home · contact · privacy
Re-write caching.
authorChristian Heller <c.heller@plomlompom.de>
Thu, 25 Apr 2024 03:38:31 +0000 (05:38 +0200)
committerChristian Heller <c.heller@plomlompom.de>
Thu, 25 Apr 2024 03:38:31 +0000 (05:38 +0200)
plomtask/conditions.py
plomtask/days.py
plomtask/db.py
plomtask/processes.py
plomtask/todos.py
tests/todos.py
tests/utils.py

index cd147cb79cafcfd81f4672fae0a0b6296201ff6a..4b012491d6151ed74995849378fa8dc1c04f20f8 100644 (file)
@@ -7,7 +7,7 @@ from plomtask.misc import VersionedAttribute
 from plomtask.exceptions import NotFoundException
 
 
-class Condition(BaseModel):
+class Condition(BaseModel[int]):
     """Non Process-dependency for ProcessSteps and Todos."""
     table_name = 'conditions'
     to_save = ['is_active']
@@ -35,12 +35,13 @@ class Condition(BaseModel):
     def all(cls, db_conn: DatabaseConnection) -> list[Condition]:
         """Collect all Conditions and their VersionedAttributes."""
         conditions = {}
-        for id_, condition in db_conn.cached_conditions.items():
+        for id_, condition in cls.cache_.items():
             conditions[id_] = condition
         already_recorded = conditions.keys()
         for id_ in db_conn.column_all('conditions', 'id'):
             if id_ not in already_recorded:
                 condition = cls.by_id(db_conn, id_)
+                assert isinstance(condition.id_, int)
                 conditions[condition.id_] = condition
         return list(conditions.values())
 
@@ -65,7 +66,6 @@ class Condition(BaseModel):
         self.title.save(db_conn)
         self.description.save(db_conn)
         assert isinstance(self.id_, int)
-        db_conn.cached_conditions[self.id_] = self
 
 
 class ConditionsRelations:
index d838039a715677ffb3c9022f82715859d80fd652..5fe984b7c49596121246bb7e822c8b4bb1536139 100644 (file)
@@ -24,7 +24,7 @@ def todays_date() -> str:
     return datetime.now().strftime(DATE_FORMAT)
 
 
-class Day(BaseModel):
+class Day(BaseModel[str]):
     """Individual days defined by their dates."""
     table_name = 'days'
     to_save = ['comment']
@@ -77,7 +77,7 @@ class Day(BaseModel):
         if not create:
             raise NotFoundException(f'Day not found for date: {date}')
         day = cls(date)
-        db_conn.cached_days[date] = day
+        day.cache()
         assert isinstance(day, Day)
         return day
 
index dd2ee2452a7480878c47b38f2669c08cee229ca6..e8a542e6176937a7c205a9a5909a5c776bdbd7db 100644 (file)
@@ -1,8 +1,9 @@
 """Database management."""
+from __future__ import annotations
 from os.path import isfile
 from difflib import Differ
 from sqlite3 import connect as sql_connect, Cursor, Row
-from typing import Any, Dict
+from typing import Any, Self, TypeVar, Generic
 from plomtask.exceptions import HandledException
 
 PATH_DB_SCHEMA = 'scripts/init.sql'
@@ -61,11 +62,6 @@ class DatabaseConnection:
     def __init__(self, db_file: DatabaseFile) -> None:
         self.file = db_file
         self.conn = sql_connect(self.file.path)
-        self.cached_todos: Dict[int, Any] = {}
-        self.cached_days: Dict[str, Any] = {}
-        self.cached_process_steps: Dict[int, Any] = {}
-        self.cached_processes: Dict[int, Any] = {}
-        self.cached_conditions: Dict[int, Any] = {}
 
     def commit(self) -> None:
         """Commit SQL transaction."""
@@ -116,38 +112,41 @@ class DatabaseConnection:
         return '(' + ','.join(['?'] * len(values)) + ')'
 
 
-class BaseModel:
+X = TypeVar('X', int, str)
+T = TypeVar('T', bound='BaseModel[Any]')
+
+
+class BaseModel(Generic[X]):
     """Template for most of the models we use/derive from the DB."""
     table_name = ''
     to_save: list[str] = []
-    id_: None | int | str
-    id_type: type[Any] = int
+    id_: None | X
+    cache_: dict[X, Self] = {}
 
     @classmethod
