From 13845c83a9e3e107aa7c40e86d8a0cda1a317f8a Mon Sep 17 00:00:00 2001
From: Christian Heller <c.heller@plomlompom.de>
Date: Tue, 2 Apr 2024 02:35:20 +0200
Subject: [PATCH] Draw Process descendant trees, and guard against recursion
 within them.

---
 plomtask/http.py       |  2 +-
 plomtask/processes.py  | 26 +++++++++++++++----
 templates/process.html | 26 ++++++++++++-------
 tests/processes.py     | 58 ++++++++++++++++++++++++------------------
 4 files changed, 72 insertions(+), 40 deletions(-)

diff --git a/plomtask/http.py b/plomtask/http.py
index 3710595..5b7100c 100644
--- a/plomtask/http.py
+++ b/plomtask/http.py
@@ -137,7 +137,7 @@ class TaskHandler(BaseHTTPRequestHandler):
         id_ = params.get_int_or_none('id')
         process = Process.by_id(conn, id_, create=True)
         return self.server.jinja.get_template('process.html').render(
-                process=process, children=process.children(conn),
+                process=process, children=process.get_descendants(conn),
                 candidates=Process.all(conn))
 
     def do_GET_processes(self, conn: DatabaseConnection,
diff --git a/plomtask/processes.py b/plomtask/processes.py
index bebd394..ba9707b 100644
--- a/plomtask/processes.py
+++ b/plomtask/processes.py
@@ -41,7 +41,7 @@ class Process:
     @classmethod
     def by_id(cls, db_conn: DatabaseConnection, id_: int | None,
               create: bool = False) -> Process:
-        """Collect all Processes and their connected VersionedAttributes."""
+        """Collect Process, its VersionedAttributes, and its child IDs."""
         process = None
         for row in db_conn.exec('SELECT * FROM processes '
                                 'WHERE id = ?', (id_,)):
@@ -66,12 +66,27 @@ class Process:
                 process.child_ids += [row[1]]
         return process
 
-    def children(self, db_conn: DatabaseConnection) -> list[Process]:
-        """Return child Processes as determined by self.child_ids."""
-        return [self.__class__.by_id(db_conn, id_) for id_ in self.child_ids]
+    def get_descendants(self, db_conn: DatabaseConnection) ->\
+            dict[int, dict[str, object]]:
+        """Return tree of descendant Processes"""
+        descendants = {}
+        for id_ in self.child_ids:
+            child = self.__class__.by_id(db_conn, id_)
+            descendants[id_] = {'process': child,
+                                'children': child.get_descendants(db_conn)}
+        return descendants
 
     def save(self, db_conn: DatabaseConnection) -> None:
-        """Add (or re-write) self and connected VersionedAttributes to DB."""
+        """Add (or re-write) self and connected VersionedAttributes to DB.
+
+        Also is the point at which descendancy recursion is checked.
+        """
+        def walk_descendants(node_id: int) -> None:
+            if node_id == self.id_:
+                raise BadFormatException('bad child selection: recursion')
+            descendant = self.by_id(db_conn, node_id)
+            for descendant_id in descendant.child_ids:
+                walk_descendants(descendant_id)
         cursor = db_conn.exec('REPLACE INTO processes VALUES (?)', (self.id_,))
         self.id_ = cursor.lastrowid
         self.title.save(db_conn)
@@ -80,6 +95,7 @@ class Process:
         db_conn.exec('DELETE FROM process_children WHERE parent_id = ?',
                      (self.id_,))
         for child_id in self.child_ids:
+            walk_descendants(child_id)
             db_conn.exec('INSERT INTO process_children VALUES (?, ?)',
                          (self.id_, child_id))
 
diff --git a/templates/process.html b/templates/process.html
index f2ef5aa..f2d5055 100644
--- a/templates/process.html
+++ b/templates/process.html
@@ -1,5 +1,20 @@
 {% extends 'base.html' %}
 
+{% macro process_with_children(node, indent) %}
+<tr>
+<td>
+<input type="checkbox" name="children" value="{{node.process.id_}}" checked />
+</td>
+<td>
+{% for i in range(indent) %}+{%endfor %}
+<a href="process?id={{node.process.id_}}">{{node.process.title.newest|e}}</a>
+</td>
+</tr>
+{% for child in node.children.values() %}
+{{ process_with_children(child, indent+1) }}
+{% endfor %}
+{% endmacro %}
+
 {% block content %}
 <h3>Process</h3>
 <form action="process?id={{process.id_ or ''}}" method="POST">
@@ -7,15 +22,8 @@ title: <input name="title" value="{{process.title.newest|e}}" />
 description: <input name="description" value="{{process.description.newest|e}}" />
 default effort: <input name="effort" type="number" step=0.1 value={{process.effort.newest}} />
 <table>
-{% for child in children %}
-<tr>
-<td>
-<input type="checkbox" name="children" value="{{child.id_}}" checked />
-</td>
-<td>
-<a href="process?id={{child.id_}}">{{child.title.newest|e}}</a>
-</td>
-</tr>
+{% for child in children.values() %}
+{{ process_with_children(child, 0) }}
 {% endfor %}
 </table>
 add child: <input name="children" list="candidates" autocomplete="off" />
diff --git a/tests/processes.py b/tests/processes.py
index 8a7f91d..02f6644 100644
--- a/tests/processes.py
+++ b/tests/processes.py
@@ -36,6 +36,19 @@ class TestsWithDB(TestCaseWithDB):
         p_saved.save(self.db_conn)
         p_loaded = Process.by_id(self.db_conn, p_saved.id_)
         self.assertEqual(p_saved.title.history, p_loaded.title.history)
+        p_9 = Process(9)
+        p_9.child_ids = [4]
+        with self.assertRaises(NotFoundException):
+            p_9.save(self.db_conn)
+        p_9.child_ids = [5]
+        p_9.save(self.db_conn)
+        p_5 = Process.by_id(self.db_conn, 5)
+        p_5.child_ids = [1]
+        p_5.save(self.db_conn)
+        p_1 = Process.by_id(self.db_conn, 1)
+        p_1.child_ids = [9]
+        with self.assertRaises(BadFormatException):
+            p_1.save(self.db_conn)
 
     def test_Process_by_id(self) -> None:
         """Test Process.by_id()."""
@@ -67,20 +80,11 @@ class TestsWithServer(TestCaseWithServer):
 
     def test_do_POST_process(self) -> None:
         """Test POST /process and its effect on the database."""
-        form_data = {'title': 'foo', 'description': 'foo',
-                     'effort': 1.1, 'children': [1, 2]}
+        form_data = {'title': 'foo', 'description': 'foo', 'effort': 1.1}
+        self.check_post(form_data, '/process?id=', 302, '/')
         self.check_post(form_data, '/process?id=FOO', 400)
         form_data['effort'] = 'foo'
         self.check_post(form_data, '/process?id=', 400)
-        form_data['effort'] = 1.1
-        form_data['children'] = 1.1
-        self.check_post(form_data, '/process?id=', 400)
-        form_data['children'] = 'a'
-        self.check_post(form_data, '/process?id=', 400)
-        form_data['children'] = [1, 1.2]
-        self.check_post(form_data, '/process?id=', 400)
-        form_data['children'] = [1, 'b']
-        self.check_post(form_data, '/process?id=', 400)
         self.check_post({}, '/process?id=', 400)
         form_data = {'title': '', 'description': ''}
         self.check_post(form_data, '/process?id=', 400)
@@ -89,25 +93,29 @@ class TestsWithServer(TestCaseWithServer):
         form_data = {'description': '', 'effort': 1.0}
         self.check_post(form_data, '/process?id=', 400)
         form_data = {'title': '', 'description': '',
-                     'effort': 1.1, 'children': [1, 2]}
+                     'effort': 1.1, 'children': [1]}
         self.check_post(form_data, '/process?id=', 302, '/')
