X-Git-Url: https://plomlompom.com/repos/?a=blobdiff_plain;f=tests%2Futils.py;h=63b07e93e6f14ca51426e2fc00e959fdbfca7bf1;hb=HEAD;hp=cd0c457c0b0affb0034f5e81de2bc5d4a6139263;hpb=f20d686a4972db5e6bc10bdbd48d27d4b035a716;p=plomtask diff --git a/tests/utils.py b/tests/utils.py index cd0c457..6654368 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,38 +1,465 @@ """Shared test utilities.""" +from __future__ import annotations from unittest import TestCase +from typing import Mapping, Any, Callable from threading import Thread -from datetime import datetime +from http.client import HTTPConnection +from json import loads as json_loads +from urllib.parse import urlencode +from uuid import uuid4 from os import remove as remove_file from plomtask.db import DatabaseFile, DatabaseConnection from plomtask.http import TaskHandler, TaskServer +from plomtask.processes import Process, ProcessStep +from plomtask.conditions import Condition +from plomtask.days import Day +from plomtask.todos import Todo +from plomtask.exceptions import NotFoundException, HandledException + + +def _within_checked_class(f: Callable[..., None]) -> Callable[..., None]: + def wrapper(self: TestCase) -> None: + if hasattr(self, 'checked_class'): + f(self) + return wrapper + + +class TestCaseSansDB(TestCase): + """Tests requiring no DB setup.""" + checked_class: Any + default_init_args: list[Any] = [] + versioned_defaults_to_test: dict[str, str | float] = {} + legal_ids = [1, 5] + illegal_ids = [0] + + @_within_checked_class + def test_id_validation(self) -> None: + """Test .id_ validation/setting.""" + for id_ in self.illegal_ids: + with self.assertRaises(HandledException): + self.checked_class(id_, *self.default_init_args) + for id_ in self.legal_ids: + obj = self.checked_class(id_, *self.default_init_args) + self.assertEqual(obj.id_, id_) + + @_within_checked_class + def test_versioned_defaults(self) -> None: + """Test defaults of VersionedAttributes.""" + id_ = self.legal_ids[0] + obj = self.checked_class(id_, *self.default_init_args) + for k, v in self.versioned_defaults_to_test.items(): + self.assertEqual(getattr(obj, k).newest, v) class TestCaseWithDB(TestCase): """Module tests not requiring DB setup.""" + checked_class: Any + default_ids: tuple[int | str, int | str, int | str] = (1, 2, 3) + default_init_kwargs: dict[str, Any] = {} + test_versioneds: dict[str, type] = {} - def setUp(self): - timestamp = datetime.now().timestamp() - self.db_file = DatabaseFile(f'test_db:{timestamp}') - self.db_file.remake() + def setUp(self) -> None: + Condition.empty_cache() + Day.empty_cache() + Process.empty_cache() + ProcessStep.empty_cache() + Todo.empty_cache() + self.db_file = DatabaseFile.create_at(f'test_db:{uuid4()}') self.db_conn = DatabaseConnection(self.db_file) - def tearDown(self): + def tearDown(self) -> None: self.db_conn.close() remove_file(self.db_file.path) + def _load_from_db(self, id_: int | str) -> list[object]: + db_found: list[object] = [] + for row in self.db_conn.row_where(self.checked_class.table_name, + 'id', id_): + db_found += [self.checked_class.from_table_row(self.db_conn, + row)] + return db_found + + def _change_obj(self, obj: object) -> str: + attr_name: str = self.checked_class.to_save[-1] + attr = getattr(obj, attr_name) + new_attr: str | int | float | bool + if isinstance(attr, (int, float)): + new_attr = attr + 1 + elif isinstance(attr, str): + new_attr = attr + '_' + elif isinstance(attr, bool): + new_attr = not attr + setattr(obj, attr_name, new_attr) + return attr_name + + def check_identity_with_cache_and_db(self, content: list[Any]) -> None: + """Test both cache and DB equal content.""" + expected_cache = {} + for item in content: + expected_cache[item.id_] = item + self.assertEqual(self.checked_class.get_cache(), expected_cache) + hashes_content = [hash(x) for x in content] + db_found: list[Any] = [] + for item in content: + assert isinstance(item.id_, type(self.default_ids[0])) + db_found += self._load_from_db(item.id_) + hashes_db_found = [hash(x) for x in db_found] + self.assertEqual(sorted(hashes_content), sorted(hashes_db_found)) + + @_within_checked_class + def test_saving_versioned(self) -> None: + """Test storage and initialization of versioned attributes.""" + def retrieve_attr_vals() -> list[object]: + attr_vals_saved: list[object] = [] + assert hasattr(retrieved, 'id_') + for row in self.db_conn.row_where(attr.table_name, 'parent', + retrieved.id_): + attr_vals_saved += [row[2]] + return attr_vals_saved + for attr_name, type_ in self.test_versioneds.items(): + # fail saving attributes on non-saved owner + owner = self.checked_class(None, **self.default_init_kwargs) + vals: list[Any] = ['t1', 't2'] if type_ == str else [0.9, 1.1] + attr = getattr(owner, attr_name) + attr.set(vals[0]) + attr.set(vals[1]) + with self.assertRaises(NotFoundException): + attr.save(self.db_conn) + owner.save(self.db_conn) + # check stored attribute is as expected + retrieved = self._load_from_db(owner.id_)[0] + attr = getattr(retrieved, attr_name) + self.assertEqual(sorted(attr.history.values()), vals) + # check owner.save() created entries in attr table + attr_vals_saved = retrieve_attr_vals() + self.assertEqual(vals, attr_vals_saved) + # check setting new val to attr inconsequential to DB without save + attr.set(vals[0]) + attr_vals_saved = retrieve_attr_vals() + self.assertEqual(vals, attr_vals_saved) + # check save finally adds new val + attr.save(self.db_conn) + attr_vals_saved = retrieve_attr_vals() + self.assertEqual(vals + [vals[0]], attr_vals_saved) + + @_within_checked_class + def test_saving_and_caching(self) -> None: + """Test effects of .cache() and .save().""" + id1 = self.default_ids[0] + # check failure to cache without ID (if None-ID input possible) + if isinstance(id1, int): + obj0 = self.checked_class(None, **self.default_init_kwargs) + with self.assertRaises(HandledException): + obj0.cache() + # check mere object init itself doesn't even store in cache + obj1 = self.checked_class(id1, **self.default_init_kwargs) + self.assertEqual(self.checked_class.get_cache(), {}) + # check .cache() fills cache, but not DB + obj1.cache() + self.assertEqual(self.checked_class.get_cache(), {id1: obj1}) + db_found = self._load_from_db(id1) + self.assertEqual(db_found, []) + # check .save() sets ID (for int IDs), updates cache, and fills DB + # (expect ID to be set to id1, despite obj1 already having that as ID: + # it's generated by cursor.lastrowid on the DB table, and with obj1 + # not written there, obj2 should get it first!) + id_input = None if isinstance(id1, int) else id1 + obj2 = self.checked_class(id_input, **self.default_init_kwargs) + obj2.save(self.db_conn) + obj2_hash = hash(obj2) + self.assertEqual(self.checked_class.get_cache(), {id1: obj2}) + db_found += self._load_from_db(id1) + self.assertEqual([hash(o) for o in db_found], [obj2_hash]) + # check we cannot overwrite obj2 with obj1 despite its same ID, + # since it has disappeared now + with self.assertRaises(HandledException): + obj1.save(self.db_conn) + + @_within_checked_class + def test_by_id(self) -> None: + """Test .by_id().""" + id1, id2, _ = self.default_ids + # check failure if not yet saved + obj1 = self.checked_class(id1, **self.default_init_kwargs) + with self.assertRaises(NotFoundException): + self.checked_class.by_id(self.db_conn, id1) + # check identity of cached and retrieved + obj1.cache() + self.assertEqual(obj1, self.checked_class.by_id(self.db_conn, id1)) + # check identity of saved and retrieved + obj2 = self.checked_class(id2, **self.default_init_kwargs) + obj2.save(self.db_conn) + self.assertEqual(obj2, self.checked_class.by_id(self.db_conn, id2)) + + @_within_checked_class + def test_by_id_or_create(self) -> None: + """Test .by_id_or_create.""" + # check .by_id_or_create fails if wrong class + if not self.checked_class.can_create_by_id: + with self.assertRaises(HandledException): + self.checked_class.by_id_or_create(self.db_conn, None) + return + # check ID input of None creates, on saving, ID=1,2,… for int IDs + if isinstance(self.default_ids[0], int): + for n in range(2): + item = self.checked_class.by_id_or_create(self.db_conn, None) + self.assertEqual(item.id_, None) + item.save(self.db_conn) + self.assertEqual(item.id_, n+1) + # check .by_id_or_create acts like normal instantiation (sans saving) + id_ = self.default_ids[2] + item = self.checked_class.by_id_or_create(self.db_conn, id_) + self.assertEqual(item.id_, id_) + with self.assertRaises(NotFoundException): + self.checked_class.by_id(self.db_conn, item.id_) + self.assertEqual(self.checked_class(item.id_), item) + + @_within_checked_class + def test_from_table_row(self) -> None: + """Test .from_table_row() properly reads in class directly from DB.""" + id_ = self.default_ids[0] + obj = self.checked_class(id_, **self.default_init_kwargs) + obj.save(self.db_conn) + assert isinstance(obj.id_, type(id_)) + for row in self.db_conn.row_where(self.checked_class.table_name, + 'id', obj.id_): + # check .from_table_row reproduces state saved, no matter if obj + # later changed (with caching even) + hash_original = hash(obj) + attr_name = self._change_obj(obj) + obj.cache() + to_cmp = getattr(obj, attr_name) + retrieved = self.checked_class.from_table_row(self.db_conn, row) + self.assertNotEqual(to_cmp, getattr(retrieved, attr_name)) + self.assertEqual(hash_original, hash(retrieved)) + # check cache contains what .from_table_row just produced + self.assertEqual({retrieved.id_: retrieved}, + self.checked_class.get_cache()) + # check .from_table_row also reads versioned attributes from DB + for attr_name, type_ in self.test_versioneds.items(): + owner = self.checked_class(None) + vals: list[Any] = ['t1', 't2'] if type_ == str else [0.9, 1.1] + attr = getattr(owner, attr_name) + attr.set(vals[0]) + attr.set(vals[1]) + owner.save(self.db_conn) + for row in self.db_conn.row_where(owner.table_name, 'id', + owner.id_): + retrieved = owner.__class__.from_table_row(self.db_conn, row) + attr = getattr(retrieved, attr_name) + self.assertEqual(sorted(attr.history.values()), vals) + + @_within_checked_class + def test_all(self) -> None: + """Test .all() and its relation to cache and savings.""" + id_1, id_2, id_3 = self.default_ids + item1 = self.checked_class(id_1, **self.default_init_kwargs) + item2 = self.checked_class(id_2, **self.default_init_kwargs) + item3 = self.checked_class(id_3, **self.default_init_kwargs) + # check .all() returns empty list on un-cached items + self.assertEqual(self.checked_class.all(self.db_conn), []) + # check that all() shows only cached/saved items + item1.cache() + item3.save(self.db_conn) + self.assertEqual(sorted(self.checked_class.all(self.db_conn)), + sorted([item1, item3])) + item2.save(self.db_conn) + self.assertEqual(sorted(self.checked_class.all(self.db_conn)), + sorted([item1, item2, item3])) + + @_within_checked_class + def test_singularity(self) -> None: + """Test pointers made for single object keep pointing to it.""" + id1 = self.default_ids[0] + obj = self.checked_class(id1, **self.default_init_kwargs) + obj.save(self.db_conn) + # change object, expect retrieved through .by_id to carry change + attr_name = self._change_obj(obj) + new_attr = getattr(obj, attr_name) + retrieved = self.checked_class.by_id(self.db_conn, id1) + self.assertEqual(new_attr, getattr(retrieved, attr_name)) + + @_within_checked_class + def test_versioned_singularity_title(self) -> None: + """Test singularity of VersionedAttributes on saving (with .title).""" + if 'title' in self.test_versioneds: + obj = self.checked_class(None) + obj.save(self.db_conn) + assert isinstance(obj.id_, int) + # change obj, expect retrieved through .by_id to carry change + obj.title.set('named') + retrieved = self.checked_class.by_id(self.db_conn, obj.id_) + self.assertEqual(obj.title.history, retrieved.title.history) + + @_within_checked_class + def test_remove(self) -> None: + """Test .remove() effects on DB and cache.""" + id_ = self.default_ids[0] + obj = self.checked_class(id_, **self.default_init_kwargs) + # check removal only works after saving + with self.assertRaises(HandledException): + obj.remove(self.db_conn) + obj.save(self.db_conn) + obj.remove(self.db_conn) + # check access to obj fails after removal + with self.assertRaises(HandledException): + print(obj.id_) + # check DB and cache now empty + self.check_identity_with_cache_and_db([]) + class TestCaseWithServer(TestCaseWithDB): """Module tests against our HTTP server/handler (and database).""" - def setUp(self): + def setUp(self) -> None: super().setUp() self.httpd = TaskServer(self.db_file, ('localhost', 0), TaskHandler) self.server_thread = Thread(target=self.httpd.serve_forever) self.server_thread.daemon = True self.server_thread.start() + self.conn = HTTPConnection(str(self.httpd.server_address[0]), + self.httpd.server_address[1]) + self.httpd.set_json_mode() - def tearDown(self): + def tearDown(self) -> None: self.httpd.shutdown() self.httpd.server_close() self.server_thread.join() super().tearDown() + + @staticmethod + def as_id_list(items: list[dict[str, object]]) -> list[int | str]: + """Return list of only 'id' fields of items.""" + id_list = [] + for item in items: + assert isinstance(item['id'], (int, str)) + id_list += [item['id']] + return id_list + + @staticmethod + def as_refs(items: list[dict[str, object]] + ) -> dict[str, dict[str, object]]: + """Return dictionary of items by their 'id' fields.""" + refs = {} + for item in items: + refs[str(item['id'])] = item + return refs + + @staticmethod + def cond_as_dict(id_: int = 1, + is_active: bool = False, + titles: None | list[str] = None, + descriptions: None | list[str] = None + ) -> dict[str, object]: + """Return JSON of Condition to expect.""" + d = {'id': id_, + 'is_active': is_active, + '_versioned': { + 'title': {}, + 'description': {}}} + titles = titles if titles else [] + descriptions = descriptions if descriptions else [] + assert isinstance(d['_versioned'], dict) + for i, title in enumerate(titles): + d['_versioned']['title'][i] = title + for i, description in enumerate(descriptions): + d['_versioned']['description'][i] = description + return d + + @staticmethod + def proc_as_dict(id_: int = 1, + title: str = 'A', + description: str = '', + effort: float = 1.0, + conditions: None | list[int] = None, + disables: None | list[int] = None, + blockers: None | list[int] = None, + enables: None | list[int] = None + ) -> dict[str, object]: + """Return JSON of Process to expect.""" + # pylint: disable=too-many-arguments + d = {'id': id_, + 'calendarize': False, + 'suppressed_steps': [], + 'explicit_steps': [], + '_versioned': { + 'title': {0: title}, + 'description': {0: description}, + 'effort': {0: effort}}, + 'conditions': conditions if conditions else [], + 'disables': disables if disables else [], + 'enables': enables if enables else [], + 'blockers': blockers if blockers else []} + return d + + def check_redirect(self, target: str) -> None: + """Check that self.conn answers with a 302 redirect to target.""" + response = self.conn.getresponse() + self.assertEqual(response.status, 302) + self.assertEqual(response.getheader('Location'), target) + + def check_get(self, target: str, expected_code: int) -> None: + """Check that a GET to target yields expected_code.""" + self.conn.request('GET', target) + self.assertEqual(self.conn.getresponse().status, expected_code) + + def check_post(self, data: Mapping[str, object], target: str, + expected_code: int, redirect_location: str = '') -> None: + """Check that POST of data to target yields expected_code.""" + encoded_form_data = urlencode(data, doseq=True).encode('utf-8') + headers = {'Content-Type': 'application/x-www-form-urlencoded', + 'Content-Length': str(len(encoded_form_data))} + self.conn.request('POST', target, + body=encoded_form_data, headers=headers) + if 302 == expected_code: + if redirect_location == '': + redirect_location = target + self.check_redirect(redirect_location) + else: + self.assertEqual(self.conn.getresponse().status, expected_code) + + def check_get_defaults(self, path: str) -> None: + """Some standard model paths to test.""" + self.check_get(path, 200) + self.check_get(f'{path}?id=', 200) + self.check_get(f'{path}?id=foo', 400) + self.check_get(f'/{path}?id=0', 500) + self.check_get(f'{path}?id=1', 200) + + def post_process(self, id_: int = 1, + form_data: dict[str, Any] | None = None + ) -> dict[str, Any]: + """POST basic Process.""" + if not form_data: + form_data = {'title': 'foo', 'description': 'foo', 'effort': 1.1} + self.check_post(form_data, f'/process?id={id_}', 302, + f'/process?id={id_}') + return form_data + + def check_json_get(self, path: str, expected: dict[str, object]) -> None: + """Compare JSON on GET path with expected. + + To simplify comparison of VersionedAttribute histories, transforms + timestamp keys of VersionedAttribute history keys into integers + counting chronologically forward from 0. + """ + def rewrite_history_keys_in(item: Any) -> Any: + if isinstance(item, dict): + if '_versioned' in item.keys(): + for k in item['_versioned']: + vals = item['_versioned'][k].values() + history = {} + for i, val in enumerate(vals): + history[i] = val + item['_versioned'][k] = history + for k in list(item.keys()): + rewrite_history_keys_in(item[k]) + elif isinstance(item, list): + item[:] = [rewrite_history_keys_in(i) for i in item] + return item + self.conn.request('GET', path) + response = self.conn.getresponse() + self.assertEqual(response.status, 200) + retrieved = json_loads(response.read().decode()) + rewrite_history_keys_in(retrieved) + self.assertEqual(expected, retrieved)