X-Git-Url: https://plomlompom.com/repos/berlin_corona.txt?a=blobdiff_plain;f=tests%2Futils.py;h=f473c180ba565355a3b6ac0a03acb85d2fa307de;hb=bdb93117ce0f2b08b7b70cf43ac086afa4689c0f;hp=25cc9ba1e79d663ec692570f6f4c1fce4eaaf911;hpb=c021152e6566c8374170de916c69d6b5c816cd54;p=plomtask diff --git a/tests/utils.py b/tests/utils.py index 25cc9ba..f473c18 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -17,27 +17,36 @@ from plomtask.todos import Todo from plomtask.exceptions import NotFoundException, HandledException +def _within_checked_class(f: Callable[..., None]) -> Callable[..., None]: + def wrapper(self: TestCase) -> None: + if hasattr(self, 'checked_class'): + f(self) + return wrapper + + class TestCaseSansDB(TestCase): """Tests requiring no DB setup.""" checked_class: Any - do_id_test: bool = False default_init_args: list[Any] = [] versioned_defaults_to_test: dict[str, str | float] = {} + legal_ids = [1, 5] + illegal_ids = [0] - def test_id_setting(self) -> None: - """Test .id_ being set and its legal range being enforced.""" - if not self.do_id_test: - return - with self.assertRaises(HandledException): - self.checked_class(0, *self.default_init_args) - obj = self.checked_class(5, *self.default_init_args) - self.assertEqual(obj.id_, 5) + @_within_checked_class + def test_id_validation(self) -> None: + """Test .id_ validation/setting.""" + for id_ in self.illegal_ids: + with self.assertRaises(HandledException): + self.checked_class(id_, *self.default_init_args) + for id_ in self.legal_ids: + obj = self.checked_class(id_, *self.default_init_args) + self.assertEqual(obj.id_, id_) + @_within_checked_class def test_versioned_defaults(self) -> None: """Test defaults of VersionedAttributes.""" - if len(self.versioned_defaults_to_test) == 0: - return - obj = self.checked_class(1, *self.default_init_args) + id_ = self.legal_ids[0] + obj = self.checked_class(id_, *self.default_init_args) for k, v in self.versioned_defaults_to_test.items(): self.assertEqual(getattr(obj, k).newest, v) @@ -62,13 +71,6 @@ class TestCaseWithDB(TestCase): self.db_conn.close() remove_file(self.db_file.path) - @staticmethod - def _within_checked_class(f: Callable[..., None]) -> Callable[..., None]: - def wrapper(self: TestCaseWithDB) -> None: - if hasattr(self, 'checked_class'): - f(self) - return wrapper - def _load_from_db(self, id_: int | str) -> list[object]: db_found: list[object] = [] for row in self.db_conn.row_where(self.checked_class.table_name, @@ -192,18 +194,25 @@ class TestCaseWithDB(TestCase): @_within_checked_class def test_by_id_or_create(self) -> None: """Test .by_id_or_create.""" - # check .by_id_or_create acts like normal instantiation (sans saving) - id_ = self.default_ids[0] + # check .by_id_or_create fails if wrong class if not self.checked_class.can_create_by_id: with self.assertRaises(HandledException): - self.checked_class.by_id_or_create(self.db_conn, id_) - # check .by_id_or_create fails if wrong class - else: - by_id_created = self.checked_class.by_id_or_create(self.db_conn, - id_) - with self.assertRaises(NotFoundException): - self.checked_class.by_id(self.db_conn, id_) - self.assertEqual(self.checked_class(id_), by_id_created) + self.checked_class.by_id_or_create(self.db_conn, None) + return + # check ID input of None creates, on saving, ID=1,2,… for int IDs + if isinstance(self.default_ids[0], int): + for n in range(2): + item = self.checked_class.by_id_or_create(self.db_conn, None) + self.assertEqual(item.id_, None) + item.save(self.db_conn) + self.assertEqual(item.id_, n+1) + # check .by_id_or_create acts like normal instantiation (sans saving) + id_ = self.default_ids[2] + item = self.checked_class.by_id_or_create(self.db_conn, id_) + self.assertEqual(item.id_, id_) + with self.assertRaises(NotFoundException): + self.checked_class.by_id(self.db_conn, item.id_) + self.assertEqual(self.checked_class(item.id_), item) @_within_checked_class def test_from_table_row(self) -> None: