From: Christian Heller <c.heller@plomlompom.de>
Date: Thu, 25 Apr 2024 03:38:31 +0000 (+0200)
Subject: Re-write caching.
X-Git-Url: https://plomlompom.com/repos/%7B%7Bdb.prefix%7D%7D/%7B%7B%20web_path%20%7D%7D/%7B%7Bprefix%7D%7D/static/bar%20baz.html?a=commitdiff_plain;h=ac5a85f6d0186d714415ce7e2b51597bf5dca248;p=plomtask

Re-write caching.
---

diff --git a/plomtask/conditions.py b/plomtask/conditions.py
index cd147cb..4b01249 100644
--- a/plomtask/conditions.py
+++ b/plomtask/conditions.py
@@ -7,7 +7,7 @@ from plomtask.misc import VersionedAttribute
 from plomtask.exceptions import NotFoundException
 
 
-class Condition(BaseModel):
+class Condition(BaseModel[int]):
     """Non Process-dependency for ProcessSteps and Todos."""
     table_name = 'conditions'
     to_save = ['is_active']
@@ -35,12 +35,13 @@ class Condition(BaseModel):
     def all(cls, db_conn: DatabaseConnection) -> list[Condition]:
         """Collect all Conditions and their VersionedAttributes."""
         conditions = {}
-        for id_, condition in db_conn.cached_conditions.items():
+        for id_, condition in cls.cache_.items():
             conditions[id_] = condition
         already_recorded = conditions.keys()
         for id_ in db_conn.column_all('conditions', 'id'):
             if id_ not in already_recorded:
                 condition = cls.by_id(db_conn, id_)
+                assert isinstance(condition.id_, int)
                 conditions[condition.id_] = condition
         return list(conditions.values())
 
@@ -65,7 +66,6 @@ class Condition(BaseModel):
         self.title.save(db_conn)
         self.description.save(db_conn)
         assert isinstance(self.id_, int)
-        db_conn.cached_conditions[self.id_] = self
 
 
 class ConditionsRelations:
diff --git a/plomtask/days.py b/plomtask/days.py
index d838039..5fe984b 100644
--- a/plomtask/days.py
+++ b/plomtask/days.py
@@ -24,7 +24,7 @@ def todays_date() -> str:
     return datetime.now().strftime(DATE_FORMAT)
 
 
-class Day(BaseModel):
+class Day(BaseModel[str]):
     """Individual days defined by their dates."""
     table_name = 'days'
     to_save = ['comment']
@@ -77,7 +77,7 @@ class Day(BaseModel):
         if not create:
             raise NotFoundException(f'Day not found for date: {date}')
         day = cls(date)
-        db_conn.cached_days[date] = day
+        day.cache()
         assert isinstance(day, Day)
         return day
 
diff --git a/plomtask/db.py b/plomtask/db.py
index dd2ee24..e8a542e 100644
--- a/plomtask/db.py
+++ b/plomtask/db.py
@@ -1,8 +1,9 @@
 """Database management."""
+from __future__ import annotations
 from os.path import isfile
 from difflib import Differ
 from sqlite3 import connect as sql_connect, Cursor, Row
-from typing import Any, Dict
+from typing import Any, Self, TypeVar, Generic
 from plomtask.exceptions import HandledException
 
 PATH_DB_SCHEMA = 'scripts/init.sql'
@@ -61,11 +62,6 @@ class DatabaseConnection:
     def __init__(self, db_file: DatabaseFile) -> None:
         self.file = db_file
         self.conn = sql_connect(self.file.path)
-        self.cached_todos: Dict[int, Any] = {}
-        self.cached_days: Dict[str, Any] = {}
-        self.cached_process_steps: Dict[int, Any] = {}
-        self.cached_processes: Dict[int, Any] = {}
-        self.cached_conditions: Dict[int, Any] = {}
 
     def commit(self) -> None:
         """Commit SQL transaction."""
@@ -116,38 +112,41 @@ class DatabaseConnection:
         return '(' + ','.join(['?'] * len(values)) + ')'
 
 
-class BaseModel:
+X = TypeVar('X', int, str)
+T = TypeVar('T', bound='BaseModel[Any]')
+
+
+class BaseModel(Generic[X]):
     """Template for most of the models we use/derive from the DB."""
     table_name = ''
     to_save: list[str] = []
-    id_: None | int | str
-    id_type: type[Any] = int
+    id_: None | X
+    cache_: dict[X, Self] = {}
 
     @classmethod
-    def from_table_row(cls, db_conn: DatabaseConnection,
-                       row: Row | list[Any]) -> Any:
+    def from_table_row(cls: type[T],
+                       # pylint: disable=unused-argument
+                       db_conn: DatabaseConnection,
+                       row: Row | list[Any]) -> T:
         """Make from DB row, write to DB cache."""
         obj = cls(*row)
-        assert isinstance(obj.id_, cls.id_type)
-        cache = getattr(db_conn, f'cached_{cls.table_name}')
-        cache[obj.id_] = obj
+        obj.cache()
         return obj
 
     @classmethod
     def _by_id(cls,
                db_conn: DatabaseConnection,
-               id_: int | str) -> tuple[Any, bool]:
+               id_: X) -> tuple[Self | None, bool]:
         """Return instance found by ID, or None, and if from cache or not."""
         from_cache = False
-        obj = None
-        cache = getattr(db_conn, f'cached_{cls.table_name}')
-        if id_ in cache.keys():
-            obj = cache[id_]
+        obj = cls.get_cached(id_)
+        if obj:
             from_cache = True
         else:
             for row in db_conn.row_where(cls.table_name, 'id', id_):
                 obj = cls.from_table_row(db_conn, row)
-                cache[id_] = obj
+                assert isinstance(obj, cls)
+                obj.cache()
                 break
         return obj, from_cache
 
@@ -156,7 +155,7 @@ class BaseModel:
         if (id_ is not None) and id_ < 1:
             msg = f'illegal {self.__class__.__name__} ID, must be >=1: {id_}'
             raise HandledException(msg)
-        self.id_ = id_
+        self.id_ = id_  # type: ignore[assignment]
 
     def save_core(self, db_conn: DatabaseConnection,
                   update_with_lastrowid: bool = True) -> None:
@@ -168,6 +167,32 @@ class BaseModel:
         cursor = db_conn.exec(f'REPLACE INTO {table_name} VALUES {q_marks}',
                               values)
         if update_with_lastrowid:
-            self.id_ = cursor.lastrowid
-        cache = getattr(db_conn, f'cached_{table_name}')
-        cache[self.id_] = self
+            self.id_ = cursor.lastrowid  # type: ignore[assignment]
+        self.cache()
+
+    @classmethod
+    def get_cached(cls: type[T], id_: X) -> T | None:
+        """Get object of id_ from class's cache, or None if not found."""
+        # pylint: disable=consider-iterating-dictionary
+        if id_ in cls.cache_.keys():
+            obj = cls.cache_[id_]
+            assert isinstance(obj, cls)
+            return obj
+        return None
+
+    def cache(self) -> None:
+        """Update object in class's cache."""
+        if self.id_ is None:
+            raise HandledException('Cannot cache object without ID.')
+        self.__class__.cache_[self.id_] = self
+
+    def uncache(self) -> None:
+        """Remove self from cache."""
+        if self.id_ is None:
+            raise HandledException('Cannot un-cache object without ID.')
+        del self.__class__.cache_[self.id_]
+
+    @classmethod
+    def empty_cache(cls) -> None:
+        """Empty class's cache."""
+        cls.cache_ = {}
diff --git a/plomtask/processes.py b/plomtask/processes.py
index 6249d48..590c5bc 100644
--- a/plomtask/processes.py
+++ b/plomtask/processes.py
@@ -18,7 +18,7 @@ class ProcessStepsNode:
     seen: bool
 
 
-class Process(BaseModel, ConditionsRelations):
+class Process(BaseModel[int], ConditionsRelations):
     """Template for, and metadata for, Todos, and their arrangements."""
     table_name = 'processes'
 
@@ -38,12 +38,13 @@ class Process(BaseModel, ConditionsRelations):
     def all(cls, db_conn: DatabaseConnection) -> list[Process]:
         """Collect all Processes and their connected VersionedAttributes."""
         processes = {}
-        for id_, process in db_conn.cached_processes.items():
+        for id_, process in cls.cache_.items():
             processes[id_] = process
         already_recorded = processes.keys()
         for id_ in db_conn.column_all('processes', 'id'):
             if id_ not in already_recorded:
                 process = cls.by_id(db_conn, id_)
+                assert isinstance(process.id_, int)
                 processes[process.id_] = process
         return list(processes.values())
 
