home · contact · privacy
Refactor remaining test.utils helpers into actual tests.
[plomtask] / tests / utils.py
index 60157104624ac79a59581757dd58a0344b30da6b..25cc9ba1e79d663ec692570f6f4c1fce4eaaf911 100644 (file)
@@ -77,6 +77,33 @@ class TestCaseWithDB(TestCase):
                                                            row)]
         return db_found
 
+    def _change_obj(self, obj: object) -> str:
+        attr_name: str = self.checked_class.to_save[-1]
+        attr = getattr(obj, attr_name)
+        new_attr: str | int | float | bool
+        if isinstance(attr, (int, float)):
+            new_attr = attr + 1
+        elif isinstance(attr, str):
+            new_attr = attr + '_'
+        elif isinstance(attr, bool):
+            new_attr = not attr
+        setattr(obj, attr_name, new_attr)
+        return attr_name
+
+    def check_identity_with_cache_and_db(self, content: list[Any]) -> None:
+        """Test both cache and DB equal content."""
+        expected_cache = {}
+        for item in content:
+            expected_cache[item.id_] = item
+        self.assertEqual(self.checked_class.get_cache(), expected_cache)
+        hashes_content = [hash(x) for x in content]
+        db_found: list[Any] = []
+        for item in content:
+            assert isinstance(item.id_, type(self.default_ids[0]))
+            db_found += self._load_from_db(item.id_)
+        hashes_db_found = [hash(x) for x in db_found]
+        self.assertEqual(sorted(hashes_content), sorted(hashes_db_found))
+
     @_within_checked_class
     def test_saving_versioned(self) -> None:
         """Test storage and initialization of versioned attributes."""
@@ -113,20 +140,6 @@ class TestCaseWithDB(TestCase):
             attr_vals_saved = retrieve_attr_vals()
             self.assertEqual(vals + [vals[0]], attr_vals_saved)
 
-    def check_identity_with_cache_and_db(self, content: list[Any]) -> None:
-        """Test both cache and DB equal content."""
-        expected_cache = {}
-        for item in content:
-            expected_cache[item.id_] = item
-        self.assertEqual(self.checked_class.get_cache(), expected_cache)
-        hashes_content = [hash(x) for x in content]
-        db_found: list[Any] = []
-        for item in content:
-            assert isinstance(item.id_, type(self.default_ids[0]))
-            db_found += self._load_from_db(item.id_)
-        hashes_db_found = [hash(x) for x in db_found]
-        self.assertEqual(sorted(hashes_content), sorted(hashes_db_found))
-
     @_within_checked_class
     def test_saving_and_caching(self) -> None:
         """Test effects of .cache() and .save()."""
@@ -198,20 +211,13 @@ class TestCaseWithDB(TestCase):
         id_ = self.default_ids[0]
         obj = self.checked_class(id_, **self.default_init_kwargs)
         obj.save(self.db_conn)
-        assert isinstance(obj.id_, type(self.default_ids[0]))
+        assert isinstance(obj.id_, type(id_))
         for row in self.db_conn.row_where(self.checked_class.table_name,
                                           'id', obj.id_):
             # check .from_table_row reproduces state saved, no matter if obj
             # later changed (with caching even)
             hash_original = hash(obj)
-            attr_name = self.checked_class.to_save[-1]
-            attr = getattr(obj, attr_name)
-            if isinstance(attr, (int, float)):
-                setattr(obj, attr_name, attr + 1)
-            elif isinstance(attr, str):
-                setattr(obj, attr_name, attr + "_")
-            elif isinstance(attr, bool):
-                setattr(obj, attr_name, not attr)
+            attr_name = self._change_obj(obj)
             obj.cache()
             to_cmp = getattr(obj, attr_name)
             retrieved = self.checked_class.from_table_row(self.db_conn, row)
@@ -220,20 +226,19 @@ class TestCaseWithDB(TestCase):
             # check cache contains what .from_table_row just produced
             self.assertEqual({retrieved.id_: retrieved},
                              self.checked_class.get_cache())
