X-Git-Url: https://plomlompom.com/repos/?a=blobdiff_plain;f=tests%2Futils.py;h=fbe739d2e002137a4187ca09b19846ea63ccbedf;hb=c4ccb784bb3a83c1c614c9bab7fc007ee17f6615;hp=2a919a31dbb057f4db558861d2fafe8277db5521;hpb=e14580b4ee47363cad317e4ec1de91affe03d53a;p=plomtask diff --git a/tests/utils.py b/tests/utils.py index 2a919a3..fbe739d 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -5,17 +5,38 @@ from http.client import HTTPConnection from urllib.parse import urlencode from datetime import datetime from os import remove as remove_file -from typing import Mapping +from typing import Mapping, Any from plomtask.db import DatabaseFile, DatabaseConnection from plomtask.http import TaskHandler, TaskServer from plomtask.processes import Process, ProcessStep from plomtask.conditions import Condition from plomtask.days import Day from plomtask.todos import Todo +from plomtask.exceptions import NotFoundException, HandledException + + +class TestCaseSansDB(TestCase): + """Tests requiring no DB setup.""" + checked_class: Any + + def check_id_setting(self, *args: Any) -> None: + """Test .id_ being set and its legal range being enforced.""" + with self.assertRaises(HandledException): + self.checked_class(0, *args) + obj = self.checked_class(5, *args) + self.assertEqual(obj.id_, 5) + + def check_versioned_defaults(self, attrs: dict[str, Any]) -> None: + """Test defaults of VersionedAttributes.""" + obj = self.checked_class(None) + for k, v in attrs.items(): + self.assertEqual(getattr(obj, k).newest, v) class TestCaseWithDB(TestCase): """Module tests not requiring DB setup.""" + checked_class: Any + default_ids: tuple[int | str, int | str, int | str] = (1, 2, 3) def setUp(self) -> None: Condition.empty_cache() @@ -32,6 +53,110 @@ class TestCaseWithDB(TestCase): self.db_conn.close() remove_file(self.db_file.path) + def check_storage(self, content: list[Any]) -> None: + """Test cache and DB equal content.""" + expected_cache = {} + for item in content: + expected_cache[item.id_] = item + self.assertEqual(self.checked_class.get_cache(), expected_cache) + db_found: list[Any] = [] + for item in content: + assert isinstance(item.id_, type(self.default_ids[0])) + for row in self.db_conn.row_where(self.checked_class.table_name, + 'id', item.id_): + db_found += [self.checked_class.from_table_row(self.db_conn, + row)] + self.assertEqual(sorted(content), sorted(db_found)) + + def check_saving_and_caching(self, **kwargs: Any) -> Any: + """Test instance.save in its core without relations.""" + obj = self.checked_class(**kwargs) # pylint: disable=not-callable + # check object init itself doesn't store anything yet + self.check_storage([]) + # check saving stores in cache and DB + obj.save(self.db_conn) + self.check_storage([obj]) + # check core attributes set properly (and not unset by saving) + for key, value in kwargs.items(): + self.assertEqual(getattr(obj, key), value) + + def check_by_id(self) -> None: + """Test .by_id(), including creation.""" + # check failure if not yet saved + id1, id2 = self.default_ids[0], self.default_ids[1] + obj = self.checked_class(id1) # pylint: disable=not-callable + with self.assertRaises(NotFoundException): + self.checked_class.by_id(self.db_conn, id1) + # check identity of saved and retrieved + obj.save(self.db_conn) + self.assertEqual(obj, self.checked_class.by_id(self.db_conn, id1)) + # check create=True acts like normal instantiation (sans saving) + by_id_created = self.checked_class.by_id(self.db_conn, id2, + create=True) + # pylint: disable=not-callable + self.assertEqual(self.checked_class(id2), by_id_created) + self.check_storage([obj]) + + def check_from_table_row(self, *args: Any) -> None: + """Test .from_table_row() properly reads in class from DB""" + id_ = self.default_ids[0] + obj = self.checked_class(id_, *args) # pylint: disable=not-callable + obj.save(self.db_conn) + assert isinstance(obj.id_, type(self.default_ids[0])) + for row in self.db_conn.row_where(self.checked_class.table_name, + 'id', obj.id_): + retrieved = self.checked_class.from_table_row(self.db_conn, row) + self.assertEqual(obj, retrieved) + self.assertEqual({obj.id_: obj}, self.checked_class.get_cache()) + + def check_all(self) -> tuple[Any, Any, Any]: + """Test .all().""" + # pylint: disable=not-callable + item1 = self.checked_class(self.default_ids[0]) + item2 = self.checked_class(self.default_ids[1]) + item3 = self.checked_class(self.default_ids[2]) + # check pre-save .all() returns empty list + self.assertEqual(self.checked_class.all(self.db_conn), []) + # check that all() shows all saved, but no unsaved items + item1.save(self.db_conn) + item3.save(self.db_conn) + self.assertEqual(sorted(self.checked_class.all(self.db_conn)), + sorted([item1, item3])) + item2.save(self.db_conn) + self.assertEqual(sorted(self.checked_class.all(self.db_conn)), + sorted([item1, item2, item3])) + return item1, item2, item3 + + def check_singularity(self, defaulting_field: str, + non_default_value: Any, *args: Any) -> None: + """Test pointers made for single object keep pointing to it.""" + id1 = self.default_ids[0] + obj = self.checked_class(id1, *args) # pylint: disable=not-callable + obj.save(self.db_conn) + setattr(obj, defaulting_field, non_default_value) + retrieved = self.checked_class.by_id(self.db_conn, id1) + self.assertEqual(non_default_value, + getattr(retrieved, defaulting_field)) + + def check_versioned_singularity(self) -> None: + """Test singularity of VersionedAttributes on saving (with .title).""" + obj = self.checked_class(None) # pylint: disable=not-callable + obj.save(self.db_conn) + assert isinstance(obj.id_, int) + obj.title.set('named') + retrieved = self.checked_class.by_id(self.db_conn, obj.id_) + self.assertEqual(obj.title.history, retrieved.title.history) + + def check_remove(self, *args: Any) -> None: + """Test .remove() effects on DB and cache.""" + id_ = self.default_ids[0] + obj = self.checked_class(id_, *args) # pylint: disable=not-callable + with self.assertRaises(HandledException): + obj.remove(self.db_conn) + obj.save(self.db_conn) + obj.remove(self.db_conn) + self.check_storage([]) + class TestCaseWithServer(TestCaseWithDB): """Module tests against our HTTP server/handler (and database).""" @@ -70,9 +195,26 @@ class TestCaseWithServer(TestCaseWithDB): 'Content-Length': str(len(encoded_form_data))} self.conn.request('POST', target, body=encoded_form_data, headers=headers) - if redirect_location == '': - redirect_location = target if 302 == expected_code: + if redirect_location == '': + redirect_location = target self.check_redirect(redirect_location) else: self.assertEqual(self.conn.getresponse().status, expected_code) + + def check_get_defaults(self, path: str) -> None: + """Some standard model paths to test.""" + self.check_get(path, 200) + self.check_get(f'{path}?id=', 200) + self.check_get(f'{path}?id=foo', 400) + self.check_get(f'/{path}?id=0', 500) + self.check_get(f'{path}?id=1', 200) + + def post_process(self, id_: int = 1, + form_data: dict[str, Any] | None = None + ) -> dict[str, Any]: + """POST basic Process.""" + if not form_data: + form_data = {'title': 'foo', 'description': 'foo', 'effort': 1.1} + self.check_post(form_data, '/process?id=', 302, f'/process?id={id_}') + return form_data