From 6b329a28bb4aec8d1846f5cc5402ed6fca5eb3da Mon Sep 17 00:00:00 2001
From: Christian Heller <c.heller@plomlompom.de>
Date: Sat, 13 Apr 2024 00:00:58 +0200
Subject: [PATCH] Cache DB objects to ensure we do not accidentally edit
 clones.

---
 plomtask/days.py      | 28 +++++++++++++++-------
 plomtask/db.py        |  6 ++++-
 plomtask/processes.py | 56 +++++++++++++++++++++++++++----------------
 plomtask/todos.py     | 28 +++++++++++++---------
 tests/days.py         | 10 +++++++-
 tests/processes.py    | 45 ++++++++++++++++++++++++----------
 tests/todos.py        | 24 ++++++++++++-------
 7 files changed, 135 insertions(+), 62 deletions(-)

diff --git a/plomtask/days.py b/plomtask/days.py
index afdea33..abfce06 100644
--- a/plomtask/days.py
+++ b/plomtask/days.py
@@ -40,9 +40,11 @@ class Day:
         return self.date < other.date
 
     @classmethod
-    def from_table_row(cls, row: Row) -> Day:
-        """Make Day from database row."""
-        return cls(row[0], row[1])
+    def from_table_row(cls, db_conn: DatabaseConnection, row: Row) -> Day:
+        """Make Day from database row, write to cache."""
+        day = cls(row[0], row[1])
+        db_conn.cached_days[day.date] = day
+        return day
 
     @classmethod
     def all(cls, db_conn: DatabaseConnection,
@@ -54,9 +56,9 @@ class Day:
         start_date = valid_date(date_range[0] if date_range[0] else min_date)
         end_date = valid_date(date_range[1] if date_range[1] else max_date)
         days = []
-        sql = 'SELECT * FROM days WHERE date >= ? AND date <= ?'
+        sql = 'SELECT date FROM days WHERE date >= ? AND date <= ?'
         for row in db_conn.exec(sql, (start_date, end_date)):
-            days += [cls.from_table_row(row)]
+            days += [cls.by_date(db_conn, row[0])]
         days.sort()
         if fill_gaps and len(days) > 1:
             gapless_days = []
@@ -72,12 +74,19 @@ class Day:
     @classmethod
     def by_date(cls, db_conn: DatabaseConnection,
                 date: str, create: bool = False) -> Day:
-        """Retrieve Day by date if in DB, else return None."""
+        """Retrieve Day by date if in DB (prefer cache), else return None."""
+        if date in db_conn.cached_days.keys():
+            day = db_conn.cached_days[date]
+            assert isinstance(day, Day)
+            return day
         for row in db_conn.exec('SELECT * FROM days WHERE date = ?', (date,)):
-            return cls.from_table_row(row)
+            return cls.from_table_row(db_conn, row)
         if not create:
             raise NotFoundException(f'Day not found for date: {date}')
-        return cls(date)
+        day = cls(date)
+        db_conn.cached_days[date] = day
+        assert isinstance(day, Day)
+        return day
 
     @property
     def weekday(self) -> str:
@@ -97,6 +106,7 @@ class Day:
         return next_datetime.strftime(DATE_FORMAT)
 
     def save(self, db_conn: DatabaseConnection) -> None:
-        """Add (or re-write) self to database."""
+        """Add (or re-write) self to DB and cache."""
         db_conn.exec('REPLACE INTO days VALUES (?, ?)',
                      (self.date, self.comment))
+        db_conn.cached_days[self.date] = self
diff --git a/plomtask/db.py b/plomtask/db.py
index 929a733..01bc3e9 100644
--- a/plomtask/db.py
+++ b/plomtask/db.py
@@ -2,7 +2,7 @@
 from os.path import isfile
 from difflib import Differ
 from sqlite3 import connect as sql_connect, Cursor
-from typing import Any
+from typing import Any, Dict
 from plomtask.exceptions import HandledException
 
 PATH_DB_SCHEMA = 'scripts/init.sql'
@@ -49,6 +49,10 @@ class DatabaseConnection:
     def __init__(self, db_file: DatabaseFile) -> None:
         self.file = db_file
         self.conn = sql_connect(self.file.path)
+        self.cached_todos: Dict[int, Any] = {}
+        self.cached_days: Dict[str, Any] = {}
+        self.cached_process_steps: Dict[int, Any] = {}
+        self.cached_processes: Dict[int, Any] = {}
 
     def commit(self) -> None:
         """Commit SQL transaction."""
diff --git a/plomtask/processes.py b/plomtask/processes.py
index 03fecb2..fe9bd4a 100644
--- a/plomtask/processes.py
+++ b/plomtask/processes.py
@@ -19,33 +19,35 @@ class Process:
         self.effort = VersionedAttribute(self, 'effort', 1.0)
         self.explicit_steps: list[ProcessStep] = []
 
-    def __eq__(self, other: object) -> bool:
-        return isinstance(other, self.__class__) and self.id_ == other.id_
-
     @classmethod
-    def from_table_row(cls, row: Row) -> Process:
+    def from_table_row(cls, db_conn: DatabaseConnection, row: Row) -> Process:
         """Make Process from database row, with empty VersionedAttributes."""
-        return cls(row[0])
+        process = cls(row[0])
+        assert process.id_ is not None
+        db_conn.cached_processes[process.id_] = process
+        return process
 
     @classmethod
     def all(cls, db_conn: DatabaseConnection) -> list[Process]:
         """Collect all Processes and their connected VersionedAttributes."""
         processes = {}
-        for row in db_conn.exec('SELECT * FROM processes'):
-            process = cls.from_table_row(row)
-            processes[process.id_] = process
-        for row in db_conn.exec('SELECT * FROM process_titles'):
-            processes[row[0]].title.history[row[1]] = row[2]
-        for row in db_conn.exec('SELECT * FROM process_descriptions'):
-            processes[row[0]].description.history[row[1]] = row[2]
-        for row in db_conn.exec('SELECT * FROM process_efforts'):
-            processes[row[0]].effort.history[row[1]] = row[2]
+        for id_, process in db_conn.cached_processes.items():
+            processes[id_] = process
+        already_recorded = processes.keys()
+        for row in db_conn.exec('SELECT id FROM processes'):
+            if row[0] not in already_recorded:
+                process = cls.by_id(db_conn, row[0])
+                processes[process.id_] = process
         return list(processes.values())
 
     @classmethod
     def by_id(cls, db_conn: DatabaseConnection, id_: int | None,
               create: bool = False) -> Process:
         """Collect Process, its VersionedAttributes, and its child IDs."""
+        if id_ in db_conn.cached_processes.keys():
+            process = db_conn.cached_processes[id_]
+            assert isinstance(process, Process)
+            return process
         process = None
         for row in db_conn.exec('SELECT * FROM processes '
                                 'WHERE id = ?', (id_,)):
@@ -67,7 +69,9 @@ class Process:
                 process.effort.history[row[1]] = row[2]
             for row in db_conn.exec('SELECT * FROM process_steps '
                                     'WHERE owner_id = ?', (process.id_,)):
-                process.explicit_steps += [ProcessStep.from_table_row(row)]
+                process.explicit_steps += [ProcessStep.from_table_row(db_conn,
+                                                                      row)]
+        assert isinstance(process, Process)
         return process
 
     def used_as_step_by(self, db_conn: DatabaseConnection) -> list[Process]:
@@ -151,6 +155,8 @@ class Process:
         self.title.save(db_conn)
         self.description.save(db_conn)
         self.effort.save(db_conn)
+        assert self.id_ is not None
+        db_conn.cached_processes[self.id_] = self
 
     def fix_steps(self, db_conn: DatabaseConnection) -> None:
         """Rewrite ProcessSteps from self.explicit_steps.
@@ -184,24 +190,34 @@ class ProcessStep:
         self.parent_step_id = parent_step_id
 
     @classmethod
-    def from_table_row(cls, row: Row) -> ProcessStep:
-        """Make ProcessStep from database row."""
-        return cls(row[0], row[1], row[2], row[3])
+    def from_table_row(cls, db_conn: DatabaseConnection,
+                       row: Row) -> ProcessStep:
+        """Make ProcessStep from database row, store in DB cache."""
+        step = cls(row[0], row[1], row[2], row[3])
+        assert step.id_ is not None
+        db_conn.cached_process_steps[step.id_] = step
+        return step
 
     @classmethod
     def by_id(cls, db_conn: DatabaseConnection, id_: int) -> ProcessStep:
         """Retrieve ProcessStep by id_, or throw NotFoundException."""
+        if id_ in db_conn.cached_process_steps.keys():
+            step = db_conn.cached_process_steps[id_]
+            assert isinstance(step, ProcessStep)
+            return step
         for row in db_conn.exec('SELECT * FROM process_steps '
                                 'WHERE step_id = ?', (id_,)):
-            return cls.from_table_row(row)
+            return cls.from_table_row(db_conn, row)
         raise NotFoundException(f'found no ProcessStep of ID {id_}')
 
     def save(self, db_conn: DatabaseConnection) -> None:
-        """Save to database."""
+        """Save to database and cache."""
         cursor = db_conn.exec('REPLACE INTO process_steps VALUES (?, ?, ?, ?)',
                               (self.id_, self.owner_id, self.step_process_id,
                                self.parent_step_id))
         self.id_ = cursor.lastrowid
+        assert self.id_ is not None
+        db_conn.cached_process_steps[self.id_] = self
 
 
 class VersionedAttribute:
diff --git a/plomtask/todos.py b/plomtask/todos.py
index 7150f0d..f1d98ad 100644
--- a/plomtask/todos.py
+++ b/plomtask/todos.py
@@ -17,37 +17,43 @@ class Todo:
         self.is_done = is_done
         self.day = day
 
-    def __eq__(self, other: object) -> bool:
-        return isinstance(other, self.__class__) and self.id_ == other.id_
-
     @classmethod
-    def from_table_row(cls, row: Row, db_conn: DatabaseConnection) -> Todo:
-        """Make Todo from database row."""
-        return cls(id_=row[0],
+    def from_table_row(cls, db_conn: DatabaseConnection, row: Row) -> Todo:
+        """Make Todo from database row, write to DB cache."""
+        todo = cls(id_=row[0],
                    process=Process.by_id(db_conn, row[1]),
                    is_done=row[2],
                    day=Day.by_date(db_conn, row[3]))
+        assert todo.id_ is not None
+        db_conn.cached_todos[todo.id_] = todo
+        return todo
 
     @classmethod
     def by_id(cls, db_conn: DatabaseConnection, id_: int) -> Todo:
-        """Get Todo of .id_=id_."""
+        """Get Todo of .id_=id_ – from DB cache if possible."""
+        if id_ in db_conn.cached_todos.keys():
+            todo = db_conn.cached_todos[id_]
+            assert isinstance(todo, Todo)
+            return todo
         for row in db_conn.exec('SELECT * FROM todos WHERE id = ?', (id_,)):
-            return cls.from_table_row(row, db_conn)
+            return cls.from_table_row(db_conn, row)
         raise NotFoundException(f'Todo of ID not found: {id_}')
 
     @classmethod
     def by_date(cls, db_conn: DatabaseConnection, date: str) -> list[Todo]:
         """Collect all Todos for Day of date."""
         todos = []
-        for row in db_conn.exec('SELECT * FROM todos WHERE day = ?', (date,)):
-            todos += [cls.from_table_row(row, db_conn)]
+        for row in db_conn.exec('SELECT id FROM todos WHERE day = ?', (date,)):
+            todos += [cls.by_id(db_conn, row[0])]
         return todos
 
     def save(self, db_conn: DatabaseConnection) -> None:
-        """Write self to DB."""
+        """Write self to DB and its cache."""
         if self.process.id_ is None:
             raise NotFoundException('Process of Todo without ID (not saved?)')
         cursor = db_conn.exec('REPLACE INTO todos VALUES (?,?,?,?)',
                               (self.id_, self.process.id_,
                                self.is_done, self.day.date))
         self.id_ = cursor.lastrowid
+        assert self.id_ is not None
+        db_conn.cached_todos[self.id_] = self
diff --git a/tests/days.py b/tests/days.py
index 81dbcf5..2e2ef50 100644
--- a/tests/days.py
+++ b/tests/days.py
@@ -33,7 +33,7 @@ class TestsSansDB(TestCase):
 
 
 class TestsWithDB(TestCaseWithDB):
-    """Days module tests not requiring DB setup."""
+    """Tests requiring DB, but not server setup."""
 
     def test_Day_by_date(self) -> None:
         """Test Day.by_date()."""
@@ -86,6 +86,14 @@ class TestsWithDB(TestCaseWithDB):
         self.assertEqual(Day('2024-01-01').prev_date, '2023-12-31')
         self.assertEqual(Day('2023-02-28').next_date, '2023-03-01')
 
+    def test_Day_singularity(self) -> None:
+        """Test pointers made for single object keep pointing to it."""
+        day = Day('2024-01-01')
+        day.save(self.db_conn)
+        retrieved_day = Day.by_date(self.db_conn, '2024-01-01')
+        day.comment = 'foo'
+        self.assertEqual(retrieved_day.comment, 'foo')
+
 
 class TestsWithServer(TestCaseWithServer):
     """Tests against our HTTP server/handler (and database)."""
diff --git a/tests/processes.py b/tests/processes.py
index 87b0b09..bda6275 100644
--- a/tests/processes.py
+++ b/tests/processes.py
@@ -2,7 +2,7 @@
 from unittest import TestCase
 from typing import Any
 from tests.utils import TestCaseWithDB, TestCaseWithServer
-from plomtask.processes import Process
+from plomtask.processes import Process, ProcessStep
 from plomtask.exceptions import NotFoundException, BadFormatException
 
 
@@ -39,17 +39,6 @@ class TestsWithDB(TestCaseWithDB):
         self.assertEqual(p.id_,
                          Process.by_id(self.db_conn, 5, create=False).id_)
 
-    def test_Process_versioned_attributes(self) -> None:
-        """Test behavior of VersionedAttributes on saving (with .title)."""
-        p = Process(None)
-        p.save_without_steps(self.db_conn)
-        p.title.set('named')
-        p_loaded = Process.by_id(self.db_conn, p.id_)
-        self.assertNotEqual(p.title.history, p_loaded.title.history)
-        p.save_without_steps(self.db_conn)
-        p_loaded = Process.by_id(self.db_conn, p.id_)
-        self.assertEqual(p.title.history, p_loaded.title.history)
-
     def test_Process_steps(self) -> None:
         """Test addition, nesting, and non-recursion of ProcessSteps"""
         p_1 = Process(1)
@@ -128,6 +117,38 @@ class TestsWithDB(TestCaseWithDB):
         self.assertEqual({p_1.id_, p_2.id_},
                          set(p.id_ for p in Process.all(self.db_conn)))
 
+    def test_ProcessStep_singularity(self) -> None:
+        """Test pointers made for single object keep pointing to it."""
+        p_1 = Process(None)
+        p_1.save_without_steps(self.db_conn)
+        p_2 = Process(None)
+        p_2.save_without_steps(self.db_conn)
+        assert p_2.id_ is not None
+        step = p_1.add_step(self.db_conn, None, p_2.id_, None)
+        assert step.id_ is not None
+        step_retrieved = ProcessStep.by_id(self.db_conn, step.id_)
+        step.parent_step_id = 99
+        self.assertEqual(step.parent_step_id, step_retrieved.parent_step_id)
+
+    def test_Process_singularity(self) -> None:
+        """Test pointers made for single object keep pointing to it."""
+        p_1 = Process(None)
+        p_1.save_without_steps(self.db_conn)
+        p_2 = Process(None)
+        p_2.save_without_steps(self.db_conn)
+        assert p_2.id_ is not None
+        p_1.add_step(self.db_conn, None, p_2.id_, None)
+        p_retrieved = Process.by_id(self.db_conn, p_1.id_)
+        self.assertEqual(p_1.explicit_steps, p_retrieved.explicit_steps)
+
+    def test_Process_versioned_attributes_singularity(self) -> None:
+        """Test behavior of VersionedAttributes on saving (with .title)."""
+        p = Process(None)
+        p.save_without_steps(self.db_conn)
+        p.title.set('named')
+        p_loaded = Process.by_id(self.db_conn, p.id_)
+        self.assertEqual(p.title.history, p_loaded.title.history)
+
 
 class TestsWithServer(TestCaseWithServer):
     """Module tests against our HTTP server/handler (and database)."""
diff --git a/tests/todos.py b/tests/todos.py
index 8bcd181..db4ad9c 100644
--- a/tests/todos.py
+++ b/tests/todos.py
@@ -7,7 +7,7 @@ from plomtask.exceptions import NotFoundException
 
 
 class TestsWithDB(TestCaseWithDB):
-    """Tests not requiring DB setup."""
+    """Tests requiring DB, but not server setup."""
 
     def test_Todo_by_id(self) -> None:
         """Test creation and findability of Todos."""
@@ -18,9 +18,6 @@ class TestsWithDB(TestCaseWithDB):
             todo.save(self.db_conn)
         process.save_without_steps(self.db_conn)
         todo.save(self.db_conn)
-        with self.assertRaises(NotFoundException):
-            _ = Todo.by_id(self.db_conn, 1)
-        day.save(self.db_conn)
         self.assertEqual(Todo.by_id(self.db_conn, 1), todo)
         with self.assertRaises(NotFoundException):
             self.assertEqual(Todo.by_id(self.db_conn, 0), todo)
@@ -37,14 +34,25 @@ class TestsWithDB(TestCaseWithDB):
         todo1.save(self.db_conn)
         todo2 = Todo(None, process, False, day1)
         todo2.save(self.db_conn)
-        with self.assertRaises(NotFoundException):
-            _ = Todo.by_date(self.db_conn, day1.date)
-        day1.save(self.db_conn)
-        day2.save(self.db_conn)
         self.assertEqual(Todo.by_date(self.db_conn, day1.date), [todo1, todo2])
         self.assertEqual(Todo.by_date(self.db_conn, day2.date), [])
         self.assertEqual(Todo.by_date(self.db_conn, 'foo'), [])
 
+    def test_Todo_singularity(self) -> None:
+        """Test pointers made for single object keep pointing to it."""
+        day = Day('2024-01-01')
+        day.save(self.db_conn)
+        process = Process(None)
+        process.save_without_steps(self.db_conn)
+        todo = Todo(None, process, False, day)
+        todo.save(self.db_conn)
+        retrieved_todo = Todo.by_id(self.db_conn, 1)
+        todo.is_done = True
+        self.assertEqual(retrieved_todo.is_done, True)
+        retrieved_todo = Todo.by_date(self.db_conn, '2024-01-01')[0]
+        retrieved_todo.is_done = False
+        self.assertEqual(todo.is_done, False)
+
 
 class TestsWithServer(TestCaseWithServer):
     """Tests against our HTTP server/handler (and database)."""
-- 
2.30.2