home · contact · privacy
Refactor object retrieval and creation.
authorChristian Heller <c.heller@plomlompom.de>
Thu, 25 Apr 2024 06:40:48 +0000 (08:40 +0200)
committerChristian Heller <c.heller@plomlompom.de>
Thu, 25 Apr 2024 06:40:48 +0000 (08:40 +0200)
plomtask/conditions.py
plomtask/days.py
plomtask/db.py
plomtask/processes.py
plomtask/todos.py

index 8d67e5a78c85b1672aaaaac133ab0422e80a6d73..66961256e28ad2843ae9d72ba9babbbd552b6f60 100644 (file)
@@ -4,7 +4,6 @@ from typing import Any
 from sqlite3 import Row
 from plomtask.db import DatabaseConnection, BaseModel
 from plomtask.misc import VersionedAttribute
-from plomtask.exceptions import NotFoundException
 
 
 class Condition(BaseModel[int]):
@@ -13,7 +12,7 @@ class Condition(BaseModel[int]):
     to_save = ['is_active']
 
     def __init__(self, id_: int | None, is_active: bool = False) -> None:
-        self.set_int_id(id_)
+        super().__init__(id_)
         self.is_active = is_active
         self.title = VersionedAttribute(self, 'condition_titles', 'UNNAMED')
         self.description = VersionedAttribute(self, 'condition_descriptions',
@@ -30,20 +29,6 @@ class Condition(BaseModel[int]):
                 getattr(condition, name).history_from_row(row_)
         return condition
 
-    @classmethod
-    def by_id(cls, db_conn: DatabaseConnection, id_: int | None,
-              create: bool = False) -> Condition:
-        """Collect (or create) Condition and its VersionedAttributes."""
-        condition = None
-        if id_:
-            condition, _ = super()._by_id(db_conn, id_)
-        if not condition:
-            if not create:
-                raise NotFoundException(f'Condition not found of id: {id_}')
-            condition = cls(id_, False)
-            condition.save(db_conn)
-        return condition
-
     def save(self, db_conn: DatabaseConnection) -> None:
         """Save self and its VersionedAttributes to DB and cache."""
         self.save_core(db_conn)
index 258d38dbbf1d7f920f36ca06c9f8292506d16807..78340e2d13000e7ef49bff1bdc24f27e561993c1 100644 (file)
@@ -1,7 +1,7 @@
 """Collecting Day and date-related items."""
 from __future__ import annotations
 from datetime import datetime, timedelta
-from plomtask.exceptions import BadFormatException, NotFoundException
+from plomtask.exceptions import BadFormatException
 from plomtask.db import DatabaseConnection, BaseModel
 
 DATE_FORMAT = '%Y-%m-%d'
@@ -30,6 +30,7 @@ class Day(BaseModel[str]):
     to_save = ['comment']
 
     def __init__(self, date: str, comment: str = '') -> None:
+        super().__init__(date)
         self.id_: str = valid_date(date)
         self.datetime = datetime.strptime(self.date, DATE_FORMAT)
         self.comment = comment
@@ -65,19 +66,6 @@ class Day(BaseModel[str]):
             days = gapless_days
         return days
 
-    @classmethod
-    def by_id(cls, db_conn: DatabaseConnection,
-              date: str, create: bool = False) -> Day:
-        """Retrieve Day by date if in DB (prefer cache), else return None."""
-        day, _ = super()._by_id(db_conn, date)
-        if day:
-            return day
-        if not create:
-            raise NotFoundException(f'Day not found for date: {date}')
-        day = cls(date)
-        day.cache()
-        return day
-
     @property
     def date(self) -> str:
         """Return self.id_ under the assumption it's a date string."""
index fe67e5cc799303efada094f510b537616a30bbd0..ebd8c6c544fd9dd3aca7e040bccb3d354629806d 100644 (file)
@@ -4,7 +4,7 @@ from os.path import isfile
 from difflib import Differ
 from sqlite3 import connect as sql_connect, Cursor, Row
 from typing import Any, Self, TypeVar, Generic
-from plomtask.exceptions import HandledException
+from plomtask.exceptions import HandledException, NotFoundException
 
 PATH_DB_SCHEMA = 'scripts/init.sql'
 EXPECTED_DB_VERSION = 0
@@ -123,6 +123,12 @@ class BaseModel(Generic[BaseModelId]):
     id_: None | BaseModelId
     cache_: dict[BaseModelId, Self]
 
+    def __init__(self, id_: BaseModelId | None) -> None:
+        if isinstance(id_, int) and id_ < 1:
+            msg = f'illegal {self.__class__.__name__} ID, must be >=1: {id_}'
+            raise HandledException(msg)
+        self.id_ = id_
+
     @classmethod
     def get_cached(cls: type[BaseModelInstance],
                    id_: BaseModelId) -> BaseModelInstance | None:
@@ -173,20 +179,32 @@ class BaseModel(Generic[BaseModelId]):
         return obj
 
     @classmethod
-    def _by_id(cls,
-               db_conn: DatabaseConnection,
-               id_: BaseModelId) -> tuple[Self | None, bool]:
+    def _by_id(cls, db_conn: DatabaseConnection,
+               id_: BaseModelId) -> Self | None:
         """Return instance found by ID, or None, and if from cache or not."""
-        from_cache = False
         obj = cls.get_cached(id_)
-        if obj:
-            from_cache = True
-        else:
+        if not obj:
             for row in db_conn.row_where(cls.table_name, 'id', id_):
                 obj = cls.from_table_row(db_conn, row)
                 obj.cache()
                 break
-        return obj, from_cache
+        return obj
+
+    @classmethod
+    def by_id(cls, db_conn: DatabaseConnection,
+              id_: BaseModelId | None,
+              # pylint: disable=unused-argument
+              create: bool = False) -> Self:
+        """Retrieve by id_, on failure throw NotFoundException."""
+        obj = None
+        if id_ is not None:
+            obj = cls._by_id(db_conn, id_)
+        if obj:
+            return obj
+        if create:
+            obj = cls(id_)
+            return obj
+        raise NotFoundException(f'found no object of ID {id_}')
 
     @classmethod
     def all(cls: type[BaseModelInstance],
@@ -199,18 +217,11 @@ class BaseModel(Generic[BaseModelId]):
         already_recorded = items.keys()
         for id_ in db_conn.column_all(cls.table_name, 'id'):
             if id_ not in already_recorded:
-                # pylint: disable=no-member
-                item = cls.by_id(db_conn, id_)  # type: ignore[attr-defined]
+                item = cls.by_id(db_conn, id_)
+                assert item.id_ is not None
                 items[item.id_] = item
         return list(items.values())
 
-    def set_int_id(self, id_: int | None) -> None:
-        """Set id_ if >= 1 or None, else fail."""
-        if (id_ is not None) and id_ < 1:
-            msg = f'illegal {self.__class__.__name__} ID, must be >=1: {id_}'
-            raise HandledException(msg)
-        self.id_ = id_  # type: ignore[assignment]
-
     def save_core(self, db_conn: DatabaseConnection,
                   update_with_lastrowid: bool = True) -> None:
         """Write bare-bones self (sans connected items), ensuring self.id_."""
index 9705f17a5672336371ba186ff5d52cdd5fe002ef..c0b13b551862b0c8a00c927c4cc0665ae178daeb 100644 (file)
@@ -1,7 +1,8 @@
 """Collecting Processes and Process-related items."""
 from __future__ import annotations
 from dataclasses import dataclass
-from typing import Set
+from typing import Set, Any
+from sqlite3 import Row
 from plomtask.db import DatabaseConnection, BaseModel
 from plomtask.misc import VersionedAttribute
 from plomtask.conditions import Condition, ConditionsRelations
@@ -25,7 +26,7 @@ class Process(BaseModel[int], ConditionsRelations):
     # pylint: disable=too-many-instance-attributes
 
     def __init__(self, id_: int | None) -> None:
-        self.set_int_id(id_)
+        super().__init__(id_)
         self.title = VersionedAttribute(self, 'process_titles', 'UNNAMED')
         self.description = VersionedAttribute(self, 'process_descriptions', '')
         self.effort = VersionedAttribute(self, 'process_efforts', 1.0)
@@ -35,34 +36,26 @@ class Process(BaseModel[int], ConditionsRelations):
         self.disables: list[Condition] = []
 
     @classmethod
-    def by_id(cls, db_conn: DatabaseConnection, id_: int | None,
-              create: bool = False) -> Process:
-        """Collect Process, its VersionedAttributes, and its child IDs."""
-        process = None
-        from_cache = False
-        if id_:
-            process, from_cache = super()._by_id(db_conn, id_)
-        if not from_cache:
-            if not process:
-                if not create:
-                    raise NotFoundException(f'Process not found of id: {id_}')
-                process = Process(id_)
-            if isinstance(process.id_, int):
-                for name in ('title', 'description', 'effort'):
-                    table = f'process_{name}s'
-                    for row in db_conn.row_where(table, 'parent', process.id_):
-                        getattr(process, name).history_from_row(row)
-                for row in db_conn.row_where('process_steps', 'owner',
-                                             process.id_):
-                    step = ProcessStep.from_table_row(db_conn, row)
-                    process.explicit_steps += [step]
-                for name in ('conditions', 'enables', 'disables'):
-                    table = f'process_{name}'
-                    for c_id in db_conn.column_where(table, 'condition',
-                                                     'process', process.id_):
-                        target = getattr(process, name)
-                        target += [Condition.by_id(db_conn, c_id)]
-        assert isinstance(process, Process)
+    def from_table_row(cls, db_conn: DatabaseConnection,
+                       row: Row | list[Any]) -> Process:
+        """Make from DB row, with dependencies."""
+        process = super().from_table_row(db_conn, row)
+        assert isinstance(process.id_, int)
+        for name in ('title', 'description', 'effort'):
+            table = f'process_{name}s'
+            for row_ in db_conn.row_where(table, 'parent', process.id_):
+                getattr(process, name).history_from_row(row_)
+        for row_ in db_conn.row_where('process_steps', 'owner',
+                                      process.id_):
+            step = ProcessStep.from_table_row(db_conn, row_)
+            process.explicit_steps += [step]  # pylint: disable=no-member
+        for name in ('conditions', 'enables', 'disables'):
+            table = f'process_{name}'
+            assert isinstance(process.id_, int)
+            for c_id in db_conn.column_where(table, 'condition',
+                                             'process', process.id_):
+                target = getattr(process, name)
+                target += [Condition.by_id(db_conn, c_id)]
         return process
 
     def used_as_step_by(self, db_conn: DatabaseConnection) -> list[Process]:
@@ -184,19 +177,11 @@ class ProcessStep(BaseModel[int]):
 
     def __init__(self, id_: int | None, owner_id: int, step_process_id: int,
                  parent_step_id: int | None) -> None:
-        self.set_int_id(id_)
+        super().__init__(id_)
         self.owner_id = owner_id
         self.step_process_id = step_process_id
         self.parent_step_id = parent_step_id
 
-    @classmethod
-    def by_id(cls, db_conn: DatabaseConnection, id_: int) -> ProcessStep:
-        """Retrieve ProcessStep by id_, or throw NotFoundException."""
-        step, _ = super()._by_id(db_conn, id_)
-        if step:
-            return step
-        raise NotFoundException(f'found no ProcessStep of ID {id_}')
-
     def save(self, db_conn: DatabaseConnection) -> None:
         """Default to simply calling self.save_core for simple cases."""
         self.save_core(db_conn)
index 5901571d6e2e52c1539e48587d6e8ea7590fb1d1..9b9bc0b95527a901bd5a6008e1707f719df74b06 100644 (file)
@@ -21,15 +21,13 @@ class TodoStepsNode:
 
 class Todo(BaseModel[int], ConditionsRelations):
     """Individual actionable."""
-
     # pylint: disable=too-many-instance-attributes
-
     table_name = 'todos'
     to_save = ['process_id', 'is_done', 'date']
 
     def __init__(self, id_: int | None, process: Process,
                  is_done: bool, date: str) -> None:
-        self.set_int_id(id_)
+        super().__init__(id_)
         self.process = process
         self._is_done = is_done
         self.date = date
@@ -46,36 +44,29 @@ class Todo(BaseModel[int], ConditionsRelations):
     @classmethod
     def from_table_row(cls, db_conn: DatabaseConnection,
                        row: Row | list[Any]) -> Todo:
-        """Make from DB row, write to DB cache."""
+        """Make from DB row, with dependencies."""
         if row[1] == 0:
             raise NotFoundException('calling Todo of '
                                     'unsaved Process')
         row_as_list = list(row)
         row_as_list[1] = Process.by_id(db_conn, row[1])
         todo = super().from_table_row(db_conn, row_as_list)
-        assert isinstance(todo, Todo)
-        return todo
-
-    @classmethod
-    def by_id(cls, db_conn: DatabaseConnection, id_: int) -> Todo:
-        """Get Todo of .id_=id_ and children (from DB cache if possible)."""
-        todo, from_cache = super()._by_id(db_conn, id_)
-        if todo is None:
-            raise NotFoundException(f'Todo of ID not found: {id_}')
-        if not from_cache:
-            for t_id in db_conn.column_where('todo_children', 'child',
-                                             'parent', id_):
-                todo.children += [cls.by_id(db_conn, t_id)]
-            for t_id in db_conn.column_where('todo_children', 'parent',
-                                             'child', id_):
-                todo.parents += [cls.by_id(db_conn, t_id)]
-            for name in ('conditions', 'enables', 'disables'):
-                table = f'todo_{name}'
-                assert isinstance(todo.id_, int)
-                for cond_id in db_conn.column_where(table, 'condition',
-                                                    'todo', todo.id_):
-                    target = getattr(todo, name)
-                    target += [Condition.by_id(db_conn, cond_id)]
+        assert isinstance(todo.id_, int)
+        for t_id in db_conn.column_where('todo_children', 'child',
+                                         'parent', todo.id_):
+            # pylint: disable=no-member
+            todo.children += [cls.by_id(db_conn, t_id)]
+        for t_id in db_conn.column_where('todo_children', 'parent',
+                                         'child', todo.id_):
+            # pylint: disable=no-member
+            todo.parents += [cls.by_id(db_conn, t_id)]
+        for name in ('conditions', 'enables', 'disables'):
+            table = f'todo_{name}'
+            assert isinstance(todo.id_, int)
+            for cond_id in db_conn.column_where(table, 'condition',
+                                                'todo', todo.id_):
+                target = getattr(todo, name)
+                target += [Condition.by_id(db_conn, cond_id)]
         return todo
 
     @classmethod