class TestsWithDB(TestCaseWithDB, TestCaseSansDB):
"""Tests requiring DB, but not server setup.
- NB: We subclass TestCaseSansDB too, to pull in its .test_id_setting, which
- for Todo wouldn't run without a DB being set up due to the need for
+ NB: We subclass TestCaseSansDB too, to pull in its .test_id_validation,
+ which for Todo wouldn't run without a DB being set up due to the need for
Processes with set IDs.
"""
checked_class = Todo
'date': '2024-01-01'}
# solely used for TestCaseSansDB.test_id_setting
default_init_args = [None, False, '2024-01-01']
- do_id_test = True
def setUp(self) -> None:
super().setUp()
from plomtask.exceptions import NotFoundException, HandledException
+def _within_checked_class(f: Callable[..., None]) -> Callable[..., None]:
+ def wrapper(self: TestCase) -> None:
+ if hasattr(self, 'checked_class'):
+ f(self)
+ return wrapper
+
+
class TestCaseSansDB(TestCase):
"""Tests requiring no DB setup."""
checked_class: Any
- do_id_test: bool = False
default_init_args: list[Any] = []
versioned_defaults_to_test: dict[str, str | float] = {}
+ legal_ids = [1, 5]
+ illegal_ids = [0]
- def test_id_setting(self) -> None:
- """Test .id_ being set and its legal range being enforced."""
- if not self.do_id_test:
- return
- with self.assertRaises(HandledException):
- self.checked_class(0, *self.default_init_args)
- obj = self.checked_class(5, *self.default_init_args)
- self.assertEqual(obj.id_, 5)
+ @_within_checked_class
+ def test_id_validation(self) -> None:
+ """Test .id_ validation/setting."""
+ for id_ in self.illegal_ids:
+ with self.assertRaises(HandledException):
+ self.checked_class(id_, *self.default_init_args)
+ for id_ in self.legal_ids:
+ obj = self.checked_class(id_, *self.default_init_args)
+ self.assertEqual(obj.id_, id_)
+ @_within_checked_class
def test_versioned_defaults(self) -> None:
"""Test defaults of VersionedAttributes."""
- if len(self.versioned_defaults_to_test) == 0:
- return
- obj = self.checked_class(1, *self.default_init_args)
+ id_ = self.legal_ids[0]
+ obj = self.checked_class(id_, *self.default_init_args)
for k, v in self.versioned_defaults_to_test.items():
self.assertEqual(getattr(obj, k).newest, v)
self.db_conn.close()
remove_file(self.db_file.path)
- @staticmethod
- def _within_checked_class(f: Callable[..., None]) -> Callable[..., None]:
- def wrapper(self: TestCaseWithDB) -> None:
- if hasattr(self, 'checked_class'):
- f(self)
- return wrapper
-
def _load_from_db(self, id_: int | str) -> list[object]:
db_found: list[object] = []
for row in self.db_conn.row_where(self.checked_class.table_name,