From e23c0539d42465c49d535c805ac508334effd7d2 Mon Sep 17 00:00:00 2001
From: Christian Heller <c.heller@plomlompom.de>
Date: Thu, 11 Jul 2024 04:19:44 +0200
Subject: [PATCH] Minor test code improvements.

---
 tests/conditions.py | 85 ++++++++++++++++++++++---------------------
 tests/utils.py      | 88 +++++++++++++++++++++++++--------------------
 2 files changed, 94 insertions(+), 79 deletions(-)

diff --git a/tests/conditions.py b/tests/conditions.py
index 9b3a403..f84533e 100644
--- a/tests/conditions.py
+++ b/tests/conditions.py
@@ -24,17 +24,18 @@ class TestsWithDB(TestCaseWithDB):
         proc = Process(None)
         proc.save(self.db_conn)
         todo = Todo(None, proc, False, '2024-01-01')
+        todo.save(self.db_conn)
+        # check condition can only be deleted if not depended upon
         for depender in (proc, todo):
             assert hasattr(depender, 'save')
             assert hasattr(depender, 'set_conditions')
             c = Condition(None)
             c.save(self.db_conn)
-            depender.save(self.db_conn)
-            depender.set_conditions(self.db_conn, [c.id_], 'conditions')
+            depender.set_conditions(self.db_conn, [c.id_])
             depender.save(self.db_conn)
             with self.assertRaises(HandledException):
                 c.remove(self.db_conn)
-            depender.set_conditions(self.db_conn, [], 'conditions')
+            depender.set_conditions(self.db_conn, [])
             depender.save(self.db_conn)
             c.remove(self.db_conn)
 
@@ -66,7 +67,7 @@ class TestsWithServer(TestCaseWithServer):
 
     def test_fail_POST_condition(self) -> None:
         """Test malformed/illegal POST /condition requests."""
-        # check invalid POST payloads
+        # check incomplete POST payloads
         url = '/condition'
         self.check_post({}, url, 400)
         self.check_post({'title': ''}, url, 400)
@@ -81,31 +82,32 @@ class TestsWithServer(TestCaseWithServer):
         """Test (valid) POST /condition and its effect on GET /condition[s]."""
         # test valid POST's effect on …
         post = {'title': 'foo', 'description': 'oof', 'is_active': False}
-        self.check_post(post, '/condition', 302, '/condition?id=1')
+        self.check_post(post, '/condition', redir='/condition?id=1')
         # … single /condition
-        cond = self.cond_as_dict(titles=['foo'], descriptions=['oof'])
-        assert isinstance(cond['_versioned'], dict)
-        expected_single = self.GET_condition_dict(cond)
+        expected_cond = self.cond_as_dict(titles=['foo'], descriptions=['oof'])
+        assert isinstance(expected_cond['_versioned'], dict)
+        expected_single = self.GET_condition_dict(expected_cond)
         self.check_json_get('/condition?id=1', expected_single)
         # … full /conditions
-        expected_all = self.GET_conditions_dict([cond])
+        expected_all = self.GET_conditions_dict([expected_cond])
         self.check_json_get('/conditions', expected_all)
         # test (no) effect of invalid POST to existing Condition on /condition
         self.check_post({}, '/condition?id=1', 400)
         self.check_json_get('/condition?id=1', expected_single)
         # test effect of POST changing title and activeness
         post = {'title': 'bar', 'description': 'oof', 'is_active': True}
-        self.check_post(post, '/condition?id=1', 302)
-        cond['_versioned']['title'][1] = 'bar'
-        cond['is_active'] = True
+        self.check_post(post, '/condition?id=1')
+        expected_cond['_versioned']['title'][1] = 'bar'
+        expected_cond['is_active'] = True
         self.check_json_get('/condition?id=1', expected_single)
