home · contact · privacy
Refactor from_table_row methods of core DB models.
[plomtask] / plomtask / processes.py
index dc0613f8dbb2b01f941603203f57cbddb3766693..2f8c2d537a168062cee953985242b732f1973a72 100644 (file)
@@ -1,22 +1,20 @@
 """Collecting Processes and Process-related items."""
 from __future__ import annotations
-from sqlite3 import Row
 from typing import Any, Set
-from plomtask.db import DatabaseConnection
+from plomtask.db import DatabaseConnection, BaseModel
 from plomtask.misc import VersionedAttribute
 from plomtask.conditions import Condition
 from plomtask.exceptions import NotFoundException, BadFormatException
 
 
-class Process:
+class Process(BaseModel):
     """Template for, and metadata for, Todos, and their arrangements."""
+    table_name = 'processes'
 
     # pylint: disable=too-many-instance-attributes
 
     def __init__(self, id_: int | None) -> None:
-        if (id_ is not None) and id_ < 1:
-            raise BadFormatException(f'illegal Process ID, must be >=1: {id_}')
-        self.id_ = id_
+        self.set_int_id(id_)
         self.title = VersionedAttribute(self, 'process_titles', 'UNNAMED')
         self.description = VersionedAttribute(self, 'process_descriptions', '')
         self.effort = VersionedAttribute(self, 'process_efforts', 1.0)
@@ -25,14 +23,6 @@ class Process:
         self.fulfills: list[Condition] = []
         self.undoes: list[Condition] = []
 
-    @classmethod
-    def from_table_row(cls, db_conn: DatabaseConnection, row: Row) -> Process:
-        """Make Process from database row, with empty VersionedAttributes."""
-        process = cls(row[0])
-        assert process.id_ is not None
-        db_conn.cached_processes[process.id_] = process
-        return process
-
     @classmethod
     def all(cls, db_conn: DatabaseConnection) -> list[Process]:
         """Collect all Processes and their connected VersionedAttributes."""
@@ -125,7 +115,7 @@ class Process:
             external_owner = self
         for step in [s for s in self.explicit_steps
                      if s.parent_step_id is None]:
-            assert step.id_ is not None  # for mypy
+            assert isinstance(step.id_, int)
             steps[step.id_] = make_node(step)
         for step_id, step_node in steps.items():
             walk_steps(step_id, step_node)
@@ -149,12 +139,15 @@ class Process:
         """Set self.undoes to Conditions identified by ids."""
         self.set_conditions(db_conn, ids, 'undoes')
 
