home · contact · privacy
Split BaseModel.by_id into .by_id and by_id_or_create, refactor tests.
[plomtask] / tests / utils.py
index f76fe33c93fc65d68aa07b7ca04aa0b98c762072..55c948a409dbf1b2c43f775c0d39106ba3ed510d 100644 (file)
@@ -1,12 +1,13 @@
 """Shared test utilities."""
+from __future__ import annotations
 from unittest import TestCase
+from typing import Mapping, Any, Callable
 from threading import Thread
 from http.client import HTTPConnection
 from json import loads as json_loads
 from urllib.parse import urlencode
 from uuid import uuid4
 from os import remove as remove_file
-from typing import Mapping, Any
 from plomtask.db import DatabaseFile, DatabaseConnection
 from plomtask.http import TaskHandler, TaskServer
 from plomtask.processes import Process, ProcessStep
@@ -61,19 +62,33 @@ class TestCaseWithDB(TestCase):
         self.db_conn.close()
         remove_file(self.db_file.path)
 
+    @staticmethod
+    def _within_checked_class(f: Callable[..., None]) -> Callable[..., None]:
+        def wrapper(self: TestCaseWithDB) -> None:
+            if hasattr(self, 'checked_class'):
+                f(self)
+        return wrapper
+
+    @_within_checked_class
     def test_saving_and_caching(self) -> None:
         """Test storage and initialization of instances and attributes."""
-        if not hasattr(self, 'checked_class'):
-            return
         self.check_saving_and_caching(id_=1, **self.default_init_kwargs)
         obj = self.checked_class(None, **self.default_init_kwargs)
         obj.save(self.db_conn)
         self.assertEqual(obj.id_, 2)
-        for k, v in self.test_versioneds.items():
-            self.check_saving_of_versioned(k, v)
+        for attr_name, type_ in self.test_versioneds.items():
+            owner = self.checked_class(None)
+            vals: list[Any] = ['t1', 't2'] if type_ == str else [0.9, 1.1]
+            attr = getattr(owner, attr_name)
+            attr.set(vals[0])
+            attr.set(vals[1])
+            owner.save(self.db_conn)
+            retrieved = owner.__class__.by_id(self.db_conn, owner.id_)
+            attr = getattr(retrieved, attr_name)
+            self.assertEqual(sorted(attr.history.values()), vals)
 
-    def check_storage(self, content: list[Any]) -> None:
-        """Test cache and DB equal content."""
+    def check_identity_with_cache_and_db(self, content: list[Any]) -> None:
+        """Test both cache and DB equal content."""
         expected_cache = {}
         for item in content:
             expected_cache[item.id_] = item
@@ -93,47 +108,51 @@ class TestCaseWithDB(TestCase):
         """Test instance.save in its core without relations."""
         obj = self.checked_class(**kwargs)  # pylint: disable=not-callable
         # check object init itself doesn't store anything yet
-        self.check_storage([])
+        self.check_identity_with_cache_and_db([])
         # check saving sets core attributes properly
         obj.save(self.db_conn)
         for key, value in kwargs.items():
             self.assertEqual(getattr(obj, key), value)
         # check saving stored properly in cache and DB
-        self.check_storage([obj])
-
-    def check_saving_of_versioned(self, attr_name: str, type_: type) -> None:
-        """Test owner's versioned attributes."""
-        owner = self.checked_class(None)
-        vals: list[Any] = ['t1', 't2'] if type_ == str else [0.9, 1.1]
-        attr = getattr(owner, attr_name)
-        attr.set(vals[0])
-        attr.set(vals[1])
-        owner.save(self.db_conn)
-        retrieved = owner.__class__.by_id(self.db_conn, owner.id_)
-        attr = getattr(retrieved, attr_name)
-        self.assertEqual(sorted(attr.history.values()), vals)
+        self.check_identity_with_cache_and_db([obj])
 
-    def check_by_id(self) -> None:
-        """Test .by_id(), including creation."""
+    @_within_checked_class
+    def test_by_id(self) -> None:
+        """Test .by_id()."""
+        id1, id2, _ = self.default_ids
         # check failure if not yet saved
-        id1, id2 = self.default_ids[0], self.default_ids[1]
-        obj = self.checked_class(id1)  # pylint: disable=not-callable
+        obj1 = self.checked_class(id1, **self.default_init_kwargs)
         with self.assertRaises(NotFoundException):
             self.checked_class.by_id(self.db_conn, id1)
+        # check identity of cached and retrieved
+        obj1.cache()
+        self.assertEqual(obj1, self.checked_class.by_id(self.db_conn, id1))
         # check identity of saved and retrieved
