From 5a5d713ce0b223ab2f6ef34c15bb82b614bdda98 Mon Sep 17 00:00:00 2001
From: Christian Heller <c.heller@plomlompom.de>
Date: Fri, 19 Apr 2024 04:58:34 +0200
Subject: [PATCH] Refactor models' .by_id().

---
 plomtask/conditions.py | 14 ++++++--------
 plomtask/days.py       | 16 ++++++----------
 plomtask/db.py         | 22 +++++++++++++++++++++-
 plomtask/http.py       |  4 ++--
 plomtask/processes.py  | 18 ++++--------------
 plomtask/todos.py      | 35 +++++++++++++++++++++--------------
 scripts/init.sql       |  6 +++---
 tests/days.py          | 14 +++++++-------
 tests/todos.py         | 12 +++++++-----
 9 files changed, 77 insertions(+), 64 deletions(-)

diff --git a/plomtask/conditions.py b/plomtask/conditions.py
index 9fab77f..80fc13e 100644
--- a/plomtask/conditions.py
+++ b/plomtask/conditions.py
@@ -1,5 +1,6 @@
 """Non-doable elements of ProcessStep/Todo chains."""
 from __future__ import annotations
+from typing import Any
 from sqlite3 import Row
 from plomtask.db import DatabaseConnection, BaseModel
 from plomtask.misc import VersionedAttribute
@@ -20,7 +21,7 @@ class Condition(BaseModel):
 
     @classmethod
     def from_table_row(cls, db_conn: DatabaseConnection,
-                       row: Row) -> Condition:
+                       row: Row | list[Any]) -> Condition:
         """Build condition from row, including VersionedAttributes."""
         condition = super().from_table_row(db_conn, row)
         assert isinstance(condition, Condition)
@@ -52,17 +53,14 @@ class Condition(BaseModel):
               create: bool = False) -> Condition:
         """Collect (or create) Condition and its VersionedAttributes."""
         condition = None
-        if id_ in db_conn.cached_conditions.keys():
-            condition = db_conn.cached_conditions[id_]
-        else:
-            for row in db_conn.exec('SELECT * FROM conditions WHERE id = ?',
-                                    (id_,)):
-                condition = cls.from_table_row(db_conn, row)
-                break
+        if id_:
+            condition, _ = super()._by_id(db_conn, id_)
         if not condition:
             if not create:
                 raise NotFoundException(f'Condition not found of id: {id_}')
             condition = cls(id_, False)
+            condition.save(db_conn)
+        assert isinstance(condition, Condition)
         return condition
 
     def save(self, db_conn: DatabaseConnection) -> None:
diff --git a/plomtask/days.py b/plomtask/days.py
index a21b4ef..d838039 100644
--- a/plomtask/days.py
+++ b/plomtask/days.py
@@ -51,9 +51,9 @@ class Day(BaseModel):
         start_date = valid_date(date_range[0] if date_range[0] else min_date)
         end_date = valid_date(date_range[1] if date_range[1] else max_date)
         days = []
-        sql = 'SELECT date FROM days WHERE date >= ? AND date <= ?'
+        sql = 'SELECT id FROM days WHERE id >= ? AND id <= ?'
         for row in db_conn.exec(sql, (start_date, end_date)):
-            days += [cls.by_date(db_conn, row[0])]
+            days += [cls.by_id(db_conn, row[0])]
         days.sort()
         if fill_gaps and len(days) > 1:
             gapless_days = []
@@ -67,15 +67,11 @@ class Day(BaseModel):
         return days
 
     @classmethod
-    def by_date(cls, db_conn: DatabaseConnection,
-                date: str, create: bool = False) -> Day:
+    def by_id(cls, db_conn: DatabaseConnection,
+              date: str, create: bool = False) -> Day:
         """Retrieve Day by date if in DB (prefer cache), else return None."""
