From: Christian Heller <c.heller@plomlompom.de>
Date: Tue, 18 Jun 2024 08:05:18 +0000 (+0200)
Subject: Refactor remaining test.utils helpers into actual tests.
X-Git-Url: https://plomlompom.com/repos/%7B%7Bprefix%7D%7D/static/%7B%7B%20web_path%20%7D%7D/do_todos?a=commitdiff_plain;h=c021152e6566c8374170de916c69d6b5c816cd54;p=plomtask

Refactor remaining test.utils helpers into actual tests.
---

diff --git a/plomtask/db.py b/plomtask/db.py
index 797b08e..cce2630 100644
--- a/plomtask/db.py
+++ b/plomtask/db.py
@@ -337,9 +337,8 @@ class BaseModel(Generic[BaseModelId]):
     def _get_cached(cls: type[BaseModelInstance],
                     id_: BaseModelId) -> BaseModelInstance | None:
         """Get object of id_ from class's cache, or None if not found."""
-        # pylint: disable=consider-iterating-dictionary
         cache = cls.get_cache()
-        if id_ in cache.keys():
+        if id_ in cache:
             obj = cache[id_]
             assert isinstance(obj, cls)
             return obj
diff --git a/tests/conditions.py b/tests/conditions.py
index 969942b..4ac69a8 100644
--- a/tests/conditions.py
+++ b/tests/conditions.py
@@ -19,19 +19,9 @@ class TestsWithDB(TestCaseWithDB):
     default_init_kwargs = {'is_active': False}
     test_versioneds = {'title': str, 'description': str}
 
-    def test_from_table_row(self) -> None:
-        """Test .from_table_row() properly reads in class from DB"""
-        super().test_from_table_row()
-        self.check_versioned_from_table_row('title', str)
-        self.check_versioned_from_table_row('description', str)
-
-    def test_Condition_versioned_attributes_singularity(self) -> None:
-        """Test behavior of VersionedAttributes on saving (with .title)."""
-        self.check_versioned_singularity()
-
-    def test_Condition_remove(self) -> None:
+    def test_remove(self) -> None:
         """Test .remove() effects on DB and cache."""
-        self.check_remove()
+        super().test_remove()
         proc = Process(None)
         proc.save(self.db_conn)
         todo = Todo(None, proc, False, '2024-01-01')
diff --git a/tests/days.py b/tests/days.py
index 9fb12ad..36d0285 100644
--- a/tests/days.py
+++ b/tests/days.py
@@ -82,10 +82,6 @@ class TestsWithDB(TestCaseWithDB):
                                                   'today', 'today'),
                          [today])
 
-    def test_Day_remove(self) -> None:
-        """Test .remove() effects on DB and cache."""
-        self.check_remove()
-
 
 class TestsWithServer(TestCaseWithServer):
     """Tests against our HTTP server/handler (and database)."""
diff --git a/tests/processes.py b/tests/processes.py
index d33aa8f..481d2d4 100644
--- a/tests/processes.py
+++ b/tests/processes.py
@@ -65,11 +65,8 @@ class TestsWithDB(TestCaseWithDB):
         self.assertEqual(sorted(r.disables), sorted(set3))
 
     def test_from_table_row(self) -> None:
-        """Test .from_table_row() properly reads in class from DB"""
+        """Test .from_table_row() properly reads in class from DB."""
         super().test_from_table_row()
-        self.check_versioned_from_table_row('title', str)
-        self.check_versioned_from_table_row('description', str)
-        self.check_versioned_from_table_row('effort', float)
         p, set1, set2, set3 = self.p_of_conditions()
         p.save(self.db_conn)
         assert isinstance(p.id_, int)
@@ -183,13 +180,9 @@ class TestsWithDB(TestCaseWithDB):
             method(self.db_conn, [c1.id_, c2.id_])
             self.assertEqual(getattr(p, target), [c1, c2])
 
-    def test_Process_versioned_attributes_singularity(self) -> None:
-        """Test behavior of VersionedAttributes on saving (with .title)."""
-        self.check_versioned_singularity()
-
-    def test_Process_removal(self) -> None:
+    def test_remove(self) -> None:
         """Test removal of Processes and ProcessSteps."""
-        self.check_remove()
+        super().test_remove()
         p1, p2, p3 = self.three_processes()
         assert isinstance(p1.id_, int)
         assert isinstance(p2.id_, int)
@@ -222,28 +215,24 @@ class TestsWithDB(TestCaseWithDB):
 class TestsWithDBForProcessStep(TestCaseWithDB):
     """Module tests requiring DB setup."""
     checked_class = ProcessStep
-    default_init_kwargs = {'owner_id': 2, 'step_process_id': 3,
-                           'parent_step_id': 4}
+    default_init_kwargs = {'owner_id': 1, 'step_process_id': 2,
+                           'parent_step_id': 3}
 
     def setUp(self) -> None:
         super().setUp()
-        p = Process(1)
-        p.save(self.db_conn)
-        p = Process(2)
-        p.save(self.db_conn)
+        self.p1 = Process(1)
+        self.p1.save(self.db_conn)
 
-    def test_ProcessStep_remove(self) -> None:
+    def test_remove(self) -> None:
         """Test .remove and unsetting of owner's .explicit_steps entry."""
-        p1 = Process(None)
-        p2 = Process(None)
-        p1.save(self.db_conn)
+        p2 = Process(2)
         p2.save(self.db_conn)
-        assert isinstance(p1.id_, int)
+        assert isinstance(self.p1.id_, int)
         assert isinstance(p2.id_, int)
-        step = ProcessStep(None, p1.id_, p2.id_, None)
-        p1.set_steps(self.db_conn, [step])
+        step = ProcessStep(None, self.p1.id_, p2.id_, None)
+        self.p1.set_steps(self.db_conn, [step])
         step.remove(self.db_conn)
-        self.assertEqual(p1.explicit_steps, [])
+        self.assertEqual(self.p1.explicit_steps, [])
         self.check_identity_with_cache_and_db([])
 
 
diff --git a/tests/utils.py b/tests/utils.py
index 6015710..25cc9ba 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -77,6 +77,33 @@ class TestCaseWithDB(TestCase):
                                                            row)]
         return db_found
 
+    def _change_obj(self, obj: object) -> str:
+        attr_name: str = self.checked_class.to_save[-1]
+        attr = getattr(obj, attr_name)
+        new_attr: str | int | float | bool
+        if isinstance(attr, (int, float)):
+            new_attr = attr + 1
+        elif isinstance(attr, str):
+            new_attr = attr + '_'
+        elif isinstance(attr, bool):
+            new_attr = not attr
+        setattr(obj, attr_name, new_attr)
+        return attr_name
+
+    def check_identity_with_cache_and_db(self, content: list[Any]) -> None:
+        """Test both cache and DB equal content."""
+        expected_cache = {}
+        for item in content:
+            expected_cache[item.id_] = item
+        self.assertEqual(self.checked_class.get_cache(), expected_cache)
+        hashes_content = [hash(x) for x in content]
+        db_found: list[Any] = []
+        for item in content:
+            assert isinstance(item.id_, type(self.default_ids[0]))
+            db_found += self._load_from_db(item.id_)
+        hashes_db_found = [hash(x) for x in db_found]
+        self.assertEqual(sorted(hashes_content), sorted(hashes_db_found))
+
     @_within_checked_class
     def test_saving_versioned(self) -> None:
         """Test storage and initialization of versioned attributes."""
@@ -113,20 +140,6 @@ class TestCaseWithDB(TestCase):
             attr_vals_saved = retrieve_attr_vals()
             self.assertEqual(vals + [vals[0]], attr_vals_saved)
 
-    def check_identity_with_cache_and_db(self, content: list[Any]) -> None:
-        """Test both cache and DB equal content."""
-        expected_cache = {}
-        for item in content:
-            expected_cache[item.id_] = item
-        self.assertEqual(self.checked_class.get_cache(), expected_cache)
-        hashes_content = [hash(x) for x in content]
-        db_found: list[Any] = []
-        for item in content:
-            assert isinstance(item.id_, type(self.default_ids[0]))
-            db_found += self._load_from_db(item.id_)
-        hashes_db_found = [hash(x) for x in db_found]
-        self.assertEqual(sorted(hashes_content), sorted(hashes_db_found))
-
     @_within_checked_class
     def test_saving_and_caching(self) -> None:
         """Test effects of .cache() and .save()."""
@@ -198,20 +211,13 @@ class TestCaseWithDB(TestCase):
         id_ = self.default_ids[0]
         obj = self.checked_class(id_, **self.default_init_kwargs)
         obj.save(self.db_conn)
-        assert isinstance(obj.id_, type(self.default_ids[0]))
+        assert isinstance(obj.id_, type(id_))
         for row in self.db_conn.row_where(self.checked_class.table_name,
                                           'id', obj.id_):
             # check .from_table_row reproduces state saved, no matter if obj
             # later changed (with caching even)
             hash_original = hash(obj)