-
-    def check_versioned_from_table_row(self, attr_name: str,
-                                       type_: type) -> None:
-        """Test .from_table_row() reads versioned attributes from DB."""
-        owner = self.checked_class(None)
-        vals: list[Any] = ['t1', 't2'] if type_ == str else [0.9, 1.1]
-        attr = getattr(owner, attr_name)
-        attr.set(vals[0])
-        attr.set(vals[1])
-        owner.save(self.db_conn)
-        for row in self.db_conn.row_where(owner.table_name, 'id', owner.id_):
-            retrieved = owner.__class__.from_table_row(self.db_conn, row)
-            attr = getattr(retrieved, attr_name)
-            self.assertEqual(sorted(attr.history.values()), vals)
+        # check .from_table_row also reads versioned attributes from DB
+        for attr_name, type_ in self.test_versioneds.items():
+            owner = self.checked_class(None)
+            vals: list[Any] = ['t1', 't2'] if type_ == str else [0.9, 1.1]
+            attr = getattr(owner, attr_name)
+            attr.set(vals[0])
+            attr.set(vals[1])
+            owner.save(self.db_conn)
+            for row in self.db_conn.row_where(owner.table_name, 'id',
+                                              owner.id_):
+                retrieved = owner.__class__.from_table_row(self.db_conn, row)
+                attr = getattr(retrieved, attr_name)
+                self.assertEqual(sorted(attr.history.values()), vals)
 
     @_within_checked_class
     def test_all(self) -> None:
@@ -259,36 +264,38 @@ class TestCaseWithDB(TestCase):
         id1 = self.default_ids[0]
         obj = self.checked_class(id1, **self.default_init_kwargs)
         obj.save(self.db_conn)
-        attr_name = self.checked_class.to_save[-1]
-        attr = getattr(obj, attr_name)
-        new_attr: str | int | float | bool
-        if isinstance(attr, (int, float)):
-            new_attr = attr + 1
-        elif isinstance(attr, str):
-            new_attr = attr + '_'
-        elif isinstance(attr, bool):
-            new_attr = not attr
-        setattr(obj, attr_name, new_attr)
+        # change object, expect retrieved through .by_id to carry change
+        attr_name = self._change_obj(obj)
+        new_attr = getattr(obj, attr_name)
         retrieved = self.checked_class.by_id(self.db_conn, id1)
         self.assertEqual(new_attr, getattr(retrieved, attr_name))
 
-    def check_versioned_singularity(self) -> None:
+    @_within_checked_class
+    def test_versioned_singularity_title(self) -> None:
         """Test singularity of VersionedAttributes on saving (with .title)."""
-        obj = self.checked_class(None)  # pylint: disable=not-callable
-        obj.save(self.db_conn)
-        assert isinstance(obj.id_, int)
-        obj.title.set('named')
-        retrieved = self.checked_class.by_id(self.db_conn, obj.id_)
-        self.assertEqual(obj.title.history, retrieved.title.history)
+        if 'title' in self.test_versioneds:
+            obj = self.checked_class(None)
+            obj.save(self.db_conn)
+            assert isinstance(obj.id_, int)
+            # change obj, expect retrieved through .by_id to carry change
+            obj.title.set('named')
+            retrieved = self.checked_class.by_id(self.db_conn, obj.id_)
+            self.assertEqual(obj.title.history, retrieved.title.history)
 
-    def check_remove(self, *args: Any) -> None:
+    @_within_checked_class
+    def test_remove(self) -> None:
         """Test .remove() effects on DB and cache."""
         id_ = self.default_ids[0]
-        obj = self.checked_class(id_, *args)  # pylint: disable=not-callable
+        obj = self.checked_class(id_, **self.default_init_kwargs)
+        # check removal only works after saving
         with self.assertRaises(HandledException):
             obj.remove(self.db_conn)
         obj.save(self.db_conn)
         obj.remove(self.db_conn)
+        # check access to obj fails after removal
+        with self.assertRaises(HandledException):
+            print(obj.id_)
+        # check DB and cache now empty
         self.check_identity_with_cache_and_db([])