+        form_data['children'] = 1.1
+        self.check_post(form_data, '/process?id=', 400)
+        form_data['children'] = 'a'
+        self.check_post(form_data, '/process?id=', 400)
+        form_data['children'] = [1.2]
+        self.check_post(form_data, '/process?id=', 400)
+        form_data['children'] = ['b']
+        self.check_post(form_data, '/process?id=', 400)
+        form_data['children'] = [2]
+        self.check_post(form_data, '/process?id=1', 400)
         retrieved_1 = Process.by_id(self.db_conn, 1)
-        self.assertEqual(retrieved_1.title.newest, '')
-        self.assertEqual(retrieved_1.child_ids, [1, 2])
-        form_data['children'] = []
-        self.check_post(form_data, '/process?id=', 302, '/')
+        self.assertEqual(retrieved_1.title.newest, 'foo')
+        self.assertEqual(retrieved_1.child_ids, [])
         retrieved_2 = Process.by_id(self.db_conn, 2)
-        self.assertEqual(retrieved_2.child_ids, [])
-        del form_data['children']
-        self.check_post(form_data, '/process?id=', 302, '/')
-        retrieved_3 = Process.by_id(self.db_conn, 3)
-        self.assertEqual(retrieved_2.child_ids, [])
+        self.assertEqual(retrieved_2.child_ids, [1])
+        form_data = {'title': 'bar', 'description': 'bar', 'effort': 1.1}
         self.check_post(form_data, '/process?id=1', 302, '/')
-        self.assertEqual([p.id_ for p in Process.all(self.db_conn)],
-                         [retrieved_1.id_, retrieved_2.id_, retrieved_3.id_])
         retrieved_1 = Process.by_id(self.db_conn, 1)
-        self.assertEqual(retrieved_1.child_ids, [])
-        self.check_post(form_data, '/process', 302, '/')
+        self.assertEqual(retrieved_1.title.newest, 'bar')
+        self.assertEqual([p.id_ for p in Process.all(self.db_conn)],
+                         [retrieved_1.id_, retrieved_2.id_])
 
     def test_do_GET(self) -> None:
         """Test /process and /processes response codes."""
-- 
2.30.2