home · contact · privacy
Refactor BaseModel.from_table_row testing.
authorChristian Heller <c.heller@plomlompom.de>
Mon, 17 Jun 2024 23:54:46 +0000 (01:54 +0200)
committerChristian Heller <c.heller@plomlompom.de>
Mon, 17 Jun 2024 23:54:46 +0000 (01:54 +0200)
plomtask/db.py
tests/conditions.py
tests/days.py
tests/processes.py
tests/utils.py

index 385e79855286c45087537b3b63b8b024fd553ca0..853b4c68c65780e339b77785b961788370373648 100644 (file)
@@ -344,7 +344,7 @@ class BaseModel(Generic[BaseModelId]):
             return obj
         return None
 
             return obj
         return None
 
-    def _cache(self) -> None:
+    def cache(self) -> None:
         """Update object in class's cache.
 
         Also calls ._disappear if cache holds older reference to object of same
         """Update object in class's cache.
 
         Also calls ._disappear if cache holds older reference to object of same
@@ -383,7 +383,7 @@ class BaseModel(Generic[BaseModelId]):
             table_name = attr.table_name
             for row_ in db_conn.row_where(table_name, 'parent', obj.id_):
                 attr.history_from_row(row_)
             table_name = attr.table_name
             for row_ in db_conn.row_where(table_name, 'parent', obj.id_):
                 attr.history_from_row(row_)
-        obj._cache()
+        obj.cache()
         return obj
 
     @classmethod
         return obj
 
     @classmethod
@@ -497,7 +497,7 @@ class BaseModel(Generic[BaseModelId]):
                                       values)
         if not isinstance(self.id_, str):
             self.id_ = cursor.lastrowid  # type: ignore[assignment]
                                       values)
         if not isinstance(self.id_, str):
             self.id_ = cursor.lastrowid  # type: ignore[assignment]
-        self._cache()
+        self.cache()
         for attr_name in self.to_save_versioned:
             getattr(self, attr_name).save(db_conn)
         for table, column, attr_name, key_index in self.to_save_relations:
         for attr_name in self.to_save_versioned:
             getattr(self, attr_name).save(db_conn)
         for table, column, attr_name, key_index in self.to_save_relations:
index 5270812a41e9f7fc369a426876fe316b349e34b7..afb1841e9a6c230a64f66c9144169999e70da372 100644 (file)
@@ -19,9 +19,9 @@ class TestsWithDB(TestCaseWithDB):
     default_init_kwargs = {'is_active': False}
     test_versioneds = {'title': str, 'description': str}
 
     default_init_kwargs = {'is_active': False}
     test_versioneds = {'title': str, 'description': str}
 
-    def test_Condition_from_table_row(self) -> None:
+    def test_from_table_row(self) -> None:
         """Test .from_table_row() properly reads in class from DB"""
         """Test .from_table_row() properly reads in class from DB"""
-        self.check_from_table_row()
+        super().test_from_table_row()
         self.check_versioned_from_table_row('title', str)
         self.check_versioned_from_table_row('description', str)
 
         self.check_versioned_from_table_row('title', str)
         self.check_versioned_from_table_row('description', str)
 
index 901667f4c6e0276a2800bf4b21b15b03a17be2fb..e4c9de53b0576f685a968a4b7682749c5fe1a7f5 100644 (file)
@@ -53,10 +53,6 @@ class TestsWithDB(TestCaseWithDB):
         kwargs = {'date': self.default_ids[0], 'comment': 'foo'}
         self.check_saving_and_caching(**kwargs)
 
         kwargs = {'date': self.default_ids[0], 'comment': 'foo'}
         self.check_saving_and_caching(**kwargs)
 
-    def test_Day_from_table_row(self) -> None:
-        """Test .from_table_row() properly reads in class from DB"""
-        self.check_from_table_row()
-
     def test_Day_by_id(self) -> None:
         """Test .by_id()."""
         self.check_by_id()
     def test_Day_by_id(self) -> None:
         """Test .by_id()."""
         self.check_by_id()
index 4d2252c4b748718ee5f61844521322a15ff03669..d54fe84bb1041f1084fd341803576c99457e1877 100644 (file)
@@ -63,9 +63,9 @@ class TestsWithDB(TestCaseWithDB):
         self.assertEqual(sorted(r.enables), sorted(set2))
         self.assertEqual(sorted(r.disables), sorted(set3))
 
         self.assertEqual(sorted(r.enables), sorted(set2))
         self.assertEqual(sorted(r.disables), sorted(set3))
 
-    def test_Process_from_table_row(self) -> None:
+    def test_from_table_row(self) -> None:
         """Test .from_table_row() properly reads in class from DB"""
         """Test .from_table_row() properly reads in class from DB"""
-        self.check_from_table_row()
+        super().test_from_table_row()
         self.check_versioned_from_table_row('title', str)
         self.check_versioned_from_table_row('description', str)
         self.check_versioned_from_table_row('effort', float)
         self.check_versioned_from_table_row('title', str)
         self.check_versioned_from_table_row('description', str)
         self.check_versioned_from_table_row('effort', float)
index d6c5b20ac7882281d4958e99e3dbcd6a35de708b..f76fe33c93fc65d68aa07b7ca04aa0b98c762072 100644 (file)
@@ -130,17 +130,33 @@ class TestCaseWithDB(TestCase):
         self.assertEqual(self.checked_class(id2), by_id_created)
         self.check_storage([obj])
 
         self.assertEqual(self.checked_class(id2), by_id_created)
         self.check_storage([obj])
 
-    def check_from_table_row(self, *args: Any) -> None:
-        """Test .from_table_row() properly reads in class from DB"""
+    def test_from_table_row(self) -> None:
+        """Test .from_table_row() properly reads in class from DB."""
+        if not hasattr(self, 'checked_class'):
+            return
         id_ = self.default_ids[0]
         id_ = self.default_ids[0]
-        obj = self.checked_class(id_, *args)  # pylint: disable=not-callable
+        obj = self.checked_class(id_, **self.default_init_kwargs)
         obj.save(self.db_conn)
         assert isinstance(obj.id_, type(self.default_ids[0]))
         for row in self.db_conn.row_where(self.checked_class.table_name,
                                           'id', obj.id_):
         obj.save(self.db_conn)
         assert isinstance(obj.id_, type(self.default_ids[0]))
         for row in self.db_conn.row_where(self.checked_class.table_name,
                                           'id', obj.id_):
+            # check .from_table_row reproduces state saved, no matter if obj
+            # later changed (with caching even)
             hash_original = hash(obj)
             hash_original = hash(obj)
+            attr_name = self.checked_class.to_save[-1]
+            attr = getattr(obj, attr_name)
+            if isinstance(attr, (int, float)):
+                setattr(obj, attr_name, attr + 1)
+            elif isinstance(attr, str):
+                setattr(obj, attr_name, attr + "_")
+            elif isinstance(attr, bool):
+                setattr(obj, attr_name, not attr)
+            obj.cache()
+            to_cmp = getattr(obj, attr_name)
             retrieved = self.checked_class.from_table_row(self.db_conn, row)
             retrieved = self.checked_class.from_table_row(self.db_conn, row)
+            self.assertNotEqual(to_cmp, getattr(retrieved, attr_name))
             self.assertEqual(hash_original, hash(retrieved))
             self.assertEqual(hash_original, hash(retrieved))
+            # check cache contains what .from_table_row just produced
             self.assertEqual({retrieved.id_: retrieved},
                              self.checked_class.get_cache())
 
             self.assertEqual({retrieved.id_: retrieved},
                              self.checked_class.get_cache())