home · contact · privacy
Simplify BaseModel type and .id_ genealogy (at cost of adding two asserts). master
authorChristian Heller <c.heller@plomlompom.de>
Mon, 6 Jan 2025 15:02:24 +0000 (16:02 +0100)
committerChristian Heller <c.heller@plomlompom.de>
Mon, 6 Jan 2025 15:02:24 +0000 (16:02 +0100)
plomtask/conditions.py
plomtask/days.py
plomtask/db.py
plomtask/http.py
plomtask/processes.py
plomtask/todos.py

index 8d4160424423cf6ddfccb06b5b04d77e2bea4b15..2240baf0ffaef3eafd84cbbf2fbda9393fab723f 100644 (file)
@@ -5,7 +5,7 @@ from plomtask.versioned_attributes import VersionedAttribute
 from plomtask.exceptions import HandledException
 
 
-class Condition(BaseModel[int]):
+class Condition(BaseModel):
     """Non-Process dependency for ProcessSteps and Todos."""
     table_name = 'conditions'
     to_save_simples = ['is_active']
index 05f93eb8846016c23406fc09346991ec1196f452..b576bb2c16262d76a8d0caa71d358364570a3969 100644 (file)
@@ -1,14 +1,14 @@
 """Collecting Day and date-related items."""
 from __future__ import annotations
-from typing import Any
+from typing import Any, Self
 from sqlite3 import Row
 from datetime import datetime, timedelta
-from plomtask.db import DatabaseConnection, BaseModel
+from plomtask.db import DatabaseConnection, BaseModel, BaseModelId
 from plomtask.todos import Todo
 from plomtask.dating import (DATE_FORMAT, valid_date)
 
 
-class Day(BaseModel[str]):
+class Day(BaseModel):
     """Individual days defined by their dates."""
     table_name = 'days'
     to_save_simples = ['comment']
@@ -22,12 +22,12 @@ class Day(BaseModel[str]):
         self.comment = comment
         self.todos: list[Todo] = []
 
-    def __lt__(self, other: Day) -> bool:
+    def __lt__(self, other: Self) -> bool:
         return self.date < other.date
 
     @classmethod
     def from_table_row(cls, db_conn: DatabaseConnection, row: Row | list[Any]
-                       ) -> Day:
+                       ) -> Self:
         """Make from DB row, with linked Todos."""
         day = super().from_table_row(db_conn, row)
         assert isinstance(day.id_, str)
@@ -35,23 +35,25 @@ class Day(BaseModel[str]):
         return day
 
     @classmethod
-    def by_id(cls, db_conn: DatabaseConnection, id_: str) -> Day:
+    def by_id(cls, db_conn: DatabaseConnection, id_: BaseModelId) -> Self:
         """Extend BaseModel.by_id
 
         Checks Todo.days_to_update if we need to a retrieved Day's .todos,
         and also ensures we're looking for proper dates and not strings like
         "yesterday" by enforcing the valid_date translation.
         """
+        assert isinstance(id_, str)
         possibly_translated_date = valid_date(id_)
         day = super().by_id(db_conn, possibly_translated_date)
         if day.id_ in Todo.days_to_update:
+            assert isinstance(day.id_, str)
             Todo.days_to_update.remove(day.id_)
             day.todos = Todo.by_date(db_conn, day.id_)
         return day
 
     @classmethod
-    def with_filled_gaps(cls, days: list[Day], start_date: str, end_date: str
-                         ) -> list[Day]:
+    def with_filled_gaps(cls, days: list[Self], start_date: str, end_date: str
+                         ) -> list[Self]:
         """In days, fill with (un-stored) Days gaps between start/end_date."""
         days = days[:]
         start_date, end_date = valid_date(start_date), valid_date(end_date)
@@ -60,16 +62,16 @@ class Day(BaseModel[str]):
         days = [d for d in days if d.date >= start_date and d.date <= end_date]
         days.sort()
         if start_date not in [d.date for d in days]:
-            days[:] = [Day(start_date)] + days
+            days[:] = [cls(start_date)] + days
         if end_date not in [d.date for d in days]:
