home · contact · privacy
Split BaseModel.by_id into .by_id and by_id_or_create, refactor tests.
[plomtask] / tests / utils.py
index 13e4f949d177a7d73456d2546db9f49de3e33fb1..55c948a409dbf1b2c43f775c0d39106ba3ed510d 100644 (file)
@@ -1,11 +1,13 @@
 """Shared test utilities."""
 """Shared test utilities."""
+from __future__ import annotations
 from unittest import TestCase
 from unittest import TestCase
+from typing import Mapping, Any, Callable
 from threading import Thread
 from http.client import HTTPConnection
 from threading import Thread
 from http.client import HTTPConnection
+from json import loads as json_loads
 from urllib.parse import urlencode
 from uuid import uuid4
 from os import remove as remove_file
 from urllib.parse import urlencode
 from uuid import uuid4
 from os import remove as remove_file
-from typing import Mapping, Any
 from plomtask.db import DatabaseFile, DatabaseConnection
 from plomtask.http import TaskHandler, TaskServer
 from plomtask.processes import Process, ProcessStep
 from plomtask.db import DatabaseFile, DatabaseConnection
 from plomtask.http import TaskHandler, TaskServer
 from plomtask.processes import Process, ProcessStep
@@ -60,19 +62,33 @@ class TestCaseWithDB(TestCase):
         self.db_conn.close()
         remove_file(self.db_file.path)
 
         self.db_conn.close()
         remove_file(self.db_file.path)
 
+    @staticmethod
+    def _within_checked_class(f: Callable[..., None]) -> Callable[..., None]:
+        def wrapper(self: TestCaseWithDB) -> None:
+            if hasattr(self, 'checked_class'):
+                f(self)
+        return wrapper
+
+    @_within_checked_class
     def test_saving_and_caching(self) -> None:
         """Test storage and initialization of instances and attributes."""
     def test_saving_and_caching(self) -> None:
         """Test storage and initialization of instances and attributes."""
-        if not hasattr(self, 'checked_class'):
-            return
         self.check_saving_and_caching(id_=1, **self.default_init_kwargs)
         obj = self.checked_class(None, **self.default_init_kwargs)
         obj.save(self.db_conn)
         self.assertEqual(obj.id_, 2)
         self.check_saving_and_caching(id_=1, **self.default_init_kwargs)
         obj = self.checked_class(None, **self.default_init_kwargs)
         obj.save(self.db_conn)
         self.assertEqual(obj.id_, 2)
-        for k, v in self.test_versioneds.items():
-            self.check_saving_of_versioned(k, v)
+        for attr_name, type_ in self.test_versioneds.items():
+            owner = self.checked_class(None)
+            vals: list[Any] = ['t1', 't2'] if type_ == str else [0.9, 1.1]
+            attr = getattr(owner, attr_name)
+            attr.set(vals[0])
+            attr.set(vals[1])
+            owner.save(self.db_conn)
+            retrieved = owner.__class__.by_id(self.db_conn, owner.id_)
+            attr = getattr(retrieved, attr_name)
+            self.assertEqual(sorted(attr.history.values()), vals)
 
 
-    def check_storage(self, content: list[Any]) -> None:
-        """Test cache and DB equal content."""
+    def check_identity_with_cache_and_db(self, content: list[Any]) -> None:
+        """Test both cache and DB equal content."""
         expected_cache = {}
         for item in content:
             expected_cache[item.id_] = item
         expected_cache = {}
         for item in content:
             expected_cache[item.id_] = item
@@ -92,54 +108,74 @@ class TestCaseWithDB(TestCase):
         """Test instance.save in its core without relations."""
         obj = self.checked_class(**kwargs)  # pylint: disable=not-callable
         # check object init itself doesn't store anything yet
         """Test instance.save in its core without relations."""
         obj = self.checked_class(**kwargs)  # pylint: disable=not-callable
         # check object init itself doesn't store anything yet
-        self.check_storage([])
+        self.check_identity_with_cache_and_db([])
         # check saving sets core attributes properly
         obj.save(self.db_conn)
         for key, value in kwargs.items():
             self.assertEqual(getattr(obj, key), value)
         # check saving stored properly in cache and DB
         # check saving sets core attributes properly
         obj.save(self.db_conn)
         for key, value in kwargs.items():
             self.assertEqual(getattr(obj, key), value)
         # check saving stored properly in cache and DB
-        self.check_storage([obj])
-
-    def check_saving_of_versioned(self, attr_name: str, type_: type) -> None:
-        """Test owner's versioned attributes."""
-        owner = self.checked_class(None)
-        vals: list[Any] = ['t1', 't2'] if type_ == str else [0.9, 1.1]
-        attr = getattr(owner, attr_name)
-        attr.set(vals[0])
-        attr.set(vals[1])
-        owner.save(self.db_conn)
-        retrieved = owner.__class__.by_id(self.db_conn, owner.id_)
-        attr = getattr(retrieved, attr_name)
-        self.assertEqual(sorted(attr.history.values()), vals)
+        self.check_identity_with_cache_and_db([obj])
 
 
-    def check_by_id(self) -> None:
-        """Test .by_id(), including creation."""
+    @_within_checked_class
+    def test_by_id(self) -> None:
+        """Test .by_id()."""
+        id1, id2, _ = self.default_ids
         # check failure if not yet saved
         # check failure if not yet saved
-        id1, id2 = self.default_ids[0], self.default_ids[1]
-        obj = self.checked_class(id1)  # pylint: disable=not-callable
+        obj1 = self.checked_class(id1, **self.default_init_kwargs)
         with self.assertRaises(NotFoundException):
             self.checked_class.by_id(self.db_conn, id1)
         with self.assertRaises(NotFoundException):
             self.checked_class.by_id(self.db_conn, id1)
+        # check identity of cached and retrieved
+        obj1.cache()
+        self.assertEqual(obj1, self.checked_class.by_id(self.db_conn, id1))
         # check identity of saved and retrieved
         # check identity of saved and retrieved
-        obj.save(self.db_conn)
-        self.assertEqual(obj, self.checked_class.by_id(self.db_conn, id1))
-        # check create=True acts like normal instantiation (sans saving)
-        by_id_created = self.checked_class.by_id(self.db_conn, id2,
-                                                 create=True)
-        # pylint: disable=not-callable
-        self.assertEqual(self.checked_class(id2), by_id_created)
-        self.check_storage([obj])
+        obj2 = self.checked_class(id2, **self.default_init_kwargs)
+        obj2.save(self.db_conn)
+        self.assertEqual(obj2, self.checked_class.by_id(self.db_conn, id2))
+        # obj1.save(self.db_conn)
+        # self.check_identity_with_cache_and_db([obj1, obj2])
 
 
-    def check_from_table_row(self, *args: Any) -> None:
-        """Test .from_table_row() properly reads in class from DB"""
+    @_within_checked_class
+    def test_by_id_or_create(self) -> None:
+        """Test .by_id_or_create."""
+        # check .by_id_or_create acts like normal instantiation (sans saving)
         id_ = self.default_ids[0]
         id_ = self.default_ids[0]
