From: Christian Heller <c.heller@plomlompom.de>
Date: Mon, 17 Jun 2024 23:54:46 +0000 (+0200)
Subject: Refactor BaseModel.from_table_row testing.
X-Git-Url: https://plomlompom.com/repos/%7B%7B%20web_path%20%7D%7D/%7B%7Bprefix%7D%7D/%7B%7Bdb.prefix%7D%7D/template?a=commitdiff_plain;h=e3bfd84f9061d5f03ec5f5764f75e4137505ea45;p=taskplom

Refactor BaseModel.from_table_row testing.
---

diff --git a/plomtask/db.py b/plomtask/db.py
index 385e798..853b4c6 100644
--- a/plomtask/db.py
+++ b/plomtask/db.py
@@ -344,7 +344,7 @@ class BaseModel(Generic[BaseModelId]):
             return obj
         return None
 
-    def _cache(self) -> None:
+    def cache(self) -> None:
         """Update object in class's cache.
 
         Also calls ._disappear if cache holds older reference to object of same
@@ -383,7 +383,7 @@ class BaseModel(Generic[BaseModelId]):
             table_name = attr.table_name
             for row_ in db_conn.row_where(table_name, 'parent', obj.id_):
                 attr.history_from_row(row_)
-        obj._cache()
+        obj.cache()
         return obj
 
     @classmethod
@@ -497,7 +497,7 @@ class BaseModel(Generic[BaseModelId]):
                                       values)
         if not isinstance(self.id_, str):
             self.id_ = cursor.lastrowid  # type: ignore[assignment]
-        self._cache()
+        self.cache()
         for attr_name in self.to_save_versioned:
             getattr(self, attr_name).save(db_conn)
         for table, column, attr_name, key_index in self.to_save_relations:
diff --git a/tests/conditions.py b/tests/conditions.py
index 5270812..afb1841 100644
--- a/tests/conditions.py
+++ b/tests/conditions.py
@@ -19,9 +19,9 @@ class TestsWithDB(TestCaseWithDB):
     default_init_kwargs = {'is_active': False}
     test_versioneds = {'title': str, 'description': str}
 
-    def test_Condition_from_table_row(self) -> None:
+    def test_from_table_row(self) -> None:
         """Test .from_table_row() properly reads in class from DB"""
-        self.check_from_table_row()
+        super().test_from_table_row()
         self.check_versioned_from_table_row('title', str)
         self.check_versioned_from_table_row('description', str)
 
diff --git a/tests/days.py b/tests/days.py
index 901667f..e4c9de5 100644
--- a/tests/days.py
+++ b/tests/days.py
@@ -53,10 +53,6 @@ class TestsWithDB(TestCaseWithDB):
         kwargs = {'date': self.default_ids[0], 'comment': 'foo'}
         self.check_saving_and_caching(**kwargs)
 
-    def test_Day_from_table_row(self) -> None:
-        """Test .from_table_row() properly reads in class from DB"""
-        self.check_from_table_row()
-
     def test_Day_by_id(self) -> None:
         """Test .by_id()."""
         self.check_by_id()
diff --git a/tests/processes.py b/tests/processes.py
index 4d2252c..d54fe84 100644
--- a/tests/processes.py
+++ b/tests/processes.py
@@ -63,9 +63,9 @@ class TestsWithDB(TestCaseWithDB):
         self.assertEqual(sorted(r.enables), sorted(set2))
         self.assertEqual(sorted(r.disables), sorted(set3))
 
-    def test_Process_from_table_row(self) -> None:
+    def test_from_table_row(self) -> None:
         """Test .from_table_row() properly reads in class from DB"""
-        self.check_from_table_row()
+        super().test_from_table_row()
         self.check_versioned_from_table_row('title', str)
         self.check_versioned_from_table_row('description', str)
         self.check_versioned_from_table_row('effort', float)
diff --git a/tests/utils.py b/tests/utils.py
index d6c5b20..f76fe33 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -130,17 +130,33 @@ class TestCaseWithDB(TestCase):
         self.assertEqual(self.checked_class(id2), by_id_created)
         self.check_storage([obj])
 
-    def check_from_table_row(self, *args: Any) -> None:
-        """Test .from_table_row() properly reads in class from DB"""
+    def test_from_table_row(self) -> None:
+        """Test .from_table_row() properly reads in class from DB."""
+        if not hasattr(self, 'checked_class'):
+            return
         id_ = self.default_ids[0]
-        obj = self.checked_class(id_, *args)  # pylint: disable=not-callable
+        obj = self.checked_class(id_, **self.default_init_kwargs)
         obj.save(self.db_conn)
         assert isinstance(obj.id_, type(self.default_ids[0]))
         for row in self.db_conn.row_where(self.checked_class.table_name,
                                           'id', obj.id_):
+            # check .from_table_row reproduces state saved, no matter if obj
+            # later changed (with caching even)
             hash_original = hash(obj)
+            attr_name = self.checked_class.to_save[-1]
+            attr = getattr(obj, attr_name)
+            if isinstance(attr, (int, float)):
+                setattr(obj, attr_name, attr + 1)
+            elif isinstance(attr, str):
+                setattr(obj, attr_name, attr + "_")
+            elif isinstance(attr, bool):
+                setattr(obj, attr_name, not attr)
+            obj.cache()
+            to_cmp = getattr(obj, attr_name)
             retrieved = self.checked_class.from_table_row(self.db_conn, row)
+            self.assertNotEqual(to_cmp, getattr(retrieved, attr_name))
             self.assertEqual(hash_original, hash(retrieved))
+            # check cache contains what .from_table_row just produced
             self.assertEqual({retrieved.id_: retrieved},
                              self.checked_class.get_cache())