From 8570f4ce4d44b813a1f02b72c5c45a57d2003bae Mon Sep 17 00:00:00 2001 From: Christian Heller Date: Thu, 25 Apr 2024 08:40:48 +0200 Subject: [PATCH] Refactor object retrieval and creation. --- plomtask/conditions.py | 17 +----------- plomtask/days.py | 16 ++--------- plomtask/db.py | 47 +++++++++++++++++++------------ plomtask/processes.py | 63 ++++++++++++++++-------------------------- plomtask/todos.py | 45 ++++++++++++------------------ 5 files changed, 74 insertions(+), 114 deletions(-) diff --git a/plomtask/conditions.py b/plomtask/conditions.py index 8d67e5a..6696125 100644 --- a/plomtask/conditions.py +++ b/plomtask/conditions.py @@ -4,7 +4,6 @@ from typing import Any from sqlite3 import Row from plomtask.db import DatabaseConnection, BaseModel from plomtask.misc import VersionedAttribute -from plomtask.exceptions import NotFoundException class Condition(BaseModel[int]): @@ -13,7 +12,7 @@ class Condition(BaseModel[int]): to_save = ['is_active'] def __init__(self, id_: int | None, is_active: bool = False) -> None: - self.set_int_id(id_) + super().__init__(id_) self.is_active = is_active self.title = VersionedAttribute(self, 'condition_titles', 'UNNAMED') self.description = VersionedAttribute(self, 'condition_descriptions', @@ -30,20 +29,6 @@ class Condition(BaseModel[int]): getattr(condition, name).history_from_row(row_) return condition - @classmethod - def by_id(cls, db_conn: DatabaseConnection, id_: int | None, - create: bool = False) -> Condition: - """Collect (or create) Condition and its VersionedAttributes.""" - condition = None - if id_: - condition, _ = super()._by_id(db_conn, id_) - if not condition: - if not create: - raise NotFoundException(f'Condition not found of id: {id_}') - condition = cls(id_, False) - condition.save(db_conn) - return condition - def save(self, db_conn: DatabaseConnection) -> None: """Save self and its VersionedAttributes to DB and cache.""" self.save_core(db_conn) diff --git a/plomtask/days.py b/plomtask/days.py index 258d38d..78340e2 100644 --- a/plomtask/days.py +++ b/plomtask/days.py @@ -1,7 +1,7 @@ """Collecting Day and date-related items.""" from __future__ import annotations from datetime import datetime, timedelta -from plomtask.exceptions import BadFormatException, NotFoundException +from plomtask.exceptions import BadFormatException from plomtask.db import DatabaseConnection, BaseModel DATE_FORMAT = '%Y-%m-%d' @@ -30,6 +30,7 @@ class Day(BaseModel[str]): to_save = ['comment'] def __init__(self, date: str, comment: str = '') -> None: + super().__init__(date) self.id_: str = valid_date(date) self.datetime = datetime.strptime(self.date, DATE_FORMAT) self.comment = comment @@ -65,19 +66,6 @@ class Day(BaseModel[str]): days = gapless_days return days - @classmethod - def by_id(cls, db_conn: DatabaseConnection, - date: str, create: bool = False) -> Day: - """Retrieve Day by date if in DB (prefer cache), else return None.""" - day, _ = super()._by_id(db_conn, date) - if day: - return day - if not create: - raise NotFoundException(f'Day not found for date: {date}') - day = cls(date) - day.cache() - return day - @property def date(self) -> str: """Return self.id_ under the assumption it's a date string.""" diff --git a/plomtask/db.py b/plomtask/db.py index fe67e5c..ebd8c6c 100644 --- a/plomtask/db.py +++ b/plomtask/db.py @@ -4,7 +4,7 @@ from os.path import isfile from difflib import Differ from sqlite3 import connect as sql_connect, Cursor, Row from typing import Any, Self, TypeVar, Generic -from plomtask.exceptions import HandledException +from plomtask.exceptions import HandledException, NotFoundException PATH_DB_SCHEMA = 'scripts/init.sql' EXPECTED_DB_VERSION = 0 @@ -123,6 +123,12 @@ class BaseModel(Generic[BaseModelId]): id_: None | BaseModelId cache_: dict[BaseModelId, Self] + def __init__(self, id_: BaseModelId | None) -> None: + if isinstance(id_, int) and id_ < 1: + msg = f'illegal {self.__class__.__name__} ID, must be >=1: {id_}' + raise HandledException(msg) + self.id_ = id_ + @classmethod def get_cached(cls: type[BaseModelInstance], id_: BaseModelId) -> BaseModelInstance | None: @@ -173,20 +179,32 @@ class BaseModel(Generic[BaseModelId]): return obj @classmethod - def _by_id(cls, - db_conn: DatabaseConnection, - id_: BaseModelId) -> tuple[Self | None, bool]: + def _by_id(cls, db_conn: DatabaseConnection, + id_: BaseModelId) -> Self | None: """Return instance found by ID, or None, and if from cache or not.""" - from_cache = False obj = cls.get_cached(id_) - if obj: - from_cache = True - else: + if not obj: for row in db_conn.row_where(cls.table_name, 'id', id_): obj = cls.from_table_row(db_conn, row) obj.cache() break - return obj, from_cache + return obj + + @classmethod + def by_id(cls, db_conn: DatabaseConnection, + id_: BaseModelId | None, + # pylint: disable=unused-argument + create: bool = False) -> Self: + """Retrieve by id_, on failure throw NotFoundException.""" + obj = None + if id_ is not None: + obj = cls._by_id(db_conn, id_) + if obj: + return obj + if create: + obj = cls(id_) + return obj + raise NotFoundException(f'found no object of ID {id_}') @classmethod def all(cls: type[BaseModelInstance], @@ -199,18 +217,11 @@ class BaseModel(Generic[BaseModelId]): already_recorded = items.keys() for id_ in db_conn.column_all(cls.table_name, 'id'): if id_ not in already_recorded: - # pylint: disable=no-member - item = cls.by_id(db_conn, id_) # type: ignore[attr-defined] + item = cls.by_id(db_conn, id_) + assert item.id_ is not None items[item.id_] = item return list(items.values()) - def set_int_id(self, id_: int | None) -> None: - """Set id_ if >= 1 or None, else fail.""" - if (id_ is not None) and id_ < 1: - msg = f'illegal {self.__class__.__name__} ID, must be >=1: {id_}' - raise HandledException(msg) - self.id_ = id_ # type: ignore[assignment] - def save_core(self, db_conn: DatabaseConnection, update_with_lastrowid: bool = True) -> None: """Write bare-bones self (sans connected items), ensuring self.id_.""" diff --git a/plomtask/processes.py b/plomtask/processes.py index 9705f17..c0b13b5 100644 --- a/plomtask/processes.py +++ b/plomtask/processes.py @@ -1,7 +1,8 @@ """Collecting Processes and Process-related items.""" from __future__ import annotations from dataclasses import dataclass -from typing import Set +from typing import Set, Any +from sqlite3 import Row from plomtask.db import DatabaseConnection, BaseModel from plomtask.misc import VersionedAttribute from plomtask.conditions import Condition, ConditionsRelations @@ -25,7 +26,7 @@ class Process(BaseModel[int], ConditionsRelations): # pylint: disable=too-many-instance-attributes def __init__(self, id_: int | None) -> None: - self.set_int_id(id_) + super().__init__(id_) self.title = VersionedAttribute(self, 'process_titles', 'UNNAMED') self.description = VersionedAttribute(self, 'process_descriptions', '') self.effort = VersionedAttribute(self, 'process_efforts', 1.0) @@ -35,34 +36,26 @@ class Process(BaseModel[int], ConditionsRelations): self.disables: list[Condition] = [] @classmethod - def by_id(cls, db_conn: DatabaseConnection, id_: int | None, - create: bool = False) -> Process: - """Collect Process, its VersionedAttributes, and its child IDs.""" - process = None - from_cache = False - if id_: - process, from_cache = super()._by_id(db_conn, id_) - if not from_cache: - if not process: - if not create: - raise NotFoundException(f'Process not found of id: {id_}') - process = Process(id_) - if isinstance(process.id_, int): - for name in ('title', 'description', 'effort'): - table = f'process_{name}s' - for row in db_conn.row_where(table, 'parent', process.id_): - getattr(process, name).history_from_row(row) - for row in db_conn.row_where('process_steps', 'owner', - process.id_): - step = ProcessStep.from_table_row(db_conn, row) - process.explicit_steps += [step] - for name in ('conditions', 'enables', 'disables'): - table = f'process_{name}' - for c_id in db_conn.column_where(table, 'condition', - 'process', process.id_): - target = getattr(process, name) - target += [Condition.by_id(db_conn, c_id)] - assert isinstance(process, Process) + def from_table_row(cls, db_conn: DatabaseConnection, + row: Row | list[Any]) -> Process: + """Make from DB row, with dependencies.""" + process = super().from_table_row(db_conn, row) + assert isinstance(process.id_, int) + for name in ('title', 'description', 'effort'): + table = f'process_{name}s' + for row_ in db_conn.row_where(table, 'parent', process.id_): + getattr(process, name).history_from_row(row_) + for row_ in db_conn.row_where('process_steps', 'owner', + process.id_): + step = ProcessStep.from_table_row(db_conn, row_) + process.explicit_steps += [step] # pylint: disable=no-member + for name in ('conditions', 'enables', 'disables'): + table = f'process_{name}' + assert isinstance(process.id_, int) + for c_id in db_conn.column_where(table, 'condition', + 'process', process.id_): + target = getattr(process, name) + target += [Condition.by_id(db_conn, c_id)] return process def used_as_step_by(self, db_conn: DatabaseConnection) -> list[Process]: @@ -184,19 +177,11 @@ class ProcessStep(BaseModel[int]): def __init__(self, id_: int | None, owner_id: int, step_process_id: int, parent_step_id: int | None) -> None: - self.set_int_id(id_) + super().__init__(id_) self.owner_id = owner_id self.step_process_id = step_process_id self.parent_step_id = parent_step_id - @classmethod - def by_id(cls, db_conn: DatabaseConnection, id_: int) -> ProcessStep: - """Retrieve ProcessStep by id_, or throw NotFoundException.""" - step, _ = super()._by_id(db_conn, id_) - if step: - return step - raise NotFoundException(f'found no ProcessStep of ID {id_}') - def save(self, db_conn: DatabaseConnection) -> None: """Default to simply calling self.save_core for simple cases.""" self.save_core(db_conn) diff --git a/plomtask/todos.py b/plomtask/todos.py index 5901571..9b9bc0b 100644 --- a/plomtask/todos.py +++ b/plomtask/todos.py @@ -21,15 +21,13 @@ class TodoStepsNode: class Todo(BaseModel[int], ConditionsRelations): """Individual actionable.""" - # pylint: disable=too-many-instance-attributes - table_name = 'todos' to_save = ['process_id', 'is_done', 'date'] def __init__(self, id_: int | None, process: Process, is_done: bool, date: str) -> None: - self.set_int_id(id_) + super().__init__(id_) self.process = process self._is_done = is_done self.date = date @@ -46,36 +44,29 @@ class Todo(BaseModel[int], ConditionsRelations): @classmethod def from_table_row(cls, db_conn: DatabaseConnection, row: Row | list[Any]) -> Todo: - """Make from DB row, write to DB cache.""" + """Make from DB row, with dependencies.""" if row[1] == 0: raise NotFoundException('calling Todo of ' 'unsaved Process') row_as_list = list(row) row_as_list[1] = Process.by_id(db_conn, row[1]) todo = super().from_table_row(db_conn, row_as_list) - assert isinstance(todo, Todo) - return todo - - @classmethod - def by_id(cls, db_conn: DatabaseConnection, id_: int) -> Todo: - """Get Todo of .id_=id_ and children (from DB cache if possible).""" - todo, from_cache = super()._by_id(db_conn, id_) - if todo is None: - raise NotFoundException(f'Todo of ID not found: {id_}') - if not from_cache: - for t_id in db_conn.column_where('todo_children', 'child', - 'parent', id_): - todo.children += [cls.by_id(db_conn, t_id)] - for t_id in db_conn.column_where('todo_children', 'parent', - 'child', id_): - todo.parents += [cls.by_id(db_conn, t_id)] - for name in ('conditions', 'enables', 'disables'): - table = f'todo_{name}' - assert isinstance(todo.id_, int) - for cond_id in db_conn.column_where(table, 'condition', - 'todo', todo.id_): - target = getattr(todo, name) - target += [Condition.by_id(db_conn, cond_id)] + assert isinstance(todo.id_, int) + for t_id in db_conn.column_where('todo_children', 'child', + 'parent', todo.id_): + # pylint: disable=no-member + todo.children += [cls.by_id(db_conn, t_id)] + for t_id in db_conn.column_where('todo_children', 'parent', + 'child', todo.id_): + # pylint: disable=no-member + todo.parents += [cls.by_id(db_conn, t_id)] + for name in ('conditions', 'enables', 'disables'): + table = f'todo_{name}' + assert isinstance(todo.id_, int) + for cond_id in db_conn.column_where(table, 'condition', + 'todo', todo.id_): + target = getattr(todo, name) + target += [Condition.by_id(db_conn, cond_id)] return todo @classmethod -- 2.30.2