from plomtask.exceptions import NotFoundException, HandledException
+def _within_checked_class(f: Callable[..., None]) -> Callable[..., None]:
+ def wrapper(self: TestCase) -> None:
+ if hasattr(self, 'checked_class'):
+ f(self)
+ return wrapper
+
+
class TestCaseSansDB(TestCase):
"""Tests requiring no DB setup."""
checked_class: Any
- do_id_test: bool = False
default_init_args: list[Any] = []
versioned_defaults_to_test: dict[str, str | float] = {}
+ legal_ids = [1, 5]
+ illegal_ids = [0]
- def test_id_setting(self) -> None:
- """Test .id_ being set and its legal range being enforced."""
- if not self.do_id_test:
- return
- with self.assertRaises(HandledException):
- self.checked_class(0, *self.default_init_args)
- obj = self.checked_class(5, *self.default_init_args)
- self.assertEqual(obj.id_, 5)
+ @_within_checked_class
+ def test_id_validation(self) -> None:
+ """Test .id_ validation/setting."""
+ for id_ in self.illegal_ids:
+ with self.assertRaises(HandledException):
+ self.checked_class(id_, *self.default_init_args)
+ for id_ in self.legal_ids:
+ obj = self.checked_class(id_, *self.default_init_args)
+ self.assertEqual(obj.id_, id_)
+ @_within_checked_class
def test_versioned_defaults(self) -> None:
"""Test defaults of VersionedAttributes."""
- if len(self.versioned_defaults_to_test) == 0:
- return
- obj = self.checked_class(1, *self.default_init_args)
+ id_ = self.legal_ids[0]
+ obj = self.checked_class(id_, *self.default_init_args)
for k, v in self.versioned_defaults_to_test.items():
self.assertEqual(getattr(obj, k).newest, v)
self.db_conn.close()
remove_file(self.db_file.path)
- @staticmethod
- def _within_checked_class(f: Callable[..., None]) -> Callable[..., None]:
- def wrapper(self: TestCaseWithDB) -> None:
- if hasattr(self, 'checked_class'):
- f(self)
- return wrapper
-
def _load_from_db(self, id_: int | str) -> list[object]:
db_found: list[object] = []
for row in self.db_conn.row_where(self.checked_class.table_name,
self.server_thread.join()
super().tearDown()
+ @staticmethod
+ def as_id_list(items: list[dict[str, object]]) -> list[int | str]:
+ """Return list of only 'id' fields of items."""
+ id_list = []
+ for item in items:
+ assert isinstance(item['id'], (int, str))
+ id_list += [item['id']]
+ return id_list
+
+ @staticmethod
+ def as_refs(items: list[dict[str, object]]
+ ) -> dict[str, dict[str, object]]:
+ """Return dictionary of items by their 'id' fields."""
+ refs = {}
+ for item in items:
+ refs[str(item['id'])] = item
+ return refs
+
+ @staticmethod
+ def cond_as_dict(id_: int = 1,
+ is_active: bool = False,
+ titles: None | list[str] = None,
+ descriptions: None | list[str] = None
+ ) -> dict[str, object]:
+ """Return JSON of Condition to expect."""
+ d = {'id': id_,
+ 'is_active': is_active,
+ '_versioned': {
+ 'title': {},
+ 'description': {}}}
+ titles = titles if titles else []
+ descriptions = descriptions if descriptions else []
+ assert isinstance(d['_versioned'], dict)
+ for i, title in enumerate(titles):
+ d['_versioned']['title'][i] = title
+ for i, description in enumerate(descriptions):
+ d['_versioned']['description'][i] = description
+ return d
+
+ @staticmethod
+ def proc_as_dict(id_: int = 1,
+ title: str = 'A',
+ description: str = '',
+ effort: float = 1.0,
+ conditions: None | list[int] = None,
+ disables: None | list[int] = None,
+ blockers: None | list[int] = None,
+ enables: None | list[int] = None
+ ) -> dict[str, object]:
+ """Return JSON of Process to expect."""
+ # pylint: disable=too-many-arguments
+ d = {'id': id_,
+ 'calendarize': False,
+ 'suppressed_steps': [],
+ 'explicit_steps': [],
+ '_versioned': {
+ 'title': {0: title},
+ 'description': {0: description},
+ 'effort': {0: effort}},
+ 'conditions': conditions if conditions else [],
+ 'disables': disables if disables else [],
+ 'enables': enables if enables else [],
+ 'blockers': blockers if blockers else []}
+ return d
+
def check_redirect(self, target: str) -> None:
"""Check that self.conn answers with a 302 redirect to target."""
response = self.conn.getresponse()