-    def add_step(self, db_conn: DatabaseConnection, id_: int | None,
-                 step_process_id: int,
-                 parent_step_id: int | None) -> ProcessStep:
+    def _add_step(self,
+                  db_conn: DatabaseConnection,
+                  id_: int | None,
+                  step_process_id: int,
+                  parent_step_id: int | None) -> ProcessStep:
         """Create new ProcessStep, save and add it to self.explicit_steps.
 
         Also checks against step recursion.
+
         The new step's parent_step_id will fall back to None either if no
         matching ProcessStep is found (which can be assumed in case it was
         just deleted under its feet), or if the parent step would not be
@@ -173,17 +166,29 @@ class Process:
                     parent_step_id = None
             except NotFoundException:
                 parent_step_id = None
-        assert self.id_ is not None
+        assert isinstance(self.id_, int)
         step = ProcessStep(id_, self.id_, step_process_id, parent_step_id)
         walk_steps(step)
         self.explicit_steps += [step]
         step.save(db_conn)  # NB: This ensures a non-None step.id_.
         return step
 
-    def save_without_steps(self, db_conn: DatabaseConnection) -> None:
-        """Add (or re-write) self and connected VersionedAttributes to DB."""
-        cursor = db_conn.exec('REPLACE INTO processes VALUES (?)', (self.id_,))
-        self.id_ = cursor.lastrowid
+    def set_steps(self, db_conn: DatabaseConnection,
+                  steps: list[tuple[int | None, int, int | None]]) -> None:
+        """Set self.explicit_steps in bulk."""
+        for step in self.explicit_steps:
+            assert isinstance(step.id_, int)
+            del db_conn.cached_process_steps[step.id_]
+        self.explicit_steps = []
+        db_conn.exec('DELETE FROM process_steps WHERE owner_id = ?',
+                     (self.id_,))
+        for step_tuple in steps:
+            self._add_step(db_conn, step_tuple[0],
+                           step_tuple[1], step_tuple[2])
+
+    def save(self, db_conn: DatabaseConnection) -> None:
+        """Add (or re-write) self and connected items to DB."""
+        self.save_core(db_conn)
         self.title.save(db_conn)
         self.description.save(db_conn)
         self.effort.save(db_conn)
@@ -202,49 +207,26 @@ class Process:
         for condition in self.undoes:
             db_conn.exec('INSERT INTO process_undoes VALUES (?,?)',
                          (self.id_, condition.id_))
-        assert self.id_ is not None
-        db_conn.cached_processes[self.id_] = self
-
-    def fix_steps(self, db_conn: DatabaseConnection) -> None:
-        """Rewrite ProcessSteps from self.explicit_steps.
-
-        This also fixes illegal Step.parent_step_id values, i.e. those pointing
-        to steps now absent, or owned by a different Process, fall back into
-        .parent_step_id=None
-        """
+        assert isinstance(self.id_, int)
         db_conn.exec('DELETE FROM process_steps WHERE owner_id = ?',
                      (self.id_,))
         for step in self.explicit_steps:
-            if step.parent_step_id is not None:
-                try:
-                    parent_step = ProcessStep.by_id(db_conn,
-                                                    step.parent_step_id)
-                    if parent_step.owner_id != self.id_:
-                        step.parent_step_id = None
-                except NotFoundException:
-                    step.parent_step_id = None
             step.save(db_conn)
+        db_conn.cached_processes[self.id_] = self
 
 
-class ProcessStep:
+class ProcessStep(BaseModel):
     """Sub-unit of Processes."""
+    table_name = 'process_steps'
+    to_save = ['owner_id', 'step_process_id', 'parent_step_id']
 
     def __init__(self, id_: int | None, owner_id: int, step_process_id: int,
                  parent_step_id: int | None) -> None:
-        self.id_ = id_
+        self.set_int_id(id_)
         self.owner_id = owner_id
         self.step_process_id = step_process_id
         self.parent_step_id = parent_step_id
 
-    @classmethod
-    def from_table_row(cls, db_conn: DatabaseConnection,
-                       row: Row) -> ProcessStep:
-        """Make ProcessStep from database row, store in DB cache."""
-        step = cls(row[0], row[1], row[2], row[3])
-        assert step.id_ is not None
-        db_conn.cached_process_steps[step.id_] = step
-        return step
-
     @classmethod
     def by_id(cls, db_conn: DatabaseConnection, id_: int) -> ProcessStep:
         """Retrieve ProcessStep by id_, or throw NotFoundException."""
@@ -254,14 +236,10 @@ class ProcessStep:
             return step
         for row in db_conn.exec('SELECT * FROM process_steps '
                                 'WHERE step_id = ?', (id_,)):
-            return cls.from_table_row(db_conn, row)
+            step = cls.from_table_row(db_conn, row)
+            assert isinstance(step, ProcessStep)
         raise NotFoundException(f'found no ProcessStep of ID {id_}')
 
     def save(self, db_conn: DatabaseConnection) -> None:
-        """Save to database and cache."""
-        cursor = db_conn.exec('REPLACE INTO process_steps VALUES (?, ?, ?, ?)',
-                              (self.id_, self.owner_id, self.step_process_id,
-                               self.parent_step_id))
-        self.id_ = cursor.lastrowid
-        assert self.id_ is not None
-        db_conn.cached_process_steps[self.id_] = self
+        """Default to simply calling self.save_core for simple cases."""
+        self.save_core(db_conn)