home · contact · privacy
Refactor BaseModel.from_table_row testing.
[plomtask] / tests / utils.py
index 86d049dac2387b97f441d228ecc3c33e9bfe6552..f76fe33c93fc65d68aa07b7ca04aa0b98c762072 100644 (file)
@@ -130,17 +130,33 @@ class TestCaseWithDB(TestCase):
         self.assertEqual(self.checked_class(id2), by_id_created)
         self.check_storage([obj])
 
-    def check_from_table_row(self, *args: Any) -> None:
-        """Test .from_table_row() properly reads in class from DB"""
+    def test_from_table_row(self) -> None:
+        """Test .from_table_row() properly reads in class from DB."""
+        if not hasattr(self, 'checked_class'):
+            return
         id_ = self.default_ids[0]
-        obj = self.checked_class(id_, *args)  # pylint: disable=not-callable
+        obj = self.checked_class(id_, **self.default_init_kwargs)
         obj.save(self.db_conn)
         assert isinstance(obj.id_, type(self.default_ids[0]))
         for row in self.db_conn.row_where(self.checked_class.table_name,
                                           'id', obj.id_):
+            # check .from_table_row reproduces state saved, no matter if obj
+            # later changed (with caching even)
             hash_original = hash(obj)
+            attr_name = self.checked_class.to_save[-1]
+            attr = getattr(obj, attr_name)
+            if isinstance(attr, (int, float)):
+                setattr(obj, attr_name, attr + 1)
+            elif isinstance(attr, str):
+                setattr(obj, attr_name, attr + "_")
+            elif isinstance(attr, bool):
+                setattr(obj, attr_name, not attr)
+            obj.cache()
+            to_cmp = getattr(obj, attr_name)
             retrieved = self.checked_class.from_table_row(self.db_conn, row)
+            self.assertNotEqual(to_cmp, getattr(retrieved, attr_name))
             self.assertEqual(hash_original, hash(retrieved))
+            # check cache contains what .from_table_row just produced
             self.assertEqual({retrieved.id_: retrieved},
                              self.checked_class.get_cache())
 
@@ -274,22 +290,23 @@ class TestCaseWithServer(TestCaseWithDB):
         """Compare JSON on GET path with expected.
 
         To simplify comparison of VersionedAttribute histories, transforms
-        keys under "history"-named dicts into bracketed integer strings
-        counting upwards in chronology.
+        timestamp keys of VersionedAttribute history keys into integers
+        counting chronologically forward from 0.
         """
         def rewrite_history_keys_in(item: Any) -> Any:
             if isinstance(item, dict):
-                if 'history' in item.keys():
-                    vals = item['history'].values()
-                    history = {}
-                    for i, val in enumerate(vals):
-                        history[f'[{i}]'] = val
-                    item['history'] = history
+                if '_versioned' in item.keys():
+                    for k in item['_versioned']:
+                        vals = item['_versioned'][k].values()
+                        history = {}
+                        for i, val in enumerate(vals):
+                            history[i] = val
+                        item['_versioned'][k] = history
                 for k in list(item.keys()):
                     rewrite_history_keys_in(item[k])
             elif isinstance(item, list):
                 item[:] = [rewrite_history_keys_in(i) for i in item]
-            return item 
+            return item
         self.conn.request('GET', path)
         response = self.conn.getresponse()
         self.assertEqual(response.status, 200)