"""Shared test utilities."""
+from __future__ import annotations
from unittest import TestCase
+from typing import Mapping, Any, Callable
from threading import Thread
from http.client import HTTPConnection
from json import loads as json_loads
from urllib.parse import urlencode
from uuid import uuid4
from os import remove as remove_file
-from typing import Mapping, Any
from plomtask.db import DatabaseFile, DatabaseConnection
from plomtask.http import TaskHandler, TaskServer
from plomtask.processes import Process, ProcessStep
self.db_conn.close()
remove_file(self.db_file.path)
+ @staticmethod
+ def _within_checked_class(f: Callable[..., None]) -> Callable[..., None]:
+ def wrapper(self: TestCaseWithDB) -> None:
+ if hasattr(self, 'checked_class'):
+ f(self)
+ return wrapper
+
+ @_within_checked_class
def test_saving_and_caching(self) -> None:
"""Test storage and initialization of instances and attributes."""
- if not hasattr(self, 'checked_class'):
- return
self.check_saving_and_caching(id_=1, **self.default_init_kwargs)
obj = self.checked_class(None, **self.default_init_kwargs)
obj.save(self.db_conn)
self.assertEqual(obj.id_, 2)
- for k, v in self.test_versioneds.items():
- self.check_saving_of_versioned(k, v)
+ for attr_name, type_ in self.test_versioneds.items():
+ owner = self.checked_class(None)
+ vals: list[Any] = ['t1', 't2'] if type_ == str else [0.9, 1.1]
+ attr = getattr(owner, attr_name)
+ attr.set(vals[0])
+ attr.set(vals[1])
+ owner.save(self.db_conn)
+ retrieved = owner.__class__.by_id(self.db_conn, owner.id_)
+ attr = getattr(retrieved, attr_name)
+ self.assertEqual(sorted(attr.history.values()), vals)
- def check_storage(self, content: list[Any]) -> None:
- """Test cache and DB equal content."""
+ def check_identity_with_cache_and_db(self, content: list[Any]) -> None:
+ """Test both cache and DB equal content."""
expected_cache = {}
for item in content:
expected_cache[item.id_] = item
"""Test instance.save in its core without relations."""
obj = self.checked_class(**kwargs) # pylint: disable=not-callable
# check object init itself doesn't store anything yet
- self.check_storage([])
+ self.check_identity_with_cache_and_db([])
# check saving sets core attributes properly
obj.save(self.db_conn)
for key, value in kwargs.items():
self.assertEqual(getattr(obj, key), value)
# check saving stored properly in cache and DB
- self.check_storage([obj])
-
- def check_saving_of_versioned(self, attr_name: str, type_: type) -> None:
- """Test owner's versioned attributes."""
- owner = self.checked_class(None)
- vals: list[Any] = ['t1', 't2'] if type_ == str else [0.9, 1.1]
- attr = getattr(owner, attr_name)
- attr.set(vals[0])
- attr.set(vals[1])
- owner.save(self.db_conn)
- retrieved = owner.__class__.by_id(self.db_conn, owner.id_)
- attr = getattr(retrieved, attr_name)
- self.assertEqual(sorted(attr.history.values()), vals)
+ self.check_identity_with_cache_and_db([obj])
- def check_by_id(self) -> None:
- """Test .by_id(), including creation."""
+ @_within_checked_class
+ def test_by_id(self) -> None:
+ """Test .by_id()."""
+ id1, id2, _ = self.default_ids
# check failure if not yet saved
- id1, id2 = self.default_ids[0], self.default_ids[1]
- obj = self.checked_class(id1) # pylint: disable=not-callable
+ obj1 = self.checked_class(id1, **self.default_init_kwargs)
with self.assertRaises(NotFoundException):
self.checked_class.by_id(self.db_conn, id1)
+ # check identity of cached and retrieved
+ obj1.cache()
+ self.assertEqual(obj1, self.checked_class.by_id(self.db_conn, id1))
# check identity of saved and retrieved
- obj.save(self.db_conn)
- self.assertEqual(obj, self.checked_class.by_id(self.db_conn, id1))
- # check create=True acts like normal instantiation (sans saving)
- by_id_created = self.checked_class.by_id(self.db_conn, id2,
- create=True)
- # pylint: disable=not-callable
- self.assertEqual(self.checked_class(id2), by_id_created)
- self.check_storage([obj])
+ obj2 = self.checked_class(id2, **self.default_init_kwargs)
+ obj2.save(self.db_conn)
+ self.assertEqual(obj2, self.checked_class.by_id(self.db_conn, id2))
+ # obj1.save(self.db_conn)
+ # self.check_identity_with_cache_and_db([obj1, obj2])
+ @_within_checked_class
+ def test_by_id_or_create(self) -> None:
+ """Test .by_id_or_create."""
+ # check .by_id_or_create acts like normal instantiation (sans saving)
+ id_ = self.default_ids[0]
+ if not self.checked_class.can_create_by_id:
+ with self.assertRaises(HandledException):
+ self.checked_class.by_id_or_create(self.db_conn, id_)
+ # check .by_id_or_create fails if wrong class
+ else:
+ by_id_created = self.checked_class.by_id_or_create(self.db_conn,
+ id_)
+ with self.assertRaises(NotFoundException):
+ self.checked_class.by_id(self.db_conn, id_)
+ self.assertEqual(self.checked_class(id_), by_id_created)
+
+ @_within_checked_class
def test_from_table_row(self) -> None:
"""Test .from_table_row() properly reads in class directly from DB."""
- if not hasattr(self, 'checked_class'):
- return
id_ = self.default_ids[0]
obj = self.checked_class(id_, **self.default_init_kwargs)
obj.save(self.db_conn)
attr = getattr(retrieved, attr_name)
self.assertEqual(sorted(attr.history.values()), vals)
+ @_within_checked_class
def test_all(self) -> None:
"""Test .all() and its relation to cache and savings."""
- if not hasattr(self, 'checked_class'):
- return
id_1, id_2, id_3 = self.default_ids
item1 = self.checked_class(id_1, **self.default_init_kwargs)
item2 = self.checked_class(id_2, **self.default_init_kwargs)
self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
sorted([item1, item2, item3]))
+ @_within_checked_class
def test_singularity(self) -> None:
"""Test pointers made for single object keep pointing to it."""
- if not hasattr(self, 'checked_class'):
- return
id1 = self.default_ids[0]
obj = self.checked_class(id1, **self.default_init_kwargs)
obj.save(self.db_conn)
obj.remove(self.db_conn)
obj.save(self.db_conn)
obj.remove(self.db_conn)
- self.check_storage([])
+ self.check_identity_with_cache_and_db([])
class TestCaseWithServer(TestCaseWithDB):