-            attr_name = self.checked_class.to_save[-1]
-            attr = getattr(obj, attr_name)
-            if isinstance(attr, (int, float)):
-                setattr(obj, attr_name, attr + 1)
-            elif isinstance(attr, str):
-                setattr(obj, attr_name, attr + "_")
-            elif isinstance(attr, bool):
-                setattr(obj, attr_name, not attr)
+            attr_name = self._change_obj(obj)
             obj.cache()
             to_cmp = getattr(obj, attr_name)
             retrieved = self.checked_class.from_table_row(self.db_conn, row)
@@ -220,20 +226,19 @@ class TestCaseWithDB(TestCase):
             # check cache contains what .from_table_row just produced
             self.assertEqual({retrieved.id_: retrieved},
                              self.checked_class.get_cache())
-
-    def check_versioned_from_table_row(self, attr_name: str,
-                                       type_: type) -> None:
-        """Test .from_table_row() reads versioned attributes from DB."""
-        owner = self.checked_class(None)
-        vals: list[Any] = ['t1', 't2'] if type_ == str else [0.9, 1.1]
-        attr = getattr(owner, attr_name)
-        attr.set(vals[0])
-        attr.set(vals[1])
-        owner.save(self.db_conn)
-        for row in self.db_conn.row_where(owner.table_name, 'id', owner.id_):
-            retrieved = owner.__class__.from_table_row(self.db_conn, row)
-            attr = getattr(retrieved, attr_name)
-            self.assertEqual(sorted(attr.history.values()), vals)
+        # check .from_table_row also reads versioned attributes from DB
+        for attr_name, type_ in self.test_versioneds.items():
+            owner = self.checked_class(None)
+            vals: list[Any] = ['t1', 't2'] if type_ == str else [0.9, 1.1]
+            attr = getattr(owner, attr_name)
+            attr.set(vals[0])
+            attr.set(vals[1])
+            owner.save(self.db_conn)
+            for row in self.db_conn.row_where(owner.table_name, 'id',
+                                              owner.id_):
+                retrieved = owner.__class__.from_table_row(self.db_conn, row)
+                attr = getattr(retrieved, attr_name)
+                self.assertEqual(sorted(attr.history.values()), vals)
 
     @_within_checked_class
     def test_all(self) -> None:
@@ -259,36 +264,38 @@ class TestCaseWithDB(TestCase):
         id1 = self.default_ids[0]
         obj = self.checked_class(id1, **self.default_init_kwargs)
         obj.save(self.db_conn)
-        attr_name = self.checked_class.to_save[-1]
-        attr = getattr(obj, attr_name)
-        new_attr: str | int | float | bool
-        if isinstance(attr, (int, float)):
-            new_attr = attr + 1
-        elif isinstance(attr, str):
-            new_attr = attr + '_'
-        elif isinstance(attr, bool):
-            new_attr = not attr
-        setattr(obj, attr_name, new_attr)
+        # change object, expect retrieved through .by_id to carry change
+        attr_name = self._change_obj(obj)
+        new_attr = getattr(obj, attr_name)
         retrieved = self.checked_class.by_id(self.db_conn, id1)
         self.assertEqual(new_attr, getattr(retrieved, attr_name))
 
-    def check_versioned_singularity(self) -> None:
+    @_within_checked_class
+    def test_versioned_singularity_title(self) -> None:
         """Test singularity of VersionedAttributes on saving (with .title)."""
-        obj = self.checked_class(None)  # pylint: disable=not-callable
-        obj.save(self.db_conn)
-        assert isinstance(obj.id_, int)
-        obj.title.set('named')
-        retrieved = self.checked_class.by_id(self.db_conn, obj.id_)
-        self.assertEqual(obj.title.history, retrieved.title.history)
+        if 'title' in self.test_versioneds:
+            obj = self.checked_class(None)
+            obj.save(self.db_conn)
+            assert isinstance(obj.id_, int)
+            # change obj, expect retrieved through .by_id to carry change
+            obj.title.set('named')
+            retrieved = self.checked_class.by_id(self.db_conn, obj.id_)
+            self.assertEqual(obj.title.history, retrieved.title.history)
 
-    def check_remove(self, *args: Any) -> None:
+    @_within_checked_class
+    def test_remove(self) -> None:
         """Test .remove() effects on DB and cache."""
         id_ = self.default_ids[0]
-        obj = self.checked_class(id_, *args)  # pylint: disable=not-callable
+        obj = self.checked_class(id_, **self.default_init_kwargs)
+        # check removal only works after saving
         with self.assertRaises(HandledException):
             obj.remove(self.db_conn)
         obj.save(self.db_conn)
         obj.remove(self.db_conn)
+        # check access to obj fails after removal
+        with self.assertRaises(HandledException):
+            print(obj.id_)
+        # check DB and cache now empty
         self.check_identity_with_cache_and_db([])