home · contact · privacy
Some test utils refactoring.
[plomtask] / tests / utils.py
index 86d049dac2387b97f441d228ecc3c33e9bfe6552..9d3d11d9f841290fed50b03c8d33acea7f5248ac 100644 (file)
@@ -1,12 +1,13 @@
 """Shared test utilities."""
 """Shared test utilities."""
+from __future__ import annotations
 from unittest import TestCase
 from unittest import TestCase
+from typing import Mapping, Any, Callable
 from threading import Thread
 from http.client import HTTPConnection
 from json import loads as json_loads
 from urllib.parse import urlencode
 from uuid import uuid4
 from os import remove as remove_file
 from threading import Thread
 from http.client import HTTPConnection
 from json import loads as json_loads
 from urllib.parse import urlencode
 from uuid import uuid4
 from os import remove as remove_file
-from typing import Mapping, Any
 from plomtask.db import DatabaseFile, DatabaseConnection
 from plomtask.http import TaskHandler, TaskServer
 from plomtask.processes import Process, ProcessStep
 from plomtask.db import DatabaseFile, DatabaseConnection
 from plomtask.http import TaskHandler, TaskServer
 from plomtask.processes import Process, ProcessStep
@@ -61,16 +62,30 @@ class TestCaseWithDB(TestCase):
         self.db_conn.close()
         remove_file(self.db_file.path)
 
         self.db_conn.close()
         remove_file(self.db_file.path)
 
+    @staticmethod
+    def _within_checked_class(f: Callable[..., None]) -> Callable[..., None]:
+        def wrapper(self: TestCaseWithDB) -> None:
+            if hasattr(self, 'checked_class'):
+                f(self)
+        return wrapper
+
+    @_within_checked_class
     def test_saving_and_caching(self) -> None:
         """Test storage and initialization of instances and attributes."""
     def test_saving_and_caching(self) -> None:
         """Test storage and initialization of instances and attributes."""
-        if not hasattr(self, 'checked_class'):
-            return
         self.check_saving_and_caching(id_=1, **self.default_init_kwargs)
         obj = self.checked_class(None, **self.default_init_kwargs)
         obj.save(self.db_conn)
         self.assertEqual(obj.id_, 2)
         self.check_saving_and_caching(id_=1, **self.default_init_kwargs)
         obj = self.checked_class(None, **self.default_init_kwargs)
         obj.save(self.db_conn)
         self.assertEqual(obj.id_, 2)
-        for k, v in self.test_versioneds.items():
-            self.check_saving_of_versioned(k, v)
+        for attr_name, type_ in self.test_versioneds.items():
+            owner = self.checked_class(None)
+            vals: list[Any] = ['t1', 't2'] if type_ == str else [0.9, 1.1]
+            attr = getattr(owner, attr_name)
+            attr.set(vals[0])
+            attr.set(vals[1])
+            owner.save(self.db_conn)
+            retrieved = owner.__class__.by_id(self.db_conn, owner.id_)
+            attr = getattr(retrieved, attr_name)
+            self.assertEqual(sorted(attr.history.values()), vals)
 
     def check_storage(self, content: list[Any]) -> None:
         """Test cache and DB equal content."""
 
     def check_storage(self, content: list[Any]) -> None:
         """Test cache and DB equal content."""
@@ -101,18 +116,6 @@ class TestCaseWithDB(TestCase):
         # check saving stored properly in cache and DB
         self.check_storage([obj])
 
         # check saving stored properly in cache and DB
         self.check_storage([obj])
 
-    def check_saving_of_versioned(self, attr_name: str, type_: type) -> None:
-        """Test owner's versioned attributes."""
-        owner = self.checked_class(None)
-        vals: list[Any] = ['t1', 't2'] if type_ == str else [0.9, 1.1]
-        attr = getattr(owner, attr_name)
-        attr.set(vals[0])
-        attr.set(vals[1])
-        owner.save(self.db_conn)
-        retrieved = owner.__class__.by_id(self.db_conn, owner.id_)
-        attr = getattr(retrieved, attr_name)
-        self.assertEqual(sorted(attr.history.values()), vals)
-
     def check_by_id(self) -> None:
         """Test .by_id(), including creation."""
         # check failure if not yet saved
     def check_by_id(self) -> None:
         """Test .by_id(), including creation."""
         # check failure if not yet saved
@@ -130,17 +133,32 @@ class TestCaseWithDB(TestCase):
         self.assertEqual(self.checked_class(id2), by_id_created)
         self.check_storage([obj])
 
         self.assertEqual(self.checked_class(id2), by_id_created)
         self.check_storage([obj])
 
-    def check_from_table_row(self, *args: Any) -> None:
-        """Test .from_table_row() properly reads in class from DB"""
+    @_within_checked_class
+    def test_from_table_row(self) -> None:
+        """Test .from_table_row() properly reads in class directly from DB."""
         id_ = self.default_ids[0]
         id_ = self.default_ids[0]
-        obj = self.checked_class(id_, *args)  # pylint: disable=not-callable
+        obj = self.checked_class(id_, **self.default_init_kwargs)
         obj.save(self.db_conn)
         assert isinstance(obj.id_, type(self.default_ids[0]))
         for row in self.db_conn.row_where(self.checked_class.table_name,
                                           'id', obj.id_):
         obj.save(self.db_conn)
         assert isinstance(obj.id_, type(self.default_ids[0]))
         for row in self.db_conn.row_where(self.checked_class.table_name,
                                           'id', obj.id_):
+            # check .from_table_row reproduces state saved, no matter if obj
+            # later changed (with caching even)
             hash_original = hash(obj)
             hash_original = hash(obj)
+            attr_name = self.checked_class.to_save[-1]
+            attr = getattr(obj, attr_name)
+            if isinstance(attr, (int, float)):
+                setattr(obj, attr_name, attr + 1)
+            elif isinstance(attr, str):
+                setattr(obj, attr_name, attr + "_")
+            elif isinstance(attr, bool):
+                setattr(obj, attr_name, not attr)
+            obj.cache()
+            to_cmp = getattr(obj, attr_name)
             retrieved = self.checked_class.from_table_row(self.db_conn, row)
             retrieved = self.checked_class.from_table_row(self.db_conn, row)
+            self.assertNotEqual(to_cmp, getattr(retrieved, attr_name))
             self.assertEqual(hash_original, hash(retrieved))
             self.assertEqual(hash_original, hash(retrieved))
+            # check cache contains what .from_table_row just produced
             self.assertEqual({retrieved.id_: retrieved},
                              self.checked_class.get_cache())
 
             self.assertEqual({retrieved.id_: retrieved},
                              self.checked_class.get_cache())
 
@@ -158,34 +176,42 @@ class TestCaseWithDB(TestCase):
             attr = getattr(retrieved, attr_name)
             self.assertEqual(sorted(attr.history.values()), vals)
 
             attr = getattr(retrieved, attr_name)
             self.assertEqual(sorted(attr.history.values()), vals)
 
-    def check_all(self) -> tuple[Any, Any, Any]:
-        """Test .all()."""
-        # pylint: disable=not-callable
-        item1 = self.checked_class(self.default_ids[0])
-        item2 = self.checked_class(self.default_ids[1])
-        item3 = self.checked_class(self.default_ids[2])
-        # check pre-save .all() returns empty list
+    @_within_checked_class
+    def test_all(self) -> None:
+        """Test .all() and its relation to cache and savings."""
+        id_1, id_2, id_3 = self.default_ids
+        item1 = self.checked_class(id_1, **self.default_init_kwargs)
+        item2 = self.checked_class(id_2, **self.default_init_kwargs)
+        item3 = self.checked_class(id_3, **self.default_init_kwargs)
+        # check .all() returns empty list on un-cached items
         self.assertEqual(self.checked_class.all(self.db_conn), [])
         self.assertEqual(self.checked_class.all(self.db_conn), [])
-        # check that all() shows all saved, but no unsaved items
-        item1.save(self.db_conn)
+        # check that all() shows only cached/saved items
+        item1.cache()
         item3.save(self.db_conn)
         self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
                          sorted([item1, item3]))
         item2.save(self.db_conn)
         self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
                          sorted([item1, item2, item3]))
         item3.save(self.db_conn)
         self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
                          sorted([item1, item3]))
         item2.save(self.db_conn)
         self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
                          sorted([item1, item2, item3]))
-        return item1, item2, item3
 
 
-    def check_singularity(self, defaulting_field: str,
-                          non_default_value: Any, *args: Any) -> None:
+    @_within_checked_class
+    def test_singularity(self) -> None:
         """Test pointers made for single object keep pointing to it."""
         id1 = self.default_ids[0]
         """Test pointers made for single object keep pointing to it."""
         id1 = self.default_ids[0]
-        obj = self.checked_class(id1, *args)  # pylint: disable=not-callable
+        obj = self.checked_class(id1, **self.default_init_kwargs)
         obj.save(self.db_conn)
         obj.save(self.db_conn)
-        setattr(obj, defaulting_field, non_default_value)
+        attr_name = self.checked_class.to_save[-1]
+        attr = getattr(obj, attr_name)
+        new_attr: str | int | float | bool
+        if isinstance(attr, (int, float)):
+            new_attr = attr + 1
+        elif isinstance(attr, str):
+            new_attr = attr + '_'
+        elif isinstance(attr, bool):
+            new_attr = not attr
+        setattr(obj, attr_name, new_attr)
         retrieved = self.checked_class.by_id(self.db_conn, id1)
         retrieved = self.checked_class.by_id(self.db_conn, id1)
-        self.assertEqual(non_default_value,
-                         getattr(retrieved, defaulting_field))
+        self.assertEqual(new_attr, getattr(retrieved, attr_name))
 
     def check_versioned_singularity(self) -> None:
         """Test singularity of VersionedAttributes on saving (with .title)."""
 
     def check_versioned_singularity(self) -> None:
         """Test singularity of VersionedAttributes on saving (with .title)."""
@@ -274,22 +300,23 @@ class TestCaseWithServer(TestCaseWithDB):
         """Compare JSON on GET path with expected.
 
         To simplify comparison of VersionedAttribute histories, transforms
         """Compare JSON on GET path with expected.
 
         To simplify comparison of VersionedAttribute histories, transforms
-        keys under "history"-named dicts into bracketed integer strings
-        counting upwards in chronology.
+        timestamp keys of VersionedAttribute history keys into integers
+        counting chronologically forward from 0.
         """
         def rewrite_history_keys_in(item: Any) -> Any:
             if isinstance(item, dict):
         """
         def rewrite_history_keys_in(item: Any) -> Any:
             if isinstance(item, dict):
-                if 'history' in item.keys():
-                    vals = item['history'].values()
-                    history = {}
-                    for i, val in enumerate(vals):
-                        history[f'[{i}]'] = val
-                    item['history'] = history
+                if '_versioned' in item.keys():
+                    for k in item['_versioned']:
+                        vals = item['_versioned'][k].values()
+                        history = {}
+                        for i, val in enumerate(vals):
+                            history[i] = val
+                        item['_versioned'][k] = history
                 for k in list(item.keys()):
                     rewrite_history_keys_in(item[k])
             elif isinstance(item, list):
                 item[:] = [rewrite_history_keys_in(i) for i in item]
                 for k in list(item.keys()):
                     rewrite_history_keys_in(item[k])
             elif isinstance(item, list):
                 item[:] = [rewrite_history_keys_in(i) for i in item]
-            return item 
+            return item
         self.conn.request('GET', path)
         response = self.conn.getresponse()
         self.assertEqual(response.status, 200)
         self.conn.request('GET', path)
         response = self.conn.getresponse()
         self.assertEqual(response.status, 200)