-        # test deletion POST's effect on …
-        self.check_post({'delete': ''}, '/condition?id=1', 302, '/conditions')
-        cond = self.cond_as_dict()
+        # test deletion POST's effect, both to return id=1 into empty single, …
+        self.check_post({'delete': ''}, '/condition?id=1', redir='/conditions')
+        expected_cond = self.cond_as_dict()
         assert isinstance(expected_single['_library'], dict)
-        expected_single['_library']['Condition'] = self.as_refs([cond])
+        expected_single['_library']['Condition'] = self.as_refs(
+                [expected_cond])
         self.check_json_get('/condition?id=1', expected_single)
-        # … full /conditions
+        # … and full /conditions into empty list
         expected_all['conditions'] = []
         expected_all['_library'] = {}
         self.check_json_get('/conditions', expected_all)
@@ -117,7 +119,7 @@ class TestsWithServer(TestCaseWithServer):
         # make Condition and two Processes that among them establish all
         # possible ConditionsRelations to it, …
         cond_post = {'title': 'foo', 'description': 'oof', 'is_active': False}
-        self.check_post(cond_post, '/condition', 302, '/condition?id=1')
+        self.check_post(cond_post, '/condition', redir='/condition?id=1')
         proc1_post = {'title': 'A', 'description': '', 'effort': 1.0,
                       'conditions': [1], 'disables': [1]}
         proc2_post = {'title': 'B', 'description': '', 'effort': 1.0,
@@ -125,40 +127,41 @@ class TestsWithServer(TestCaseWithServer):
         self.post_process(1, proc1_post)
         self.post_process(2, proc2_post)
         # … then check /condition displays all these properly.
-        cond = self.cond_as_dict(titles=['foo'], descriptions=['oof'])
-        assert isinstance(cond['id'], int)
-        proc1 = self.proc_as_dict(conditions=[cond['id']],
-                                  disables=[cond['id']])
+        cond_expected = self.cond_as_dict(titles=['foo'], descriptions=['oof'])
+        assert isinstance(cond_expected['id'], int)
+        proc1 = self.proc_as_dict(conditions=[cond_expected['id']],
+                                  disables=[cond_expected['id']])
         proc2 = self.proc_as_dict(2, 'B',
-                                  blockers=[cond['id']],
-                                  enables=[cond['id']])
-        expected = self.GET_condition_dict(cond)
-        assert isinstance(expected['_library'], dict)
-        expected['enabled_processes'] = self.as_id_list([proc1])
-        expected['disabled_processes'] = self.as_id_list([proc2])
-        expected['enabling_processes'] = self.as_id_list([proc2])
-        expected['disabling_processes'] = self.as_id_list([proc1])
-        expected['_library']['Process'] = self.as_refs([proc1, proc2])
-        self.check_json_get('/condition?id=1', expected)
+                                  blockers=[cond_expected['id']],
+                                  enables=[cond_expected['id']])
+        display_expected = self.GET_condition_dict(cond_expected)
+        assert isinstance(display_expected['_library'], dict)
+        display_expected['enabled_processes'] = self.as_id_list([proc1])
+        display_expected['disabled_processes'] = self.as_id_list([proc2])
+        display_expected['enabling_processes'] = self.as_id_list([proc2])
+        display_expected['disabling_processes'] = self.as_id_list([proc1])
+        display_expected['_library']['Process'] = self.as_refs([proc1, proc2])
+        self.check_json_get('/condition?id=1', display_expected)
 
     def test_GET_conditions(self) -> None:
         """Test GET /conditions."""
         # test empty result on empty DB, default-settings on empty params
         expected = self.GET_conditions_dict([])
         self.check_json_get('/conditions', expected)
-        # test on meaningless non-empty params (incl. entirely un-used key),
+        # test ignorance of meaningless non-empty params (incl. unknown key),
         # that 'sort_by' default to 'title' (even if set to something else, as
         # long as without handler) and 'pattern' get preserved
         expected['pattern'] = 'bar'  # preserved despite zero effect!
