From 6b329a28bb4aec8d1846f5cc5402ed6fca5eb3da Mon Sep 17 00:00:00 2001 From: Christian Heller <c.heller@plomlompom.de> Date: Sat, 13 Apr 2024 00:00:58 +0200 Subject: [PATCH] Cache DB objects to ensure we do not accidentally edit clones. --- plomtask/days.py | 28 +++++++++++++++------- plomtask/db.py | 6 ++++- plomtask/processes.py | 56 +++++++++++++++++++++++++++---------------- plomtask/todos.py | 28 +++++++++++++--------- tests/days.py | 10 +++++++- tests/processes.py | 45 ++++++++++++++++++++++++---------- tests/todos.py | 24 ++++++++++++------- 7 files changed, 135 insertions(+), 62 deletions(-) diff --git a/plomtask/days.py b/plomtask/days.py index afdea33..abfce06 100644 --- a/plomtask/days.py +++ b/plomtask/days.py @@ -40,9 +40,11 @@ class Day: return self.date < other.date @classmethod - def from_table_row(cls, row: Row) -> Day: - """Make Day from database row.""" - return cls(row[0], row[1]) + def from_table_row(cls, db_conn: DatabaseConnection, row: Row) -> Day: + """Make Day from database row, write to cache.""" + day = cls(row[0], row[1]) + db_conn.cached_days[day.date] = day + return day @classmethod def all(cls, db_conn: DatabaseConnection, @@ -54,9 +56,9 @@ class Day: start_date = valid_date(date_range[0] if date_range[0] else min_date) end_date = valid_date(date_range[1] if date_range[1] else max_date) days = [] - sql = 'SELECT * FROM days WHERE date >= ? AND date <= ?' + sql = 'SELECT date FROM days WHERE date >= ? AND date <= ?' for row in db_conn.exec(sql, (start_date, end_date)): - days += [cls.from_table_row(row)] + days += [cls.by_date(db_conn, row[0])] days.sort() if fill_gaps and len(days) > 1: gapless_days = [] @@ -72,12 +74,19 @@ class Day: @classmethod def by_date(cls, db_conn: DatabaseConnection, date: str, create: bool = False) -> Day: - """Retrieve Day by date if in DB, else return None.""" + """Retrieve Day by date if in DB (prefer cache), else return None.""" + if date in db_conn.cached_days.keys(): + day = db_conn.cached_days[date] + assert isinstance(day, Day) + return day for row in db_conn.exec('SELECT * FROM days WHERE date = ?', (date,)): - return cls.from_table_row(row) + return cls.from_table_row(db_conn, row) if not create: raise NotFoundException(f'Day not found for date: {date}') - return cls(date) + day = cls(date) + db_conn.cached_days[date] = day + assert isinstance(day, Day) + return day @property def weekday(self) -> str: @@ -97,6 +106,7 @@ class Day: return next_datetime.strftime(DATE_FORMAT) def save(self, db_conn: DatabaseConnection) -> None: - """Add (or re-write) self to database.""" + """Add (or re-write) self to DB and cache.""" db_conn.exec('REPLACE INTO days VALUES (?, ?)', (self.date, self.comment)) + db_conn.cached_days[self.date] = self diff --git a/plomtask/db.py b/plomtask/db.py index 929a733..01bc3e9 100644 --- a/plomtask/db.py +++ b/plomtask/db.py @@ -2,7 +2,7 @@ from os.path import isfile from difflib import Differ from sqlite3 import connect as sql_connect, Cursor -from typing import Any +from typing import Any, Dict from plomtask.exceptions import HandledException PATH_DB_SCHEMA = 'scripts/init.sql' @@ -49,6 +49,10 @@ class DatabaseConnection: def __init__(self, db_file: DatabaseFile) -> None: self.file = db_file self.conn = sql_connect(self.file.path) + self.cached_todos: Dict[int, Any] = {} + self.cached_days: Dict[str, Any] = {} + self.cached_process_steps: Dict[int, Any] = {} + self.cached_processes: Dict[int, Any] = {} def commit(self) -> None: """Commit SQL transaction.""" diff --git a/plomtask/processes.py b/plomtask/processes.py index 03fecb2..fe9bd4a 100644 --- a/plomtask/processes.py +++ b/plomtask/processes.py @@ -19,33 +19,35 @@ class Process: self.effort = VersionedAttribute(self, 'effort', 1.0) self.explicit_steps: list[ProcessStep] = [] - def __eq__(self, other: object) -> bool: - return isinstance(other, self.__class__) and self.id_ == other.id_ - @classmethod - def from_table_row(cls, row: Row) -> Process: + def from_table_row(cls, db_conn: DatabaseConnection, row: Row) -> Process: """Make Process from database row, with empty VersionedAttributes.""" - return cls(row[0]) + process = cls(row[0]) + assert process.id_ is not None + db_conn.cached_processes[process.id_] = process + return process @classmethod def all(cls, db_conn: DatabaseConnection) -> list[Process]: """Collect all Processes and their connected VersionedAttributes.""" processes = {} - for row in db_conn.exec('SELECT * FROM processes'): - process = cls.from_table_row(row) - processes[process.id_] = process - for row in db_conn.exec('SELECT * FROM process_titles'): - processes[row[0]].title.history[row[1]] = row[2] - for row in db_conn.exec('SELECT * FROM process_descriptions'): - processes[row[0]].description.history[row[1]] = row[2] - for row in db_conn.exec('SELECT * FROM process_efforts'): - processes[row[0]].effort.history[row[1]] = row[2] + for id_, process in db_conn.cached_processes.items(): + processes[id_] = process + already_recorded = processes.keys() + for row in db_conn.exec('SELECT id FROM processes'): + if row[0] not in already_recorded: + process = cls.by_id(db_conn, row[0]) + processes[process.id_] = process return list(processes.values()) @classmethod def by_id(cls, db_conn: DatabaseConnection, id_: int | None, create: bool = False) -> Process: """Collect Process, its VersionedAttributes, and its child IDs.""" + if id_ in db_conn.cached_processes.keys(): + process = db_conn.cached_processes[id_] + assert isinstance(process, Process) + return process process = None for row in db_conn.exec('SELECT * FROM processes ' 'WHERE id = ?', (id_,)): @@ -67,7 +69,9 @@ class Process: process.effort.history[row[1]] = row[2] for row in db_conn.exec('SELECT * FROM process_steps ' 'WHERE owner_id = ?', (process.id_,)): - process.explicit_steps += [ProcessStep.from_table_row(row)] + process.explicit_steps += [ProcessStep.from_table_row(db_conn, + row)] + assert isinstance(process, Process) return process def used_as_step_by(self, db_conn: DatabaseConnection) -> list[Process]: @@ -151,6 +155,8 @@ class Process: self.title.save(db_conn) self.description.save(db_conn) self.effort.save(db_conn) + assert self.id_ is not None + db_conn.cached_processes[self.id_] = self def fix_steps(self, db_conn: DatabaseConnection) -> None: """Rewrite ProcessSteps from self.explicit_steps. @@ -184,24 +190,34 @@ class ProcessStep: self.parent_step_id = parent_step_id @classmethod - def from_table_row(cls, row: Row) -> ProcessStep: - """Make ProcessStep from database row.""" - return cls(row[0], row[1], row[2], row[3]) + def from_table_row(cls, db_conn: DatabaseConnection, + row: Row) -> ProcessStep: + """Make ProcessStep from database row, store in DB cache.""" + step = cls(row[0], row[1], row[2], row[3]) + assert step.id_ is not None + db_conn.cached_process_steps[step.id_] = step + return step @classmethod def by_id(cls, db_conn: DatabaseConnection, id_: int) -> ProcessStep: """Retrieve ProcessStep by id_, or throw NotFoundException.""" + if id_ in db_conn.cached_process_steps.keys(): + step = db_conn.cached_process_steps[id_] + assert isinstance(step, ProcessStep) + return step for row in db_conn.exec('SELECT * FROM process_steps ' 'WHERE step_id = ?', (id_,)): - return cls.from_table_row(row) + return cls.from_table_row(db_conn, row) raise NotFoundException(f'found no ProcessStep of ID {id_}') def save(self, db_conn: DatabaseConnection) -> None: - """Save to database.""" + """Save to database and cache.""" cursor = db_conn.exec('REPLACE INTO process_steps VALUES (?, ?, ?, ?)', (self.id_, self.owner_id, self.step_process_id, self.parent_step_id)) self.id_ = cursor.lastrowid + assert self.id_ is not None + db_conn.cached_process_steps[self.id_] = self class VersionedAttribute: diff --git a/plomtask/todos.py b/plomtask/todos.py index 7150f0d..f1d98ad 100644 --- a/plomtask/todos.py +++ b/plomtask/todos.py @@ -17,37 +17,43 @@ class Todo: self.is_done = is_done self.day = day - def __eq__(self, other: object) -> bool: - return isinstance(other, self.__class__) and self.id_ == other.id_ - @classmethod - def from_table_row(cls, row: Row, db_conn: DatabaseConnection) -> Todo: - """Make Todo from database row.""" - return cls(id_=row[0], + def from_table_row(cls, db_conn: DatabaseConnection, row: Row) -> Todo: + """Make Todo from database row, write to DB cache.""" + todo = cls(id_=row[0], process=Process.by_id(db_conn, row[1]), is_done=row[2], day=Day.by_date(db_conn, row[3])) + assert todo.id_ is not None + db_conn.cached_todos[todo.id_] = todo + return todo @classmethod def by_id(cls, db_conn: DatabaseConnection, id_: int) -> Todo: - """Get Todo of .id_=id_.""" + """Get Todo of .id_=id_ â from DB cache if possible.""" + if id_ in db_conn.cached_todos.keys(): + todo = db_conn.cached_todos[id_] + assert isinstance(todo, Todo) + return todo for row in db_conn.exec('SELECT * FROM todos WHERE id = ?', (id_,)): - return cls.from_table_row(row, db_conn) + return cls.from_table_row(db_conn, row) raise NotFoundException(f'Todo of ID not found: {id_}') @classmethod def by_date(cls, db_conn: DatabaseConnection, date: str) -> list[Todo]: """Collect all Todos for Day of date.""" todos = [] - for row in db_conn.exec('SELECT * FROM todos WHERE day = ?', (date,)): - todos += [cls.from_table_row(row, db_conn)] + for row in db_conn.exec('SELECT id FROM todos WHERE day = ?', (date,)): + todos += [cls.by_id(db_conn, row[0])] return todos def save(self, db_conn: DatabaseConnection) -> None: - """Write self to DB.""" + """Write self to DB and its cache.""" if self.process.id_ is None: raise NotFoundException('Process of Todo without ID (not saved?)') cursor = db_conn.exec('REPLACE INTO todos VALUES (?,?,?,?)', (self.id_, self.process.id_, self.is_done, self.day.date)) self.id_ = cursor.lastrowid + assert self.id_ is not None + db_conn.cached_todos[self.id_] = self diff --git a/tests/days.py b/tests/days.py index 81dbcf5..2e2ef50 100644 --- a/tests/days.py +++ b/tests/days.py @@ -33,7 +33,7 @@ class TestsSansDB(TestCase): class TestsWithDB(TestCaseWithDB): - """Days module tests not requiring DB setup.""" + """Tests requiring DB, but not server setup.""" def test_Day_by_date(self) -> None: """Test Day.by_date().""" @@ -86,6 +86,14 @@ class TestsWithDB(TestCaseWithDB): self.assertEqual(Day('2024-01-01').prev_date, '2023-12-31') self.assertEqual(Day('2023-02-28').next_date, '2023-03-01') + def test_Day_singularity(self) -> None: + """Test pointers made for single object keep pointing to it.""" + day = Day('2024-01-01') + day.save(self.db_conn) + retrieved_day = Day.by_date(self.db_conn, '2024-01-01') + day.comment = 'foo' + self.assertEqual(retrieved_day.comment, 'foo') + class TestsWithServer(TestCaseWithServer): """Tests against our HTTP server/handler (and database).""" diff --git a/tests/processes.py b/tests/processes.py index 87b0b09..bda6275 100644 --- a/tests/processes.py +++ b/tests/processes.py @@ -2,7 +2,7 @@ from unittest import TestCase from typing import Any from tests.utils import TestCaseWithDB, TestCaseWithServer -from plomtask.processes import Process +from plomtask.processes import Process, ProcessStep from plomtask.exceptions import NotFoundException, BadFormatException @@ -39,17 +39,6 @@ class TestsWithDB(TestCaseWithDB): self.assertEqual(p.id_, Process.by_id(self.db_conn, 5, create=False).id_) - def test_Process_versioned_attributes(self) -> None: - """Test behavior of VersionedAttributes on saving (with .title).""" - p = Process(None) - p.save_without_steps(self.db_conn) - p.title.set('named') - p_loaded = Process.by_id(self.db_conn, p.id_) - self.assertNotEqual(p.title.history, p_loaded.title.history) - p.save_without_steps(self.db_conn) - p_loaded = Process.by_id(self.db_conn, p.id_) - self.assertEqual(p.title.history, p_loaded.title.history) - def test_Process_steps(self) -> None: """Test addition, nesting, and non-recursion of ProcessSteps""" p_1 = Process(1) @@ -128,6 +117,38 @@ class TestsWithDB(TestCaseWithDB): self.assertEqual({p_1.id_, p_2.id_}, set(p.id_ for p in Process.all(self.db_conn))) + def test_ProcessStep_singularity(self) -> None: + """Test pointers made for single object keep pointing to it.""" + p_1 = Process(None) + p_1.save_without_steps(self.db_conn) + p_2 = Process(None) + p_2.save_without_steps(self.db_conn) + assert p_2.id_ is not None + step = p_1.add_step(self.db_conn, None, p_2.id_, None) + assert step.id_ is not None + step_retrieved = ProcessStep.by_id(self.db_conn, step.id_) + step.parent_step_id = 99 + self.assertEqual(step.parent_step_id, step_retrieved.parent_step_id) + + def test_Process_singularity(self) -> None: + """Test pointers made for single object keep pointing to it.""" + p_1 = Process(None) + p_1.save_without_steps(self.db_conn) + p_2 = Process(None) + p_2.save_without_steps(self.db_conn) + assert p_2.id_ is not None + p_1.add_step(self.db_conn, None, p_2.id_, None) + p_retrieved = Process.by_id(self.db_conn, p_1.id_) + self.assertEqual(p_1.explicit_steps, p_retrieved.explicit_steps) + + def test_Process_versioned_attributes_singularity(self) -> None: + """Test behavior of VersionedAttributes on saving (with .title).""" + p = Process(None) + p.save_without_steps(self.db_conn) + p.title.set('named') + p_loaded = Process.by_id(self.db_conn, p.id_) + self.assertEqual(p.title.history, p_loaded.title.history) + class TestsWithServer(TestCaseWithServer): """Module tests against our HTTP server/handler (and database).""" diff --git a/tests/todos.py b/tests/todos.py index 8bcd181..db4ad9c 100644 --- a/tests/todos.py +++ b/tests/todos.py @@ -7,7 +7,7 @@ from plomtask.exceptions import NotFoundException class TestsWithDB(TestCaseWithDB): - """Tests not requiring DB setup.""" + """Tests requiring DB, but not server setup.""" def test_Todo_by_id(self) -> None: """Test creation and findability of Todos.""" @@ -18,9 +18,6 @@ class TestsWithDB(TestCaseWithDB): todo.save(self.db_conn) process.save_without_steps(self.db_conn) todo.save(self.db_conn) - with self.assertRaises(NotFoundException): - _ = Todo.by_id(self.db_conn, 1) - day.save(self.db_conn) self.assertEqual(Todo.by_id(self.db_conn, 1), todo) with self.assertRaises(NotFoundException): self.assertEqual(Todo.by_id(self.db_conn, 0), todo) @@ -37,14 +34,25 @@ class TestsWithDB(TestCaseWithDB): todo1.save(self.db_conn) todo2 = Todo(None, process, False, day1) todo2.save(self.db_conn) - with self.assertRaises(NotFoundException): - _ = Todo.by_date(self.db_conn, day1.date) - day1.save(self.db_conn) - day2.save(self.db_conn) self.assertEqual(Todo.by_date(self.db_conn, day1.date), [todo1, todo2]) self.assertEqual(Todo.by_date(self.db_conn, day2.date), []) self.assertEqual(Todo.by_date(self.db_conn, 'foo'), []) + def test_Todo_singularity(self) -> None: + """Test pointers made for single object keep pointing to it.""" + day = Day('2024-01-01') + day.save(self.db_conn) + process = Process(None) + process.save_without_steps(self.db_conn) + todo = Todo(None, process, False, day) + todo.save(self.db_conn) + retrieved_todo = Todo.by_id(self.db_conn, 1) + todo.is_done = True + self.assertEqual(retrieved_todo.is_done, True) + retrieved_todo = Todo.by_date(self.db_conn, '2024-01-01')[0] + retrieved_todo.is_done = False + self.assertEqual(todo.is_done, False) + class TestsWithServer(TestCaseWithServer): """Tests against our HTTP server/handler (and database).""" -- 2.30.2