-        if date in db_conn.cached_days.keys():
-            day = db_conn.cached_days[date]
-            assert isinstance(day, Day)
-            return day
-        for row in db_conn.exec('SELECT * FROM days WHERE date = ?', (date,)):
-            day = cls.from_table_row(db_conn, row)
+        day, _ = super()._by_id(db_conn, date)
+        if day:
             assert isinstance(day, Day)
             return day
         if not create:
diff --git a/plomtask/db.py b/plomtask/db.py
index 2cc1d64..abd8f61 100644
--- a/plomtask/db.py
+++ b/plomtask/db.py
@@ -76,7 +76,8 @@ class BaseModel:
     id_type: type[Any] = int
 
     @classmethod
-    def from_table_row(cls, db_conn: DatabaseConnection, row: Row) -> Any:
+    def from_table_row(cls, db_conn: DatabaseConnection,
+                       row: Row | list[Any]) -> Any:
         """Make from DB row, write to DB cache."""
         obj = cls(*row)
         assert isinstance(obj.id_, cls.id_type)
@@ -84,6 +85,25 @@ class BaseModel:
         cache[obj.id_] = obj
         return obj
 
+    @classmethod
+    def _by_id(cls,
+               db_conn: DatabaseConnection,
+               id_: int | str) -> tuple[Any, bool]:
+        """Return instance found by ID, or None, and if from cache or not."""
+        from_cache = False
+        obj = None
+        cache = getattr(db_conn, f'cached_{cls.table_name}')
+        if id_ in cache.keys():
+            obj = cache[id_]
+            from_cache = True
+        else:
+            for row in db_conn.exec(f'SELECT * FROM {cls.table_name} '
+                                    'WHERE id = ?', (id_,)):
+                obj = cls.from_table_row(db_conn, row)
+                cache[id_] = obj
+                break
+        return obj, from_cache
+
     def set_int_id(self, id_: int | None) -> None:
         """Set id_ if >= 1 or None, else fail."""
         if (id_ is not None) and id_ < 1:
diff --git a/plomtask/http.py b/plomtask/http.py
index 55120ff..5f739b6 100644
--- a/plomtask/http.py
+++ b/plomtask/http.py
@@ -127,7 +127,7 @@ class TaskHandler(BaseHTTPRequestHandler):
                     'condition': condition,
                     'enablers': enablers,
                     'disablers': disablers}]
-        return {'day': Day.by_date(self.conn, date, create=True),
+        return {'day': Day.by_id(self.conn, date, create=True),
                 'todos': Todo.by_date(self.conn, date),
                 'processes': Process.all(self.conn),
                 'conditions_listing': conditions_listing}
@@ -187,7 +187,7 @@ class TaskHandler(BaseHTTPRequestHandler):
     def do_POST_day(self) -> None:
         """Update or insert Day of date and Todos mapped to it."""
         date = self.params.get_str('date')
-        day = Day.by_date(self.conn, date, create=True)
+        day = Day.by_id(self.conn, date, create=True)
         day.comment = self.form_data.get_str('comment')
         day.save(self.conn)
         process_id = self.form_data.get_int_or_none('new_todo')
diff --git a/plomtask/processes.py b/plomtask/processes.py
index 2f8c2d5..7872c33 100644
--- a/plomtask/processes.py
+++ b/plomtask/processes.py
@@ -40,15 +40,9 @@ class Process(BaseModel):
     def by_id(cls, db_conn: DatabaseConnection, id_: int | None,
               create: bool = False) -> Process:
         """Collect Process, its VersionedAttributes, and its child IDs."""
-        if id_ in db_conn.cached_processes.keys():
-            process = db_conn.cached_processes[id_]
-            assert isinstance(process, Process)
-            return process
         process = None
-        for row in db_conn.exec('SELECT * FROM processes '
-                                'WHERE id = ?', (id_,)):
-            process = cls(row[0])
-            break
+        if id_:
+            process, _ = super()._by_id(db_conn, id_)
         if not process:
             if not create:
                 raise NotFoundException(f'Process not found of id: {id_}')
@@ -230,14 +224,10 @@ class ProcessStep(BaseModel):
     @classmethod
     def by_id(cls, db_conn: DatabaseConnection, id_: int) -> ProcessStep:
         """Retrieve ProcessStep by id_, or throw NotFoundException."""
-        if id_ in db_conn.cached_process_steps.keys():
-            step = db_conn.cached_process_steps[id_]
+        step, _ = super()._by_id(db_conn, id_)
+        if step:
             assert isinstance(step, ProcessStep)
             return step
-        for row in db_conn.exec('SELECT * FROM process_steps '
-                                'WHERE step_id = ?', (id_,)):
-            step = cls.from_table_row(db_conn, row)
-            assert isinstance(step, ProcessStep)
         raise NotFoundException(f'found no ProcessStep of ID {id_}')
 
     def save(self, db_conn: DatabaseConnection) -> None:
diff --git a/plomtask/todos.py b/plomtask/todos.py
index cfac5b5..840c298 100644
--- a/plomtask/todos.py
+++ b/plomtask/todos.py
@@ -1,5 +1,7 @@
 """Actionables."""
 from __future__ import annotations
+from typing import Any
+from sqlite3 import Row
 from plomtask.db import DatabaseConnection, BaseModel
 from plomtask.processes import Process
 from plomtask.conditions import Condition
@@ -31,24 +33,29 @@ class Todo(BaseModel):
             self.fulfills = process.fulfills[:]
             self.undoes = process.undoes[:]
 
+    @classmethod
+    def from_table_row(cls, db_conn: DatabaseConnection,
+                       row: Row | list[Any]) -> Todo:
+        """Make from DB row, write to DB cache."""
+        if row[1] == 0:
+            raise NotFoundException('calling Todo of '
+                                    'unsaved Process')
+        row_as_list = list(row)
+        row_as_list[1] = Process.by_id(db_conn, row[1])
+        todo = super().from_table_row(db_conn, row_as_list)
+        assert isinstance(todo, Todo)
+        return todo
+
     @classmethod
     def by_id(cls, db_conn: DatabaseConnection, id_: int | None) -> Todo:
         """Get Todo of .id_=id_ and children (from DB cache if possible)."""
-        if id_ in db_conn.cached_todos.keys():
-            todo = db_conn.cached_todos[id_]
+        if id_:
+            todo, from_cache = super()._by_id(db_conn, id_)
         else:
-            todo = None
-            for row in db_conn.exec('SELECT * FROM todos WHERE id = ?',
-                                    (id_,)):
-                row = list(row)
-                if row[1] == 0:
-                    raise NotFoundException('calling Todo of '
-                                            'unsaved Process')
-                row[1] = Process.by_id(db_conn, row[1])
-                todo = cls.from_table_row(db_conn, row)
-                break
-            if todo is None:
-                raise NotFoundException(f'Todo of ID not found: {id_}')
+            todo, from_cache = None, False
+        if todo is None:
+            raise NotFoundException(f'Todo of ID not found: {id_}')
+        if not from_cache:
             for row in db_conn.exec('SELECT child FROM todo_children '
                                     'WHERE parent = ?', (id_,)):
                 todo.children += [cls.by_id(db_conn, row[0])]
diff --git a/scripts/init.sql b/scripts/init.sql
index 807e1e7..870e845 100644
--- a/scripts/init.sql
+++ b/scripts/init.sql
@@ -17,7 +17,7 @@ CREATE TABLE conditions (
     is_active BOOLEAN NOT NULL
 );
 CREATE TABLE days (
-    date TEXT PRIMARY KEY,
+    id TEXT PRIMARY KEY,
     comment TEXT NOT NULL
 );
 CREATE TABLE process_conditions (
@@ -49,7 +49,7 @@ CREATE TABLE process_fulfills (
     FOREIGN KEY (condition) REFERENCES conditions(id)
 );
 CREATE TABLE process_steps (
-    step_id INTEGER PRIMARY KEY,
+    id INTEGER PRIMARY KEY,
     owner_id INTEGER NOT NULL,
     step_process_id INTEGER NOT NULL,
     parent_step_id INTEGER,
@@ -108,5 +108,5 @@ CREATE TABLE todos (
     is_done BOOLEAN NOT NULL,
     day TEXT NOT NULL,
     FOREIGN KEY (process_id) REFERENCES processes(id),
-    FOREIGN KEY (day) REFERENCES days(date)
+    FOREIGN KEY (day) REFERENCES days(id)
 );
diff --git a/tests/days.py b/tests/days.py
index 3524a66..895f59d 100644
--- a/tests/days.py
+++ b/tests/days.py
@@ -35,17 +35,17 @@ class TestsSansDB(TestCase):
 class TestsWithDB(TestCaseWithDB):
     """Tests requiring DB, but not server setup."""
 
-    def test_Day_by_date(self) -> None:
-        """Test Day.by_date()."""
+    def test_Day_by_id(self) -> None:
+        """Test Day.by_id()."""
         with self.assertRaises(NotFoundException):
-            Day.by_date(self.db_conn, '2024-01-01')
+            Day.by_id(self.db_conn, '2024-01-01')
         Day('2024-01-01').save(self.db_conn)
         self.assertEqual(Day('2024-01-01'),
-                         Day.by_date(self.db_conn, '2024-01-01'))
+                         Day.by_id(self.db_conn, '2024-01-01'))
         with self.assertRaises(NotFoundException):
-            Day.by_date(self.db_conn, '2024-01-02')
+            Day.by_id(self.db_conn, '2024-01-02')
         self.assertEqual(Day('2024-01-02'),
-                         Day.by_date(self.db_conn, '2024-01-02', create=True))
+                         Day.by_id(self.db_conn, '2024-01-02', create=True))
 
     def test_Day_all(self) -> None:
         """Test Day.all(), especially in regards to date range filtering."""
@@ -94,7 +94,7 @@ class TestsWithDB(TestCaseWithDB):
         """Test pointers made for single object keep pointing to it."""
         day = Day('2024-01-01')
         day.save(self.db_conn)
-        retrieved_day = Day.by_date(self.db_conn, '2024-01-01')
+        retrieved_day = Day.by_id(self.db_conn, '2024-01-01')
         day.comment = 'foo'
         self.assertEqual(retrieved_day.comment, 'foo')
 
diff --git a/tests/todos.py b/tests/todos.py
index a90f466..17454c5 100644
--- a/tests/todos.py
+++ b/tests/todos.py
@@ -121,19 +121,21 @@ class TestsWithDB(TestCaseWithDB):
         """Test Todo.children relations."""
         todo_1 = Todo(None, self.proc, False, self.date1)
         todo_2 = Todo(None, self.proc, False, self.date1)
+        todo_2.save(self.db_conn)
         with self.assertRaises(HandledException):
             todo_1.add_child(todo_2)
         todo_1.save(self.db_conn)
+        todo_3 = Todo(None, self.proc, False, self.date1)
         with self.assertRaises(HandledException):
-            todo_1.add_child(todo_2)
-        todo_2.save(self.db_conn)
-        todo_1.add_child(todo_2)
+            todo_1.add_child(todo_3)
+        todo_3.save(self.db_conn)
+        todo_1.add_child(todo_3)
         todo_1.save(self.db_conn)
         assert isinstance(todo_1.id_, int)
         todo_retrieved = Todo.by_id(self.db_conn, todo_1.id_)
-        self.assertEqual(todo_retrieved.children, [todo_2])
+        self.assertEqual(todo_retrieved.children, [todo_3])
         with self.assertRaises(BadFormatException):
-            todo_2.add_child(todo_1)
+            todo_3.add_child(todo_1)
 
     def test_Todo_conditioning(self) -> None:
         """Test Todo.doability conditions."""
-- 
2.30.2