home · contact · privacy
Slightly improve and re-organize Condition tests.
[plomtask] / tests / utils.py
index d41e7b39ad5fe07abdc701be674d9f81678cd30f..665436873c27af704a13827715d3c795e04e1fe1 100644 (file)
 """Shared test utilities."""
 """Shared test utilities."""
+from __future__ import annotations
 from unittest import TestCase
 from unittest import TestCase
+from typing import Mapping, Any, Callable
 from threading import Thread
 from http.client import HTTPConnection
 from threading import Thread
 from http.client import HTTPConnection
+from json import loads as json_loads
 from urllib.parse import urlencode
 from urllib.parse import urlencode
-from datetime import datetime
+from uuid import uuid4
 from os import remove as remove_file
 from os import remove as remove_file
-from typing import Mapping
 from plomtask.db import DatabaseFile, DatabaseConnection
 from plomtask.http import TaskHandler, TaskServer
 from plomtask.db import DatabaseFile, DatabaseConnection
 from plomtask.http import TaskHandler, TaskServer
+from plomtask.processes import Process, ProcessStep
+from plomtask.conditions import Condition
+from plomtask.days import Day
+from plomtask.todos import Todo
+from plomtask.exceptions import NotFoundException, HandledException
+
+
+def _within_checked_class(f: Callable[..., None]) -> Callable[..., None]:
+    def wrapper(self: TestCase) -> None:
+        if hasattr(self, 'checked_class'):
+            f(self)
+    return wrapper
+
+
+class TestCaseSansDB(TestCase):
+    """Tests requiring no DB setup."""
+    checked_class: Any
+    default_init_args: list[Any] = []
+    versioned_defaults_to_test: dict[str, str | float] = {}
+    legal_ids = [1, 5]
+    illegal_ids = [0]
+
+    @_within_checked_class
+    def test_id_validation(self) -> None:
+        """Test .id_ validation/setting."""
+        for id_ in self.illegal_ids:
+            with self.assertRaises(HandledException):
+                self.checked_class(id_, *self.default_init_args)
+        for id_ in self.legal_ids:
+            obj = self.checked_class(id_, *self.default_init_args)
+            self.assertEqual(obj.id_, id_)
+
+    @_within_checked_class
+    def test_versioned_defaults(self) -> None:
+        """Test defaults of VersionedAttributes."""
+        id_ = self.legal_ids[0]
+        obj = self.checked_class(id_, *self.default_init_args)
+        for k, v in self.versioned_defaults_to_test.items():
+            self.assertEqual(getattr(obj, k).newest, v)
 
 
 class TestCaseWithDB(TestCase):
     """Module tests not requiring DB setup."""
 
 
 class TestCaseWithDB(TestCase):
     """Module tests not requiring DB setup."""
+    checked_class: Any
+    default_ids: tuple[int | str, int | str, int | str] = (1, 2, 3)
+    default_init_kwargs: dict[str, Any] = {}
+    test_versioneds: dict[str, type] = {}
 
     def setUp(self) -> None:
 
     def setUp(self) -> None:
-        timestamp = datetime.now().timestamp()
-        self.db_file = DatabaseFile(f'test_db:{timestamp}')
-        self.db_file.remake()
+        Condition.empty_cache()
+        Day.empty_cache()
+        Process.empty_cache()
+        ProcessStep.empty_cache()
+        Todo.empty_cache()
+        self.db_file = DatabaseFile.create_at(f'test_db:{uuid4()}')
         self.db_conn = DatabaseConnection(self.db_file)
 
     def tearDown(self) -> None:
         self.db_conn.close()
         remove_file(self.db_file.path)
 
         self.db_conn = DatabaseConnection(self.db_file)
 
     def tearDown(self) -> None:
         self.db_conn.close()
         remove_file(self.db_file.path)
 
