from plomtask.exceptions import NotFoundException, HandledException
+def _within_checked_class(f: Callable[..., None]) -> Callable[..., None]:
+ def wrapper(self: TestCase) -> None:
+ if hasattr(self, 'checked_class'):
+ f(self)
+ return wrapper
+
+
class TestCaseSansDB(TestCase):
"""Tests requiring no DB setup."""
checked_class: Any
- do_id_test: bool = False
default_init_args: list[Any] = []
versioned_defaults_to_test: dict[str, str | float] = {}
+ legal_ids = [1, 5]
+ illegal_ids = [0]
- def test_id_setting(self) -> None:
- """Test .id_ being set and its legal range being enforced."""
- if not self.do_id_test:
- return
- with self.assertRaises(HandledException):
- self.checked_class(0, *self.default_init_args)
- obj = self.checked_class(5, *self.default_init_args)
- self.assertEqual(obj.id_, 5)
+ @_within_checked_class
+ def test_id_validation(self) -> None:
+ """Test .id_ validation/setting."""
+ for id_ in self.illegal_ids:
+ with self.assertRaises(HandledException):
+ self.checked_class(id_, *self.default_init_args)
+ for id_ in self.legal_ids:
+ obj = self.checked_class(id_, *self.default_init_args)
+ self.assertEqual(obj.id_, id_)
+ @_within_checked_class
def test_versioned_defaults(self) -> None:
"""Test defaults of VersionedAttributes."""
- if len(self.versioned_defaults_to_test) == 0:
- return
- obj = self.checked_class(1, *self.default_init_args)
+ id_ = self.legal_ids[0]
+ obj = self.checked_class(id_, *self.default_init_args)
for k, v in self.versioned_defaults_to_test.items():
self.assertEqual(getattr(obj, k).newest, v)
self.db_conn.close()
remove_file(self.db_file.path)
- @staticmethod
- def _within_checked_class(f: Callable[..., None]) -> Callable[..., None]:
- def wrapper(self: TestCaseWithDB) -> None:
- if hasattr(self, 'checked_class'):
- f(self)
- return wrapper
-
def _load_from_db(self, id_: int | str) -> list[object]:
db_found: list[object] = []
for row in self.db_conn.row_where(self.checked_class.table_name,
@_within_checked_class
def test_by_id_or_create(self) -> None:
"""Test .by_id_or_create."""
- # check .by_id_or_create acts like normal instantiation (sans saving)
- id_ = self.default_ids[0]
+ # check .by_id_or_create fails if wrong class
if not self.checked_class.can_create_by_id:
with self.assertRaises(HandledException):
- self.checked_class.by_id_or_create(self.db_conn, id_)
- # check .by_id_or_create fails if wrong class
- else:
- by_id_created = self.checked_class.by_id_or_create(self.db_conn,
- id_)
- with self.assertRaises(NotFoundException):
- self.checked_class.by_id(self.db_conn, id_)
- self.assertEqual(self.checked_class(id_), by_id_created)
+ self.checked_class.by_id_or_create(self.db_conn, None)
+ return
+ # check ID input of None creates, on saving, ID=1,2,… for int IDs
+ if isinstance(self.default_ids[0], int):
+ for n in range(2):
+ item = self.checked_class.by_id_or_create(self.db_conn, None)
+ self.assertEqual(item.id_, None)
+ item.save(self.db_conn)
+ self.assertEqual(item.id_, n+1)
+ # check .by_id_or_create acts like normal instantiation (sans saving)
+ id_ = self.default_ids[2]
+ item = self.checked_class.by_id_or_create(self.db_conn, id_)
+ self.assertEqual(item.id_, id_)
+ with self.assertRaises(NotFoundException):
+ self.checked_class.by_id(self.db_conn, item.id_)
+ self.assertEqual(self.checked_class(item.id_), item)
@_within_checked_class
def test_from_table_row(self) -> None: