home · contact · privacy
Some test utils refactoring.
authorChristian Heller <c.heller@plomlompom.de>
Tue, 18 Jun 2024 01:58:37 +0000 (03:58 +0200)
committerChristian Heller <c.heller@plomlompom.de>
Tue, 18 Jun 2024 01:58:37 +0000 (03:58 +0200)
tests/utils.py

index d8bd247bd8391682e5f7577c56f5e77d86216e9f..9d3d11d9f841290fed50b03c8d33acea7f5248ac 100644 (file)
@@ -1,12 +1,13 @@
 """Shared test utilities."""
 """Shared test utilities."""
+from __future__ import annotations
 from unittest import TestCase
 from unittest import TestCase
+from typing import Mapping, Any, Callable
 from threading import Thread
 from http.client import HTTPConnection
 from json import loads as json_loads
 from urllib.parse import urlencode
 from uuid import uuid4
 from os import remove as remove_file
 from threading import Thread
 from http.client import HTTPConnection
 from json import loads as json_loads
 from urllib.parse import urlencode
 from uuid import uuid4
 from os import remove as remove_file
-from typing import Mapping, Any
 from plomtask.db import DatabaseFile, DatabaseConnection
 from plomtask.http import TaskHandler, TaskServer
 from plomtask.processes import Process, ProcessStep
 from plomtask.db import DatabaseFile, DatabaseConnection
 from plomtask.http import TaskHandler, TaskServer
 from plomtask.processes import Process, ProcessStep
@@ -61,16 +62,30 @@ class TestCaseWithDB(TestCase):
         self.db_conn.close()
         remove_file(self.db_file.path)
 
         self.db_conn.close()
         remove_file(self.db_file.path)
 
+    @staticmethod
+    def _within_checked_class(f: Callable[..., None]) -> Callable[..., None]:
+        def wrapper(self: TestCaseWithDB) -> None:
+            if hasattr(self, 'checked_class'):
+                f(self)
+        return wrapper
+
+    @_within_checked_class
     def test_saving_and_caching(self) -> None:
         """Test storage and initialization of instances and attributes."""
     def test_saving_and_caching(self) -> None:
         """Test storage and initialization of instances and attributes."""
-        if not hasattr(self, 'checked_class'):
-            return
         self.check_saving_and_caching(id_=1, **self.default_init_kwargs)
         obj = self.checked_class(None, **self.default_init_kwargs)
         obj.save(self.db_conn)
         self.assertEqual(obj.id_, 2)
         self.check_saving_and_caching(id_=1, **self.default_init_kwargs)
         obj = self.checked_class(None, **self.default_init_kwargs)
         obj.save(self.db_conn)
         self.assertEqual(obj.id_, 2)
-        for k, v in self.test_versioneds.items():
-            self.check_saving_of_versioned(k, v)
+        for attr_name, type_ in self.test_versioneds.items():
+            owner = self.checked_class(None)
+            vals: list[Any] = ['t1', 't2'] if type_ == str else [0.9, 1.1]
+            attr = getattr(owner, attr_name)
+            attr.set(vals[0])
+            attr.set(vals[1])
+            owner.save(self.db_conn)
+            retrieved = owner.__class__.by_id(self.db_conn, owner.id_)
+            attr = getattr(retrieved, attr_name)
+            self.assertEqual(sorted(attr.history.values()), vals)
 
     def check_storage(self, content: list[Any]) -> None:
         """Test cache and DB equal content."""
 
     def check_storage(self, content: list[Any]) -> None:
         """Test cache and DB equal content."""
@@ -101,18 +116,6 @@ class TestCaseWithDB(TestCase):
         # check saving stored properly in cache and DB
         self.check_storage([obj])
 
         # check saving stored properly in cache and DB
         self.check_storage([obj])
 
-    def check_saving_of_versioned(self, attr_name: str, type_: type) -> None:
-        """Test owner's versioned attributes."""
-        owner = self.checked_class(None)
-        vals: list[Any] = ['t1', 't2'] if type_ == str else [0.9, 1.1]
-        attr = getattr(owner, attr_name)
-        attr.set(vals[0])
-        attr.set(vals[1])
-        owner.save(self.db_conn)
-        retrieved = owner.__class__.by_id(self.db_conn, owner.id_)
-        attr = getattr(retrieved, attr_name)
-        self.assertEqual(sorted(attr.history.values()), vals)
-
     def check_by_id(self) -> None:
         """Test .by_id(), including creation."""
         # check failure if not yet saved
     def check_by_id(self) -> None:
         """Test .by_id(), including creation."""
         # check failure if not yet saved
@@ -130,10 +133,9 @@ class TestCaseWithDB(TestCase):
         self.assertEqual(self.checked_class(id2), by_id_created)
         self.check_storage([obj])
 
         self.assertEqual(self.checked_class(id2), by_id_created)
         self.check_storage([obj])
 
+    @_within_checked_class
     def test_from_table_row(self) -> None:
         """Test .from_table_row() properly reads in class directly from DB."""
     def test_from_table_row(self) -> None:
         """Test .from_table_row() properly reads in class directly from DB."""
-        if not hasattr(self, 'checked_class'):
-            return
         id_ = self.default_ids[0]
         obj = self.checked_class(id_, **self.default_init_kwargs)
         obj.save(self.db_conn)
         id_ = self.default_ids[0]
         obj = self.checked_class(id_, **self.default_init_kwargs)
         obj.save(self.db_conn)
@@ -174,10 +176,9 @@ class TestCaseWithDB(TestCase):
             attr = getattr(retrieved, attr_name)
             self.assertEqual(sorted(attr.history.values()), vals)
 
             attr = getattr(retrieved, attr_name)
             self.assertEqual(sorted(attr.history.values()), vals)
 
+    @_within_checked_class
     def test_all(self) -> None:
         """Test .all() and its relation to cache and savings."""
     def test_all(self) -> None:
         """Test .all() and its relation to cache and savings."""
-        if not hasattr(self, 'checked_class'):
-            return
         id_1, id_2, id_3 = self.default_ids
         item1 = self.checked_class(id_1, **self.default_init_kwargs)
         item2 = self.checked_class(id_2, **self.default_init_kwargs)
         id_1, id_2, id_3 = self.default_ids
         item1 = self.checked_class(id_1, **self.default_init_kwargs)
         item2 = self.checked_class(id_2, **self.default_init_kwargs)
@@ -193,10 +194,9 @@ class TestCaseWithDB(TestCase):
         self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
                          sorted([item1, item2, item3]))
 
         self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
                          sorted([item1, item2, item3]))
 
+    @_within_checked_class
     def test_singularity(self) -> None:
         """Test pointers made for single object keep pointing to it."""
     def test_singularity(self) -> None:
         """Test pointers made for single object keep pointing to it."""
-        if not hasattr(self, 'checked_class'):
-            return
         id1 = self.default_ids[0]
         obj = self.checked_class(id1, **self.default_init_kwargs)
         obj.save(self.db_conn)
         id1 = self.default_ids[0]
         obj = self.checked_class(id1, **self.default_init_kwargs)
         obj.save(self.db_conn)