From 8f28c8c685fa91b9cbabb4b424da4091e52058cf Mon Sep 17 00:00:00 2001 From: Christian Heller Date: Tue, 18 Jun 2024 07:02:04 +0200 Subject: [PATCH] Refactor saving and caching tests, treatment of None IDs. --- plomtask/days.py | 3 +- plomtask/db.py | 9 ++-- plomtask/http.py | 19 ++++--- plomtask/versioned_attributes.py | 5 +- tests/days.py | 9 ---- tests/processes.py | 10 ++-- tests/todos.py | 2 + tests/utils.py | 92 +++++++++++++++++++++++--------- 8 files changed, 95 insertions(+), 54 deletions(-) diff --git a/plomtask/days.py b/plomtask/days.py index 267156d..0bd942c 100644 --- a/plomtask/days.py +++ b/plomtask/days.py @@ -41,10 +41,9 @@ class Day(BaseModel[str]): return day @classmethod - def by_id(cls, db_conn: DatabaseConnection, id_: str | None) -> Day: + def by_id(cls, db_conn: DatabaseConnection, id_: str) -> Day: """Extend BaseModel.by_id checking for new/lost .todos.""" day = super().by_id(db_conn, id_) - assert day.id_ is not None if day.id_ in Todo.days_to_update: Todo.days_to_update.remove(day.id_) day.todos = Todo.by_date(db_conn, day.id_) diff --git a/plomtask/db.py b/plomtask/db.py index f6ef1cb..797b08e 100644 --- a/plomtask/db.py +++ b/plomtask/db.py @@ -388,9 +388,7 @@ class BaseModel(Generic[BaseModelId]): return obj @classmethod - def by_id(cls, db_conn: DatabaseConnection, - id_: BaseModelId | None - ) -> Self: + def by_id(cls, db_conn: DatabaseConnection, id_: BaseModelId) -> Self: """Retrieve by id_, on failure throw NotFoundException. First try to get from cls.cache_, only then check DB; if found, @@ -414,11 +412,12 @@ class BaseModel(Generic[BaseModelId]): """Wrapper around .by_id, creating (not caching/saving) if not find.""" if not cls.can_create_by_id: raise HandledException('Class cannot .by_id_or_create.') + if id_ is None: + return cls(None) try: return cls.by_id(db_conn, id_) except NotFoundException: - obj = cls(id_) - return obj + return cls(id_) @classmethod def all(cls: type[BaseModelInstance], diff --git a/plomtask/http.py b/plomtask/http.py index be79159..7c7fbd4 100644 --- a/plomtask/http.py +++ b/plomtask/http.py @@ -335,7 +335,8 @@ class TaskHandler(BaseHTTPRequestHandler): adoptables: dict[int, list[Todo]] = {} any_adoptables = [Todo.by_id(self.conn, t.id_) for t in Todo.by_date(self.conn, todo.date) - if t != todo] + if t.id_ is not None + and t != todo] for id_ in collect_adoptables_keys(steps_todo_to_process): adoptables[id_] = [t for t in any_adoptables if t.process.id_ == id_] @@ -410,13 +411,13 @@ class TaskHandler(BaseHTTPRequestHandler): def do_GET_condition_titles(self) -> dict[str, object]: """Show title history of Condition of ?id=.""" - id_ = self._params.get_int_or_none('id') + id_ = self._params.get_int('id') condition = Condition.by_id(self.conn, id_) return {'condition': condition} def do_GET_condition_descriptions(self) -> dict[str, object]: """Show description historys of Condition of ?id=.""" - id_ = self._params.get_int_or_none('id') + id_ = self._params.get_int('id') condition = Condition.by_id(self.conn, id_) return {'condition': condition} @@ -443,19 +444,19 @@ class TaskHandler(BaseHTTPRequestHandler): def do_GET_process_titles(self) -> dict[str, object]: """Show title history of Process of ?id=.""" - id_ = self._params.get_int_or_none('id') + id_ = self._params.get_int('id') process = Process.by_id(self.conn, id_) return {'process': process} def do_GET_process_descriptions(self) -> dict[str, object]: """Show description historys of Process of ?id=.""" - id_ = self._params.get_int_or_none('id') + id_ = self._params.get_int('id') process = Process.by_id(self.conn, id_) return {'process': process} def do_GET_process_efforts(self) -> dict[str, object]: """Show default effort history of Process of ?id=.""" - id_ = self._params.get_int_or_none('id') + id_ = self._params.get_int('id') process = Process.by_id(self.conn, id_) return {'process': process} @@ -597,6 +598,8 @@ class TaskHandler(BaseHTTPRequestHandler): # pylint: disable=too-many-branches id_ = self._params.get_int_or_none('id') for _ in self._form_data.get_all_str('delete'): + if id_ is None: + raise NotFoundException('trying to delete non-saved Process') process = Process.by_id(self.conn, id_) process.remove(self.conn) return '/processes' @@ -673,7 +676,9 @@ class TaskHandler(BaseHTTPRequestHandler): """Update/insert Condition of ?id= and fields defined in postvars.""" id_ = self._params.get_int_or_none('id') for _ in self._form_data.get_all_str('delete'): - condition = Condition.by_id(self.conn, id_) + if id_ is None: + raise NotFoundException('trying to delete non-saved Condition') + condition = Condition.by_id_or_create(self.conn, id_) condition.remove(self.conn) return '/conditions' condition = Condition.by_id_or_create(self.conn, id_) diff --git a/plomtask/versioned_attributes.py b/plomtask/versioned_attributes.py index cbd1c8e..8861c98 100644 --- a/plomtask/versioned_attributes.py +++ b/plomtask/versioned_attributes.py @@ -4,7 +4,8 @@ from typing import Any from sqlite3 import Row from time import sleep from plomtask.db import DatabaseConnection -from plomtask.exceptions import HandledException, BadFormatException +from plomtask.exceptions import (HandledException, BadFormatException, + NotFoundException) TIMESTAMP_FMT = '%Y-%m-%d %H:%M:%S.%f' @@ -98,6 +99,8 @@ class VersionedAttribute: def save(self, db_conn: DatabaseConnection) -> None: """Save as self.history entries, but first wipe old ones.""" + if self.parent.id_ is None: + raise NotFoundException('cannot save attribute to parent if no ID') db_conn.rewrite_relations(self.table_name, 'parent', self.parent.id_, [[item[0], item[1]] for item in self.history.items()]) diff --git a/tests/days.py b/tests/days.py index 02b6c22..9fb12ad 100644 --- a/tests/days.py +++ b/tests/days.py @@ -44,15 +44,6 @@ class TestsWithDB(TestCaseWithDB): checked_class = Day default_ids = ('2024-01-01', '2024-01-02', '2024-01-03') - def test_saving_and_caching(self) -> None: - """Test storage of instances. - - We don't use the parent class's method here because the checked class - has too different a handling of IDs. - """ - kwargs = {'date': self.default_ids[0], 'comment': 'foo'} - self.check_saving_and_caching(**kwargs) - def test_Day_by_date_range_filled(self) -> None: """Test Day.by_date_range_filled.""" date1, date2, date3 = self.default_ids diff --git a/tests/processes.py b/tests/processes.py index 0f43a4d..d33aa8f 100644 --- a/tests/processes.py +++ b/tests/processes.py @@ -58,6 +58,7 @@ class TestsWithDB(TestCaseWithDB): def test_Process_conditions_saving(self) -> None: """Test .save/.save_core.""" p, set1, set2, set3 = self.p_of_conditions() + assert p.id_ is not None r = Process.by_id(self.db_conn, p.id_) self.assertEqual(sorted(r.conditions), sorted(set1)) self.assertEqual(sorted(r.enables), sorted(set2)) @@ -200,13 +201,15 @@ class TestsWithDB(TestCaseWithDB): p1.remove(self.db_conn) p2.set_steps(self.db_conn, []) with self.assertRaises(NotFoundException): + assert step_id is not None ProcessStep.by_id(self.db_conn, step_id) p1.remove(self.db_conn) step = ProcessStep(None, p2.id_, p3.id_, None) - step_id = step.id_ p2.set_steps(self.db_conn, [step]) + step_id = step.id_ p2.remove(self.db_conn) with self.assertRaises(NotFoundException): + assert step_id is not None ProcessStep.by_id(self.db_conn, step_id) todo = Todo(None, p3, False, '2024-01-01') todo.save(self.db_conn) @@ -229,10 +232,6 @@ class TestsWithDBForProcessStep(TestCaseWithDB): p = Process(2) p.save(self.db_conn) - def test_saving_and_caching(self) -> None: - """Test storage and initialization of instances and attributes.""" - self.check_saving_and_caching(id_=1, **self.default_init_kwargs) - def test_ProcessStep_remove(self) -> None: """Test .remove and unsetting of owner's .explicit_steps entry.""" p1 = Process(None) @@ -300,6 +299,7 @@ class TestsWithServer(TestCaseWithServer): self.post_process(1, form_data_1) retrieved_process = Process.by_id(self.db_conn, 1) self.assertEqual(retrieved_process.explicit_steps, []) + assert retrieved_step_id is not None with self.assertRaises(NotFoundException): ProcessStep.by_id(self.db_conn, retrieved_step_id) # post new first (top_level) step of process 3 to process 1 diff --git a/tests/todos.py b/tests/todos.py index 56aaf48..7632f39 100644 --- a/tests/todos.py +++ b/tests/todos.py @@ -206,6 +206,7 @@ class TestsWithDB(TestCaseWithDB, TestCaseSansDB): """Test removal.""" todo_1 = Todo(None, self.proc, False, self.date1) todo_1.save(self.db_conn) + assert todo_1.id_ is not None todo_0 = Todo(None, self.proc, False, self.date1) todo_0.save(self.db_conn) todo_0.add_child(todo_1) @@ -233,6 +234,7 @@ class TestsWithDB(TestCaseWithDB, TestCaseSansDB): todo_1.comment = 'foo' todo_1.effort = -0.1 todo_1.save(self.db_conn) + assert todo_1.id_ is not None Todo.by_id(self.db_conn, todo_1.id_) todo_1.comment = '' todo_1_id = todo_1.id_ diff --git a/tests/utils.py b/tests/utils.py index 55c948a..6015710 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -69,23 +69,49 @@ class TestCaseWithDB(TestCase): f(self) return wrapper + def _load_from_db(self, id_: int | str) -> list[object]: + db_found: list[object] = [] + for row in self.db_conn.row_where(self.checked_class.table_name, + 'id', id_): + db_found += [self.checked_class.from_table_row(self.db_conn, + row)] + return db_found + @_within_checked_class - def test_saving_and_caching(self) -> None: - """Test storage and initialization of instances and attributes.""" - self.check_saving_and_caching(id_=1, **self.default_init_kwargs) - obj = self.checked_class(None, **self.default_init_kwargs) - obj.save(self.db_conn) - self.assertEqual(obj.id_, 2) + def test_saving_versioned(self) -> None: + """Test storage and initialization of versioned attributes.""" + def retrieve_attr_vals() -> list[object]: + attr_vals_saved: list[object] = [] + assert hasattr(retrieved, 'id_') + for row in self.db_conn.row_where(attr.table_name, 'parent', + retrieved.id_): + attr_vals_saved += [row[2]] + return attr_vals_saved for attr_name, type_ in self.test_versioneds.items(): - owner = self.checked_class(None) + # fail saving attributes on non-saved owner + owner = self.checked_class(None, **self.default_init_kwargs) vals: list[Any] = ['t1', 't2'] if type_ == str else [0.9, 1.1] attr = getattr(owner, attr_name) attr.set(vals[0]) attr.set(vals[1]) + with self.assertRaises(NotFoundException): + attr.save(self.db_conn) owner.save(self.db_conn) - retrieved = owner.__class__.by_id(self.db_conn, owner.id_) + # check stored attribute is as expected + retrieved = self._load_from_db(owner.id_)[0] attr = getattr(retrieved, attr_name) self.assertEqual(sorted(attr.history.values()), vals) + # check owner.save() created entries in attr table + attr_vals_saved = retrieve_attr_vals() + self.assertEqual(vals, attr_vals_saved) + # check setting new val to attr inconsequential to DB without save + attr.set(vals[0]) + attr_vals_saved = retrieve_attr_vals() + self.assertEqual(vals, attr_vals_saved) + # check save finally adds new val + attr.save(self.db_conn) + attr_vals_saved = retrieve_attr_vals() + self.assertEqual(vals + [vals[0]], attr_vals_saved) def check_identity_with_cache_and_db(self, content: list[Any]) -> None: """Test both cache and DB equal content.""" @@ -97,24 +123,42 @@ class TestCaseWithDB(TestCase): db_found: list[Any] = [] for item in content: assert isinstance(item.id_, type(self.default_ids[0])) - for row in self.db_conn.row_where(self.checked_class.table_name, - 'id', item.id_): - db_found += [self.checked_class.from_table_row(self.db_conn, - row)] + db_found += self._load_from_db(item.id_) hashes_db_found = [hash(x) for x in db_found] self.assertEqual(sorted(hashes_content), sorted(hashes_db_found)) - def check_saving_and_caching(self, **kwargs: Any) -> None: - """Test instance.save in its core without relations.""" - obj = self.checked_class(**kwargs) # pylint: disable=not-callable - # check object init itself doesn't store anything yet - self.check_identity_with_cache_and_db([]) - # check saving sets core attributes properly - obj.save(self.db_conn) - for key, value in kwargs.items(): - self.assertEqual(getattr(obj, key), value) - # check saving stored properly in cache and DB - self.check_identity_with_cache_and_db([obj]) + @_within_checked_class + def test_saving_and_caching(self) -> None: + """Test effects of .cache() and .save().""" + id1 = self.default_ids[0] + # check failure to cache without ID (if None-ID input possible) + if isinstance(id1, int): + obj0 = self.checked_class(None, **self.default_init_kwargs) + with self.assertRaises(HandledException): + obj0.cache() + # check mere object init itself doesn't even store in cache + obj1 = self.checked_class(id1, **self.default_init_kwargs) + self.assertEqual(self.checked_class.get_cache(), {}) + # check .cache() fills cache, but not DB + obj1.cache() + self.assertEqual(self.checked_class.get_cache(), {id1: obj1}) + db_found = self._load_from_db(id1) + self.assertEqual(db_found, []) + # check .save() sets ID (for int IDs), updates cache, and fills DB + # (expect ID to be set to id1, despite obj1 already having that as ID: + # it's generated by cursor.lastrowid on the DB table, and with obj1 + # not written there, obj2 should get it first!) + id_input = None if isinstance(id1, int) else id1 + obj2 = self.checked_class(id_input, **self.default_init_kwargs) + obj2.save(self.db_conn) + obj2_hash = hash(obj2) + self.assertEqual(self.checked_class.get_cache(), {id1: obj2}) + db_found += self._load_from_db(id1) + self.assertEqual([hash(o) for o in db_found], [obj2_hash]) + # check we cannot overwrite obj2 with obj1 despite its same ID, + # since it has disappeared now + with self.assertRaises(HandledException): + obj1.save(self.db_conn) @_within_checked_class def test_by_id(self) -> None: @@ -131,8 +175,6 @@ class TestCaseWithDB(TestCase): obj2 = self.checked_class(id2, **self.default_init_kwargs) obj2.save(self.db_conn) self.assertEqual(obj2, self.checked_class.by_id(self.db_conn, id2)) - # obj1.save(self.db_conn) - # self.check_identity_with_cache_and_db([obj1, obj2]) @_within_checked_class def test_by_id_or_create(self) -> None: -- 2.30.2