home · contact · privacy
Cache DB objects to ensure we do not accidentally edit clones.
[plomtask] / plomtask / processes.py
index 03fecb2fee3f570f0595ab785f096fef15cdee85..fe9bd4a63ff36a34afdbcbe0abed36edb891c8f5 100644 (file)
@@ -19,33 +19,35 @@ class Process:
         self.effort = VersionedAttribute(self, 'effort', 1.0)
         self.explicit_steps: list[ProcessStep] = []
 
-    def __eq__(self, other: object) -> bool:
-        return isinstance(other, self.__class__) and self.id_ == other.id_
-
     @classmethod
-    def from_table_row(cls, row: Row) -> Process:
+    def from_table_row(cls, db_conn: DatabaseConnection, row: Row) -> Process:
         """Make Process from database row, with empty VersionedAttributes."""
-        return cls(row[0])
+        process = cls(row[0])
+        assert process.id_ is not None
+        db_conn.cached_processes[process.id_] = process
+        return process
 
     @classmethod
     def all(cls, db_conn: DatabaseConnection) -> list[Process]:
         """Collect all Processes and their connected VersionedAttributes."""
         processes = {}
-        for row in db_conn.exec('SELECT * FROM processes'):
-            process = cls.from_table_row(row)
-            processes[process.id_] = process
-        for row in db_conn.exec('SELECT * FROM process_titles'):
-            processes[row[0]].title.history[row[1]] = row[2]
-        for row in db_conn.exec('SELECT * FROM process_descriptions'):
-            processes[row[0]].description.history[row[1]] = row[2]
-        for row in db_conn.exec('SELECT * FROM process_efforts'):
-            processes[row[0]].effort.history[row[1]] = row[2]
+        for id_, process in db_conn.cached_processes.items():
+            processes[id_] = process
+        already_recorded = processes.keys()
+        for row in db_conn.exec('SELECT id FROM processes'):
+            if row[0] not in already_recorded:
+                process = cls.by_id(db_conn, row[0])
+                processes[process.id_] = process
         return list(processes.values())
 
     @classmethod
     def by_id(cls, db_conn: DatabaseConnection, id_: int | None,
               create: bool = False) -> Process:
         """Collect Process, its VersionedAttributes, and its child IDs."""
+        if id_ in db_conn.cached_processes.keys():
+            process = db_conn.cached_processes[id_]
+            assert isinstance(process, Process)
+            return process
         process = None
         for row in db_conn.exec('SELECT * FROM processes '
                                 'WHERE id = ?', (id_,)):
@@ -67,7 +69,9 @@ class Process:
                 process.effort.history[row[1]] = row[2]
             for row in db_conn.exec('SELECT * FROM process_steps '
                                     'WHERE owner_id = ?', (process.id_,)):
-                process.explicit_steps += [ProcessStep.from_table_row(row)]
+                process.explicit_steps += [ProcessStep.from_table_row(db_conn,
+                                                                      row)]
+        assert isinstance(process, Process)
         return process
 
     def used_as_step_by(self, db_conn: DatabaseConnection) -> list[Process]:
@@ -151,6 +155,8 @@ class Process:
         self.title.save(db_conn)
         self.description.save(db_conn)
         self.effort.save(db_conn)
+        assert self.id_ is not None
+        db_conn.cached_processes[self.id_] = self
 
     def fix_steps(self, db_conn: DatabaseConnection) -> None:
         """Rewrite ProcessSteps from self.explicit_steps.
@@ -184,24 +190,34 @@ class ProcessStep:
         self.parent_step_id = parent_step_id
 
     @classmethod
-    def from_table_row(cls, row: Row) -> ProcessStep:
-        """Make ProcessStep from database row."""
-        return cls(row[0], row[1], row[2], row[3])
+    def from_table_row(cls, db_conn: DatabaseConnection,
+                       row: Row) -> ProcessStep:
+        """Make ProcessStep from database row, store in DB cache."""
+        step = cls(row[0], row[1], row[2], row[3])
+        assert step.id_ is not None
+        db_conn.cached_process_steps[step.id_] = step
+        return step
 
     @classmethod
     def by_id(cls, db_conn: DatabaseConnection, id_: int) -> ProcessStep:
         """Retrieve ProcessStep by id_, or throw NotFoundException."""
+        if id_ in db_conn.cached_process_steps.keys():
+            step = db_conn.cached_process_steps[id_]
+            assert isinstance(step, ProcessStep)
+            return step
         for row in db_conn.exec('SELECT * FROM process_steps '
                                 'WHERE step_id = ?', (id_,)):
-            return cls.from_table_row(row)
+            return cls.from_table_row(db_conn, row)
         raise NotFoundException(f'found no ProcessStep of ID {id_}')
 
     def save(self, db_conn: DatabaseConnection) -> None:
-        """Save to database."""
+        """Save to database and cache."""
         cursor = db_conn.exec('REPLACE INTO process_steps VALUES (?, ?, ?, ?)',
                               (self.id_, self.owner_id, self.step_process_id,
                                self.parent_step_id))
         self.id_ = cursor.lastrowid
+        assert self.id_ is not None
+        db_conn.cached_process_steps[self.id_] = self
 
 
 class VersionedAttribute: