home · contact · privacy
Some test utils refactoring.
[plomtask] / tests / utils.py
index bb37270ca68afde8e00df09a19892011fd371a29..9d3d11d9f841290fed50b03c8d33acea7f5248ac 100644 (file)
@@ -1,11 +1,13 @@
 """Shared test utilities."""
 """Shared test utilities."""
+from __future__ import annotations
 from unittest import TestCase
 from unittest import TestCase
+from typing import Mapping, Any, Callable
 from threading import Thread
 from http.client import HTTPConnection
 from threading import Thread
 from http.client import HTTPConnection
+from json import loads as json_loads
 from urllib.parse import urlencode
 from urllib.parse import urlencode
-from datetime import datetime
+from uuid import uuid4
 from os import remove as remove_file
 from os import remove as remove_file
-from typing import Mapping, Any
 from plomtask.db import DatabaseFile, DatabaseConnection
 from plomtask.http import TaskHandler, TaskServer
 from plomtask.processes import Process, ProcessStep
 from plomtask.db import DatabaseFile, DatabaseConnection
 from plomtask.http import TaskHandler, TaskServer
 from plomtask.processes import Process, ProcessStep
@@ -18,18 +20,25 @@ from plomtask.exceptions import NotFoundException, HandledException
 class TestCaseSansDB(TestCase):
     """Tests requiring no DB setup."""
     checked_class: Any
 class TestCaseSansDB(TestCase):
     """Tests requiring no DB setup."""
     checked_class: Any
+    do_id_test: bool = False
+    default_init_args: list[Any] = []
+    versioned_defaults_to_test: dict[str, str | float] = {}
 
 
-    def check_id_setting(self, *args: Any) -> None:
+    def test_id_setting(self) -> None:
         """Test .id_ being set and its legal range being enforced."""
         """Test .id_ being set and its legal range being enforced."""
+        if not self.do_id_test:
+            return
         with self.assertRaises(HandledException):
         with self.assertRaises(HandledException):
-            self.checked_class(0, *args)
-        obj = self.checked_class(5, *args)
+            self.checked_class(0, *self.default_init_args)
+        obj = self.checked_class(5, *self.default_init_args)
         self.assertEqual(obj.id_, 5)
 
         self.assertEqual(obj.id_, 5)
 
-    def check_versioned_defaults(self, attrs: dict[str, Any]) -> None:
+    def test_versioned_defaults(self) -> None:
         """Test defaults of VersionedAttributes."""
         """Test defaults of VersionedAttributes."""
-        obj = self.checked_class(None)
-        for k, v in attrs.items():
+        if len(self.versioned_defaults_to_test) == 0:
+            return
+        obj = self.checked_class(1, *self.default_init_args)
+        for k, v in self.versioned_defaults_to_test.items():
             self.assertEqual(getattr(obj, k).newest, v)
 
 
             self.assertEqual(getattr(obj, k).newest, v)
 
 
@@ -37,6 +46,8 @@ class TestCaseWithDB(TestCase):
     """Module tests not requiring DB setup."""
     checked_class: Any
     default_ids: tuple[int | str, int | str, int | str] = (1, 2, 3)
     """Module tests not requiring DB setup."""
     checked_class: Any
     default_ids: tuple[int | str, int | str, int | str] = (1, 2, 3)
+    default_init_kwargs: dict[str, Any] = {}
+    test_versioneds: dict[str, type] = {}
 
     def setUp(self) -> None:
         Condition.empty_cache()
 
     def setUp(self) -> None:
         Condition.empty_cache()
@@ -44,21 +55,45 @@ class TestCaseWithDB(TestCase):
         Process.empty_cache()
         ProcessStep.empty_cache()
         Todo.empty_cache()
         Process.empty_cache()
         ProcessStep.empty_cache()
         Todo.empty_cache()
-        timestamp = datetime.now().timestamp()
-        self.db_file = DatabaseFile(f'test_db:{timestamp}')
-        self.db_file.remake()
+        self.db_file = DatabaseFile.create_at(f'test_db:{uuid4()}')
         self.db_conn = DatabaseConnection(self.db_file)
 
     def tearDown(self) -> None:
         self.db_conn.close()
         remove_file(self.db_file.path)
 
         self.db_conn = DatabaseConnection(self.db_file)
 
     def tearDown(self) -> None:
         self.db_conn.close()
         remove_file(self.db_file.path)
 