-            days += [Day(end_date)]
+            days += [cls(end_date)]
         if len(days) > 1:
             gapless_days = []
             for i, day in enumerate(days):
                 gapless_days += [day]
                 if i < len(days) - 1:
                     while day.next_date != days[i+1].date:
-                        day = Day(day.next_date)
+                        day = cls(day.next_date)
                         gapless_days += [day]
             days[:] = gapless_days
         return days
index b2e69f811e42c726da52968897baf5339565c8f0..a8e11badc379e121d5c1a8a8923c1ce5232710bd 100644 (file)
@@ -4,7 +4,7 @@ from os import listdir
 from os.path import isfile
 from difflib import Differ
 from sqlite3 import connect as sql_connect, Cursor, Row
-from typing import Any, Self, TypeVar, Generic, Callable
+from typing import Any, Self, Callable
 from plomtask.exceptions import (HandledException, NotFoundException,
                                  BadFormatException)
 from plomtask.dating import valid_date
@@ -225,11 +225,10 @@ class DatabaseConnection:
         self.exec(f'DELETE FROM {table_name} WHERE {key} =', (target,))
 
 
-BaseModelId = TypeVar('BaseModelId', int, str)
-BaseModelInstance = TypeVar('BaseModelInstance', bound='BaseModel[Any]')
+BaseModelId = int | str
 
 
-class BaseModel(Generic[BaseModelId]):
+class BaseModel:
     """Template for most of the models we use/derive from the DB."""
     table_name = ''
     to_save_simples: list[str] = []
@@ -281,11 +280,10 @@ class BaseModel(Generic[BaseModelId]):
         return list(cls.versioned_defaults.keys())
 
     @property
-    def as_dict_and_refs(self) -> tuple[dict[str, object],
-                                        list[BaseModel[int] | BaseModel[str]]]:
+    def as_dict_and_refs(self) -> tuple[dict[str, object], list[BaseModel]]:
         """Return self as json.dumps-ready dict, list of referenced objects."""
         d: dict[str, object] = {'id': self.id_}
-        refs: list[BaseModel[int] | BaseModel[str]] = []
+        refs: list[BaseModel] = []
         for to_save in self.to_save_simples:
             d[to_save] = getattr(self, to_save)
         if len(self.to_save_versioned()) > 0:
@@ -366,18 +364,15 @@ class BaseModel(Generic[BaseModelId]):
         cls.cache_ = {}
 
     @classmethod
-    def get_cache(cls: type[BaseModelInstance]
-                  ) -> dict[Any, BaseModelInstance]:
+    def get_cache(cls) -> dict[BaseModelId, Self]:
         """Get cache dictionary, create it if not yet existing."""
         if not hasattr(cls, 'cache_'):
-            d: dict[Any, BaseModelInstance] = {}
+            d: dict[BaseModelId, BaseModel] = {}
             cls.cache_ = d
         return cls.cache_
 
     @classmethod
-    def _get_cached(cls: type[BaseModelInstance],
-                    id_: BaseModelId
-                    ) -> BaseModelInstance | None:
+    def _get_cached(cls, id_: BaseModelId) -> Self | None:
         """Get object of id_ from class's cache, or None if not found."""
         cache = cls.get_cache()
         if id_ in cache:
@@ -412,10 +407,9 @@ class BaseModel(Generic[BaseModelId]):
     # object retrieval and generation
 
     @classmethod
-    def from_table_row(cls: type[BaseModelInstance],
-                       # pylint: disable=unused-argument
+    def from_table_row(cls,
                        db_conn: DatabaseConnection,
-                       row: Row | list[Any]) -> BaseModelInstance:
+                       row: Row | list[Any]) -> Self:
         """Make from DB row (sans relations), update DB cache with it."""
         obj = cls(*row)
         assert obj.id_ is not None
@@ -462,8 +456,7 @@ class BaseModel(Generic[BaseModelId]):
             return cls(id_)
 
     @classmethod
-    def all(cls: type[BaseModelInstance],
-            db_conn: DatabaseConnection) -> list[BaseModelInstance]:
+    def all(cls, db_conn: DatabaseConnection) -> list[Self]:
         """Collect all objects of class into list.
 
         Note that this primarily returns the contents of the cache, and only
@@ -471,7 +464,7 @@ class BaseModel(Generic[BaseModelId]):
         cache is always instantly cleaned of any items that would be removed
         from the DB.
         """
-        items: dict[BaseModelId, BaseModelInstance] = {}
+        items: dict[BaseModelId, Self] = {}
         for k, v in cls.get_cache().items():
             items[k] = v
         already_recorded = items.keys()
@@ -483,12 +476,11 @@ class BaseModel(Generic[BaseModelId]):
         return sorted(list(items.values()))
 
     @classmethod
-    def by_date_range_with_limits(cls: type[BaseModelInstance],
+    def by_date_range_with_limits(cls,
                                   db_conn: DatabaseConnection,
                                   date_range: tuple[str, str],
                                   date_col: str = 'day'
-                                  ) -> tuple[list[BaseModelInstance], str,
-                                             str]:
+                                  ) -> tuple[list[Self], str, str]:
         """Return list of items in DB within (closed) date_range interval.
 
         If no range values provided, defaults them to 'yesterday' and
@@ -507,8 +499,7 @@ class BaseModel(Generic[BaseModelId]):
         return items, start_date, end_date
 
     @classmethod
-    def matching(cls: type[BaseModelInstance], db_conn: DatabaseConnection,
-                 pattern: str) -> list[BaseModelInstance]:
+    def matching(cls, db_conn: DatabaseConnection, pattern: str) -> list[Self]:
         """Return all objects whose .to_search match pattern."""
         items = cls.all(db_conn)
         if pattern:
@@ -544,7 +535,7 @@ class BaseModel(Generic[BaseModelId]):
         table_name = self.table_name
         cursor = db_conn.exec(f'REPLACE INTO {table_name} VALUES', values)
         if not isinstance(self.id_, str):
-            self.id_ = cursor.lastrowid  # type: ignore[assignment]
+            self.id_ = cursor.lastrowid
         self.cache()
         for attr_name in self.to_save_versioned():
             getattr(self, attr_name).save(db_conn)
index 4a5cb53d62b5345c0e283f6daabf60b4174e37be..b30b22c6e0edb6c92829fc3e5de06755b2a775b9 100644 (file)
@@ -158,8 +158,7 @@ class TaskHandler(BaseHTTPRequestHandler):
 
         def flatten(node: object) -> object:
 
-            def update_library_with(
-                    item: BaseModel[int] | BaseModel[str]) -> None:
+            def update_library_with(item: BaseModel) -> None:
                 cls_name = item.__class__.__name__
                 if cls_name not in library:
                     library[cls_name] = {}
@@ -524,6 +523,7 @@ class TaskHandler(BaseHTTPRequestHandler):
         for process_id in owned_ids:
             Process.by_id(self._conn, process_id)  # to ensure ID exists
             preset_top_step = process_id
+        assert not isinstance(process.id_, str)
         return {'process': process,
                 'is_new': not exists,
                 'preset_top_step': preset_top_step,
@@ -642,7 +642,7 @@ class TaskHandler(BaseHTTPRequestHandler):
         # pylint: disable=too-many-locals
         # pylint: disable=too-many-branches
         # pylint: disable=too-many-statements
-        assert todo.id_ is not None
+        assert isinstance(todo.id_, int)
         adoptees = [(id_, todo.id_) for id_ in self._form.get_all_int('adopt')]
         to_make = {'full': [(id_, todo.id_)
                             for id_ in self._form.get_all_int('make_full')],
index c90519ba41adb4665dfe1ea380065b23c3473c68..e23e97d9febb03899a005128877e32b2084443fc 100644 (file)
@@ -1,6 +1,6 @@
 """Collecting Processes and Process-related items."""
 from __future__ import annotations
-from typing import Set, Any
+from typing import Set, Self, Any
 from sqlite3 import Row
 from plomtask.misc import DictableNode
 from plomtask.db import DatabaseConnection, BaseModel
@@ -23,7 +23,7 @@ class ProcessStepsNode(DictableNode):
                 'is_suppressed']
 
 
-class Process(BaseModel[int], ConditionsRelations):
+class Process(BaseModel, ConditionsRelations):
     """Template for, and metadata for, Todos, and their arrangements."""
     # pylint: disable=too-many-instance-attributes
     table_name = 'processes'
@@ -44,7 +44,7 @@ class Process(BaseModel[int], ConditionsRelations):
                'title': lambda p: p.title.newest}
 
     def __init__(self, id_: int | None, calendarize: bool = False) -> None:
-        BaseModel.__init__(self, id_)
+        super().__init__(id_)
         ConditionsRelations.__init__(self)
         for name in ['title', 'description', 'effort']:
             attr = VersionedAttribute(self, f'process_{name}s',
@@ -56,8 +56,8 @@ class Process(BaseModel[int], ConditionsRelations):
         self.n_owners: int | None = None  # only set by from_table_row
 
     @classmethod
-    def from_table_row(cls, db_conn: DatabaseConnection,
-                       row: Row | list[Any]) -> Process:
+    def from_table_row(cls, db_conn: DatabaseConnection, row: Row | list[Any]
+                       ) -> Self:
         """Make from DB row, with dependencies."""
         process = super().from_table_row(db_conn, row)
         assert process.id_ is not None
@@ -81,7 +81,7 @@ class Process(BaseModel[int], ConditionsRelations):
         process.n_owners = len(process.used_as_step_by(db_conn))
         return process
 
-    def used_as_step_by(self, db_conn: DatabaseConnection) -> list[Process]:
+    def used_as_step_by(self, db_conn: DatabaseConnection) -> list[Self]:
         """Return Processes using self for a ProcessStep."""
         if not self.id_:
             return []
@@ -93,7 +93,7 @@ class Process(BaseModel[int], ConditionsRelations):
 
     def get_steps(self,
                   db_conn: DatabaseConnection,
-                  external_owner: Process | None = None
+                  external_owner: Self | None = None
                   ) -> list[ProcessStepsNode]:
         """Return tree of depended-on explicit and implicit ProcessSteps."""
 
@@ -163,7 +163,7 @@ class Process(BaseModel[int], ConditionsRelations):
         owners_old = self.used_as_step_by(db_conn)
         losers = [o for o in owners_old if o.id_ not in owner_ids]
         owners_old_ids = [o.id_ for o in owners_old]
-        winners = [Process.by_id(db_conn, id_) for id_ in owner_ids
+        winners = [self.by_id(db_conn, id_) for id_ in owner_ids
                    if id_ not in owners_old_ids]
         steps_to_remove = []
         for loser in losers:
@@ -244,7 +244,7 @@ class Process(BaseModel[int], ConditionsRelations):
         super().remove(db_conn)
 
 
-class ProcessStep(BaseModel[int]):
+class ProcessStep(BaseModel):
     """Sub-unit of Processes."""
     table_name = 'process_steps'
     to_save_simples = ['owner_id', 'step_process_id', 'parent_step_id']
index 03881b5ea220085584c0f8387d3d03d1dfe7765f..5a71400406578bb8720bf372076ded2a876ee36b 100644 (file)
@@ -1,6 +1,6 @@
 """Actionables."""
 from __future__ import annotations
-from typing import Any, Set
+from typing import Any, Self, Set
 from sqlite3 import Row
 from plomtask.misc import DictableNode
 from plomtask.db import DatabaseConnection, BaseModel
@@ -32,7 +32,7 @@ class TodoOrProcStepNode(DictableNode):
     _to_dict = ['node_id', 'todo', 'process', 'children', 'fillable']
 
 
-class Todo(BaseModel[int], ConditionsRelations):
+class Todo(BaseModel, ConditionsRelations):
     """Individual actionable."""
     # pylint: disable=too-many-instance-attributes
     # pylint: disable=too-many-public-methods
@@ -61,7 +61,7 @@ class Todo(BaseModel[int], ConditionsRelations):
                  date: str, comment: str = '',
                  effort: None | float = None,
                  calendarize: bool = False) -> None:
-        BaseModel.__init__(self, id_)
+        super().__init__(id_)
         ConditionsRelations.__init__(self)
         if process.id_ is None:
             raise NotFoundException('Process of Todo without ID (not saved?)')
@@ -82,7 +82,7 @@ class Todo(BaseModel[int], ConditionsRelations):
 
     @classmethod
     def by_date_range(cls, db_conn: DatabaseConnection,
-                      date_range: tuple[str, str] = ('', '')) -> list[Todo]:
+                      date_range: tuple[str, str] = ('', '')) -> list[Self]:
         """Collect Todos of Days within date_range."""
         todos, _, _ = cls.by_date_range_with_limits(db_conn, date_range)
         return todos
@@ -90,8 +90,8 @@ class Todo(BaseModel[int], ConditionsRelations):
     def ensure_children(self, db_conn: DatabaseConnection) -> None:
         """Ensure Todo children (create or adopt) demanded by Process chain."""
 
-        def walk_steps(parent: Todo, step_node: ProcessStepsNode) -> Todo:
-            adoptables = [t for t in Todo.by_date(db_conn, parent.date)
+        def walk_steps(parent: Self, step_node: ProcessStepsNode) -> Todo:
+            adoptables = [t for t in self.by_date(db_conn, parent.date)
                           if (t not in parent.children)
                           and (t != parent)
                           and step_node.process.id_ == t.process_id]
@@ -100,7 +100,8 @@ class Todo(BaseModel[int], ConditionsRelations):
                 satisfier = adoptable
                 break
             if not satisfier:
-                satisfier = Todo(None, step_node.process, False, parent.date)
+                satisfier = self.__class__(None, step_node.process, False,
+                                           parent.date)
                 satisfier.save(db_conn)
             sub_step_nodes = sorted(
                     step_node.steps,
@@ -129,7 +130,7 @@ class Todo(BaseModel[int], ConditionsRelations):
 
     @classmethod
     def from_table_row(cls, db_conn: DatabaseConnection,
-                       row: Row | list[Any]) -> Todo:
+                       row: Row | list[Any]) -> Self:
         """Make from DB row, with dependencies."""
         if row[1] == 0:
             raise NotFoundException('calling Todo of '
@@ -154,12 +155,12 @@ class Todo(BaseModel[int], ConditionsRelations):
 
     @classmethod
     def by_process_id(cls, db_conn: DatabaseConnection,
-                      process_id: int | None) -> list[Todo]:
+                      process_id: int | None) -> list[Self]:
         """Collect all Todos of Process of process_id."""
         return [t for t in cls.all(db_conn) if t.process.id_ == process_id]
 
     @classmethod
-    def by_date(cls, db_conn: DatabaseConnection, date: str) -> list[Todo]:
+    def by_date(cls, db_conn: DatabaseConnection, date: str) -> list[Self]:
         """Collect all Todos for Day of date."""
         return cls.by_date_range(db_conn, (date, date))
 
@@ -253,7 +254,7 @@ class Todo(BaseModel[int], ConditionsRelations):
     def get_step_tree(self, seen_todos: set[int]) -> TodoNode:
         """Return tree of depended-on Todos."""
 
-        def make_node(todo: Todo) -> TodoNode:
+        def make_node(todo: Self) -> TodoNode:
             children = []
             seen = todo.id_ in seen_todos
             assert isinstance(todo.id_, int)
@@ -268,7 +269,7 @@ class Todo(BaseModel[int], ConditionsRelations):
     def tree_effort(self) -> float:
         """Return sum of performed efforts of self and all descendants."""
 
-        def walk_tree(node: Todo) -> float:
+        def walk_tree(node: Self) -> float:
             local_effort = 0.0
             for child in node.children:
                 local_effort += walk_tree(child)
@@ -276,10 +277,10 @@ class Todo(BaseModel[int], ConditionsRelations):
 
         return walk_tree(self)
 
-    def add_child(self, child: Todo) -> None:
+    def add_child(self, child: Self) -> None:
         """Add child to self.children, avoid recursion, update parenthoods."""
 
-        def walk_steps(node: Todo) -> None:
+        def walk_steps(node: Self) -> None:
             if node.id_ == self.id_:
                 raise BadFormatException('bad child choice causes recursion')
             for child in node.children:
@@ -295,7 +296,7 @@ class Todo(BaseModel[int], ConditionsRelations):
         self.children += [child]
         child.parents += [self]
 
-    def remove_child(self, child: Todo) -> None:
+    def remove_child(self, child: Self) -> None:
         """Remove child from self.children, update counter relations."""
         if child not in self.children:
             raise HandledException('Cannot remove un-parented child.')