"""Shared test utilities."""
+from __future__ import annotations
from unittest import TestCase
+from typing import Mapping, Any, Callable
from threading import Thread
from http.client import HTTPConnection
from json import loads as json_loads
from urllib.parse import urlencode
from uuid import uuid4
from os import remove as remove_file
-from typing import Mapping, Any
from plomtask.db import DatabaseFile, DatabaseConnection
from plomtask.http import TaskHandler, TaskServer
from plomtask.processes import Process, ProcessStep
self.db_conn.close()
remove_file(self.db_file.path)
+ @staticmethod
+ def _within_checked_class(f: Callable[..., None]) -> Callable[..., None]:
+ def wrapper(self: TestCaseWithDB) -> None:
+ if hasattr(self, 'checked_class'):
+ f(self)
+ return wrapper
+
+ @_within_checked_class
def test_saving_and_caching(self) -> None:
"""Test storage and initialization of instances and attributes."""
- if not hasattr(self, 'checked_class'):
- return
self.check_saving_and_caching(id_=1, **self.default_init_kwargs)
obj = self.checked_class(None, **self.default_init_kwargs)
obj.save(self.db_conn)
self.assertEqual(obj.id_, 2)
- for k, v in self.test_versioneds.items():
- self.check_saving_of_versioned(k, v)
+ for attr_name, type_ in self.test_versioneds.items():
+ owner = self.checked_class(None)
+ vals: list[Any] = ['t1', 't2'] if type_ == str else [0.9, 1.1]
+ attr = getattr(owner, attr_name)
+ attr.set(vals[0])
+ attr.set(vals[1])
+ owner.save(self.db_conn)
+ retrieved = owner.__class__.by_id(self.db_conn, owner.id_)
+ attr = getattr(retrieved, attr_name)
+ self.assertEqual(sorted(attr.history.values()), vals)
def check_storage(self, content: list[Any]) -> None:
"""Test cache and DB equal content."""
# check saving stored properly in cache and DB
self.check_storage([obj])
- def check_saving_of_versioned(self, attr_name: str, type_: type) -> None:
- """Test owner's versioned attributes."""
- owner = self.checked_class(None)
- vals: list[Any] = ['t1', 't2'] if type_ == str else [0.9, 1.1]
- attr = getattr(owner, attr_name)
- attr.set(vals[0])
- attr.set(vals[1])
- owner.save(self.db_conn)
- retrieved = owner.__class__.by_id(self.db_conn, owner.id_)
- attr = getattr(retrieved, attr_name)
- self.assertEqual(sorted(attr.history.values()), vals)
-
def check_by_id(self) -> None:
"""Test .by_id(), including creation."""
# check failure if not yet saved
self.assertEqual(self.checked_class(id2), by_id_created)
self.check_storage([obj])
+ @_within_checked_class
def test_from_table_row(self) -> None:
"""Test .from_table_row() properly reads in class directly from DB."""
- if not hasattr(self, 'checked_class'):
- return
id_ = self.default_ids[0]
obj = self.checked_class(id_, **self.default_init_kwargs)
obj.save(self.db_conn)
attr = getattr(retrieved, attr_name)
self.assertEqual(sorted(attr.history.values()), vals)
+ @_within_checked_class
def test_all(self) -> None:
"""Test .all() and its relation to cache and savings."""
- if not hasattr(self, 'checked_class'):
- return
id_1, id_2, id_3 = self.default_ids
item1 = self.checked_class(id_1, **self.default_init_kwargs)
item2 = self.checked_class(id_2, **self.default_init_kwargs)
self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
sorted([item1, item2, item3]))
+ @_within_checked_class
def test_singularity(self) -> None:
"""Test pointers made for single object keep pointing to it."""
- if not hasattr(self, 'checked_class'):
- return
id1 = self.default_ids[0]
obj = self.checked_class(id1, **self.default_init_kwargs)
obj.save(self.db_conn)