-        obj = self.checked_class(id_, *args)  # pylint: disable=not-callable
+        if not self.checked_class.can_create_by_id:
+            with self.assertRaises(HandledException):
+                self.checked_class.by_id_or_create(self.db_conn, id_)
+        # check .by_id_or_create fails if wrong class
+        else:
+            by_id_created = self.checked_class.by_id_or_create(self.db_conn,
+                                                               id_)
+            with self.assertRaises(NotFoundException):
+                self.checked_class.by_id(self.db_conn, id_)
+            self.assertEqual(self.checked_class(id_), by_id_created)
+
+    @_within_checked_class
+    def test_from_table_row(self) -> None:
+        """Test .from_table_row() properly reads in class directly from DB."""
+        id_ = self.default_ids[0]
+        obj = self.checked_class(id_, **self.default_init_kwargs)
         obj.save(self.db_conn)
         assert isinstance(obj.id_, type(self.default_ids[0]))
         for row in self.db_conn.row_where(self.checked_class.table_name,
                                           'id', obj.id_):
         obj.save(self.db_conn)
         assert isinstance(obj.id_, type(self.default_ids[0]))
         for row in self.db_conn.row_where(self.checked_class.table_name,
                                           'id', obj.id_):
+            # check .from_table_row reproduces state saved, no matter if obj
+            # later changed (with caching even)
             hash_original = hash(obj)
             hash_original = hash(obj)
+            attr_name = self.checked_class.to_save[-1]
+            attr = getattr(obj, attr_name)
+            if isinstance(attr, (int, float)):
+                setattr(obj, attr_name, attr + 1)
+            elif isinstance(attr, str):
+                setattr(obj, attr_name, attr + "_")
+            elif isinstance(attr, bool):
+                setattr(obj, attr_name, not attr)
+            obj.cache()
+            to_cmp = getattr(obj, attr_name)
             retrieved = self.checked_class.from_table_row(self.db_conn, row)
             retrieved = self.checked_class.from_table_row(self.db_conn, row)
+            self.assertNotEqual(to_cmp, getattr(retrieved, attr_name))
             self.assertEqual(hash_original, hash(retrieved))
             self.assertEqual(hash_original, hash(retrieved))
+            # check cache contains what .from_table_row just produced
             self.assertEqual({retrieved.id_: retrieved},
                              self.checked_class.get_cache())
 
             self.assertEqual({retrieved.id_: retrieved},
                              self.checked_class.get_cache())
 
@@ -157,34 +193,42 @@ class TestCaseWithDB(TestCase):
             attr = getattr(retrieved, attr_name)
             self.assertEqual(sorted(attr.history.values()), vals)
 
             attr = getattr(retrieved, attr_name)
             self.assertEqual(sorted(attr.history.values()), vals)
 
-    def check_all(self) -> tuple[Any, Any, Any]:
-        """Test .all()."""
-        # pylint: disable=not-callable
-        item1 = self.checked_class(self.default_ids[0])
-        item2 = self.checked_class(self.default_ids[1])
-        item3 = self.checked_class(self.default_ids[2])
-        # check pre-save .all() returns empty list
+    @_within_checked_class
+    def test_all(self) -> None:
+        """Test .all() and its relation to cache and savings."""
+        id_1, id_2, id_3 = self.default_ids
+        item1 = self.checked_class(id_1, **self.default_init_kwargs)
+        item2 = self.checked_class(id_2, **self.default_init_kwargs)
+        item3 = self.checked_class(id_3, **self.default_init_kwargs)
+        # check .all() returns empty list on un-cached items
         self.assertEqual(self.checked_class.all(self.db_conn), [])
         self.assertEqual(self.checked_class.all(self.db_conn), [])
-        # check that all() shows all saved, but no unsaved items
-        item1.save(self.db_conn)
+        # check that all() shows only cached/saved items
+        item1.cache()
         item3.save(self.db_conn)
         self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
                          sorted([item1, item3]))
         item2.save(self.db_conn)
         self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
                          sorted([item1, item2, item3]))
         item3.save(self.db_conn)
         self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
                          sorted([item1, item3]))
         item2.save(self.db_conn)
         self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
                          sorted([item1, item2, item3]))
-        return item1, item2, item3
 
 
-    def check_singularity(self, defaulting_field: str,
-                          non_default_value: Any, *args: Any) -> None:
+    @_within_checked_class
+    def test_singularity(self) -> None:
         """Test pointers made for single object keep pointing to it."""
         id1 = self.default_ids[0]
         """Test pointers made for single object keep pointing to it."""
         id1 = self.default_ids[0]