+        expected['sort_by'] = 'title'  # for clarity (actually already set)
         url = '/conditions?sort_by=foo&pattern=bar&foo=x'
         self.check_json_get(url, expected)
         # test non-empty result, automatic (positive) sorting by title
-        post1 = {'is_active': False, 'title': 'foo', 'description': 'oof'}
-        post2 = {'is_active': False, 'title': 'bar', 'description': 'rab'}
-        post3 = {'is_active': True, 'title': 'baz', 'description': 'zab'}
-        self.check_post(post1, '/condition', 302, '/condition?id=1')
-        self.check_post(post2, '/condition', 302, '/condition?id=2')
-        self.check_post(post3, '/condition', 302, '/condition?id=3')
+        post_cond1 = {'is_active': False, 'title': 'foo', 'description': 'oof'}
+        post_cond2 = {'is_active': False, 'title': 'bar', 'description': 'rab'}
+        post_cond3 = {'is_active': True, 'title': 'baz', 'description': 'zab'}
+        self.check_post(post_cond1, '/condition', redir='/condition?id=1')
+        self.check_post(post_cond2, '/condition', redir='/condition?id=2')
+        self.check_post(post_cond3, '/condition', redir='/condition?id=3')
         cond1 = self.cond_as_dict(1, False, ['foo'], ['oof'])
         cond2 = self.cond_as_dict(2, False, ['bar'], ['rab'])
         cond3 = self.cond_as_dict(3, True, ['baz'], ['zab'])
@@ -166,7 +169,7 @@ class TestsWithServer(TestCaseWithServer):
         self.check_json_get('/conditions', expected)
         # test other sortings
         # (NB: by .is_active has two items of =False, their order currently
-        # is not explicitly made predictable, so mail fail until we do)
+        # is not explicitly made predictable, so _may_ fail until we do)
         expected['sort_by'] = '-title'
         expected['conditions'] = self.as_id_list([cond1, cond3, cond2])
         self.check_json_get('/conditions?sort_by=-title', expected)
@@ -182,7 +185,7 @@ class TestsWithServer(TestCaseWithServer):
         self.check_json_get('/conditions?pattern=ba', expected)
         # test pattern matching on description
         assert isinstance(expected['_library'], dict)
+        expected['pattern'] = 'of'
         expected['conditions'] = self.as_id_list([cond1])
         expected['_library']['Condition'] = self.as_refs([cond1])
-        expected['pattern'] = 'of'
         self.check_json_get('/conditions?pattern=of', expected)
diff --git a/tests/utils.py b/tests/utils.py
index 6654368..4d81c91 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -14,6 +14,7 @@ from plomtask.processes import Process, ProcessStep
 from plomtask.conditions import Condition
 from plomtask.days import Day
 from plomtask.todos import Todo
+from plomtask.versioned_attributes import VersionedAttribute
 from plomtask.exceptions import NotFoundException, HandledException
 
 
@@ -92,6 +93,16 @@ class TestCaseWithDB(TestCase):
         setattr(obj, attr_name, new_attr)
         return attr_name
 
+    def _versioned_attrs_and_owner(self, attr_name: str, type_: type
+                                   ) -> tuple[Any, list[str | float],
+                                              VersionedAttribute]:
+        owner = self.checked_class(None, **self.default_init_kwargs)
+        vals: list[str | float] = ['t1', 't2'] if type_ == str else [0.9, 1.1]
+        attr = getattr(owner, attr_name)
+        attr.set(vals[0])
+        attr.set(vals[1])
+        return owner, vals, attr
+
     def check_identity_with_cache_and_db(self, content: list[Any]) -> None:
         """Test both cache and DB equal content."""
         expected_cache = {}
@@ -107,40 +118,39 @@ class TestCaseWithDB(TestCase):
         self.assertEqual(sorted(hashes_content), sorted(hashes_db_found))
 
     @_within_checked_class
-    def test_saving_versioned(self) -> None:
+    def test_saving_versioned_attributes(self) -> None:
         """Test storage and initialization of versioned attributes."""
-        def retrieve_attr_vals() -> list[object]:
+
+        def retrieve_attr_vals(attr: VersionedAttribute, owner_id: int
+                               ) -> list[object]:
             attr_vals_saved: list[object] = []
-            assert hasattr(retrieved, 'id_')
             for row in self.db_conn.row_where(attr.table_name, 'parent',
-                                              retrieved.id_):
+                                              owner_id):
                 attr_vals_saved += [row[2]]
             return attr_vals_saved
-        for attr_name, type_ in self.test_versioneds.items():
+
+        owner_id = 1
+        for name, type_ in self.test_versioneds.items():
             # fail saving attributes on non-saved owner
-            owner = self.checked_class(None, **self.default_init_kwargs)
-            vals: list[Any] = ['t1', 't2'] if type_ == str else [0.9, 1.1]
-            attr = getattr(owner, attr_name)
-            attr.set(vals[0])
-            attr.set(vals[1])
+            owner, vals, attr = self._versioned_attrs_and_owner(name, type_)
             with self.assertRaises(NotFoundException):
                 attr.save(self.db_conn)
+            # check owner.save() created entries as expected in attr table
             owner.save(self.db_conn)
-            # check stored attribute is as expected
-            retrieved = self._load_from_db(owner.id_)[0]
-            attr = getattr(retrieved, attr_name)
-            self.assertEqual(sorted(attr.history.values()), vals)
-            # check owner.save() created entries in attr table
-            attr_vals_saved = retrieve_attr_vals()
+            attr_vals_saved = retrieve_attr_vals(attr, owner_id)
             self.assertEqual(vals, attr_vals_saved)
-            # check setting new val to attr inconsequential to DB without save
+            # check changing attr val without save affects owner in memory …
             attr.set(vals[0])
-            attr_vals_saved = retrieve_attr_vals()
+            cmp_attr = getattr(owner, name)
+            self.assertEqual(cmp_attr.history, attr.history)
+            # … but does not yet affect DB
+            attr_vals_saved = retrieve_attr_vals(attr, owner_id)
             self.assertEqual(vals, attr_vals_saved)
-            # check save finally adds new val
+            # check individual attr.save also stores new val to DB
             attr.save(self.db_conn)
-            attr_vals_saved = retrieve_attr_vals()
+            attr_vals_saved = retrieve_attr_vals(attr, owner_id)
             self.assertEqual(vals + [vals[0]], attr_vals_saved)
+            owner_id += 1
 
     @_within_checked_class
     def test_saving_and_caching(self) -> None:
@@ -157,8 +167,8 @@ class TestCaseWithDB(TestCase):
         # check .cache() fills cache, but not DB
         obj1.cache()
         self.assertEqual(self.checked_class.get_cache(), {id1: obj1})
-        db_found = self._load_from_db(id1)
-        self.assertEqual(db_found, [])
+        found_in_db = self._load_from_db(id1)
+        self.assertEqual(found_in_db, [])
         # check .save() sets ID (for int IDs), updates cache, and fills DB
         # (expect ID to be set to id1, despite obj1 already having that as ID:
         # it's generated by cursor.lastrowid on the DB table, and with obj1
@@ -166,10 +176,12 @@ class TestCaseWithDB(TestCase):
         id_input = None if isinstance(id1, int) else id1
         obj2 = self.checked_class(id_input, **self.default_init_kwargs)
         obj2.save(self.db_conn)
-        obj2_hash = hash(obj2)
         self.assertEqual(self.checked_class.get_cache(), {id1: obj2})
-        db_found += self._load_from_db(id1)
-        self.assertEqual([hash(o) for o in db_found], [obj2_hash])
+        # NB: we'll only compare hashes because obj2 itself disappears on
+        # .from_table_row-trioggered database reload
+        obj2_hash = hash(obj2)
+        found_in_db += self._load_from_db(id1)
+        self.assertEqual([hash(o) for o in found_in_db], [obj2_hash])
         # check we cannot overwrite obj2 with obj1 despite its same ID,
         # since it has disappeared now
         with self.assertRaises(HandledException):
