home · contact · privacy
Split BaseModel.by_id into .by_id and by_id_or_create, refactor tests.
[plomtask] / tests / utils.py
index d8bd247bd8391682e5f7577c56f5e77d86216e9f..55c948a409dbf1b2c43f775c0d39106ba3ed510d 100644 (file)
@@ -1,12 +1,13 @@
 """Shared test utilities."""
 """Shared test utilities."""
+from __future__ import annotations
 from unittest import TestCase
 from unittest import TestCase
+from typing import Mapping, Any, Callable
 from threading import Thread
 from http.client import HTTPConnection
 from json import loads as json_loads
 from urllib.parse import urlencode
 from uuid import uuid4
 from os import remove as remove_file
 from threading import Thread
 from http.client import HTTPConnection
 from json import loads as json_loads
 from urllib.parse import urlencode
 from uuid import uuid4
 from os import remove as remove_file
-from typing import Mapping, Any
 from plomtask.db import DatabaseFile, DatabaseConnection
 from plomtask.http import TaskHandler, TaskServer
 from plomtask.processes import Process, ProcessStep
 from plomtask.db import DatabaseFile, DatabaseConnection
 from plomtask.http import TaskHandler, TaskServer
 from plomtask.processes import Process, ProcessStep
@@ -61,19 +62,33 @@ class TestCaseWithDB(TestCase):
         self.db_conn.close()
         remove_file(self.db_file.path)
 
         self.db_conn.close()
         remove_file(self.db_file.path)
 
+    @staticmethod
+    def _within_checked_class(f: Callable[..., None]) -> Callable[..., None]:
+        def wrapper(self: TestCaseWithDB) -> None:
+            if hasattr(self, 'checked_class'):
+                f(self)
+        return wrapper
+
+    @_within_checked_class
     def test_saving_and_caching(self) -> None:
         """Test storage and initialization of instances and attributes."""
     def test_saving_and_caching(self) -> None:
         """Test storage and initialization of instances and attributes."""
-        if not hasattr(self, 'checked_class'):
-            return
         self.check_saving_and_caching(id_=1, **self.default_init_kwargs)
         obj = self.checked_class(None, **self.default_init_kwargs)
         obj.save(self.db_conn)
         self.assertEqual(obj.id_, 2)
         self.check_saving_and_caching(id_=1, **self.default_init_kwargs)
         obj = self.checked_class(None, **self.default_init_kwargs)
         obj.save(self.db_conn)
         self.assertEqual(obj.id_, 2)
-        for k, v in self.test_versioneds.items():
-            self.check_saving_of_versioned(k, v)
+        for attr_name, type_ in self.test_versioneds.items():
+            owner = self.checked_class(None)
+            vals: list[Any] = ['t1', 't2'] if type_ == str else [0.9, 1.1]
+            attr = getattr(owner, attr_name)
+            attr.set(vals[0])
+            attr.set(vals[1])
+            owner.save(self.db_conn)
+            retrieved = owner.__class__.by_id(self.db_conn, owner.id_)
+            attr = getattr(retrieved, attr_name)
+            self.assertEqual(sorted(attr.history.values()), vals)
 
 
-    def check_storage(self, content: list[Any]) -> None:
-        """Test cache and DB equal content."""
+    def check_identity_with_cache_and_db(self, content: list[Any]) -> None:
+        """Test both cache and DB equal content."""
         expected_cache = {}
         for item in content:
             expected_cache[item.id_] = item
         expected_cache = {}
         for item in content:
             expected_cache[item.id_] = item
@@ -93,47 +108,51 @@ class TestCaseWithDB(TestCase):
         """Test instance.save in its core without relations."""
         obj = self.checked_class(**kwargs)  # pylint: disable=not-callable
         # check object init itself doesn't store anything yet
         """Test instance.save in its core without relations."""
         obj = self.checked_class(**kwargs)  # pylint: disable=not-callable
         # check object init itself doesn't store anything yet
-        self.check_storage([])
+        self.check_identity_with_cache_and_db([])
         # check saving sets core attributes properly
         obj.save(self.db_conn)
         for key, value in kwargs.items():
             self.assertEqual(getattr(obj, key), value)
         # check saving stored properly in cache and DB
         # check saving sets core attributes properly
         obj.save(self.db_conn)
         for key, value in kwargs.items():
             self.assertEqual(getattr(obj, key), value)
         # check saving stored properly in cache and DB
-        self.check_storage([obj])
-
-    def check_saving_of_versioned(self, attr_name: str, type_: type) -> None:
-        """Test owner's versioned attributes."""
-        owner = self.checked_class(None)
-        vals: list[Any] = ['t1', 't2'] if type_ == str else [0.9, 1.1]
-        attr = getattr(owner, attr_name)
-        attr.set(vals[0])
-        attr.set(vals[1])
-        owner.save(self.db_conn)
-        retrieved = owner.__class__.by_id(self.db_conn, owner.id_)
-        attr = getattr(retrieved, attr_name)
-        self.assertEqual(sorted(attr.history.values()), vals)
+        self.check_identity_with_cache_and_db([obj])
 
 
-    def check_by_id(self) -> None:
-        """Test .by_id(), including creation."""
+    @_within_checked_class
+    def test_by_id(self) -> None:
+        """Test .by_id()."""
+        id1, id2, _ = self.default_ids
         # check failure if not yet saved
         # check failure if not yet saved