-        obj.save(self.db_conn)
-        self.assertEqual(obj, self.checked_class.by_id(self.db_conn, id1))
-        # check create=True acts like normal instantiation (sans saving)
-        by_id_created = self.checked_class.by_id(self.db_conn, id2,
-                                                 create=True)
-        # pylint: disable=not-callable
-        self.assertEqual(self.checked_class(id2), by_id_created)
-        self.check_storage([obj])
+        obj2 = self.checked_class(id2, **self.default_init_kwargs)
+        obj2.save(self.db_conn)
+        self.assertEqual(obj2, self.checked_class.by_id(self.db_conn, id2))
+        # obj1.save(self.db_conn)
+        # self.check_identity_with_cache_and_db([obj1, obj2])
 
+    @_within_checked_class
+    def test_by_id_or_create(self) -> None:
+        """Test .by_id_or_create."""
+        # check .by_id_or_create acts like normal instantiation (sans saving)
+        id_ = self.default_ids[0]
+        if not self.checked_class.can_create_by_id:
+            with self.assertRaises(HandledException):
+                self.checked_class.by_id_or_create(self.db_conn, id_)
+        # check .by_id_or_create fails if wrong class
+        else:
+            by_id_created = self.checked_class.by_id_or_create(self.db_conn,
+                                                               id_)
+            with self.assertRaises(NotFoundException):
+                self.checked_class.by_id(self.db_conn, id_)
+            self.assertEqual(self.checked_class(id_), by_id_created)
+
+    @_within_checked_class
     def test_from_table_row(self) -> None:
-        """Test .from_table_row() properly reads in class from DB."""
-        if not hasattr(self, 'checked_class'):
-            return
+        """Test .from_table_row() properly reads in class directly from DB."""
         id_ = self.default_ids[0]
         obj = self.checked_class(id_, **self.default_init_kwargs)
         obj.save(self.db_conn)
@@ -174,34 +193,42 @@ class TestCaseWithDB(TestCase):
             attr = getattr(retrieved, attr_name)
             self.assertEqual(sorted(attr.history.values()), vals)
 
-    def check_all(self) -> tuple[Any, Any, Any]:
-        """Test .all()."""
-        # pylint: disable=not-callable
-        item1 = self.checked_class(self.default_ids[0])
-        item2 = self.checked_class(self.default_ids[1])
-        item3 = self.checked_class(self.default_ids[2])
-        # check pre-save .all() returns empty list
+    @_within_checked_class
+    def test_all(self) -> None:
+        """Test .all() and its relation to cache and savings."""
+        id_1, id_2, id_3 = self.default_ids
+        item1 = self.checked_class(id_1, **self.default_init_kwargs)
+        item2 = self.checked_class(id_2, **self.default_init_kwargs)
+        item3 = self.checked_class(id_3, **self.default_init_kwargs)
+        # check .all() returns empty list on un-cached items
         self.assertEqual(self.checked_class.all(self.db_conn), [])
-        # check that all() shows all saved, but no unsaved items
-        item1.save(self.db_conn)
+        # check that all() shows only cached/saved items
+        item1.cache()
         item3.save(self.db_conn)
         self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
                          sorted([item1, item3]))
         item2.save(self.db_conn)
         self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
                          sorted([item1, item2, item3]))
-        return item1, item2, item3
 
-    def check_singularity(self, defaulting_field: str,
-                          non_default_value: Any, *args: Any) -> None:
+    @_within_checked_class
+    def test_singularity(self) -> None:
         """Test pointers made for single object keep pointing to it."""
         id1 = self.default_ids[0]
-        obj = self.checked_class(id1, *args)  # pylint: disable=not-callable
+        obj = self.checked_class(id1, **self.default_init_kwargs)
         obj.save(self.db_conn)
-        setattr(obj, defaulting_field, non_default_value)
+        attr_name = self.checked_class.to_save[-1]
+        attr = getattr(obj, attr_name)
+        new_attr: str | int | float | bool
+        if isinstance(attr, (int, float)):
+            new_attr = attr + 1
+        elif isinstance(attr, str):
+            new_attr = attr + '_'
+        elif isinstance(attr, bool):
+            new_attr = not attr
+        setattr(obj, attr_name, new_attr)
         retrieved = self.checked_class.by_id(self.db_conn, id1)
-        self.assertEqual(non_default_value,
-                         getattr(retrieved, defaulting_field))
+        self.assertEqual(new_attr, getattr(retrieved, attr_name))
 
     def check_versioned_singularity(self) -> None:
         """Test singularity of VersionedAttributes on saving (with .title)."""
@@ -220,7 +247,7 @@ class TestCaseWithDB(TestCase):
             obj.remove(self.db_conn)
         obj.save(self.db_conn)
         obj.remove(self.db_conn)
-        self.check_storage([])
+        self.check_identity_with_cache_and_db([])
 
 
 class TestCaseWithServer(TestCaseWithDB):