home · contact · privacy
Refactor .all() tests.
[plomtask] / tests / utils.py
index 15a53ae0ddc0b78835b5baacea15f43d3a81cba0..d8bd247bd8391682e5f7577c56f5e77d86216e9f 100644 (file)
@@ -2,6 +2,7 @@
 from unittest import TestCase
 from threading import Thread
 from http.client import HTTPConnection
+from json import loads as json_loads
 from urllib.parse import urlencode
 from uuid import uuid4
 from os import remove as remove_file
@@ -129,17 +130,33 @@ class TestCaseWithDB(TestCase):
         self.assertEqual(self.checked_class(id2), by_id_created)
         self.check_storage([obj])
 
-    def check_from_table_row(self, *args: Any) -> None:
-        """Test .from_table_row() properly reads in class from DB"""
+    def test_from_table_row(self) -> None:
+        """Test .from_table_row() properly reads in class directly from DB."""
+        if not hasattr(self, 'checked_class'):
+            return
         id_ = self.default_ids[0]
-        obj = self.checked_class(id_, *args)  # pylint: disable=not-callable
+        obj = self.checked_class(id_, **self.default_init_kwargs)
         obj.save(self.db_conn)
         assert isinstance(obj.id_, type(self.default_ids[0]))
         for row in self.db_conn.row_where(self.checked_class.table_name,
                                           'id', obj.id_):
+            # check .from_table_row reproduces state saved, no matter if obj
+            # later changed (with caching even)
             hash_original = hash(obj)
+            attr_name = self.checked_class.to_save[-1]
+            attr = getattr(obj, attr_name)
+            if isinstance(attr, (int, float)):
+                setattr(obj, attr_name, attr + 1)
+            elif isinstance(attr, str):
+                setattr(obj, attr_name, attr + "_")
+            elif isinstance(attr, bool):
+                setattr(obj, attr_name, not attr)
+            obj.cache()
+            to_cmp = getattr(obj, attr_name)
             retrieved = self.checked_class.from_table_row(self.db_conn, row)
+            self.assertNotEqual(to_cmp, getattr(retrieved, attr_name))
             self.assertEqual(hash_original, hash(retrieved))
+            # check cache contains what .from_table_row just produced
             self.assertEqual({retrieved.id_: retrieved},
                              self.checked_class.get_cache())
 
@@ -157,34 +174,44 @@ class TestCaseWithDB(TestCase):
             attr = getattr(retrieved, attr_name)
             self.assertEqual(sorted(attr.history.values()), vals)
 
-    def check_all(self) -> tuple[Any, Any, Any]:
-        """Test .all()."""
-        # pylint: disable=not-callable
-        item1 = self.checked_class(self.default_ids[0])
-        item2 = self.checked_class(self.default_ids[1])
-        item3 = self.checked_class(self.default_ids[2])
-        # check pre-save .all() returns empty list
+    def test_all(self) -> None:
+        """Test .all() and its relation to cache and savings."""
+        if not hasattr(self, 'checked_class'):
+            return
+        id_1, id_2, id_3 = self.default_ids
+        item1 = self.checked_class(id_1, **self.default_init_kwargs)
+        item2 = self.checked_class(id_2, **self.default_init_kwargs)
+        item3 = self.checked_class(id_3, **self.default_init_kwargs)
+        # check .all() returns empty list on un-cached items
         self.assertEqual(self.checked_class.all(self.db_conn), [])
-        # check that all() shows all saved, but no unsaved items
-        item1.save(self.db_conn)
+        # check that all() shows only cached/saved items
+        item1.cache()
         item3.save(self.db_conn)
         self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
                          sorted([item1, item3]))
         item2.save(self.db_conn)
         self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
                          sorted([item1, item2, item3]))
-        return item1, item2, item3
 
-    def check_singularity(self, defaulting_field: str,
-                          non_default_value: Any, *args: Any) -> None:
+    def test_singularity(self) -> None:
         """Test pointers made for single object keep pointing to it."""
+        if not hasattr(self, 'checked_class'):
+            return
         id1 = self.default_ids[0]
-        obj = self.checked_class(id1, *args)  # pylint: disable=not-callable
+        obj = self.checked_class(id1, **self.default_init_kwargs)
         obj.save(self.db_conn)
-        setattr(obj, defaulting_field, non_default_value)
+        attr_name = self.checked_class.to_save[-1]
+        attr = getattr(obj, attr_name)
+        new_attr: str | int | float | bool
+        if isinstance(attr, (int, float)):
+            new_attr = attr + 1
+        elif isinstance(attr, str):
+            new_attr = attr + '_'
+        elif isinstance(attr, bool):
+            new_attr = not attr
+        setattr(obj, attr_name, new_attr)
         retrieved = self.checked_class.by_id(self.db_conn, id1)
-        self.assertEqual(non_default_value,
-                         getattr(retrieved, defaulting_field))
+        self.assertEqual(new_attr, getattr(retrieved, attr_name))
 
     def check_versioned_singularity(self) -> None:
         """Test singularity of VersionedAttributes on saving (with .title)."""
@@ -268,3 +295,31 @@ class TestCaseWithServer(TestCaseWithDB):
         self.check_post(form_data, f'/process?id={id_}', 302,
                         f'/process?id={id_}')
         return form_data
+
+    def check_json_get(self, path: str, expected: dict[str, object]) -> None:
+        """Compare JSON on GET path with expected.
+
+        To simplify comparison of VersionedAttribute histories, transforms
+        timestamp keys of VersionedAttribute history keys into integers
+        counting chronologically forward from 0.
+        """
+        def rewrite_history_keys_in(item: Any) -> Any:
+            if isinstance(item, dict):
+                if '_versioned' in item.keys():
+                    for k in item['_versioned']:
+                        vals = item['_versioned'][k].values()
+                        history = {}
+                        for i, val in enumerate(vals):
+                            history[i] = val
+                        item['_versioned'][k] = history
+                for k in list(item.keys()):
+                    rewrite_history_keys_in(item[k])
+            elif isinstance(item, list):
+                item[:] = [rewrite_history_keys_in(i) for i in item]
+            return item
+        self.conn.request('GET', path)
+        response = self.conn.getresponse()
+        self.assertEqual(response.status, 200)
+        retrieved = json_loads(response.read().decode())
+        rewrite_history_keys_in(retrieved)
+        self.assertEqual(expected, retrieved)