@@ -165,8 +166,7 @@ class Process(BaseModel, ConditionsRelations):
         """Set self.explicit_steps in bulk."""
         assert isinstance(self.id_, int)
         for step in self.explicit_steps:
-            assert isinstance(step.id_, int)
-            del db_conn.cached_process_steps[step.id_]
+            step.uncache()
         self.explicit_steps = []
         db_conn.delete_where('process_steps', 'owner', self.id_)
         for step_tuple in steps:
@@ -189,10 +189,9 @@ class Process(BaseModel, ConditionsRelations):
         db_conn.delete_where('process_steps', 'owner', self.id_)
         for step in self.explicit_steps:
             step.save(db_conn)
-        db_conn.cached_processes[self.id_] = self
 
 
-class ProcessStep(BaseModel):
+class ProcessStep(BaseModel[int]):
     """Sub-unit of Processes."""
     table_name = 'process_steps'
     to_save = ['owner_id', 'step_process_id', 'parent_step_id']
diff --git a/plomtask/todos.py b/plomtask/todos.py
index 80dc97c..0b42d47 100644
--- a/plomtask/todos.py
+++ b/plomtask/todos.py
@@ -19,7 +19,7 @@ class TodoStepsNode:
     seen: bool
 
 
-class Todo(BaseModel, ConditionsRelations):
+class Todo(BaseModel[int], ConditionsRelations):
     """Individual actionable."""
 
     # pylint: disable=too-many-instance-attributes
@@ -71,6 +71,7 @@ class Todo(BaseModel, ConditionsRelations):
                 todo.parents += [cls.by_id(db_conn, t_id)]
             for name in ('conditions', 'enables', 'disables'):
                 table = f'todo_{name}'
+                assert isinstance(todo.id_, int)
                 for cond_id in db_conn.column_where(table, 'condition',
                                                     'todo', todo.id_):
                     target = getattr(todo, name)
@@ -233,7 +234,6 @@ class Todo(BaseModel, ConditionsRelations):
             raise NotFoundException('Process of Todo without ID (not saved?)')
         self.save_core(db_conn)
         assert isinstance(self.id_, int)
-        db_conn.cached_todos[self.id_] = self
         db_conn.rewrite_relations('todo_children', 'parent', self.id_,
                                   [[c.id_] for c in self.children])
         db_conn.rewrite_relations('todo_conditions', 'todo', self.id_,
diff --git a/tests/todos.py b/tests/todos.py
index 7aed5f8..b47834c 100644
--- a/tests/todos.py
+++ b/tests/todos.py
@@ -31,9 +31,9 @@ class TestsWithDB(TestCaseWithDB):
         todo.save(self.db_conn)
         self.assertEqual(Todo.by_id(self.db_conn, 1), todo)
         with self.assertRaises(NotFoundException):
-            self.assertEqual(Todo.by_id(self.db_conn, 0), todo)
+            Todo.by_id(self.db_conn, 0)
         with self.assertRaises(NotFoundException):
-            self.assertEqual(Todo.by_id(self.db_conn, 2), todo)
+            Todo.by_id(self.db_conn, 2)
 
     def test_Todo_by_date(self) -> None:
         """Test findability of Todos by date."""
@@ -301,7 +301,6 @@ class TestsWithServer(TestCaseWithServer):
         def post_and_reload(form_data: dict[str, object],
                             status: int = 302) -> Todo:
             self.check_post(form_data, '/todo?id=1', status, '/')
-            self.db_conn.cached_todos = {}
             return Todo.by_date(self.db_conn, '2024-01-01')[0]
         # test minimum
         form_data = {'title': '', 'description': '', 'effort': 1}
diff --git a/tests/utils.py b/tests/utils.py
index c80b34d..63b07e9 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -8,12 +8,21 @@ from os import remove as remove_file
 from typing import Mapping
 from plomtask.db import DatabaseFile, DatabaseConnection
 from plomtask.http import TaskHandler, TaskServer
+from plomtask.processes import Process, ProcessStep
+from plomtask.conditions import Condition
+from plomtask.days import Day
+from plomtask.todos import Todo
 
 
 class TestCaseWithDB(TestCase):
     """Module tests not requiring DB setup."""
 
     def setUp(self) -> None:
+        Condition.empty_cache()
+        Day.empty_cache()
+        Process.empty_cache()
+        ProcessStep.empty_cache()
+        Todo.empty_cache()
         timestamp = datetime.now().timestamp()
         self.db_file = DatabaseFile(f'test_db:{timestamp}')
         self.db_file.remake()