+    def _load_from_db(self, id_: int | str) -> list[object]:
+        db_found: list[object] = []
+        for row in self.db_conn.row_where(self.checked_class.table_name,
+                                          'id', id_):
+            db_found += [self.checked_class.from_table_row(self.db_conn,
+                                                           row)]
+        return db_found
+
+    def _change_obj(self, obj: object) -> str:
+        attr_name: str = self.checked_class.to_save[-1]
+        attr = getattr(obj, attr_name)
+        new_attr: str | int | float | bool
+        if isinstance(attr, (int, float)):
+            new_attr = attr + 1
+        elif isinstance(attr, str):
+            new_attr = attr + '_'
+        elif isinstance(attr, bool):
+            new_attr = not attr
+        setattr(obj, attr_name, new_attr)
+        return attr_name
+
+    def check_identity_with_cache_and_db(self, content: list[Any]) -> None:
+        """Test both cache and DB equal content."""
+        expected_cache = {}
+        for item in content:
+            expected_cache[item.id_] = item
+        self.assertEqual(self.checked_class.get_cache(), expected_cache)
+        hashes_content = [hash(x) for x in content]
+        db_found: list[Any] = []
+        for item in content:
+            assert isinstance(item.id_, type(self.default_ids[0]))
+            db_found += self._load_from_db(item.id_)
+        hashes_db_found = [hash(x) for x in db_found]
+        self.assertEqual(sorted(hashes_content), sorted(hashes_db_found))
+
+    @_within_checked_class
+    def test_saving_versioned(self) -> None:
+        """Test storage and initialization of versioned attributes."""
+        def retrieve_attr_vals() -> list[object]:
+            attr_vals_saved: list[object] = []
+            assert hasattr(retrieved, 'id_')
+            for row in self.db_conn.row_where(attr.table_name, 'parent',
+                                              retrieved.id_):
+                attr_vals_saved += [row[2]]
+            return attr_vals_saved
+        for attr_name, type_ in self.test_versioneds.items():
+            # fail saving attributes on non-saved owner
+            owner = self.checked_class(None, **self.default_init_kwargs)
+            vals: list[Any] = ['t1', 't2'] if type_ == str else [0.9, 1.1]
+            attr = getattr(owner, attr_name)
+            attr.set(vals[0])
+            attr.set(vals[1])
+            with self.assertRaises(NotFoundException):
+                attr.save(self.db_conn)
+            owner.save(self.db_conn)
+            # check stored attribute is as expected
+            retrieved = self._load_from_db(owner.id_)[0]
+            attr = getattr(retrieved, attr_name)
+            self.assertEqual(sorted(attr.history.values()), vals)
+            # check owner.save() created entries in attr table
+            attr_vals_saved = retrieve_attr_vals()
+            self.assertEqual(vals, attr_vals_saved)
+            # check setting new val to attr inconsequential to DB without save
+            attr.set(vals[0])
+            attr_vals_saved = retrieve_attr_vals()
+            self.assertEqual(vals, attr_vals_saved)
+            # check save finally adds new val
+            attr.save(self.db_conn)
+            attr_vals_saved = retrieve_attr_vals()
+            self.assertEqual(vals + [vals[0]], attr_vals_saved)
+
+    @_within_checked_class
+    def test_saving_and_caching(self) -> None:
+        """Test effects of .cache() and .save()."""
+        id1 = self.default_ids[0]
+        # check failure to cache without ID (if None-ID input possible)
+        if isinstance(id1, int):
+            obj0 = self.checked_class(None, **self.default_init_kwargs)
+            with self.assertRaises(HandledException):
+                obj0.cache()
+        # check mere object init itself doesn't even store in cache
+        obj1 = self.checked_class(id1, **self.default_init_kwargs)
+        self.assertEqual(self.checked_class.get_cache(), {})
+        # check .cache() fills cache, but not DB
+        obj1.cache()
+        self.assertEqual(self.checked_class.get_cache(), {id1: obj1})
+        db_found = self._load_from_db(id1)
+        self.assertEqual(db_found, [])
+        # check .save() sets ID (for int IDs), updates cache, and fills DB
+        # (expect ID to be set to id1, despite obj1 already having that as ID:
+        # it's generated by cursor.lastrowid on the DB table, and with obj1
+        # not written there, obj2 should get it first!)
+        id_input = None if isinstance(id1, int) else id1
+        obj2 = self.checked_class(id_input, **self.default_init_kwargs)
+        obj2.save(self.db_conn)
+        obj2_hash = hash(obj2)
+        self.assertEqual(self.checked_class.get_cache(), {id1: obj2})
+        db_found += self._load_from_db(id1)
+        self.assertEqual([hash(o) for o in db_found], [obj2_hash])
+        # check we cannot overwrite obj2 with obj1 despite its same ID,
+        # since it has disappeared now
+        with self.assertRaises(HandledException):
+            obj1.save(self.db_conn)
+
+    @_within_checked_class
+    def test_by_id(self) -> None:
+        """Test .by_id()."""
+        id1, id2, _ = self.default_ids
+        # check failure if not yet saved
+        obj1 = self.checked_class(id1, **self.default_init_kwargs)
+        with self.assertRaises(NotFoundException):
+            self.checked_class.by_id(self.db_conn, id1)
+        # check identity of cached and retrieved
+        obj1.cache()
+        self.assertEqual(obj1, self.checked_class.by_id(self.db_conn, id1))
+        # check identity of saved and retrieved
+        obj2 = self.checked_class(id2, **self.default_init_kwargs)
+        obj2.save(self.db_conn)
+        self.assertEqual(obj2, self.checked_class.by_id(self.db_conn, id2))
+
+    @_within_checked_class
+    def test_by_id_or_create(self) -> None:
+        """Test .by_id_or_create."""
+        # check .by_id_or_create fails if wrong class
+        if not self.checked_class.can_create_by_id:
+            with self.assertRaises(HandledException):
+                self.checked_class.by_id_or_create(self.db_conn, None)
+            return
+        # check ID input of None creates, on saving, ID=1,2,… for int IDs
+        if isinstance(self.default_ids[0], int):
+            for n in range(2):
+                item = self.checked_class.by_id_or_create(self.db_conn, None)
+                self.assertEqual(item.id_, None)
+                item.save(self.db_conn)
+                self.assertEqual(item.id_, n+1)
+        # check .by_id_or_create acts like normal instantiation (sans saving)
+        id_ = self.default_ids[2]
+        item = self.checked_class.by_id_or_create(self.db_conn, id_)
+        self.assertEqual(item.id_, id_)
+        with self.assertRaises(NotFoundException):
+            self.checked_class.by_id(self.db_conn, item.id_)
+        self.assertEqual(self.checked_class(item.id_), item)
+
+    @_within_checked_class
+    def test_from_table_row(self) -> None:
+        """Test .from_table_row() properly reads in class directly from DB."""
+        id_ = self.default_ids[0]
+        obj = self.checked_class(id_, **self.default_init_kwargs)
+        obj.save(self.db_conn)
+        assert isinstance(obj.id_, type(id_))
+        for row in self.db_conn.row_where(self.checked_class.table_name,
+                                          'id', obj.id_):
+            # check .from_table_row reproduces state saved, no matter if obj
+            # later changed (with caching even)
+            hash_original = hash(obj)
+            attr_name = self._change_obj(obj)
+            obj.cache()
+            to_cmp = getattr(obj, attr_name)
+            retrieved = self.checked_class.from_table_row(self.db_conn, row)
+            self.assertNotEqual(to_cmp, getattr(retrieved, attr_name))
+            self.assertEqual(hash_original, hash(retrieved))
+            # check cache contains what .from_table_row just produced
+            self.assertEqual({retrieved.id_: retrieved},
+                             self.checked_class.get_cache())
+        # check .from_table_row also reads versioned attributes from DB
+        for attr_name, type_ in self.test_versioneds.items():
+            owner = self.checked_class(None)
+            vals: list[Any] = ['t1', 't2'] if type_ == str else [0.9, 1.1]
+            attr = getattr(owner, attr_name)
+            attr.set(vals[0])
+            attr.set(vals[1])
+            owner.save(self.db_conn)
+            for row in self.db_conn.row_where(owner.table_name, 'id',
+                                              owner.id_):
+                retrieved = owner.__class__.from_table_row(self.db_conn, row)
+                attr = getattr(retrieved, attr_name)
+                self.assertEqual(sorted(attr.history.values()), vals)
+
+    @_within_checked_class
+    def test_all(self) -> None:
+        """Test .all() and its relation to cache and savings."""
+        id_1, id_2, id_3 = self.default_ids
+        item1 = self.checked_class(id_1, **self.default_init_kwargs)
+        item2 = self.checked_class(id_2, **self.default_init_kwargs)
+        item3 = self.checked_class(id_3, **self.default_init_kwargs)
+        # check .all() returns empty list on un-cached items
+        self.assertEqual(self.checked_class.all(self.db_conn), [])
+        # check that all() shows only cached/saved items
+        item1.cache()
+        item3.save(self.db_conn)
+        self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
+                         sorted([item1, item3]))
+        item2.save(self.db_conn)
+        self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
+                         sorted([item1, item2, item3]))
+
+    @_within_checked_class
+    def test_singularity(self) -> None:
+        """Test pointers made for single object keep pointing to it."""
+        id1 = self.default_ids[0]
+        obj = self.checked_class(id1, **self.default_init_kwargs)
+        obj.save(self.db_conn)
+        # change object, expect retrieved through .by_id to carry change
+        attr_name = self._change_obj(obj)
+        new_attr = getattr(obj, attr_name)
+        retrieved = self.checked_class.by_id(self.db_conn, id1)
+        self.assertEqual(new_attr, getattr(retrieved, attr_name))
+
+    @_within_checked_class
+    def test_versioned_singularity_title(self) -> None:
+        """Test singularity of VersionedAttributes on saving (with .title)."""
+        if 'title' in self.test_versioneds:
+            obj = self.checked_class(None)
+            obj.save(self.db_conn)
+            assert isinstance(obj.id_, int)
+            # change obj, expect retrieved through .by_id to carry change
+            obj.title.set('named')
+            retrieved = self.checked_class.by_id(self.db_conn, obj.id_)
+            self.assertEqual(obj.title.history, retrieved.title.history)
+
+    @_within_checked_class
+    def test_remove(self) -> None:
+        """Test .remove() effects on DB and cache."""
+        id_ = self.default_ids[0]
+        obj = self.checked_class(id_, **self.default_init_kwargs)
+        # check removal only works after saving
+        with self.assertRaises(HandledException):
+            obj.remove(self.db_conn)
+        obj.save(self.db_conn)
+        obj.remove(self.db_conn)
+        # check access to obj fails after removal
+        with self.assertRaises(HandledException):
+            print(obj.id_)
+        # check DB and cache now empty
+        self.check_identity_with_cache_and_db([])
+
 
 class TestCaseWithServer(TestCaseWithDB):
     """Module tests against our HTTP server/handler (and database)."""
 
 class TestCaseWithServer(TestCaseWithDB):
     """Module tests against our HTTP server/handler (and database)."""
