X-Git-Url: https://plomlompom.com/repos/berlin_corona.txt?a=blobdiff_plain;f=tests%2Futils.py;h=f473c180ba565355a3b6ac0a03acb85d2fa307de;hb=bdb93117ce0f2b08b7b70cf43ac086afa4689c0f;hp=0925b2d5b2adc0e415293526a4b01c04fc42b178;hpb=5e87cc0397c0aaf5b4f15eeb7518b25776bcef71;p=plomtask diff --git a/tests/utils.py b/tests/utils.py index 0925b2d..f473c18 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -17,27 +17,36 @@ from plomtask.todos import Todo from plomtask.exceptions import NotFoundException, HandledException +def _within_checked_class(f: Callable[..., None]) -> Callable[..., None]: + def wrapper(self: TestCase) -> None: + if hasattr(self, 'checked_class'): + f(self) + return wrapper + + class TestCaseSansDB(TestCase): """Tests requiring no DB setup.""" checked_class: Any - do_id_test: bool = False default_init_args: list[Any] = [] versioned_defaults_to_test: dict[str, str | float] = {} + legal_ids = [1, 5] + illegal_ids = [0] - def test_id_setting(self) -> None: - """Test .id_ being set and its legal range being enforced.""" - if not self.do_id_test: - return - with self.assertRaises(HandledException): - self.checked_class(0, *self.default_init_args) - obj = self.checked_class(5, *self.default_init_args) - self.assertEqual(obj.id_, 5) + @_within_checked_class + def test_id_validation(self) -> None: + """Test .id_ validation/setting.""" + for id_ in self.illegal_ids: + with self.assertRaises(HandledException): + self.checked_class(id_, *self.default_init_args) + for id_ in self.legal_ids: + obj = self.checked_class(id_, *self.default_init_args) + self.assertEqual(obj.id_, id_) + @_within_checked_class def test_versioned_defaults(self) -> None: """Test defaults of VersionedAttributes.""" - if len(self.versioned_defaults_to_test) == 0: - return - obj = self.checked_class(1, *self.default_init_args) + id_ = self.legal_ids[0] + obj = self.checked_class(id_, *self.default_init_args) for k, v in self.versioned_defaults_to_test.items(): self.assertEqual(getattr(obj, k).newest, v) @@ -62,13 +71,6 @@ class TestCaseWithDB(TestCase): self.db_conn.close() remove_file(self.db_file.path) - @staticmethod - def _within_checked_class(f: Callable[..., None]) -> Callable[..., None]: - def wrapper(self: TestCaseWithDB) -> None: - if hasattr(self, 'checked_class'): - f(self) - return wrapper - def _load_from_db(self, id_: int | str) -> list[object]: db_found: list[object] = [] for row in self.db_conn.row_where(self.checked_class.table_name,