+    @staticmethod
+    def _within_checked_class(f: Callable[..., None]) -> Callable[..., None]:
+        def wrapper(self: TestCaseWithDB) -> None:
+            if hasattr(self, 'checked_class'):
+                f(self)
+        return wrapper
+
+    @_within_checked_class
+    def test_saving_and_caching(self) -> None:
+        """Test storage and initialization of instances and attributes."""
+        self.check_saving_and_caching(id_=1, **self.default_init_kwargs)
+        obj = self.checked_class(None, **self.default_init_kwargs)
+        obj.save(self.db_conn)
+        self.assertEqual(obj.id_, 2)
+        for attr_name, type_ in self.test_versioneds.items():
+            owner = self.checked_class(None)
+            vals: list[Any] = ['t1', 't2'] if type_ == str else [0.9, 1.1]
+            attr = getattr(owner, attr_name)
+            attr.set(vals[0])
+            attr.set(vals[1])
+            owner.save(self.db_conn)
+            retrieved = owner.__class__.by_id(self.db_conn, owner.id_)
+            attr = getattr(retrieved, attr_name)
+            self.assertEqual(sorted(attr.history.values()), vals)
+
     def check_storage(self, content: list[Any]) -> None:
         """Test cache and DB equal content."""
         expected_cache = {}
         for item in content:
             expected_cache[item.id_] = item
         self.assertEqual(self.checked_class.get_cache(), expected_cache)
     def check_storage(self, content: list[Any]) -> None:
         """Test cache and DB equal content."""
         expected_cache = {}
         for item in content:
             expected_cache[item.id_] = item
         self.assertEqual(self.checked_class.get_cache(), expected_cache)
+        hashes_content = [hash(x) for x in content]
         db_found: list[Any] = []
         for item in content:
             assert isinstance(item.id_, type(self.default_ids[0]))
         db_found: list[Any] = []
         for item in content:
             assert isinstance(item.id_, type(self.default_ids[0]))
@@ -66,32 +101,20 @@ class TestCaseWithDB(TestCase):
                                               'id', item.id_):
                 db_found += [self.checked_class.from_table_row(self.db_conn,
                                                                row)]
                                               'id', item.id_):
                 db_found += [self.checked_class.from_table_row(self.db_conn,
                                                                row)]
-        self.assertEqual(sorted(content), sorted(db_found))
+        hashes_db_found = [hash(x) for x in db_found]
+        self.assertEqual(sorted(hashes_content), sorted(hashes_db_found))
 
 
-    def check_saving_and_caching(self, **kwargs: Any) -> Any:
+    def check_saving_and_caching(self, **kwargs: Any) -> None:
         """Test instance.save in its core without relations."""
         obj = self.checked_class(**kwargs)  # pylint: disable=not-callable
         # check object init itself doesn't store anything yet
         self.check_storage([])
         """Test instance.save in its core without relations."""
         obj = self.checked_class(**kwargs)  # pylint: disable=not-callable
         # check object init itself doesn't store anything yet
         self.check_storage([])
-        # check saving stores in cache and DB
+        # check saving sets core attributes properly
         obj.save(self.db_conn)
         obj.save(self.db_conn)
-        self.check_storage([obj])
-        # check core attributes set properly (and not unset by saving)
         for key, value in kwargs.items():
             self.assertEqual(getattr(obj, key), value)
         for key, value in kwargs.items():
             self.assertEqual(getattr(obj, key), value)
-
-    def check_saving_of_versioned(self, attr_name: str, type_: type) -> None:
-        """Test owner's versioned attributes."""
-        owner = self.checked_class(None)
-        vals: list[Any] = ['t1', 't2'] if type_ == str else [0.9, 1.1]
-        attr = getattr(owner, attr_name)
-        attr.set(vals[0])
-        attr.set(vals[1])
-        owner.save(self.db_conn)
-        owner.uncache()
-        retrieved = owner.__class__.by_id(self.db_conn, owner.id_)
-        attr = getattr(retrieved, attr_name)
-        self.assertEqual(sorted(attr.history.values()), vals)
+        # check saving stored properly in cache and DB
+        self.check_storage([obj])
 
     def check_by_id(self) -> None:
         """Test .by_id(), including creation."""
 
     def check_by_id(self) -> None:
         """Test .by_id(), including creation."""
@@ -110,17 +133,34 @@ class TestCaseWithDB(TestCase):
         self.assertEqual(self.checked_class(id2), by_id_created)
         self.check_storage([obj])
 
         self.assertEqual(self.checked_class(id2), by_id_created)
         self.check_storage([obj])
 
-    def check_from_table_row(self, *args: Any) -> None:
-        """Test .from_table_row() properly reads in class from DB"""
+    @_within_checked_class
+    def test_from_table_row(self) -> None:
+        """Test .from_table_row() properly reads in class directly from DB."""
         id_ = self.default_ids[0]
         id_ = self.default_ids[0]
-        obj = self.checked_class(id_, *args)  # pylint: disable=not-callable
+        obj = self.checked_class(id_, **self.default_init_kwargs)
         obj.save(self.db_conn)
         assert isinstance(obj.id_, type(self.default_ids[0]))
         for row in self.db_conn.row_where(self.checked_class.table_name,
                                           'id', obj.id_):
         obj.save(self.db_conn)
         assert isinstance(obj.id_, type(self.default_ids[0]))
         for row in self.db_conn.row_where(self.checked_class.table_name,
                                           'id', obj.id_):
+            # check .from_table_row reproduces state saved, no matter if obj
+            # later changed (with caching even)
+            hash_original = hash(obj)
+            attr_name = self.checked_class.to_save[-1]
+            attr = getattr(obj, attr_name)
+            if isinstance(attr, (int, float)):
+                setattr(obj, attr_name, attr + 1)
+            elif isinstance(attr, str):
+                setattr(obj, attr_name, attr + "_")
+            elif isinstance(attr, bool):
+                setattr(obj, attr_name, not attr)
+            obj.cache()
+            to_cmp = getattr(obj, attr_name)
             retrieved = self.checked_class.from_table_row(self.db_conn, row)
             retrieved = self.checked_class.from_table_row(self.db_conn, row)
-            self.assertEqual(obj, retrieved)
-            self.assertEqual({obj.id_: obj}, self.checked_class.get_cache())
+            self.assertNotEqual(to_cmp, getattr(retrieved, attr_name))
+            self.assertEqual(hash_original, hash(retrieved))
+            # check cache contains what .from_table_row just produced
+            self.assertEqual({retrieved.id_: retrieved},
+                             self.checked_class.get_cache())
 
     def check_versioned_from_table_row(self, attr_name: str,
                                        type_: type) -> None:
 
     def check_versioned_from_table_row(self, attr_name: str,
                                        type_: type) -> None:
@@ -136,34 +176,42 @@ class TestCaseWithDB(TestCase):
             attr = getattr(retrieved, attr_name)
             self.assertEqual(sorted(attr.history.values()), vals)
 
             attr = getattr(retrieved, attr_name)
             self.assertEqual(sorted(attr.history.values()), vals)
 
-    def check_all(self) -> tuple[Any, Any, Any]:
-        """Test .all()."""
-        # pylint: disable=not-callable
-        item1 = self.checked_class(self.default_ids[0])
-        item2 = self.checked_class(self.default_ids[1])
-        item3 = self.checked_class(self.default_ids[2])
-        # check pre-save .all() returns empty list
+    @_within_checked_class
+    def test_all(self) -> None:
+        """Test .all() and its relation to cache and savings."""
+        id_1, id_2, id_3 = self.default_ids
+        item1 = self.checked_class(id_1, **self.default_init_kwargs)
+        item2 = self.checked_class(id_2, **self.default_init_kwargs)
+        item3 = self.checked_class(id_3, **self.default_init_kwargs)
+        # check .all() returns empty list on un-cached items
         self.assertEqual(self.checked_class.all(self.db_conn), [])
         self.assertEqual(self.checked_class.all(self.db_conn), [])
-        # check that all() shows all saved, but no unsaved items
-        item1.save(self.db_conn)
+        # check that all() shows only cached/saved items
+        item1.cache()
         item3.save(self.db_conn)
         self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
                          sorted([item1, item3]))
         item2.save(self.db_conn)
         self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
                          sorted([item1, item2, item3]))
         item3.save(self.db_conn)
         self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
                          sorted([item1, item3]))
         item2.save(self.db_conn)
         self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
                          sorted([item1, item2, item3]))
-        return item1, item2, item3
 
 
-    def check_singularity(self, defaulting_field: str,
-                          non_default_value: Any, *args: Any) -> None:
+    @_within_checked_class
+    def test_singularity(self) -> None:
         """Test pointers made for single object keep pointing to it."""
         id1 = self.default_ids[0]
         """Test pointers made for single object keep pointing to it."""
         id1 = self.default_ids[0]
-        obj = self.checked_class(id1, *args)  # pylint: disable=not-callable
+        obj = self.checked_class(id1, **self.default_init_kwargs)
         obj.save(self.db_conn)
         obj.save(self.db_conn)
-        setattr(obj, defaulting_field, non_default_value)
+        attr_name = self.checked_class.to_save[-1]
+        attr = getattr(obj, attr_name)
+        new_attr: str | int | float | bool
+        if isinstance(attr, (int, float)):
+            new_attr = attr + 1
+        elif isinstance(attr, str):
+            new_attr = attr + '_'
+        elif isinstance(attr, bool):
+            new_attr = not attr
+        setattr(obj, attr_name, new_attr)
         retrieved = self.checked_class.by_id(self.db_conn, id1)
         retrieved = self.checked_class.by_id(self.db_conn, id1)
-        self.assertEqual(non_default_value,
-                         getattr(retrieved, defaulting_field))
+        self.assertEqual(new_attr, getattr(retrieved, attr_name))
 
     def check_versioned_singularity(self) -> None:
         """Test singularity of VersionedAttributes on saving (with .title)."""
 
     def check_versioned_singularity(self) -> None:
         """Test singularity of VersionedAttributes on saving (with .title)."""
@@ -196,6 +244,7 @@ class TestCaseWithServer(TestCaseWithDB):
         self.server_thread.start()
         self.conn = HTTPConnection(str(self.httpd.server_address[0]),
                                    self.httpd.server_address[1])
         self.server_thread.start()
         self.conn = HTTPConnection(str(self.httpd.server_address[0]),
                                    self.httpd.server_address[1])
+        self.httpd.set_json_mode()
 
     def tearDown(self) -> None:
         self.httpd.shutdown()
 
     def tearDown(self) -> None:
         self.httpd.shutdown()
@@ -243,5 +292,34 @@ class TestCaseWithServer(TestCaseWithDB):
         """POST basic Process."""
         if not form_data:
             form_data = {'title': 'foo', 'description': 'foo', 'effort': 1.1}
         """POST basic Process."""
         if not form_data:
             form_data = {'title': 'foo', 'description': 'foo', 'effort': 1.1}
-        self.check_post(form_data, '/process?id=', 302, f'/process?id={id_}')
+        self.check_post(form_data, f'/process?id={id_}', 302,
+                        f'/process?id={id_}')
         return form_data
         return form_data
+
+    def check_json_get(self, path: str, expected: dict[str, object]) -> None:
+        """Compare JSON on GET path with expected.
+
+        To simplify comparison of VersionedAttribute histories, transforms
+        timestamp keys of VersionedAttribute history keys into integers
+        counting chronologically forward from 0.
+        """
+        def rewrite_history_keys_in(item: Any) -> Any:
+            if isinstance(item, dict):
+                if '_versioned' in item.keys():
+                    for k in item['_versioned']:
+                        vals = item['_versioned'][k].values()
+                        history = {}
+                        for i, val in enumerate(vals):
+                            history[i] = val
+                        item['_versioned'][k] = history
+                for k in list(item.keys()):
+                    rewrite_history_keys_in(item[k])
+            elif isinstance(item, list):
+                item[:] = [rewrite_history_keys_in(i) for i in item]
+            return item
+        self.conn.request('GET', path)
+        response = self.conn.getresponse()
+        self.assertEqual(response.status, 200)
+        retrieved = json_loads(response.read().decode())
+        rewrite_history_keys_in(retrieved)
+        self.assertEqual(expected, retrieved)