@@ -35,6 +319,7 @@ class TestCaseWithServer(TestCaseWithDB):
         self.server_thread.start()
         self.conn = HTTPConnection(str(self.httpd.server_address[0]),
                                    self.httpd.server_address[1])
         self.server_thread.start()
         self.conn = HTTPConnection(str(self.httpd.server_address[0]),
                                    self.httpd.server_address[1])
+        self.httpd.set_json_mode()
 
     def tearDown(self) -> None:
         self.httpd.shutdown()
 
     def tearDown(self) -> None:
         self.httpd.shutdown()
@@ -42,6 +327,71 @@ class TestCaseWithServer(TestCaseWithDB):
         self.server_thread.join()
         super().tearDown()
 
         self.server_thread.join()
         super().tearDown()
 
+    @staticmethod
+    def as_id_list(items: list[dict[str, object]]) -> list[int | str]:
+        """Return list of only 'id' fields of items."""
+        id_list = []
+        for item in items:
+            assert isinstance(item['id'], (int, str))
+            id_list += [item['id']]
+        return id_list
+
+    @staticmethod
+    def as_refs(items: list[dict[str, object]]
+                ) -> dict[str, dict[str, object]]:
+        """Return dictionary of items by their 'id' fields."""
+        refs = {}
+        for item in items:
+            refs[str(item['id'])] = item
+        return refs
+
+    @staticmethod
+    def cond_as_dict(id_: int = 1,
+                     is_active: bool = False,
+                     titles: None | list[str] = None,
+                     descriptions: None | list[str] = None
+                     ) -> dict[str, object]:
+        """Return JSON of Condition to expect."""
+        d = {'id': id_,
+             'is_active': is_active,
+             '_versioned': {
+                 'title': {},
+                 'description': {}}}
+        titles = titles if titles else []
+        descriptions = descriptions if descriptions else []
+        assert isinstance(d['_versioned'], dict)
+        for i, title in enumerate(titles):
+            d['_versioned']['title'][i] = title
+        for i, description in enumerate(descriptions):
+            d['_versioned']['description'][i] = description
+        return d
+
+    @staticmethod
+    def proc_as_dict(id_: int = 1,
+                     title: str = 'A',
+                     description: str = '',
+                     effort: float = 1.0,
+                     conditions: None | list[int] = None,
+                     disables: None | list[int] = None,
+                     blockers: None | list[int] = None,
+                     enables: None | list[int] = None
+                     ) -> dict[str, object]:
+        """Return JSON of Process to expect."""
+        # pylint: disable=too-many-arguments
+        d = {'id': id_,
+             'calendarize': False,
+             'suppressed_steps': [],
+             'explicit_steps': [],
+             '_versioned': {
+                 'title': {0: title},
+                 'description': {0: description},
+                 'effort': {0: effort}},
+             'conditions': conditions if conditions else [],
+             'disables': disables if disables else [],
+             'enables': enables if enables else [],
+             'blockers': blockers if blockers else []}
+        return d
+
     def check_redirect(self, target: str) -> None:
         """Check that self.conn answers with a 302 redirect to target."""
         response = self.conn.getresponse()
     def check_redirect(self, target: str) -> None:
         """Check that self.conn answers with a 302 redirect to target."""
         response = self.conn.getresponse()
@@ -56,12 +406,60 @@ class TestCaseWithServer(TestCaseWithDB):
     def check_post(self, data: Mapping[str, object], target: str,
                    expected_code: int, redirect_location: str = '') -> None:
         """Check that POST of data to target yields expected_code."""
     def check_post(self, data: Mapping[str, object], target: str,
                    expected_code: int, redirect_location: str = '') -> None:
         """Check that POST of data to target yields expected_code."""
-        encoded_form_data = urlencode(data).encode('utf-8')
+        encoded_form_data = urlencode(data, doseq=True).encode('utf-8')
         headers = {'Content-Type': 'application/x-www-form-urlencoded',
                    'Content-Length': str(len(encoded_form_data))}
         self.conn.request('POST', target,
                           body=encoded_form_data, headers=headers)
         if 302 == expected_code:
         headers = {'Content-Type': 'application/x-www-form-urlencoded',
                    'Content-Length': str(len(encoded_form_data))}
         self.conn.request('POST', target,
                           body=encoded_form_data, headers=headers)
         if 302 == expected_code:
+            if redirect_location == '':
+                redirect_location = target
             self.check_redirect(redirect_location)
         else:
             self.assertEqual(self.conn.getresponse().status, expected_code)
             self.check_redirect(redirect_location)
         else:
             self.assertEqual(self.conn.getresponse().status, expected_code)
+
+    def check_get_defaults(self, path: str) -> None:
+        """Some standard model paths to test."""
+        self.check_get(path, 200)
+        self.check_get(f'{path}?id=', 200)
+        self.check_get(f'{path}?id=foo', 400)
+        self.check_get(f'/{path}?id=0', 500)
+        self.check_get(f'{path}?id=1', 200)
+
+    def post_process(self, id_: int = 1,
+                     form_data: dict[str, Any] | None = None
+                     ) -> dict[str, Any]:
+        """POST basic Process."""
+        if not form_data:
+            form_data = {'title': 'foo', 'description': 'foo', 'effort': 1.1}
+        self.check_post(form_data, f'/process?id={id_}', 302,
+                        f'/process?id={id_}')
+        return form_data
+
+    def check_json_get(self, path: str, expected: dict[str, object]) -> None:
+        """Compare JSON on GET path with expected.
+
+        To simplify comparison of VersionedAttribute histories, transforms
+        timestamp keys of VersionedAttribute history keys into integers
+        counting chronologically forward from 0.
+        """
+        def rewrite_history_keys_in(item: Any) -> Any:
+            if isinstance(item, dict):
+                if '_versioned' in item.keys():
+                    for k in item['_versioned']:
+                        vals = item['_versioned'][k].values()
+                        history = {}
+                        for i, val in enumerate(vals):
+                            history[i] = val
+                        item['_versioned'][k] = history
+                for k in list(item.keys()):
+                    rewrite_history_keys_in(item[k])
+            elif isinstance(item, list):
+                item[:] = [rewrite_history_keys_in(i) for i in item]
+            return item
+        self.conn.request('GET', path)
+        response = self.conn.getresponse()
+        self.assertEqual(response.status, 200)
+        retrieved = json_loads(response.read().decode())
+        rewrite_history_keys_in(retrieved)
+        self.assertEqual(expected, retrieved)