-    def from_table_row(cls, db_conn: DatabaseConnection,
-                       row: Row | list[Any]) -> Any:
+    def from_table_row(cls: type[T],
+                       # pylint: disable=unused-argument
+                       db_conn: DatabaseConnection,
+                       row: Row | list[Any]) -> T:
         """Make from DB row, write to DB cache."""
         obj = cls(*row)
-        assert isinstance(obj.id_, cls.id_type)
-        cache = getattr(db_conn, f'cached_{cls.table_name}')
-        cache[obj.id_] = obj
+        obj.cache()
         return obj
 
     @classmethod
     def _by_id(cls,
                db_conn: DatabaseConnection,
-               id_: int | str) -> tuple[Any, bool]:
+               id_: X) -> tuple[Self | None, bool]:
         """Return instance found by ID, or None, and if from cache or not."""
         from_cache = False
-        obj = None
-        cache = getattr(db_conn, f'cached_{cls.table_name}')
-        if id_ in cache.keys():
-            obj = cache[id_]
+        obj = cls.get_cached(id_)
+        if obj:
             from_cache = True
         else:
             for row in db_conn.row_where(cls.table_name, 'id', id_):
                 obj = cls.from_table_row(db_conn, row)
-                cache[id_] = obj
+                assert isinstance(obj, cls)
+                obj.cache()
                 break
         return obj, from_cache
 
@@ -156,7 +155,7 @@ class BaseModel:
         if (id_ is not None) and id_ < 1:
             msg = f'illegal {self.__class__.__name__} ID, must be >=1: {id_}'
             raise HandledException(msg)
-        self.id_ = id_
+        self.id_ = id_  # type: ignore[assignment]
 
     def save_core(self, db_conn: DatabaseConnection,
                   update_with_lastrowid: bool = True) -> None:
@@ -168,6 +167,32 @@ class BaseModel:
         cursor = db_conn.exec(f'REPLACE INTO {table_name} VALUES {q_marks}',
                               values)
         if update_with_lastrowid:
-            self.id_ = cursor.lastrowid
-        cache = getattr(db_conn, f'cached_{table_name}')
-        cache[self.id_] = self
+            self.id_ = cursor.lastrowid  # type: ignore[assignment]
+        self.cache()
+
+    @classmethod
+    def get_cached(cls: type[T], id_: X) -> T | None:
+        """Get object of id_ from class's cache, or None if not found."""
+        # pylint: disable=consider-iterating-dictionary
+        if id_ in cls.cache_.keys():
+            obj = cls.cache_[id_]
+            assert isinstance(obj, cls)
+            return obj
+        return None
+
+    def cache(self) -> None:
+        """Update object in class's cache."""
+        if self.id_ is None:
+            raise HandledException('Cannot cache object without ID.')
+        self.__class__.cache_[self.id_] = self
+
+    def uncache(self) -> None:
+        """Remove self from cache."""
+        if self.id_ is None:
+            raise HandledException('Cannot un-cache object without ID.')
+        del self.__class__.cache_[self.id_]
+
+    @classmethod
+    def empty_cache(cls) -> None:
+        """Empty class's cache."""
+        cls.cache_ = {}
index 6249d48445a3488a8f8c4bdd5748b63a5a727ae5..590c5bca56d3f84c1b67f34ee3f80c084903b1c9 100644 (file)
@@ -18,7 +18,7 @@ class ProcessStepsNode:
     seen: bool
 
 
-class Process(BaseModel, ConditionsRelations):
+class Process(BaseModel[int], ConditionsRelations):
     """Template for, and metadata for, Todos, and their arrangements."""
     table_name = 'processes'
 
@@ -38,12 +38,13 @@ class Process(BaseModel, ConditionsRelations):
     def all(cls, db_conn: DatabaseConnection) -> list[Process]:
         """Collect all Processes and their connected VersionedAttributes."""
         processes = {}
-        for id_, process in db_conn.cached_processes.items():
+        for id_, process in cls.cache_.items():
             processes[id_] = process
         already_recorded = processes.keys()
         for id_ in db_conn.column_all('processes', 'id'):
             if id_ not in already_recorded:
                 process = cls.by_id(db_conn, id_)
+                assert isinstance(process.id_, int)
                 processes[process.id_] = process
         return list(processes.values())
 
@@ -165,8 +166,7 @@ class Process(BaseModel, ConditionsRelations):
         """Set self.explicit_steps in bulk."""
         assert isinstance(self.id_, int)
         for step in self.explicit_steps:
-            assert isinstance(step.id_, int)
-            del db_conn.cached_process_steps[step.id_]
+            step.uncache()
         self.explicit_steps = []
         db_conn.delete_where('process_steps', 'owner', self.id_)
         for step_tuple in steps:
@@ -189,10 +189,9 @@ class Process(BaseModel, ConditionsRelations):
         db_conn.delete_where('process_steps', 'owner', self.id_)
         for step in self.explicit_steps:
             step.save(db_conn)
-        db_conn.cached_processes[self.id_] = self
 
 
-class ProcessStep(BaseModel):
+class ProcessStep(BaseModel[int]):
     """Sub-unit of Processes."""
     table_name = 'process_steps'
     to_save = ['owner_id', 'step_process_id', 'parent_step_id']
index 80dc97c25ac481089b5760a0c8072af5f755ef08..0b42d47b3604bb50885c4cda69dcbfb5f263afc8 100644 (file)
@@ -19,7 +19,7 @@ class TodoStepsNode:
     seen: bool
 
 
-class Todo(BaseModel, ConditionsRelations):
+class Todo(BaseModel[int], ConditionsRelations):
     """Individual actionable."""
 
     # pylint: disable=too-many-instance-attributes
@@ -71,6 +71,7 @@ class Todo(BaseModel, ConditionsRelations):
                 todo.parents += [cls.by_id(db_conn, t_id)]
             for name in ('conditions', 'enables', 'disables'):
                 table = f'todo_{name}'
+                assert isinstance(todo.id_, int)
                 for cond_id in db_conn.column_where(table, 'condition',
                                                     'todo', todo.id_):
                     target = getattr(todo, name)
@@ -233,7 +234,6 @@ class Todo(BaseModel, ConditionsRelations):
             raise NotFoundException('Process of Todo without ID (not saved?)')
         self.save_core(db_conn)
         assert isinstance(self.id_, int)
-        db_conn.cached_todos[self.id_] = self
         db_conn.rewrite_relations('todo_children', 'parent', self.id_,
                                   [[c.id_] for c in self.children])
         db_conn.rewrite_relations('todo_conditions', 'todo', self.id_,
index 7aed5f83a0eee3e2eb50845c42de9fd57e3aee0f..b47834c980a39c0ee7b7ce09f52917f712c6813c 100644 (file)
@@ -31,9 +31,9 @@ class TestsWithDB(TestCaseWithDB):
         todo.save(self.db_conn)
         self.assertEqual(Todo.by_id(self.db_conn, 1), todo)
         with self.assertRaises(NotFoundException):
-            self.assertEqual(Todo.by_id(self.db_conn, 0), todo)
+            Todo.by_id(self.db_conn, 0)
         with self.assertRaises(NotFoundException):
-            self.assertEqual(Todo.by_id(self.db_conn, 2), todo)
+            Todo.by_id(self.db_conn, 2)
 
     def test_Todo_by_date(self) -> None:
         """Test findability of Todos by date."""
@@ -301,7 +301,6 @@ class TestsWithServer(TestCaseWithServer):
         def post_and_reload(form_data: dict[str, object],
                             status: int = 302) -> Todo:
             self.check_post(form_data, '/todo?id=1', status, '/')
-            self.db_conn.cached_todos = {}
             return Todo.by_date(self.db_conn, '2024-01-01')[0]
         # test minimum
         form_data = {'title': '', 'description': '', 'effort': 1}
index c80b34da9f3405ddabf563e577e8406af4e155eb..63b07e93e6f14ca51426e2fc00e959fdbfca7bf1 100644 (file)
@@ -8,12 +8,21 @@ from os import remove as remove_file
 from typing import Mapping
 from plomtask.db import DatabaseFile, DatabaseConnection
 from plomtask.http import TaskHandler, TaskServer
+from plomtask.processes import Process, ProcessStep
+from plomtask.conditions import Condition
+from plomtask.days import Day
+from plomtask.todos import Todo
 
 
 class TestCaseWithDB(TestCase):
     """Module tests not requiring DB setup."""
 
     def setUp(self) -> None:
+        Condition.empty_cache()
+        Day.empty_cache()
+        Process.empty_cache()
+        ProcessStep.empty_cache()
+        Todo.empty_cache()
         timestamp = datetime.now().timestamp()
         self.db_file = DatabaseFile(f'test_db:{timestamp}')
         self.db_file.remake()