home · contact · privacy
Draw Process descendant trees, and guard against recursion within them.
[plomtask] / plomtask / processes.py
index 4867227e5f4e492ebeeac33c484039f36afca4fd..ba9707be5c407116d1bdd13f0f9210147942b25e 100644 (file)
@@ -3,17 +3,20 @@ from __future__ import annotations
 from sqlite3 import Row
 from datetime import datetime
 from plomtask.db import DatabaseConnection
-from plomtask.misc import HandledException
+from plomtask.exceptions import NotFoundException, BadFormatException
 
 
 class Process:
     """Template for, and metadata for, Todos, and their arrangements."""
 
     def __init__(self, id_: int | None) -> None:
-        self.id_ = id_ if id_ != 0 else None  # to avoid DB-confusing rowid=0
+        if (id_ is not None) and id_ < 1:
+            raise BadFormatException(f'illegal Process ID, must be >=1: {id_}')
+        self.id_ = id_
         self.title = VersionedAttribute(self, 'title', 'UNNAMED')
         self.description = VersionedAttribute(self, 'description', '')
         self.effort = VersionedAttribute(self, 'effort', 1.0)
+        self.child_ids: list[int] = []
 
     @classmethod
     def from_table_row(cls, row: Row) -> Process:
@@ -38,7 +41,7 @@ class Process:
     @classmethod
     def by_id(cls, db_conn: DatabaseConnection, id_: int | None,
               create: bool = False) -> Process:
-        """Collect all Processes and their connected VersionedAttributes."""
+        """Collect Process, its VersionedAttributes, and its child IDs."""
         process = None
         for row in db_conn.exec('SELECT * FROM processes '
                                 'WHERE id = ?', (id_,)):
@@ -46,7 +49,7 @@ class Process:
             break
         if not process:
             if not create:
-                raise HandledException(f'Process not found of id: {id_}')
+                raise NotFoundException(f'Process not found of id: {id_}')
             process = Process(id_)
         if process:
             for row in db_conn.exec('SELECT * FROM process_titles '
@@ -58,15 +61,43 @@ class Process:
             for row in db_conn.exec('SELECT * FROM process_efforts '
                                     'WHERE process_id = ?', (process.id_,)):
                 process.effort.history[row[1]] = row[2]
+            for row in db_conn.exec('SELECT * FROM process_children '
+                                    'WHERE parent_id = ?', (process.id_,)):
+                process.child_ids += [row[1]]
         return process
 
+    def get_descendants(self, db_conn: DatabaseConnection) ->\
+            dict[int, dict[str, object]]:
+        """Return tree of descendant Processes"""
+        descendants = {}
+        for id_ in self.child_ids:
+            child = self.__class__.by_id(db_conn, id_)
+            descendants[id_] = {'process': child,
+                                'children': child.get_descendants(db_conn)}
+        return descendants
+
     def save(self, db_conn: DatabaseConnection) -> None:
-        """Add (or re-write) self and connected VersionedAttributes to DB."""
+        """Add (or re-write) self and connected VersionedAttributes to DB.
+
+        Also is the point at which descendancy recursion is checked.
+        """
+        def walk_descendants(node_id: int) -> None:
+            if node_id == self.id_:
+                raise BadFormatException('bad child selection: recursion')
+            descendant = self.by_id(db_conn, node_id)
+            for descendant_id in descendant.child_ids:
+                walk_descendants(descendant_id)
         cursor = db_conn.exec('REPLACE INTO processes VALUES (?)', (self.id_,))
         self.id_ = cursor.lastrowid
         self.title.save(db_conn)
         self.description.save(db_conn)
         self.effort.save(db_conn)
+        db_conn.exec('DELETE FROM process_children WHERE parent_id = ?',
+                     (self.id_,))
+        for child_id in self.child_ids:
+            walk_descendants(child_id)
+            db_conn.exec('INSERT INTO process_children VALUES (?, ?)',
+                         (self.id_, child_id))
 
 
 class VersionedAttribute: