From 8570f4ce4d44b813a1f02b72c5c45a57d2003bae Mon Sep 17 00:00:00 2001
From: Christian Heller <c.heller@plomlompom.de>
Date: Thu, 25 Apr 2024 08:40:48 +0200
Subject: [PATCH] Refactor object retrieval and creation.

---
 plomtask/conditions.py | 17 +-----------
 plomtask/days.py       | 16 ++---------
 plomtask/db.py         | 47 +++++++++++++++++++------------
 plomtask/processes.py  | 63 ++++++++++++++++--------------------------
 plomtask/todos.py      | 45 ++++++++++++------------------
 5 files changed, 74 insertions(+), 114 deletions(-)

diff --git a/plomtask/conditions.py b/plomtask/conditions.py
index 8d67e5a..6696125 100644
--- a/plomtask/conditions.py
+++ b/plomtask/conditions.py
@@ -4,7 +4,6 @@ from typing import Any
 from sqlite3 import Row
 from plomtask.db import DatabaseConnection, BaseModel
 from plomtask.misc import VersionedAttribute
-from plomtask.exceptions import NotFoundException
 
 
 class Condition(BaseModel[int]):
@@ -13,7 +12,7 @@ class Condition(BaseModel[int]):
     to_save = ['is_active']
 
     def __init__(self, id_: int | None, is_active: bool = False) -> None:
-        self.set_int_id(id_)
+        super().__init__(id_)
         self.is_active = is_active
         self.title = VersionedAttribute(self, 'condition_titles', 'UNNAMED')
         self.description = VersionedAttribute(self, 'condition_descriptions',
@@ -30,20 +29,6 @@ class Condition(BaseModel[int]):
                 getattr(condition, name).history_from_row(row_)
         return condition
 
-    @classmethod
-    def by_id(cls, db_conn: DatabaseConnection, id_: int | None,
-              create: bool = False) -> Condition:
-        """Collect (or create) Condition and its VersionedAttributes."""
-        condition = None
-        if id_:
-            condition, _ = super()._by_id(db_conn, id_)
-        if not condition:
-            if not create:
-                raise NotFoundException(f'Condition not found of id: {id_}')
-            condition = cls(id_, False)
-            condition.save(db_conn)
-        return condition
-
     def save(self, db_conn: DatabaseConnection) -> None:
         """Save self and its VersionedAttributes to DB and cache."""
         self.save_core(db_conn)
diff --git a/plomtask/days.py b/plomtask/days.py
index 258d38d..78340e2 100644
--- a/plomtask/days.py
+++ b/plomtask/days.py
@@ -1,7 +1,7 @@
 """Collecting Day and date-related items."""
 from __future__ import annotations
 from datetime import datetime, timedelta
-from plomtask.exceptions import BadFormatException, NotFoundException
+from plomtask.exceptions import BadFormatException
 from plomtask.db import DatabaseConnection, BaseModel
 
 DATE_FORMAT = '%Y-%m-%d'
@@ -30,6 +30,7 @@ class Day(BaseModel[str]):
     to_save = ['comment']
 
     def __init__(self, date: str, comment: str = '') -> None:
+        super().__init__(date)
         self.id_: str = valid_date(date)
         self.datetime = datetime.strptime(self.date, DATE_FORMAT)
         self.comment = comment
@@ -65,19 +66,6 @@ class Day(BaseModel[str]):
             days = gapless_days
         return days
 
-    @classmethod
-    def by_id(cls, db_conn: DatabaseConnection,
-              date: str, create: bool = False) -> Day:
-        """Retrieve Day by date if in DB (prefer cache), else return None."""
-        day, _ = super()._by_id(db_conn, date)
-        if day:
-            return day
-        if not create:
-            raise NotFoundException(f'Day not found for date: {date}')
-        day = cls(date)
-        day.cache()
-        return day
-
     @property
     def date(self) -> str:
         """Return self.id_ under the assumption it's a date string."""
diff --git a/plomtask/db.py b/plomtask/db.py
index fe67e5c..ebd8c6c 100644
--- a/plomtask/db.py
+++ b/plomtask/db.py
@@ -4,7 +4,7 @@ from os.path import isfile
 from difflib import Differ
 from sqlite3 import connect as sql_connect, Cursor, Row
 from typing import Any, Self, TypeVar, Generic
-from plomtask.exceptions import HandledException
+from plomtask.exceptions import HandledException, NotFoundException
 
 PATH_DB_SCHEMA = 'scripts/init.sql'
 EXPECTED_DB_VERSION = 0
@@ -123,6 +123,12 @@ class BaseModel(Generic[BaseModelId]):
     id_: None | BaseModelId
     cache_: dict[BaseModelId, Self]
 
+    def __init__(self, id_: BaseModelId | None) -> None:
+        if isinstance(id_, int) and id_ < 1:
+            msg = f'illegal {self.__class__.__name__} ID, must be >=1: {id_}'
+            raise HandledException(msg)
+        self.id_ = id_
+
     @classmethod
     def get_cached(cls: type[BaseModelInstance],
                    id_: BaseModelId) -> BaseModelInstance | None:
@@ -173,20 +179,32 @@ class BaseModel(Generic[BaseModelId]):
         return obj
 
     @classmethod
-    def _by_id(cls,
-               db_conn: DatabaseConnection,
-               id_: BaseModelId) -> tuple[Self | None, bool]:
+    def _by_id(cls, db_conn: DatabaseConnection,
+               id_: BaseModelId) -> Self | None:
         """Return instance found by ID, or None, and if from cache or not."""
-        from_cache = False
         obj = cls.get_cached(id_)
-        if obj:
-            from_cache = True
-        else:
+        if not obj:
             for row in db_conn.row_where(cls.table_name, 'id', id_):
                 obj = cls.from_table_row(db_conn, row)
                 obj.cache()
                 break
-        return obj, from_cache
+        return obj
+
+    @classmethod
+    def by_id(cls, db_conn: DatabaseConnection,
+              id_: BaseModelId | None,
+              # pylint: disable=unused-argument
+              create: bool = False) -> Self:
+        """Retrieve by id_, on failure throw NotFoundException."""
+        obj = None
+        if id_ is not None:
+            obj = cls._by_id(db_conn, id_)
+        if obj:
+            return obj
+        if create:
+            obj = cls(id_)
+            return obj
+        raise NotFoundException(f'found no object of ID {id_}')
 
     @classmethod
     def all(cls: type[BaseModelInstance],
@@ -199,18 +217,11 @@ class BaseModel(Generic[BaseModelId]):
         already_recorded = items.keys()
         for id_ in db_conn.column_all(cls.table_name, 'id'):
             if id_ not in already_recorded:
-                # pylint: disable=no-member
-                item = cls.by_id(db_conn, id_)  # type: ignore[attr-defined]
+                item = cls.by_id(db_conn, id_)
+                assert item.id_ is not None
                 items[item.id_] = item
         return list(items.values())
 
-    def set_int_id(self, id_: int | None) -> None:
-        """Set id_ if >= 1 or None, else fail."""
-        if (id_ is not None) and id_ < 1:
-            msg = f'illegal {self.__class__.__name__} ID, must be >=1: {id_}'
-            raise HandledException(msg)
-        self.id_ = id_  # type: ignore[assignment]
-
     def save_core(self, db_conn: DatabaseConnection,
                   update_with_lastrowid: bool = True) -> None:
         """Write bare-bones self (sans connected items), ensuring self.id_."""
diff --git a/plomtask/processes.py b/plomtask/processes.py
index 9705f17..c0b13b5 100644
--- a/plomtask/processes.py
+++ b/plomtask/processes.py
@@ -1,7 +1,8 @@
 """Collecting Processes and Process-related items."""
 from __future__ import annotations
 from dataclasses import dataclass
-from typing import Set
+from typing import Set, Any
+from sqlite3 import Row
 from plomtask.db import DatabaseConnection, BaseModel
 from plomtask.misc import VersionedAttribute
 from plomtask.conditions import Condition, ConditionsRelations
@@ -25,7 +26,7 @@ class Process(BaseModel[int], ConditionsRelations):
     # pylint: disable=too-many-instance-attributes
 
     def __init__(self, id_: int | None) -> None:
-        self.set_int_id(id_)
+        super().__init__(id_)
         self.title = VersionedAttribute(self, 'process_titles', 'UNNAMED')
         self.description = VersionedAttribute(self, 'process_descriptions', '')
         self.effort = VersionedAttribute(self, 'process_efforts', 1.0)
@@ -35,34 +36,26 @@ class Process(BaseModel[int], ConditionsRelations):
         self.disables: list[Condition] = []
 
     @classmethod
-    def by_id(cls, db_conn: DatabaseConnection, id_: int | None,
-              create: bool = False) -> Process:
-        """Collect Process, its VersionedAttributes, and its child IDs."""
-        process = None
-        from_cache = False
-        if id_:
-            process, from_cache = super()._by_id(db_conn, id_)
-        if not from_cache:
-            if not process:
-                if not create:
-                    raise NotFoundException(f'Process not found of id: {id_}')
-                process = Process(id_)
-            if isinstance(process.id_, int):
-                for name in ('title', 'description', 'effort'):
-                    table = f'process_{name}s'
-                    for row in db_conn.row_where(table, 'parent', process.id_):
-                        getattr(process, name).history_from_row(row)
-                for row in db_conn.row_where('process_steps', 'owner',
-                                             process.id_):
-                    step = ProcessStep.from_table_row(db_conn, row)
-                    process.explicit_steps += [step]
-                for name in ('conditions', 'enables', 'disables'):
-                    table = f'process_{name}'
-                    for c_id in db_conn.column_where(table, 'condition',
-                                                     'process', process.id_):
-                        target = getattr(process, name)
-                        target += [Condition.by_id(db_conn, c_id)]
-        assert isinstance(process, Process)
+    def from_table_row(cls, db_conn: DatabaseConnection,
+                       row: Row | list[Any]) -> Process:
+        """Make from DB row, with dependencies."""
+        process = super().from_table_row(db_conn, row)
+        assert isinstance(process.id_, int)
+        for name in ('title', 'description', 'effort'):
+            table = f'process_{name}s'
+            for row_ in db_conn.row_where(table, 'parent', process.id_):
+                getattr(process, name).history_from_row(row_)
+        for row_ in db_conn.row_where('process_steps', 'owner',
+                                      process.id_):
+            step = ProcessStep.from_table_row(db_conn, row_)
+            process.explicit_steps += [step]  # pylint: disable=no-member
+        for name in ('conditions', 'enables', 'disables'):
+            table = f'process_{name}'
+            assert isinstance(process.id_, int)
+            for c_id in db_conn.column_where(table, 'condition',
+                                             'process', process.id_):
+                target = getattr(process, name)
+                target += [Condition.by_id(db_conn, c_id)]
         return process
 
     def used_as_step_by(self, db_conn: DatabaseConnection) -> list[Process]:
@@ -184,19 +177,11 @@ class ProcessStep(BaseModel[int]):
 
     def __init__(self, id_: int | None, owner_id: int, step_process_id: int,
                  parent_step_id: int | None) -> None:
-        self.set_int_id(id_)
+        super().__init__(id_)
         self.owner_id = owner_id
         self.step_process_id = step_process_id
         self.parent_step_id = parent_step_id
 
-    @classmethod
-    def by_id(cls, db_conn: DatabaseConnection, id_: int) -> ProcessStep:
-        """Retrieve ProcessStep by id_, or throw NotFoundException."""
-        step, _ = super()._by_id(db_conn, id_)
-        if step:
-            return step
-        raise NotFoundException(f'found no ProcessStep of ID {id_}')
-
     def save(self, db_conn: DatabaseConnection) -> None:
         """Default to simply calling self.save_core for simple cases."""
         self.save_core(db_conn)
diff --git a/plomtask/todos.py b/plomtask/todos.py
index 5901571..9b9bc0b 100644
--- a/plomtask/todos.py
+++ b/plomtask/todos.py
@@ -21,15 +21,13 @@ class TodoStepsNode:
 
 class Todo(BaseModel[int], ConditionsRelations):
     """Individual actionable."""
-
     # pylint: disable=too-many-instance-attributes
-
     table_name = 'todos'
     to_save = ['process_id', 'is_done', 'date']
 
     def __init__(self, id_: int | None, process: Process,
                  is_done: bool, date: str) -> None:
-        self.set_int_id(id_)
+        super().__init__(id_)
         self.process = process
         self._is_done = is_done
         self.date = date
@@ -46,36 +44,29 @@ class Todo(BaseModel[int], ConditionsRelations):
     @classmethod
     def from_table_row(cls, db_conn: DatabaseConnection,
                        row: Row | list[Any]) -> Todo:
-        """Make from DB row, write to DB cache."""
+        """Make from DB row, with dependencies."""
         if row[1] == 0:
             raise NotFoundException('calling Todo of '
                                     'unsaved Process')
         row_as_list = list(row)
         row_as_list[1] = Process.by_id(db_conn, row[1])
         todo = super().from_table_row(db_conn, row_as_list)
-        assert isinstance(todo, Todo)
-        return todo
-
-    @classmethod
-    def by_id(cls, db_conn: DatabaseConnection, id_: int) -> Todo:
-        """Get Todo of .id_=id_ and children (from DB cache if possible)."""
-        todo, from_cache = super()._by_id(db_conn, id_)
-        if todo is None:
-            raise NotFoundException(f'Todo of ID not found: {id_}')
-        if not from_cache:
-            for t_id in db_conn.column_where('todo_children', 'child',
-                                             'parent', id_):
-                todo.children += [cls.by_id(db_conn, t_id)]
-            for t_id in db_conn.column_where('todo_children', 'parent',
-                                             'child', id_):
-                todo.parents += [cls.by_id(db_conn, t_id)]
-            for name in ('conditions', 'enables', 'disables'):
-                table = f'todo_{name}'
-                assert isinstance(todo.id_, int)
-                for cond_id in db_conn.column_where(table, 'condition',
-                                                    'todo', todo.id_):
-                    target = getattr(todo, name)
-                    target += [Condition.by_id(db_conn, cond_id)]
+        assert isinstance(todo.id_, int)
+        for t_id in db_conn.column_where('todo_children', 'child',
+                                         'parent', todo.id_):
+            # pylint: disable=no-member
+            todo.children += [cls.by_id(db_conn, t_id)]
+        for t_id in db_conn.column_where('todo_children', 'parent',
+                                         'child', todo.id_):
+            # pylint: disable=no-member
+            todo.parents += [cls.by_id(db_conn, t_id)]
+        for name in ('conditions', 'enables', 'disables'):
+            table = f'todo_{name}'
+            assert isinstance(todo.id_, int)
+            for cond_id in db_conn.column_where(table, 'condition',
+                                                'todo', todo.id_):
+                target = getattr(todo, name)
+                target += [Condition.by_id(db_conn, cond_id)]
         return todo
 
     @classmethod
-- 
2.30.2