-        obj = self.checked_class(id1, *args)  # pylint: disable=not-callable
+        obj = self.checked_class(id1, **self.default_init_kwargs)
         obj.save(self.db_conn)
         obj.save(self.db_conn)
-        setattr(obj, defaulting_field, non_default_value)
+        attr_name = self.checked_class.to_save[-1]
+        attr = getattr(obj, attr_name)
+        new_attr: str | int | float | bool
+        if isinstance(attr, (int, float)):
+            new_attr = attr + 1
+        elif isinstance(attr, str):
+            new_attr = attr + '_'
+        elif isinstance(attr, bool):
+            new_attr = not attr
+        setattr(obj, attr_name, new_attr)
         retrieved = self.checked_class.by_id(self.db_conn, id1)
         retrieved = self.checked_class.by_id(self.db_conn, id1)
-        self.assertEqual(non_default_value,
-                         getattr(retrieved, defaulting_field))
+        self.assertEqual(new_attr, getattr(retrieved, attr_name))
 
     def check_versioned_singularity(self) -> None:
         """Test singularity of VersionedAttributes on saving (with .title)."""
 
     def check_versioned_singularity(self) -> None:
         """Test singularity of VersionedAttributes on saving (with .title)."""
@@ -203,7 +247,7 @@ class TestCaseWithDB(TestCase):
             obj.remove(self.db_conn)
         obj.save(self.db_conn)
         obj.remove(self.db_conn)
             obj.remove(self.db_conn)
         obj.save(self.db_conn)
         obj.remove(self.db_conn)
-        self.check_storage([])
+        self.check_identity_with_cache_and_db([])
 
 
 class TestCaseWithServer(TestCaseWithDB):
 
 
 class TestCaseWithServer(TestCaseWithDB):
@@ -269,20 +313,30 @@ class TestCaseWithServer(TestCaseWithDB):
                         f'/process?id={id_}')
         return form_data
 
                         f'/process?id={id_}')
         return form_data
 
-    @staticmethod
-    def blank_history_keys_in(d: dict[str, object]) -> None:
-        """Re-write "history" object keys to bracketed integer strings."""
-        def walk_tree(d: Any) -> Any:
-            if isinstance(d, dict):
-                if 'history' in d.keys():
-                    vals = d['history'].values()
-                    history = {}
-                    for i, val in enumerate(vals):
-                        history[f'[{i}]'] = val
-                    d['history'] = history
-                for k in list(d.keys()):
-                    walk_tree(d[k])
-            elif isinstance(d, list):
-                d[:] = [walk_tree(i) for i in d]
-            return d
-        walk_tree(d)
+    def check_json_get(self, path: str, expected: dict[str, object]) -> None:
+        """Compare JSON on GET path with expected.
+
+        To simplify comparison of VersionedAttribute histories, transforms
+        timestamp keys of VersionedAttribute history keys into integers
+        counting chronologically forward from 0.
+        """
+        def rewrite_history_keys_in(item: Any) -> Any:
+            if isinstance(item, dict):
+                if '_versioned' in item.keys():
+                    for k in item['_versioned']:
+                        vals = item['_versioned'][k].values()
+                        history = {}
+                        for i, val in enumerate(vals):
+                            history[i] = val
+                        item['_versioned'][k] = history
+                for k in list(item.keys()):
+                    rewrite_history_keys_in(item[k])
+            elif isinstance(item, list):
+                item[:] = [rewrite_history_keys_in(i) for i in item]
+            return item
+        self.conn.request('GET', path)
+        response = self.conn.getresponse()
+        self.assertEqual(response.status, 200)
+        retrieved = json_loads(response.read().decode())
+        rewrite_history_keys_in(retrieved)
+        self.assertEqual(expected, retrieved)