-        id1, id2 = self.default_ids[0], self.default_ids[1]
-        obj = self.checked_class(id1)  # pylint: disable=not-callable
+        obj1 = self.checked_class(id1, **self.default_init_kwargs)
         with self.assertRaises(NotFoundException):
             self.checked_class.by_id(self.db_conn, id1)
         with self.assertRaises(NotFoundException):
             self.checked_class.by_id(self.db_conn, id1)
+        # check identity of cached and retrieved
+        obj1.cache()
+        self.assertEqual(obj1, self.checked_class.by_id(self.db_conn, id1))
         # check identity of saved and retrieved
         # check identity of saved and retrieved
-        obj.save(self.db_conn)
-        self.assertEqual(obj, self.checked_class.by_id(self.db_conn, id1))
-        # check create=True acts like normal instantiation (sans saving)
-        by_id_created = self.checked_class.by_id(self.db_conn, id2,
-                                                 create=True)
-        # pylint: disable=not-callable
-        self.assertEqual(self.checked_class(id2), by_id_created)
-        self.check_storage([obj])
+        obj2 = self.checked_class(id2, **self.default_init_kwargs)
+        obj2.save(self.db_conn)
+        self.assertEqual(obj2, self.checked_class.by_id(self.db_conn, id2))
+        # obj1.save(self.db_conn)
+        # self.check_identity_with_cache_and_db([obj1, obj2])
 
 
+    @_within_checked_class
+    def test_by_id_or_create(self) -> None:
+        """Test .by_id_or_create."""
+        # check .by_id_or_create acts like normal instantiation (sans saving)
+        id_ = self.default_ids[0]
+        if not self.checked_class.can_create_by_id:
+            with self.assertRaises(HandledException):
+                self.checked_class.by_id_or_create(self.db_conn, id_)
+        # check .by_id_or_create fails if wrong class
+        else:
+            by_id_created = self.checked_class.by_id_or_create(self.db_conn,
+                                                               id_)
+            with self.assertRaises(NotFoundException):
+                self.checked_class.by_id(self.db_conn, id_)
+            self.assertEqual(self.checked_class(id_), by_id_created)
+
+    @_within_checked_class
     def test_from_table_row(self) -> None:
         """Test .from_table_row() properly reads in class directly from DB."""
     def test_from_table_row(self) -> None:
         """Test .from_table_row() properly reads in class directly from DB."""
-        if not hasattr(self, 'checked_class'):
-            return
         id_ = self.default_ids[0]
         obj = self.checked_class(id_, **self.default_init_kwargs)
         obj.save(self.db_conn)
         id_ = self.default_ids[0]
         obj = self.checked_class(id_, **self.default_init_kwargs)
         obj.save(self.db_conn)
@@ -174,10 +193,9 @@ class TestCaseWithDB(TestCase):
             attr = getattr(retrieved, attr_name)
             self.assertEqual(sorted(attr.history.values()), vals)
 
             attr = getattr(retrieved, attr_name)
             self.assertEqual(sorted(attr.history.values()), vals)
 
+    @_within_checked_class
     def test_all(self) -> None:
         """Test .all() and its relation to cache and savings."""
     def test_all(self) -> None:
         """Test .all() and its relation to cache and savings."""
-        if not hasattr(self, 'checked_class'):
-            return
         id_1, id_2, id_3 = self.default_ids
         item1 = self.checked_class(id_1, **self.default_init_kwargs)
         item2 = self.checked_class(id_2, **self.default_init_kwargs)
         id_1, id_2, id_3 = self.default_ids
         item1 = self.checked_class(id_1, **self.default_init_kwargs)
         item2 = self.checked_class(id_2, **self.default_init_kwargs)
@@ -193,10 +211,9 @@ class TestCaseWithDB(TestCase):
         self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
                          sorted([item1, item2, item3]))
 
         self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
                          sorted([item1, item2, item3]))
 
+    @_within_checked_class
     def test_singularity(self) -> None:
         """Test pointers made for single object keep pointing to it."""
     def test_singularity(self) -> None:
         """Test pointers made for single object keep pointing to it."""
-        if not hasattr(self, 'checked_class'):
-            return
         id1 = self.default_ids[0]
         obj = self.checked_class(id1, **self.default_init_kwargs)
         obj.save(self.db_conn)
         id1 = self.default_ids[0]
         obj = self.checked_class(id1, **self.default_init_kwargs)
         obj.save(self.db_conn)
@@ -230,7 +247,7 @@ class TestCaseWithDB(TestCase):
             obj.remove(self.db_conn)
         obj.save(self.db_conn)
         obj.remove(self.db_conn)
             obj.remove(self.db_conn)
         obj.save(self.db_conn)
         obj.remove(self.db_conn)
-        self.check_storage([])
+        self.check_identity_with_cache_and_db([])
 
 
 class TestCaseWithServer(TestCaseWithDB):
 
 
 class TestCaseWithServer(TestCaseWithDB):