From 25b71c6f0b10db05907128daf50c6e543e514c35 Mon Sep 17 00:00:00 2001 From: Christian Heller Date: Tue, 18 Jun 2024 04:37:57 +0200 Subject: [PATCH] Split BaseModel.by_id into .by_id and by_id_or_create, refactor tests. --- plomtask/conditions.py | 16 +++++++------ plomtask/days.py | 8 +++---- plomtask/db.py | 22 ++++++++++++------ plomtask/http.py | 12 +++++----- plomtask/processes.py | 1 + tests/conditions.py | 4 ---- tests/days.py | 6 +---- tests/misc.py | 2 +- tests/processes.py | 6 +---- tests/utils.py | 51 ++++++++++++++++++++++++++++-------------- 10 files changed, 71 insertions(+), 57 deletions(-) diff --git a/plomtask/conditions.py b/plomtask/conditions.py index 70365ce..b60d0af 100644 --- a/plomtask/conditions.py +++ b/plomtask/conditions.py @@ -11,6 +11,7 @@ class Condition(BaseModel[int]): to_save = ['is_active'] to_save_versioned = ['title', 'description'] to_search = ['title.newest', 'description.newest'] + can_create_by_id = True def __init__(self, id_: int | None, is_active: bool = False) -> None: super().__init__(id_) @@ -25,13 +26,14 @@ class Condition(BaseModel[int]): Checks for Todos and Processes that depend on Condition, prohibits deletion if found. """ - if self.id_ is None: - raise HandledException('cannot remove unsaved item') - for item in ('process', 'todo'): - for attr in ('conditions', 'blockers', 'enables', 'disables'): - table_name = f'{item}_{attr}' - for _ in db_conn.row_where(table_name, 'condition', self.id_): - raise HandledException('cannot remove Condition in use') + if self.id_ is not None: + for item in ('process', 'todo'): + for attr in ('conditions', 'blockers', 'enables', 'disables'): + table_name = f'{item}_{attr}' + for _ in db_conn.row_where(table_name, 'condition', + self.id_): + msg = 'cannot remove Condition in use' + raise HandledException(msg) super().remove(db_conn) diff --git a/plomtask/days.py b/plomtask/days.py index a924bbf..267156d 100644 --- a/plomtask/days.py +++ b/plomtask/days.py @@ -12,6 +12,7 @@ class Day(BaseModel[str]): """Individual days defined by their dates.""" table_name = 'days' to_save = ['comment'] + can_create_by_id = True def __init__(self, date: str, comment: str = '') -> None: id_ = valid_date(date) @@ -40,12 +41,9 @@ class Day(BaseModel[str]): return day @classmethod - def by_id(cls, - db_conn: DatabaseConnection, id_: str | None, - create: bool = False, - ) -> Day: + def by_id(cls, db_conn: DatabaseConnection, id_: str | None) -> Day: """Extend BaseModel.by_id checking for new/lost .todos.""" - day = super().by_id(db_conn, id_, create) + day = super().by_id(db_conn, id_) assert day.id_ is not None if day.id_ in Todo.days_to_update: Todo.days_to_update.remove(day.id_) diff --git a/plomtask/db.py b/plomtask/db.py index 853b4c6..f6ef1cb 100644 --- a/plomtask/db.py +++ b/plomtask/db.py @@ -238,6 +238,7 @@ class BaseModel(Generic[BaseModelId]): id_: None | BaseModelId cache_: dict[BaseModelId, Self] to_search: list[str] = [] + can_create_by_id = False _exists = True def __init__(self, id_: BaseModelId | None) -> None: @@ -388,15 +389,12 @@ class BaseModel(Generic[BaseModelId]): @classmethod def by_id(cls, db_conn: DatabaseConnection, - id_: BaseModelId | None, - # pylint: disable=unused-argument - create: bool = False) -> Self: + id_: BaseModelId | None + ) -> Self: """Retrieve by id_, on failure throw NotFoundException. First try to get from cls.cache_, only then check DB; if found, put into cache. - - If create=True, make anew (but do not cache yet). """ obj = None if id_ is not None: @@ -407,10 +405,20 @@ class BaseModel(Generic[BaseModelId]): break if obj: return obj - if create: + raise NotFoundException(f'found no object of ID {id_}') + + @classmethod + def by_id_or_create(cls, db_conn: DatabaseConnection, + id_: BaseModelId | None + ) -> Self: + """Wrapper around .by_id, creating (not caching/saving) if not find.""" + if not cls.can_create_by_id: + raise HandledException('Class cannot .by_id_or_create.') + try: + return cls.by_id(db_conn, id_) + except NotFoundException: obj = cls(id_) return obj - raise NotFoundException(f'found no object of ID {id_}') @classmethod def all(cls: type[BaseModelInstance], diff --git a/plomtask/http.py b/plomtask/http.py index a5e4613..be79159 100644 --- a/plomtask/http.py +++ b/plomtask/http.py @@ -244,7 +244,7 @@ class TaskHandler(BaseHTTPRequestHandler): def do_GET_day(self) -> dict[str, object]: """Show single Day of ?date=.""" date = self._params.get_str('date', date_in_n_days(0)) - day = Day.by_id(self.conn, date, create=True) + day = Day.by_id_or_create(self.conn, date) make_type = self._params.get_str('make_type') conditions_present = [] enablers_for = {} @@ -400,7 +400,7 @@ class TaskHandler(BaseHTTPRequestHandler): def do_GET_condition(self) -> dict[str, object]: """Show Condition of ?id=.""" id_ = self._params.get_int_or_none('id') - c = Condition.by_id(self.conn, id_, create=True) + c = Condition.by_id_or_create(self.conn, id_) ps = Process.all(self.conn) return {'condition': c, 'is_new': c.id_ is None, 'enabled_processes': [p for p in ps if c in p.conditions], @@ -423,7 +423,7 @@ class TaskHandler(BaseHTTPRequestHandler): def do_GET_process(self) -> dict[str, object]: """Show Process of ?id=.""" id_ = self._params.get_int_or_none('id') - process = Process.by_id(self.conn, id_, create=True) + process = Process.by_id_or_create(self.conn, id_) title_64 = self._params.get_str('title_b64') if title_64: title = b64decode(title_64.encode()).decode() @@ -501,7 +501,7 @@ class TaskHandler(BaseHTTPRequestHandler): def do_POST_day(self) -> str: """Update or insert Day of date and Todos mapped to it.""" date = self._params.get_str('date') - day = Day.by_id(self.conn, date, create=True) + day = Day.by_id_or_create(self.conn, date) day.comment = self._form_data.get_str('day_comment') day.save(self.conn) make_type = self._form_data.get_str('make_type') @@ -600,7 +600,7 @@ class TaskHandler(BaseHTTPRequestHandler): process = Process.by_id(self.conn, id_) process.remove(self.conn) return '/processes' - process = Process.by_id(self.conn, id_, create=True) + process = Process.by_id_or_create(self.conn, id_) process.title.set(self._form_data.get_str('title')) process.description.set(self._form_data.get_str('description')) process.effort.set(self._form_data.get_float('effort')) @@ -676,7 +676,7 @@ class TaskHandler(BaseHTTPRequestHandler): condition = Condition.by_id(self.conn, id_) condition.remove(self.conn) return '/conditions' - condition = Condition.by_id(self.conn, id_, create=True) + condition = Condition.by_id_or_create(self.conn, id_) condition.is_active = self._form_data.get_str('is_active') == 'True' condition.title.set(self._form_data.get_str('title')) condition.description.set(self._form_data.get_str('description')) diff --git a/plomtask/processes.py b/plomtask/processes.py index 4ff90ef..d007d0f 100644 --- a/plomtask/processes.py +++ b/plomtask/processes.py @@ -34,6 +34,7 @@ class Process(BaseModel[int], ConditionsRelations): ('process_step_suppressions', 'process', 'suppressed_steps', 0)] to_search = ['title.newest', 'description.newest'] + can_create_by_id = True def __init__(self, id_: int | None, calendarize: bool = False) -> None: BaseModel.__init__(self, id_) diff --git a/tests/conditions.py b/tests/conditions.py index 562dcd9..969942b 100644 --- a/tests/conditions.py +++ b/tests/conditions.py @@ -25,10 +25,6 @@ class TestsWithDB(TestCaseWithDB): self.check_versioned_from_table_row('title', str) self.check_versioned_from_table_row('description', str) - def test_Condition_by_id(self) -> None: - """Test .by_id(), including creation.""" - self.check_by_id() - def test_Condition_versioned_attributes_singularity(self) -> None: """Test behavior of VersionedAttributes on saving (with .title).""" self.check_versioned_singularity() diff --git a/tests/days.py b/tests/days.py index 1972dbd..02b6c22 100644 --- a/tests/days.py +++ b/tests/days.py @@ -53,10 +53,6 @@ class TestsWithDB(TestCaseWithDB): kwargs = {'date': self.default_ids[0], 'comment': 'foo'} self.check_saving_and_caching(**kwargs) - def test_Day_by_id(self) -> None: - """Test .by_id().""" - self.check_by_id() - def test_Day_by_date_range_filled(self) -> None: """Test Day.by_date_range_filled.""" date1, date2, date3 = self.default_ids @@ -87,7 +83,7 @@ class TestsWithDB(TestCaseWithDB): self.assertEqual(Day.by_date_range_filled(self.db_conn, day5.date, day7.date), [day5, day6, day7]) - self.check_storage([day1, day2, day3, day6]) + self.check_identity_with_cache_and_db([day1, day2, day3, day6]) # check 'today' is interpreted as today's date today = Day(date_in_n_days(0)) today.save(self.db_conn) diff --git a/tests/misc.py b/tests/misc.py index b0fb872..a27f0d0 100644 --- a/tests/misc.py +++ b/tests/misc.py @@ -151,7 +151,7 @@ class TestsWithServer(TestCaseWithServer): """Tests against our HTTP server/handler (and database).""" def test_do_GET(self) -> None: - """Test / redirect, and unknown targets failing.""" + """Test GET / redirect, and unknown targets failing.""" self.conn.request('GET', '/') self.check_redirect('/day') self.check_get('/foo', 404) diff --git a/tests/processes.py b/tests/processes.py index f495fd5..0f43a4d 100644 --- a/tests/processes.py +++ b/tests/processes.py @@ -182,10 +182,6 @@ class TestsWithDB(TestCaseWithDB): method(self.db_conn, [c1.id_, c2.id_]) self.assertEqual(getattr(p, target), [c1, c2]) - def test_Process_by_id(self) -> None: - """Test .by_id(), including creation""" - self.check_by_id() - def test_Process_versioned_attributes_singularity(self) -> None: """Test behavior of VersionedAttributes on saving (with .title).""" self.check_versioned_singularity() @@ -249,7 +245,7 @@ class TestsWithDBForProcessStep(TestCaseWithDB): p1.set_steps(self.db_conn, [step]) step.remove(self.db_conn) self.assertEqual(p1.explicit_steps, []) - self.check_storage([]) + self.check_identity_with_cache_and_db([]) class TestsWithServer(TestCaseWithServer): diff --git a/tests/utils.py b/tests/utils.py index 9d3d11d..55c948a 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -87,8 +87,8 @@ class TestCaseWithDB(TestCase): attr = getattr(retrieved, attr_name) self.assertEqual(sorted(attr.history.values()), vals) - def check_storage(self, content: list[Any]) -> None: - """Test cache and DB equal content.""" + def check_identity_with_cache_and_db(self, content: list[Any]) -> None: + """Test both cache and DB equal content.""" expected_cache = {} for item in content: expected_cache[item.id_] = item @@ -108,30 +108,47 @@ class TestCaseWithDB(TestCase): """Test instance.save in its core without relations.""" obj = self.checked_class(**kwargs) # pylint: disable=not-callable # check object init itself doesn't store anything yet - self.check_storage([]) + self.check_identity_with_cache_and_db([]) # check saving sets core attributes properly obj.save(self.db_conn) for key, value in kwargs.items(): self.assertEqual(getattr(obj, key), value) # check saving stored properly in cache and DB - self.check_storage([obj]) + self.check_identity_with_cache_and_db([obj]) - def check_by_id(self) -> None: - """Test .by_id(), including creation.""" + @_within_checked_class + def test_by_id(self) -> None: + """Test .by_id().""" + id1, id2, _ = self.default_ids # check failure if not yet saved - id1, id2 = self.default_ids[0], self.default_ids[1] - obj = self.checked_class(id1) # pylint: disable=not-callable + obj1 = self.checked_class(id1, **self.default_init_kwargs) with self.assertRaises(NotFoundException): self.checked_class.by_id(self.db_conn, id1) + # check identity of cached and retrieved + obj1.cache() + self.assertEqual(obj1, self.checked_class.by_id(self.db_conn, id1)) # check identity of saved and retrieved - obj.save(self.db_conn) - self.assertEqual(obj, self.checked_class.by_id(self.db_conn, id1)) - # check create=True acts like normal instantiation (sans saving) - by_id_created = self.checked_class.by_id(self.db_conn, id2, - create=True) - # pylint: disable=not-callable - self.assertEqual(self.checked_class(id2), by_id_created) - self.check_storage([obj]) + obj2 = self.checked_class(id2, **self.default_init_kwargs) + obj2.save(self.db_conn) + self.assertEqual(obj2, self.checked_class.by_id(self.db_conn, id2)) + # obj1.save(self.db_conn) + # self.check_identity_with_cache_and_db([obj1, obj2]) + + @_within_checked_class + def test_by_id_or_create(self) -> None: + """Test .by_id_or_create.""" + # check .by_id_or_create acts like normal instantiation (sans saving) + id_ = self.default_ids[0] + if not self.checked_class.can_create_by_id: + with self.assertRaises(HandledException): + self.checked_class.by_id_or_create(self.db_conn, id_) + # check .by_id_or_create fails if wrong class + else: + by_id_created = self.checked_class.by_id_or_create(self.db_conn, + id_) + with self.assertRaises(NotFoundException): + self.checked_class.by_id(self.db_conn, id_) + self.assertEqual(self.checked_class(id_), by_id_created) @_within_checked_class def test_from_table_row(self) -> None: @@ -230,7 +247,7 @@ class TestCaseWithDB(TestCase): obj.remove(self.db_conn) obj.save(self.db_conn) obj.remove(self.db_conn) - self.check_storage([]) + self.check_identity_with_cache_and_db([]) class TestCaseWithServer(TestCaseWithDB): -- 2.30.2