From c6bc1fddcf12ae9523cf5b1b595638c762677c52 Mon Sep 17 00:00:00 2001
From: Christian Heller <c.heller@plomlompom.de>
Date: Wed, 10 Apr 2024 03:24:03 +0200
Subject: [PATCH] Replace ProcessChildren with more flexible ProcessStep
 infrastructure.

---
 plomtask/http.py       |  26 ++++++-
 plomtask/processes.py  | 151 +++++++++++++++++++++++++++++++++--------
 scripts/init.sql       |  15 ++--
 templates/process.html |  27 +++++---
 tests/processes.py     | 140 +++++++++++++++++++++++---------------
 5 files changed, 257 insertions(+), 102 deletions(-)

diff --git a/plomtask/http.py b/plomtask/http.py
index f368232..54450e4 100644
--- a/plomtask/http.py
+++ b/plomtask/http.py
@@ -70,6 +70,13 @@ class PostvarsParser:
             msg = f'cannot int form field value: {val}'
             raise BadFormatException(msg) from e
 
+    def get_int_or_none(self, key: str) -> int | None:
+        """Retrieve int value of key from self.postvars, or None."""
+        if key not in self.postvars or \
+                0 == len(''.join(list(self.postvars[key]))):
+            return None
+        return self.get_int(key)
+
     def get_float(self, key: str) -> float:
         """Retrieve float value of key from self.postvars."""
         val = self.get_str(key)
@@ -138,7 +145,7 @@ class TaskHandler(BaseHTTPRequestHandler):
         id_ = params.get_int_or_none('id')
         process = Process.by_id(conn, id_, create=True)
         return self.server.jinja.get_template('process.html').render(
-                process=process, children=process.get_descendants(conn),
+                process=process, steps=process.get_steps(conn),
                 candidates=Process.all(conn))
 
     def do_GET_processes(self, conn: DatabaseConnection,
@@ -183,8 +190,21 @@ class TaskHandler(BaseHTTPRequestHandler):
         process.title.set(form_data.get_str('title'))
         process.description.set(form_data.get_str('description'))
         process.effort.set(form_data.get_float('effort'))
-        process.child_ids = form_data.get_all_int('children')
-        process.save(conn)
+        process.save_without_steps(conn)
+        assert process.id_ is not None  # for mypy
+        process.explicit_steps = []
+        for step_id in form_data.get_all_int('steps'):
+            for step_process_id in\
+                    form_data.get_all_int(f'new_step_to_{step_id}'):
+                process.add_step(conn, None, step_process_id, step_id)
+            if step_id not in form_data.get_all_int('keep_step'):
+                continue
+            step_process_id = form_data.get_int(f'step_{step_id}_process_id')
+            parent_id = form_data.get_int_or_none(f'step_{step_id}_parent_id')
+            process.add_step(conn, step_id, step_process_id, parent_id)
+        for step_process_id in form_data.get_all_int('new_top_step'):
+            process.add_step(conn, None, step_process_id, None)
+        process.fix_steps(conn)
 
     def _init_handling(self) -> tuple[DatabaseConnection, str, ParamsParser]:
         conn = DatabaseConnection(self.server.db)
diff --git a/plomtask/processes.py b/plomtask/processes.py
index 4f97f62..23cded7 100644
--- a/plomtask/processes.py
+++ b/plomtask/processes.py
@@ -2,6 +2,7 @@
 from __future__ import annotations
 from sqlite3 import Row
 from datetime import datetime
+from typing import Any
 from plomtask.db import DatabaseConnection
 from plomtask.exceptions import NotFoundException, BadFormatException
 
@@ -16,7 +17,10 @@ class Process:
         self.title = VersionedAttribute(self, 'title', 'UNNAMED')
         self.description = VersionedAttribute(self, 'description', '')
         self.effort = VersionedAttribute(self, 'effort', 1.0)
-        self.child_ids: list[int] = []
+        self.explicit_steps: list[ProcessStep] = []
+
+    def __eq__(self, other: object) -> bool:
+        return isinstance(other, self.__class__) and self.id_ == other.id_
 
     @classmethod
     def from_table_row(cls, row: Row) -> Process:
@@ -61,43 +65,132 @@ class Process:
             for row in db_conn.exec('SELECT * FROM process_efforts '
                                     'WHERE process_id = ?', (process.id_,)):
                 process.effort.history[row[1]] = row[2]
-            for row in db_conn.exec('SELECT * FROM process_children '
-                                    'WHERE parent_id = ?', (process.id_,)):
-                process.child_ids += [row[1]]
+            for row in db_conn.exec('SELECT * FROM process_steps '
+                                    'WHERE owner_id = ?', (process.id_,)):
+                process.explicit_steps += [ProcessStep.from_table_row(row)]
         return process
 
-    def get_descendants(self, db_conn: DatabaseConnection) ->\
-            list[dict[str, object]]:
-        """Return tree of descendant Processes"""
-        descendants = []
-        for id_ in self.child_ids:
-            child = self.__class__.by_id(db_conn, id_)
-            descendants += [{'process': child,
-                             'children': child.get_descendants(db_conn)}]
-        return descendants
-
-    def save(self, db_conn: DatabaseConnection) -> None:
-        """Add (or re-write) self and connected VersionedAttributes to DB.
-
-        Also is the point at which descendancy recursion is checked.
+    def get_steps(self, db_conn: DatabaseConnection, external_owner:
+                  Process | None = None) -> dict[int, dict[str, object]]:
+        """Return tree of depended-on explicit and implicit ProcessSteps."""
+
+        def make_node(step: ProcessStep) -> dict[str, object]:
+            step_process = self.__class__.by_id(db_conn, step.step_process_id)
+            is_explicit = False
+            if external_owner is not None:
+                is_explicit = step.owner_id == external_owner.id_
+            step_steps = step_process.get_steps(db_conn, external_owner)
+            return {'process': step_process, 'parent_id': step.parent_step_id,
+                    'is_explicit': is_explicit, 'steps': step_steps}
+
+        def walk_steps(node_id: int, node: dict[str, Any]) -> None:
+            explicit_children = [s for s in self.explicit_steps
+                                 if s.parent_step_id == node_id]
+            for child in explicit_children:
+                node['steps'][child.id_] = make_node(child)
+            for id_, step in node['steps'].items():
+                walk_steps(id_, step)
+
+        steps: dict[int, dict[str, object]] = {}
+        if external_owner is None:
+            external_owner = self
+        for step in [s for s in self.explicit_steps
+                     if s.parent_step_id is None]:
+            assert step.id_ is not None  # for mypy
+            steps[step.id_] = make_node(step)
+        for step_id, step_node in steps.items():
+            walk_steps(step_id, step_node)
+        return steps
+
+    def add_step(self, db_conn: DatabaseConnection, id_: int | None,
+                 step_process_id: int,
+                 parent_step_id: int | None) -> ProcessStep:
+        """Create new ProcessStep, save and add it to self.explicit_steps.
+
+        Also checks against step recursion.
+        The new step's parent_step_id will fall back to None either if no
+        matching ProcessStep is found (which can be assumed in case it was
+        just deleted under its feet), or if the parent step would not be
+        owned by the current Process.
         """
-        def walk_descendants(node_id: int) -> None:
-            if node_id == self.id_:
-                raise BadFormatException('bad child selection: recursion')
-            descendant = self.by_id(db_conn, node_id)
-            for descendant_id in descendant.child_ids:
-                walk_descendants(descendant_id)
+        def walk_steps(node: ProcessStep) -> None:
+            if node.step_process_id == self.id_:
+                raise BadFormatException('bad step selection causes recursion')
+            step_process = self.by_id(db_conn, node.step_process_id)
+            for step in step_process.explicit_steps:
+                walk_steps(step)
+        if parent_step_id is not None:
+            try:
+                parent_step = ProcessStep.by_id(db_conn, parent_step_id)
+                if parent_step.owner_id != self.id_:
+                    parent_step_id = None
+            except NotFoundException:
+                parent_step_id = None
+        assert self.id_ is not None
+        step = ProcessStep(id_, self.id_, step_process_id, parent_step_id)
+        walk_steps(step)
+        self.explicit_steps += [step]
+        step.save(db_conn)  # NB: This ensures a non-None step.id_.
+        return step
+
+    def save_without_steps(self, db_conn: DatabaseConnection) -> None:
+        """Add (or re-write) self and connected VersionedAttributes to DB."""
         cursor = db_conn.exec('REPLACE INTO processes VALUES (?)', (self.id_,))
         self.id_ = cursor.lastrowid
         self.title.save(db_conn)
         self.description.save(db_conn)
         self.effort.save(db_conn)
-        db_conn.exec('DELETE FROM process_children WHERE parent_id = ?',
+
+    def fix_steps(self, db_conn: DatabaseConnection) -> None:
+        """Rewrite ProcessSteps from self.explicit_steps.
+
+        This also fixes illegal Step.parent_step_id values, i.e. those pointing
+        to steps now absent, or owned by a different Process, fall back into
+        .parent_step_id=None
+        """
+        db_conn.exec('DELETE FROM process_steps WHERE owner_id = ?',
                      (self.id_,))
-        for child_id in self.child_ids:
-            walk_descendants(child_id)
-            db_conn.exec('INSERT INTO process_children VALUES (?, ?)',
-                         (self.id_, child_id))
+        for step in self.explicit_steps:
+            if step.parent_step_id is not None:
+                try:
+                    parent_step = ProcessStep.by_id(db_conn,
+                                                    step.parent_step_id)
+                    if parent_step.owner_id != self.id_:
+                        step.parent_step_id = None
+                except NotFoundException:
+                    step.parent_step_id = None
+            step.save(db_conn)
+
+
+class ProcessStep:
+    """Sub-unit of Processes."""
+
+    def __init__(self, id_: int | None, owner_id: int, step_process_id: int,
+                 parent_step_id: int | None) -> None:
+        self.id_ = id_
+        self.owner_id = owner_id
+        self.step_process_id = step_process_id
+        self.parent_step_id = parent_step_id
+
+    @classmethod
+    def from_table_row(cls, row: Row) -> ProcessStep:
+        """Make ProcessStep from database row."""
+        return cls(row[0], row[1], row[2], row[3])
+
+    @classmethod
+    def by_id(cls, db_conn: DatabaseConnection, id_: int) -> ProcessStep:
+        """Retrieve ProcessStep by id_, or throw NotFoundException."""
+        for row in db_conn.exec('SELECT * FROM process_steps '
+                                'WHERE step_id = ?', (id_,)):
+            return cls.from_table_row(row)
+        raise NotFoundException(f'found no ProcessStep of ID {id_}')
+
+    def save(self, db_conn: DatabaseConnection) -> None:
+        """Save to database."""
+        cursor = db_conn.exec('REPLACE INTO process_steps VALUES (?, ?, ?, ?)',
+                              (self.id_, self.owner_id, self.step_process_id,
+                               self.parent_step_id))
+        self.id_ = cursor.lastrowid
 
 
 class VersionedAttribute:
diff --git a/scripts/init.sql b/scripts/init.sql
index 341f2ab..1245030 100644
--- a/scripts/init.sql
+++ b/scripts/init.sql
@@ -2,12 +2,6 @@ CREATE TABLE days (
     date TEXT PRIMARY KEY,
     comment TEXT NOT NULL
 );
-CREATE TABLE process_children (
-    parent_id INTEGER NOT NULL,
-    child_id INTEGER NOT NULL,
-    FOREIGN KEY (parent_id) REFERENCES processes(id),
-    FOREIGN KEY (child_id) REFERENCES processes(id)
-);
 CREATE TABLE process_descriptions (
     process_id INTEGER NOT NULL,
     timestamp TEXT NOT NULL,
@@ -22,6 +16,15 @@ CREATE TABLE process_efforts (
     PRIMARY KEY (process_id, timestamp),
     FOREIGN KEY (process_id) REFERENCES processes(id)
 );
+CREATE TABLE process_steps (
+    step_id INTEGER PRIMARY KEY,
+    owner_id INTEGER NOT NULL,
+    step_process_id INTEGER NOT NULL,
+    parent_step_id INTEGER,
+    FOREIGN KEY (owner_id) REFERENCES processes(id),
+    FOREIGN KEY (step_process_id) REFERENCES processes(id),
+    FOREIGN KEY (parent_step_id) REFERENCES process_steps(step_id)
+);
 CREATE TABLE process_titles (
     process_id INTEGER NOT NULL,
     timestamp TEXT NOT NULL,
diff --git a/templates/process.html b/templates/process.html
index 0f0f0c5..55eeb52 100644
--- a/templates/process.html
+++ b/templates/process.html
@@ -1,18 +1,27 @@
 {% extends 'base.html' %}
 
-{% macro process_with_children(node, indent) %}
+{% macro process_with_steps(step_id, step_node, indent) %}
 <tr>
 <td>
-<input type="checkbox" name="children" value="{{node.process.id_}}" checked />
+<input type="hidden" name="steps" value="{{step_id}}" />
+{% if step_node.is_explicit %}
+<input type="checkbox" name="keep_step" value="{{step_id}}" checked />
+<input type="hidden" name="step_{{step_id}}_process_id" value="{{step_node.process.id_}}" />
+<input type="hidden" name="step_{{step_id}}_parent_id" value="{{step_node.parent_id or ''}}" />
+{% endif %}
+</td>
+<td>{% for i in range(indent) %}+{%endfor %}
+<a href="process?id={{step_node.process.id_}}">{{step_node.process.title.newest|e}}</a>
 </td>
 <td>
-{% for i in range(indent) %}+{%endfor %}
-<a href="process?id={{node.process.id_}}">{{node.process.title.newest|e}}</a>
+add step: <input name="new_step_to_{{step_id}}" list="candidates" autocomplete="off" />
 </td>
 </tr>
-{% for child in node.children %}
-{{ process_with_children(child, indent+1) }}
+{% if indent < 5 %}
+{% for substep_id, substep in step_node.steps.items() %}
+{{ process_with_steps(substep_id, substep, indent+1) }}
 {% endfor %}
+{% endif %}
 {% endmacro %}
 
 {% block content %}
@@ -22,11 +31,11 @@ title: <input name="title" value="{{process.title.newest|e}}" />
 description: <input name="description" value="{{process.description.newest|e}}" />
 default effort: <input name="effort" type="number" step=0.1 value={{process.effort.newest}} />
 <table>
-{% for child in children %}
-{{ process_with_children(child, 0) }}
+{% for step_id, step_node in steps.items() %}
+{{ process_with_steps(step_id, step_node, 0) }}
 {% endfor %}
 </table>
-add child: <input name="children" list="candidates" autocomplete="off" />
+add step: <input name="new_top_step" list="candidates" autocomplete="off" />
 <datalist id="candidates">
 {% for candidate in candidates %}
 <option value="{{candidate.id_}}">{{candidate.title.newest|e}}</option>
diff --git a/tests/processes.py b/tests/processes.py
index 02f6644..ac519c8 100644
--- a/tests/processes.py
+++ b/tests/processes.py
@@ -1,5 +1,6 @@
 """Test Processes module."""
 from unittest import TestCase
+from typing import Any
 from tests.utils import TestCaseWithDB, TestCaseWithServer
 from plomtask.processes import Process
 from plomtask.exceptions import NotFoundException, BadFormatException
@@ -14,41 +15,91 @@ class TestsSansDB(TestCase):
         self.assertEqual(Process(None).description.newest, '')
         self.assertEqual(Process(None).effort.newest, 1.0)
 
+    def test_Process_legal_ID(self) -> None:
+        """Test Process cannot be instantiated with id_=0."""
+        with self.assertRaises(BadFormatException):
+            Process(0)
+
 
 class TestsWithDB(TestCaseWithDB):
     """Mdule tests not requiring DB setup."""
 
-    def test_Process_save(self) -> None:
-        """Test Process.save()."""
-        p_saved = Process(None)
-        p_saved.save(self.db_conn)
-        self.assertEqual(p_saved.id_,
+    def test_Process_ids(self) -> None:
+        """Test Process.save_without_steps() re Process.id_."""
+        p = Process(None)
+        p.save_without_steps(self.db_conn)
+        self.assertEqual(p.id_,
                          Process.by_id(self.db_conn, 1, create=False).id_)
-        with self.assertRaises(BadFormatException):
-            p_saved = Process(0)
-        p_saved = Process(5)
-        p_saved.save(self.db_conn)
-        self.assertEqual(p_saved.id_,
+        p = Process(None)
+        p.save_without_steps(self.db_conn)
+        self.assertEqual(p.id_,
+                         Process.by_id(self.db_conn, 2, create=False).id_)
+        p = Process(5)
+        p.save_without_steps(self.db_conn)
+        self.assertEqual(p.id_,
                          Process.by_id(self.db_conn, 5, create=False).id_)
-        p_saved.title.set('named')
-        p_loaded = Process.by_id(self.db_conn, p_saved.id_)
-        self.assertNotEqual(p_saved.title.history, p_loaded.title.history)
-        p_saved.save(self.db_conn)
-        p_loaded = Process.by_id(self.db_conn, p_saved.id_)
-        self.assertEqual(p_saved.title.history, p_loaded.title.history)
-        p_9 = Process(9)
-        p_9.child_ids = [4]
-        with self.assertRaises(NotFoundException):
-            p_9.save(self.db_conn)
-        p_9.child_ids = [5]
-        p_9.save(self.db_conn)
-        p_5 = Process.by_id(self.db_conn, 5)
-        p_5.child_ids = [1]
-        p_5.save(self.db_conn)
-        p_1 = Process.by_id(self.db_conn, 1)
-        p_1.child_ids = [9]
-        with self.assertRaises(BadFormatException):
-            p_1.save(self.db_conn)
+
+    def test_Process_versioned_attributes(self) -> None:
+        """Test behavior of VersionedAttributes on saving (with .title)."""
+        p = Process(None)
+        p.save_without_steps(self.db_conn)
+        p.title.set('named')
+        p_loaded = Process.by_id(self.db_conn, p.id_)
+        self.assertNotEqual(p.title.history, p_loaded.title.history)
+        p.save_without_steps(self.db_conn)
+        p_loaded = Process.by_id(self.db_conn, p.id_)
+        self.assertEqual(p.title.history, p_loaded.title.history)
+
+    def test_Process_steps(self) -> None:
+        """Test addition, nesting, and non-recursion of ProcessSteps"""
+        p_1 = Process(1)
+        p_1.save_without_steps(self.db_conn)
+        assert p_1.id_ is not None
+        p_2 = Process(2)
+        p_2.save_without_steps(self.db_conn)
+        assert p_2.id_ is not None
+        p_3 = Process(3)
+        p_3.save_without_steps(self.db_conn)
+        assert p_3.id_ is not None
+        p_1.add_step(self.db_conn, None, p_2.id_, None)
+        p_1_dict: dict[int, dict[str, Any]] = {1: {
+            'process': p_2, 'parent_id': None,
+            'is_explicit': True, 'steps': {}
+        }}
+        self.assertEqual(p_1.get_steps(self.db_conn, None), p_1_dict)
+        s_b = p_1.add_step(self.db_conn, None, p_3.id_, None)
+        p_1_dict[2] = {
+            'process': p_3, 'parent_id': None,
+            'is_explicit': True, 'steps': {}
+        }
+        self.assertEqual(p_1.get_steps(self.db_conn, None), p_1_dict)
+        s_c = p_2.add_step(self.db_conn, None, p_3.id_, None)
+        assert s_c.id_ is not None
+        p_1_dict[1]['steps'] = {3: {
+            'process': p_3, 'parent_id': None,
+            'is_explicit': False, 'steps': {}
+        }}
+        self.assertEqual(p_1.get_steps(self.db_conn, None), p_1_dict)
+        p_1.add_step(self.db_conn, None, p_2.id_, s_b.id_)
+        p_1_dict[2]['steps'][4] = {
+            'process': p_2, 'parent_id': s_b.id_,
+            'is_explicit': True, 'steps': {3: {
+                'process': p_3, 'parent_id': None,
+                'is_explicit': False, 'steps': {}
+                }}}
+        self.assertEqual(p_1.get_steps(self.db_conn, None), p_1_dict)
+        p_1.add_step(self.db_conn, None, p_3.id_, 999)
+        p_1_dict[5] = {
+            'process': p_3, 'parent_id': None,
+            'is_explicit': True, 'steps': {}
+        }
+        self.assertEqual(p_1.get_steps(self.db_conn, None), p_1_dict)
+        p_1.add_step(self.db_conn, None, p_3.id_, 3)
+        p_1_dict[6] = {
+            'process': p_3, 'parent_id': None,
+            'is_explicit': True, 'steps': {}
+        }
+        self.assertEqual(p_1.get_steps(self.db_conn, None), p_1_dict)
 
     def test_Process_by_id(self) -> None:
         """Test Process.by_id()."""
@@ -68,9 +119,9 @@ class TestsWithDB(TestCaseWithDB):
     def test_Process_all(self) -> None:
         """Test Process.all()."""
         p_1 = Process(None)
-        p_1.save(self.db_conn)
+        p_1.save_without_steps(self.db_conn)
         p_2 = Process(None)
-        p_2.save(self.db_conn)
+        p_2.save_without_steps(self.db_conn)
         self.assertEqual({p_1.id_, p_2.id_},
                          set(p.id_ for p in Process.all(self.db_conn)))
 
@@ -80,8 +131,10 @@ class TestsWithServer(TestCaseWithServer):
 
     def test_do_POST_process(self) -> None:
         """Test POST /process and its effect on the database."""
+        self.assertEqual(0, len(Process.all(self.db_conn)))
         form_data = {'title': 'foo', 'description': 'foo', 'effort': 1.1}
         self.check_post(form_data, '/process?id=', 302, '/')
+        self.assertEqual(1, len(Process.all(self.db_conn)))
         self.check_post(form_data, '/process?id=FOO', 400)
         form_data['effort'] = 'foo'
         self.check_post(form_data, '/process?id=', 400)
@@ -92,30 +145,7 @@ class TestsWithServer(TestCaseWithServer):
         self.check_post(form_data, '/process?id=', 400)
         form_data = {'description': '', 'effort': 1.0}
         self.check_post(form_data, '/process?id=', 400)
-        form_data = {'title': '', 'description': '',
-                     'effort': 1.1, 'children': [1]}
-        self.check_post(form_data, '/process?id=', 302, '/')
-        form_data['children'] = 1.1
-        self.check_post(form_data, '/process?id=', 400)
-        form_data['children'] = 'a'
-        self.check_post(form_data, '/process?id=', 400)
-        form_data['children'] = [1.2]
-        self.check_post(form_data, '/process?id=', 400)
-        form_data['children'] = ['b']
-        self.check_post(form_data, '/process?id=', 400)
-        form_data['children'] = [2]
-        self.check_post(form_data, '/process?id=1', 400)
-        retrieved_1 = Process.by_id(self.db_conn, 1)
-        self.assertEqual(retrieved_1.title.newest, 'foo')
-        self.assertEqual(retrieved_1.child_ids, [])
-        retrieved_2 = Process.by_id(self.db_conn, 2)
-        self.assertEqual(retrieved_2.child_ids, [1])
-        form_data = {'title': 'bar', 'description': 'bar', 'effort': 1.1}
-        self.check_post(form_data, '/process?id=1', 302, '/')
-        retrieved_1 = Process.by_id(self.db_conn, 1)
-        self.assertEqual(retrieved_1.title.newest, 'bar')
-        self.assertEqual([p.id_ for p in Process.all(self.db_conn)],
-                         [retrieved_1.id_, retrieved_2.id_])
+        self.assertEqual(1, len(Process.all(self.db_conn)))
 
     def test_do_GET(self) -> None:
         """Test /process and /processes response codes."""
-- 
2.30.2