From: Christian Heller Date: Mon, 6 Jan 2025 15:02:24 +0000 (+0100) Subject: Simplify BaseModel type and .id_ genealogy (at cost of adding two asserts). X-Git-Url: https://plomlompom.com/repos/%7B%7Bprefix%7D%7D/static/%7B%7B%20web_path%20%7D%7D/task?a=commitdiff_plain;h=f6b010f979bf87ff0edb1f7c228196d402b82a4e;p=plomtask Simplify BaseModel type and .id_ genealogy (at cost of adding two asserts). --- diff --git a/plomtask/conditions.py b/plomtask/conditions.py index 8d41604..2240baf 100644 --- a/plomtask/conditions.py +++ b/plomtask/conditions.py @@ -5,7 +5,7 @@ from plomtask.versioned_attributes import VersionedAttribute from plomtask.exceptions import HandledException -class Condition(BaseModel[int]): +class Condition(BaseModel): """Non-Process dependency for ProcessSteps and Todos.""" table_name = 'conditions' to_save_simples = ['is_active'] diff --git a/plomtask/days.py b/plomtask/days.py index 05f93eb..b576bb2 100644 --- a/plomtask/days.py +++ b/plomtask/days.py @@ -1,14 +1,14 @@ """Collecting Day and date-related items.""" from __future__ import annotations -from typing import Any +from typing import Any, Self from sqlite3 import Row from datetime import datetime, timedelta -from plomtask.db import DatabaseConnection, BaseModel +from plomtask.db import DatabaseConnection, BaseModel, BaseModelId from plomtask.todos import Todo from plomtask.dating import (DATE_FORMAT, valid_date) -class Day(BaseModel[str]): +class Day(BaseModel): """Individual days defined by their dates.""" table_name = 'days' to_save_simples = ['comment'] @@ -22,12 +22,12 @@ class Day(BaseModel[str]): self.comment = comment self.todos: list[Todo] = [] - def __lt__(self, other: Day) -> bool: + def __lt__(self, other: Self) -> bool: return self.date < other.date @classmethod def from_table_row(cls, db_conn: DatabaseConnection, row: Row | list[Any] - ) -> Day: + ) -> Self: """Make from DB row, with linked Todos.""" day = super().from_table_row(db_conn, row) assert isinstance(day.id_, str) @@ -35,23 +35,25 @@ class Day(BaseModel[str]): return day @classmethod - def by_id(cls, db_conn: DatabaseConnection, id_: str) -> Day: + def by_id(cls, db_conn: DatabaseConnection, id_: BaseModelId) -> Self: """Extend BaseModel.by_id Checks Todo.days_to_update if we need to a retrieved Day's .todos, and also ensures we're looking for proper dates and not strings like "yesterday" by enforcing the valid_date translation. """ + assert isinstance(id_, str) possibly_translated_date = valid_date(id_) day = super().by_id(db_conn, possibly_translated_date) if day.id_ in Todo.days_to_update: + assert isinstance(day.id_, str) Todo.days_to_update.remove(day.id_) day.todos = Todo.by_date(db_conn, day.id_) return day @classmethod - def with_filled_gaps(cls, days: list[Day], start_date: str, end_date: str - ) -> list[Day]: + def with_filled_gaps(cls, days: list[Self], start_date: str, end_date: str + ) -> list[Self]: """In days, fill with (un-stored) Days gaps between start/end_date.""" days = days[:] start_date, end_date = valid_date(start_date), valid_date(end_date) @@ -60,16 +62,16 @@ class Day(BaseModel[str]): days = [d for d in days if d.date >= start_date and d.date <= end_date] days.sort() if start_date not in [d.date for d in days]: - days[:] = [Day(start_date)] + days + days[:] = [cls(start_date)] + days if end_date not in [d.date for d in days]: - days += [Day(end_date)] + days += [cls(end_date)] if len(days) > 1: gapless_days = [] for i, day in enumerate(days): gapless_days += [day] if i < len(days) - 1: while day.next_date != days[i+1].date: - day = Day(day.next_date) + day = cls(day.next_date) gapless_days += [day] days[:] = gapless_days return days diff --git a/plomtask/db.py b/plomtask/db.py index b2e69f8..a8e11ba 100644 --- a/plomtask/db.py +++ b/plomtask/db.py @@ -4,7 +4,7 @@ from os import listdir from os.path import isfile from difflib import Differ from sqlite3 import connect as sql_connect, Cursor, Row -from typing import Any, Self, TypeVar, Generic, Callable +from typing import Any, Self, Callable from plomtask.exceptions import (HandledException, NotFoundException, BadFormatException) from plomtask.dating import valid_date @@ -225,11 +225,10 @@ class DatabaseConnection: self.exec(f'DELETE FROM {table_name} WHERE {key} =', (target,)) -BaseModelId = TypeVar('BaseModelId', int, str) -BaseModelInstance = TypeVar('BaseModelInstance', bound='BaseModel[Any]') +BaseModelId = int | str -class BaseModel(Generic[BaseModelId]): +class BaseModel: """Template for most of the models we use/derive from the DB.""" table_name = '' to_save_simples: list[str] = [] @@ -281,11 +280,10 @@ class BaseModel(Generic[BaseModelId]): return list(cls.versioned_defaults.keys()) @property - def as_dict_and_refs(self) -> tuple[dict[str, object], - list[BaseModel[int] | BaseModel[str]]]: + def as_dict_and_refs(self) -> tuple[dict[str, object], list[BaseModel]]: """Return self as json.dumps-ready dict, list of referenced objects.""" d: dict[str, object] = {'id': self.id_} - refs: list[BaseModel[int] | BaseModel[str]] = [] + refs: list[BaseModel] = [] for to_save in self.to_save_simples: d[to_save] = getattr(self, to_save) if len(self.to_save_versioned()) > 0: @@ -366,18 +364,15 @@ class BaseModel(Generic[BaseModelId]): cls.cache_ = {} @classmethod - def get_cache(cls: type[BaseModelInstance] - ) -> dict[Any, BaseModelInstance]: + def get_cache(cls) -> dict[BaseModelId, Self]: """Get cache dictionary, create it if not yet existing.""" if not hasattr(cls, 'cache_'): - d: dict[Any, BaseModelInstance] = {} + d: dict[BaseModelId, BaseModel] = {} cls.cache_ = d return cls.cache_ @classmethod - def _get_cached(cls: type[BaseModelInstance], - id_: BaseModelId - ) -> BaseModelInstance | None: + def _get_cached(cls, id_: BaseModelId) -> Self | None: """Get object of id_ from class's cache, or None if not found.""" cache = cls.get_cache() if id_ in cache: @@ -412,10 +407,9 @@ class BaseModel(Generic[BaseModelId]): # object retrieval and generation @classmethod - def from_table_row(cls: type[BaseModelInstance], - # pylint: disable=unused-argument + def from_table_row(cls, db_conn: DatabaseConnection, - row: Row | list[Any]) -> BaseModelInstance: + row: Row | list[Any]) -> Self: """Make from DB row (sans relations), update DB cache with it.""" obj = cls(*row) assert obj.id_ is not None @@ -462,8 +456,7 @@ class BaseModel(Generic[BaseModelId]): return cls(id_) @classmethod - def all(cls: type[BaseModelInstance], - db_conn: DatabaseConnection) -> list[BaseModelInstance]: + def all(cls, db_conn: DatabaseConnection) -> list[Self]: """Collect all objects of class into list. Note that this primarily returns the contents of the cache, and only @@ -471,7 +464,7 @@ class BaseModel(Generic[BaseModelId]): cache is always instantly cleaned of any items that would be removed from the DB. """ - items: dict[BaseModelId, BaseModelInstance] = {} + items: dict[BaseModelId, Self] = {} for k, v in cls.get_cache().items(): items[k] = v already_recorded = items.keys() @@ -483,12 +476,11 @@ class BaseModel(Generic[BaseModelId]): return sorted(list(items.values())) @classmethod - def by_date_range_with_limits(cls: type[BaseModelInstance], + def by_date_range_with_limits(cls, db_conn: DatabaseConnection, date_range: tuple[str, str], date_col: str = 'day' - ) -> tuple[list[BaseModelInstance], str, - str]: + ) -> tuple[list[Self], str, str]: """Return list of items in DB within (closed) date_range interval. If no range values provided, defaults them to 'yesterday' and @@ -507,8 +499,7 @@ class BaseModel(Generic[BaseModelId]): return items, start_date, end_date @classmethod - def matching(cls: type[BaseModelInstance], db_conn: DatabaseConnection, - pattern: str) -> list[BaseModelInstance]: + def matching(cls, db_conn: DatabaseConnection, pattern: str) -> list[Self]: """Return all objects whose .to_search match pattern.""" items = cls.all(db_conn) if pattern: @@ -544,7 +535,7 @@ class BaseModel(Generic[BaseModelId]): table_name = self.table_name cursor = db_conn.exec(f'REPLACE INTO {table_name} VALUES', values) if not isinstance(self.id_, str): - self.id_ = cursor.lastrowid # type: ignore[assignment] + self.id_ = cursor.lastrowid self.cache() for attr_name in self.to_save_versioned(): getattr(self, attr_name).save(db_conn) diff --git a/plomtask/http.py b/plomtask/http.py index 4a5cb53..b30b22c 100644 --- a/plomtask/http.py +++ b/plomtask/http.py @@ -158,8 +158,7 @@ class TaskHandler(BaseHTTPRequestHandler): def flatten(node: object) -> object: - def update_library_with( - item: BaseModel[int] | BaseModel[str]) -> None: + def update_library_with(item: BaseModel) -> None: cls_name = item.__class__.__name__ if cls_name not in library: library[cls_name] = {} @@ -524,6 +523,7 @@ class TaskHandler(BaseHTTPRequestHandler): for process_id in owned_ids: Process.by_id(self._conn, process_id) # to ensure ID exists preset_top_step = process_id + assert not isinstance(process.id_, str) return {'process': process, 'is_new': not exists, 'preset_top_step': preset_top_step, @@ -642,7 +642,7 @@ class TaskHandler(BaseHTTPRequestHandler): # pylint: disable=too-many-locals # pylint: disable=too-many-branches # pylint: disable=too-many-statements - assert todo.id_ is not None + assert isinstance(todo.id_, int) adoptees = [(id_, todo.id_) for id_ in self._form.get_all_int('adopt')] to_make = {'full': [(id_, todo.id_) for id_ in self._form.get_all_int('make_full')], diff --git a/plomtask/processes.py b/plomtask/processes.py index c90519b..e23e97d 100644 --- a/plomtask/processes.py +++ b/plomtask/processes.py @@ -1,6 +1,6 @@ """Collecting Processes and Process-related items.""" from __future__ import annotations -from typing import Set, Any +from typing import Set, Self, Any from sqlite3 import Row from plomtask.misc import DictableNode from plomtask.db import DatabaseConnection, BaseModel @@ -23,7 +23,7 @@ class ProcessStepsNode(DictableNode): 'is_suppressed'] -class Process(BaseModel[int], ConditionsRelations): +class Process(BaseModel, ConditionsRelations): """Template for, and metadata for, Todos, and their arrangements.""" # pylint: disable=too-many-instance-attributes table_name = 'processes' @@ -44,7 +44,7 @@ class Process(BaseModel[int], ConditionsRelations): 'title': lambda p: p.title.newest} def __init__(self, id_: int | None, calendarize: bool = False) -> None: - BaseModel.__init__(self, id_) + super().__init__(id_) ConditionsRelations.__init__(self) for name in ['title', 'description', 'effort']: attr = VersionedAttribute(self, f'process_{name}s', @@ -56,8 +56,8 @@ class Process(BaseModel[int], ConditionsRelations): self.n_owners: int | None = None # only set by from_table_row @classmethod - def from_table_row(cls, db_conn: DatabaseConnection, - row: Row | list[Any]) -> Process: + def from_table_row(cls, db_conn: DatabaseConnection, row: Row | list[Any] + ) -> Self: """Make from DB row, with dependencies.""" process = super().from_table_row(db_conn, row) assert process.id_ is not None @@ -81,7 +81,7 @@ class Process(BaseModel[int], ConditionsRelations): process.n_owners = len(process.used_as_step_by(db_conn)) return process - def used_as_step_by(self, db_conn: DatabaseConnection) -> list[Process]: + def used_as_step_by(self, db_conn: DatabaseConnection) -> list[Self]: """Return Processes using self for a ProcessStep.""" if not self.id_: return [] @@ -93,7 +93,7 @@ class Process(BaseModel[int], ConditionsRelations): def get_steps(self, db_conn: DatabaseConnection, - external_owner: Process | None = None + external_owner: Self | None = None ) -> list[ProcessStepsNode]: """Return tree of depended-on explicit and implicit ProcessSteps.""" @@ -163,7 +163,7 @@ class Process(BaseModel[int], ConditionsRelations): owners_old = self.used_as_step_by(db_conn) losers = [o for o in owners_old if o.id_ not in owner_ids] owners_old_ids = [o.id_ for o in owners_old] - winners = [Process.by_id(db_conn, id_) for id_ in owner_ids + winners = [self.by_id(db_conn, id_) for id_ in owner_ids if id_ not in owners_old_ids] steps_to_remove = [] for loser in losers: @@ -244,7 +244,7 @@ class Process(BaseModel[int], ConditionsRelations): super().remove(db_conn) -class ProcessStep(BaseModel[int]): +class ProcessStep(BaseModel): """Sub-unit of Processes.""" table_name = 'process_steps' to_save_simples = ['owner_id', 'step_process_id', 'parent_step_id'] diff --git a/plomtask/todos.py b/plomtask/todos.py index 03881b5..5a71400 100644 --- a/plomtask/todos.py +++ b/plomtask/todos.py @@ -1,6 +1,6 @@ """Actionables.""" from __future__ import annotations -from typing import Any, Set +from typing import Any, Self, Set from sqlite3 import Row from plomtask.misc import DictableNode from plomtask.db import DatabaseConnection, BaseModel @@ -32,7 +32,7 @@ class TodoOrProcStepNode(DictableNode): _to_dict = ['node_id', 'todo', 'process', 'children', 'fillable'] -class Todo(BaseModel[int], ConditionsRelations): +class Todo(BaseModel, ConditionsRelations): """Individual actionable.""" # pylint: disable=too-many-instance-attributes # pylint: disable=too-many-public-methods @@ -61,7 +61,7 @@ class Todo(BaseModel[int], ConditionsRelations): date: str, comment: str = '', effort: None | float = None, calendarize: bool = False) -> None: - BaseModel.__init__(self, id_) + super().__init__(id_) ConditionsRelations.__init__(self) if process.id_ is None: raise NotFoundException('Process of Todo without ID (not saved?)') @@ -82,7 +82,7 @@ class Todo(BaseModel[int], ConditionsRelations): @classmethod def by_date_range(cls, db_conn: DatabaseConnection, - date_range: tuple[str, str] = ('', '')) -> list[Todo]: + date_range: tuple[str, str] = ('', '')) -> list[Self]: """Collect Todos of Days within date_range.""" todos, _, _ = cls.by_date_range_with_limits(db_conn, date_range) return todos @@ -90,8 +90,8 @@ class Todo(BaseModel[int], ConditionsRelations): def ensure_children(self, db_conn: DatabaseConnection) -> None: """Ensure Todo children (create or adopt) demanded by Process chain.""" - def walk_steps(parent: Todo, step_node: ProcessStepsNode) -> Todo: - adoptables = [t for t in Todo.by_date(db_conn, parent.date) + def walk_steps(parent: Self, step_node: ProcessStepsNode) -> Todo: + adoptables = [t for t in self.by_date(db_conn, parent.date) if (t not in parent.children) and (t != parent) and step_node.process.id_ == t.process_id] @@ -100,7 +100,8 @@ class Todo(BaseModel[int], ConditionsRelations): satisfier = adoptable break if not satisfier: - satisfier = Todo(None, step_node.process, False, parent.date) + satisfier = self.__class__(None, step_node.process, False, + parent.date) satisfier.save(db_conn) sub_step_nodes = sorted( step_node.steps, @@ -129,7 +130,7 @@ class Todo(BaseModel[int], ConditionsRelations): @classmethod def from_table_row(cls, db_conn: DatabaseConnection, - row: Row | list[Any]) -> Todo: + row: Row | list[Any]) -> Self: """Make from DB row, with dependencies.""" if row[1] == 0: raise NotFoundException('calling Todo of ' @@ -154,12 +155,12 @@ class Todo(BaseModel[int], ConditionsRelations): @classmethod def by_process_id(cls, db_conn: DatabaseConnection, - process_id: int | None) -> list[Todo]: + process_id: int | None) -> list[Self]: """Collect all Todos of Process of process_id.""" return [t for t in cls.all(db_conn) if t.process.id_ == process_id] @classmethod - def by_date(cls, db_conn: DatabaseConnection, date: str) -> list[Todo]: + def by_date(cls, db_conn: DatabaseConnection, date: str) -> list[Self]: """Collect all Todos for Day of date.""" return cls.by_date_range(db_conn, (date, date)) @@ -253,7 +254,7 @@ class Todo(BaseModel[int], ConditionsRelations): def get_step_tree(self, seen_todos: set[int]) -> TodoNode: """Return tree of depended-on Todos.""" - def make_node(todo: Todo) -> TodoNode: + def make_node(todo: Self) -> TodoNode: children = [] seen = todo.id_ in seen_todos assert isinstance(todo.id_, int) @@ -268,7 +269,7 @@ class Todo(BaseModel[int], ConditionsRelations): def tree_effort(self) -> float: """Return sum of performed efforts of self and all descendants.""" - def walk_tree(node: Todo) -> float: + def walk_tree(node: Self) -> float: local_effort = 0.0 for child in node.children: local_effort += walk_tree(child) @@ -276,10 +277,10 @@ class Todo(BaseModel[int], ConditionsRelations): return walk_tree(self) - def add_child(self, child: Todo) -> None: + def add_child(self, child: Self) -> None: """Add child to self.children, avoid recursion, update parenthoods.""" - def walk_steps(node: Todo) -> None: + def walk_steps(node: Self) -> None: if node.id_ == self.id_: raise BadFormatException('bad child choice causes recursion') for child in node.children: @@ -295,7 +296,7 @@ class Todo(BaseModel[int], ConditionsRelations): self.children += [child] child.parents += [self] - def remove_child(self, child: Todo) -> None: + def remove_child(self, child: Self) -> None: """Remove child from self.children, update counter relations.""" if child not in self.children: raise HandledException('Cannot remove un-parented child.')