home · contact · privacy
Some test utils refactoring.
authorChristian Heller <c.heller@plomlompom.de>
Tue, 18 Jun 2024 01:58:37 +0000 (03:58 +0200)
committerChristian Heller <c.heller@plomlompom.de>
Tue, 18 Jun 2024 01:58:37 +0000 (03:58 +0200)
tests/utils.py

index d8bd247bd8391682e5f7577c56f5e77d86216e9f..9d3d11d9f841290fed50b03c8d33acea7f5248ac 100644 (file)
@@ -1,12 +1,13 @@
 """Shared test utilities."""
+from __future__ import annotations
 from unittest import TestCase
+from typing import Mapping, Any, Callable
 from threading import Thread
 from http.client import HTTPConnection
 from json import loads as json_loads
 from urllib.parse import urlencode
 from uuid import uuid4
 from os import remove as remove_file
-from typing import Mapping, Any
 from plomtask.db import DatabaseFile, DatabaseConnection
 from plomtask.http import TaskHandler, TaskServer
 from plomtask.processes import Process, ProcessStep
@@ -61,16 +62,30 @@ class TestCaseWithDB(TestCase):
         self.db_conn.close()
         remove_file(self.db_file.path)
 
+    @staticmethod
+    def _within_checked_class(f: Callable[..., None]) -> Callable[..., None]:
+        def wrapper(self: TestCaseWithDB) -> None:
+            if hasattr(self, 'checked_class'):
+                f(self)
+        return wrapper
+
+    @_within_checked_class
     def test_saving_and_caching(self) -> None:
         """Test storage and initialization of instances and attributes."""
-        if not hasattr(self, 'checked_class'):
-            return
         self.check_saving_and_caching(id_=1, **self.default_init_kwargs)
         obj = self.checked_class(None, **self.default_init_kwargs)
         obj.save(self.db_conn)
         self.assertEqual(obj.id_, 2)
-        for k, v in self.test_versioneds.items():
-            self.check_saving_of_versioned(k, v)
+        for attr_name, type_ in self.test_versioneds.items():
+            owner = self.checked_class(None)
+            vals: list[Any] = ['t1', 't2'] if type_ == str else [0.9, 1.1]
+            attr = getattr(owner, attr_name)
+            attr.set(vals[0])
+            attr.set(vals[1])
+            owner.save(self.db_conn)
+            retrieved = owner.__class__.by_id(self.db_conn, owner.id_)
+            attr = getattr(retrieved, attr_name)
+            self.assertEqual(sorted(attr.history.values()), vals)
 
     def check_storage(self, content: list[Any]) -> None:
         """Test cache and DB equal content."""
@@ -101,18 +116,6 @@ class TestCaseWithDB(TestCase):
         # check saving stored properly in cache and DB
         self.check_storage([obj])
 
-    def check_saving_of_versioned(self, attr_name: str, type_: type) -> None:
-        """Test owner's versioned attributes."""
-        owner = self.checked_class(None)
-        vals: list[Any] = ['t1', 't2'] if type_ == str else [0.9, 1.1]
-        attr = getattr(owner, attr_name)
-        attr.set(vals[0])
-        attr.set(vals[1])
-        owner.save(self.db_conn)
-        retrieved = owner.__class__.by_id(self.db_conn, owner.id_)
-        attr = getattr(retrieved, attr_name)
-        self.assertEqual(sorted(attr.history.values()), vals)
-
     def check_by_id(self) -> None:
         """Test .by_id(), including creation."""
         # check failure if not yet saved
@@ -130,10 +133,9 @@ class TestCaseWithDB(TestCase):
         self.assertEqual(self.checked_class(id2), by_id_created)
         self.check_storage([obj])
 
+    @_within_checked_class
     def test_from_table_row(self) -> None:
         """Test .from_table_row() properly reads in class directly from DB."""
-        if not hasattr(self, 'checked_class'):
-            return
         id_ = self.default_ids[0]
         obj = self.checked_class(id_, **self.default_init_kwargs)
         obj.save(self.db_conn)
@@ -174,10 +176,9 @@ class TestCaseWithDB(TestCase):
             attr = getattr(retrieved, attr_name)
             self.assertEqual(sorted(attr.history.values()), vals)
 
+    @_within_checked_class
     def test_all(self) -> None:
         """Test .all() and its relation to cache and savings."""
-        if not hasattr(self, 'checked_class'):
-            return
         id_1, id_2, id_3 = self.default_ids
         item1 = self.checked_class(id_1, **self.default_init_kwargs)
         item2 = self.checked_class(id_2, **self.default_init_kwargs)
@@ -193,10 +194,9 @@ class TestCaseWithDB(TestCase):
         self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
                          sorted([item1, item2, item3]))
 
+    @_within_checked_class
     def test_singularity(self) -> None:
         """Test pointers made for single object keep pointing to it."""
-        if not hasattr(self, 'checked_class'):
-            return
         id1 = self.default_ids[0]
         obj = self.checked_class(id1, **self.default_init_kwargs)
         obj.save(self.db_conn)