@@ -225,6 +237,8 @@ class TestCaseWithDB(TestCase):
                                           'id', obj.id_):
             # check .from_table_row reproduces state saved, no matter if obj
             # later changed (with caching even)
+            # NB: we'll only compare hashes because obj itself disappears on
+            # .from_table_row-triggered database reload
             hash_original = hash(obj)
             attr_name = self._change_obj(obj)
             obj.cache()
@@ -236,17 +250,14 @@ class TestCaseWithDB(TestCase):
             self.assertEqual({retrieved.id_: retrieved},
                              self.checked_class.get_cache())
         # check .from_table_row also reads versioned attributes from DB
-        for attr_name, type_ in self.test_versioneds.items():
-            owner = self.checked_class(None)
-            vals: list[Any] = ['t1', 't2'] if type_ == str else [0.9, 1.1]
-            attr = getattr(owner, attr_name)
-            attr.set(vals[0])
-            attr.set(vals[1])
+        for name, type_ in self.test_versioneds.items():
+            owner, vals, attr = self._versioned_attrs_and_owner(name, type_)
             owner.save(self.db_conn)
+            attr.set(vals[0])
             for row in self.db_conn.row_where(owner.table_name, 'id',
                                               owner.id_):
                 retrieved = owner.__class__.from_table_row(self.db_conn, row)
-                attr = getattr(retrieved, attr_name)
+                attr = getattr(retrieved, name)
                 self.assertEqual(sorted(attr.history.values()), vals)
 
     @_within_checked_class
@@ -404,7 +415,7 @@ class TestCaseWithServer(TestCaseWithDB):
         self.assertEqual(self.conn.getresponse().status, expected_code)
 
     def check_post(self, data: Mapping[str, object], target: str,
-                   expected_code: int, redirect_location: str = '') -> None:
+                   expected_code: int = 302, redir: str = '') -> None:
         """Check that POST of data to target yields expected_code."""
         encoded_form_data = urlencode(data, doseq=True).encode('utf-8')
         headers = {'Content-Type': 'application/x-www-form-urlencoded',
@@ -412,9 +423,8 @@ class TestCaseWithServer(TestCaseWithDB):
         self.conn.request('POST', target,
                           body=encoded_form_data, headers=headers)
         if 302 == expected_code:
-            if redirect_location == '':
-                redirect_location = target
-            self.check_redirect(redirect_location)
+            redir = target if redir == '' else redir
+            self.check_redirect(redir)
         else:
             self.assertEqual(self.conn.getresponse().status, expected_code)
 
@@ -432,8 +442,8 @@ class TestCaseWithServer(TestCaseWithDB):
         """POST basic Process."""
         if not form_data:
             form_data = {'title': 'foo', 'description': 'foo', 'effort': 1.1}
-        self.check_post(form_data, f'/process?id={id_}', 302,
-                        f'/process?id={id_}')
+        self.check_post(form_data, f'/process?id={id_}',
+                        redir=f'/process?id={id_}')
         return form_data
 
     def check_json_get(self, path: str, expected: dict[str, object]) -> None:
@@ -443,6 +453,7 @@ class TestCaseWithServer(TestCaseWithDB):
         timestamp keys of VersionedAttribute history keys into integers
         counting chronologically forward from 0.
         """
+
         def rewrite_history_keys_in(item: Any) -> Any:
             if isinstance(item, dict):
                 if '_versioned' in item.keys():
@@ -457,6 +468,7 @@ class TestCaseWithServer(TestCaseWithDB):
             elif isinstance(item, list):
                 item[:] = [rewrite_history_keys_in(i) for i in item]
             return item
+
         self.conn.request('GET', path)
         response = self.conn.getresponse()
         self.assertEqual(response.status, 200)
-- 
2.30.2