From 8bbb9ac156bdca7b9dd015b62db3f07f1e7a9e17 Mon Sep 17 00:00:00 2001
From: Christian Heller <c.heller@plomlompom.de>
Date: Sat, 30 Mar 2024 07:14:40 +0100
Subject: [PATCH] Add Process.children and improve Params/Postvars parsing and
 testing.

---
 plomtask/http.py       | 51 +++++++++++++++++-----
 plomtask/processes.py  | 13 ++++++
 scripts/init.sql       |  6 +++
 templates/process.html | 18 ++++++++
 tests/misc.py          | 96 ++++++++++++++++++++++++++++++++++++++++++
 tests/processes.py     | 43 ++++++++++++++-----
 tests/utils.py         |  4 +-
 7 files changed, 208 insertions(+), 23 deletions(-)

diff --git a/plomtask/http.py b/plomtask/http.py
index cd3e445..3710595 100644
--- a/plomtask/http.py
+++ b/plomtask/http.py
@@ -32,19 +32,20 @@ class ParamsParser:
 
     def get_str(self, key: str, default: str = '') -> str:
         """Retrieve string value of key from self.params."""
-        if key not in self.params:
+        if key not in self.params or 0 == len(self.params[key]):
             return default
         return self.params[key][0]
 
     def get_int_or_none(self, key: str) -> int | None:
         """Retrieve int value of key from self.params, on empty return None."""
-        if key not in self.params or not self.params[key]:
+        if key not in self.params or \
+                0 == len(''.join(list(self.params[key]))):
             return None
-        val = self.params[key][0]
+        val_str = self.params[key][0]
         try:
-            return int(val)
+            return int(val_str)
         except ValueError as e:
-            raise BadFormatException(f'Bad ?{key}= value: {val}') from e
+            raise BadFormatException(f'Bad ?{key}= value: {val_str}') from e
 
 
 class PostvarsParser:
@@ -55,9 +56,19 @@ class PostvarsParser:
 
     def get_str(self, key: str) -> str:
         """Retrieve string value of key from self.postvars."""
-        if key not in self.postvars:
-            raise BadFormatException(f'missing value for form field: {key}')
-        return self.postvars[key][0]
+        all_str = self.get_all_str(key)
+        if 0 == len(all_str):
+            raise BadFormatException(f'missing value for key: {key}')
+        return all_str[0]
+
+    def get_int(self, key: str) -> int:
+        """Retrieve int value of key from self.postvars."""
+        val = self.get_str(key)
+        try:
+            return int(val)
+        except ValueError as e:
+            msg = f'cannot int form field value: {val}'
+            raise BadFormatException(msg) from e
 
     def get_float(self, key: str) -> float:
         """Retrieve float value of key from self.postvars."""
@@ -68,6 +79,21 @@ class PostvarsParser:
             msg = f'cannot float form field value: {val}'
             raise BadFormatException(msg) from e
 
+    def get_all_str(self, key: str) -> list[str]:
+        """Retrieve list of string values at key from self.postvars."""
+        if key not in self.postvars:
+            return []
+        return self.postvars[key]
+
+    def get_all_int(self, key: str) -> list[int]:
+        """Retrieve list of int values at key from self.postvars."""
+        all_str = self.get_all_str(key)
+        try:
+            return [int(s) for s in all_str if len(s) > 0]
+        except ValueError as e:
+            msg = f'cannot int a form field value: {all_str}'
+            raise BadFormatException(msg) from e
+
 
 class TaskHandler(BaseHTTPRequestHandler):
     """Handles single HTTP request."""
@@ -109,8 +135,10 @@ class TaskHandler(BaseHTTPRequestHandler):
                        params: ParamsParser) -> str:
         """Show process of ?id=."""
         id_ = params.get_int_or_none('id')
+        process = Process.by_id(conn, id_, create=True)
         return self.server.jinja.get_template('process.html').render(
-                process=Process.by_id(conn, id_, create=True))
+                process=process, children=process.children(conn),
+                candidates=Process.all(conn))
 
     def do_GET_processes(self, conn: DatabaseConnection,
                          _: ParamsParser) -> str:
@@ -124,7 +152,7 @@ class TaskHandler(BaseHTTPRequestHandler):
             conn, site, params = self._init_handling()
             length = int(self.headers['content-length'])
             postvars = parse_qs(self.rfile.read(length).decode(),
-                                keep_blank_values=True)
+                                keep_blank_values=True, strict_parsing=True)
             form_data = PostvarsParser(postvars)
             if site in ('day', 'process'):
                 getattr(self, f'do_POST_{site}')(conn, params, form_data)
@@ -153,13 +181,14 @@ class TaskHandler(BaseHTTPRequestHandler):
         process.title.set(form_data.get_str('title'))
         process.description.set(form_data.get_str('description'))
         process.effort.set(form_data.get_float('effort'))
+        process.child_ids = form_data.get_all_int('children')
         process.save(conn)
 
     def _init_handling(self) -> tuple[DatabaseConnection, str, ParamsParser]:
         conn = DatabaseConnection(self.server.db)
         parsed_url = urlparse(self.path)
         site = path_split(parsed_url.path)[1]
-        params = ParamsParser(parse_qs(parsed_url.query))
+        params = ParamsParser(parse_qs(parsed_url.query, strict_parsing=True))
         return conn, site, params
 
     def _redirect(self, target: str) -> None:
diff --git a/plomtask/processes.py b/plomtask/processes.py
index 0300e73..bebd394 100644
--- a/plomtask/processes.py
+++ b/plomtask/processes.py
@@ -16,6 +16,7 @@ class Process:
         self.title = VersionedAttribute(self, 'title', 'UNNAMED')
         self.description = VersionedAttribute(self, 'description', '')
         self.effort = VersionedAttribute(self, 'effort', 1.0)
+        self.child_ids: list[int] = []
 
     @classmethod
     def from_table_row(cls, row: Row) -> Process:
@@ -60,8 +61,15 @@ class Process:
             for row in db_conn.exec('SELECT * FROM process_efforts '
                                     'WHERE process_id = ?', (process.id_,)):
                 process.effort.history[row[1]] = row[2]
+            for row in db_conn.exec('SELECT * FROM process_children '
+                                    'WHERE parent_id = ?', (process.id_,)):
+                process.child_ids += [row[1]]
         return process
 
+    def children(self, db_conn: DatabaseConnection) -> list[Process]:
+        """Return child Processes as determined by self.child_ids."""
+        return [self.__class__.by_id(db_conn, id_) for id_ in self.child_ids]
+
     def save(self, db_conn: DatabaseConnection) -> None:
         """Add (or re-write) self and connected VersionedAttributes to DB."""
         cursor = db_conn.exec('REPLACE INTO processes VALUES (?)', (self.id_,))
@@ -69,6 +77,11 @@ class Process:
         self.title.save(db_conn)
         self.description.save(db_conn)
         self.effort.save(db_conn)
+        db_conn.exec('DELETE FROM process_children WHERE parent_id = ?',
+                     (self.id_,))
+        for child_id in self.child_ids:
+            db_conn.exec('INSERT INTO process_children VALUES (?, ?)',
+                         (self.id_, child_id))
 
 
 class VersionedAttribute:
diff --git a/scripts/init.sql b/scripts/init.sql
index a98e828..341f2ab 100644
--- a/scripts/init.sql
+++ b/scripts/init.sql
@@ -2,6 +2,12 @@ CREATE TABLE days (
     date TEXT PRIMARY KEY,
     comment TEXT NOT NULL
 );
+CREATE TABLE process_children (
+    parent_id INTEGER NOT NULL,
+    child_id INTEGER NOT NULL,
+    FOREIGN KEY (parent_id) REFERENCES processes(id),
+    FOREIGN KEY (child_id) REFERENCES processes(id)
+);
 CREATE TABLE process_descriptions (
     process_id INTEGER NOT NULL,
     timestamp TEXT NOT NULL,
diff --git a/templates/process.html b/templates/process.html
index 1743936..f2ef5aa 100644
--- a/templates/process.html
+++ b/templates/process.html
@@ -6,6 +6,24 @@
 title: <input name="title" value="{{process.title.newest|e}}" />
 description: <input name="description" value="{{process.description.newest|e}}" />
 default effort: <input name="effort" type="number" step=0.1 value={{process.effort.newest}} />
+<table>
+{% for child in children %}
+<tr>
+<td>
+<input type="checkbox" name="children" value="{{child.id_}}" checked />
+</td>
+<td>
+<a href="process?id={{child.id_}}">{{child.title.newest|e}}</a>
+</td>
+</tr>
+{% endfor %}
+</table>
+add child: <input name="children" list="candidates" autocomplete="off" />
+<datalist id="candidates">
+{% for candidate in candidates %}
+<option value="{{candidate.id_}}">{{candidate.title.newest|e}}</option>
+{% endfor %}
+</datalist>
 <input type="submit" value="OK" />
 </form>
 {% endblock %}
diff --git a/tests/misc.py b/tests/misc.py
index 893d67e..87b3a6e 100644
--- a/tests/misc.py
+++ b/tests/misc.py
@@ -1,5 +1,101 @@
 """Miscellaneous tests."""
+from unittest import TestCase
 from tests.utils import TestCaseWithServer
+from plomtask.http import ParamsParser, PostvarsParser
+from plomtask.exceptions import BadFormatException
+
+
+class TestsSansServer(TestCase):
+    """Tests that do not require DB setup or a server."""
+
+    def test_params_parser(self) -> None:
+        """Test behavior of ParamsParser."""
+        self.assertEqual('',
+                         ParamsParser({}).get_str('foo'))
+        self.assertEqual('bar',
+                         ParamsParser({}).get_str('foo', 'bar'))
+        self.assertEqual('bar',
+                         ParamsParser({'foo': []}).get_str('foo', 'bar'))
+        self.assertEqual('baz',
+                         ParamsParser({'foo': ['baz']}).get_str('foo', 'bar'))
+        self.assertEqual(None,
+                         ParamsParser({}).get_int_or_none('foo'))
+        self.assertEqual(None,
+                         ParamsParser({'foo': []}).get_int_or_none('foo'))
+        self.assertEqual(None,
+                         ParamsParser({'foo': ['']}).get_int_or_none('foo'))
+        self.assertEqual(0,
+                         ParamsParser({'foo': ['0']}).get_int_or_none('foo'))
+        with self.assertRaises(BadFormatException):
+            ParamsParser({'foo': ['None']}).get_int_or_none('foo')
+        with self.assertRaises(BadFormatException):
+            ParamsParser({'foo': ['0.1']}).get_int_or_none('foo')
+        self.assertEqual(23,
+                         ParamsParser({'foo': ['23']}).get_int_or_none('foo'))
+
+    def test_postvars_parser(self) -> None:
+        """Test behavior of PostvarsParser."""
+        self.assertEqual([],
+                         PostvarsParser({}).get_all_str('foo'))
+        self.assertEqual([],
+                         PostvarsParser({'foo': []}).get_all_str('foo'))
+        self.assertEqual(['bar'],
+                         PostvarsParser({'foo': ['bar']}).get_all_str('foo'))
+        self.assertEqual(['bar', 'baz'],
+                         PostvarsParser({'foo': ['bar', 'baz']}).
+                         get_all_str('foo'))
+        self.assertEqual([],
+                         PostvarsParser({}).get_all_int('foo'))
+        self.assertEqual([],
+                         PostvarsParser({'foo': []}).get_all_int('foo'))
+        self.assertEqual([],
+                         PostvarsParser({'foo': ['']}).get_all_int('foo'))
+        self.assertEqual([0],
+                         PostvarsParser({'foo': ['0']}).get_all_int('foo'))
+        self.assertEqual([0, 17],
+                         PostvarsParser({'foo': ['0', '17']}).
+                         get_all_int('foo'))
+        with self.assertRaises(BadFormatException):
+            PostvarsParser({'foo': ['0.1', '17']}).get_all_int('foo')
+        with self.assertRaises(BadFormatException):
+            PostvarsParser({'foo': ['None', '17']}).get_all_int('foo')
+        with self.assertRaises(BadFormatException):
+            PostvarsParser({}).get_str('foo')
+        with self.assertRaises(BadFormatException):
+            PostvarsParser({'foo': []}).get_str('foo')
+        self.assertEqual('bar',
+                         PostvarsParser({'foo': ['bar']}).get_str('foo'))
+        self.assertEqual('',
+                         PostvarsParser({'foo': ['', 'baz']}).get_str('foo'))
+        with self.assertRaises(BadFormatException):
+            PostvarsParser({}).get_int('foo')
+        with self.assertRaises(BadFormatException):
+            PostvarsParser({'foo': []}).get_int('foo')
+        with self.assertRaises(BadFormatException):
+            PostvarsParser({'foo': ['']}).get_int('foo')
+        with self.assertRaises(BadFormatException):
+            PostvarsParser({'foo': ['bar']}).get_int('foo')
+        with self.assertRaises(BadFormatException):
+            PostvarsParser({'foo': ['0.1']}).get_int('foo')
+        self.assertEqual(0,
+                         PostvarsParser({'foo': ['0']}).get_int('foo'))
+        self.assertEqual(17,
+                         PostvarsParser({'foo': ['17', '23']}).get_int('foo'))
+        with self.assertRaises(BadFormatException):
+            PostvarsParser({}).get_float('foo')
+        with self.assertRaises(BadFormatException):
+            PostvarsParser({'foo': []}).get_float('foo')
+        with self.assertRaises(BadFormatException):
+            PostvarsParser({'foo': ['']}).get_float('foo')
+        with self.assertRaises(BadFormatException):
+            PostvarsParser({'foo': ['bar']}).get_float('foo')
+        self.assertEqual(0,
+                         PostvarsParser({'foo': ['0']}).get_float('foo'))
+        self.assertEqual(0.1,
+                         PostvarsParser({'foo': ['0.1']}).get_float('foo'))
+        self.assertEqual(1.23,
+                         PostvarsParser({'foo': ['1.23', '456']}).
+                         get_float('foo'))
 
 
 class TestsWithServer(TestCaseWithServer):
diff --git a/tests/processes.py b/tests/processes.py
index fb572f5..8a7f91d 100644
--- a/tests/processes.py
+++ b/tests/processes.py
@@ -67,23 +67,46 @@ class TestsWithServer(TestCaseWithServer):
 
     def test_do_POST_process(self) -> None:
         """Test POST /process and its effect on the database."""
-        form_data = {'title': 'foo', 'description': 'foo', 'effort': 1.0}
+        form_data = {'title': 'foo', 'description': 'foo',
+                     'effort': 1.1, 'children': [1, 2]}
         self.check_post(form_data, '/process?id=FOO', 400)
         form_data['effort'] = 'foo'
         self.check_post(form_data, '/process?id=', 400)
-        form_data['effort'] = None
+        form_data['effort'] = 1.1
+        form_data['children'] = 1.1
+        self.check_post(form_data, '/process?id=', 400)
+        form_data['children'] = 'a'
+        self.check_post(form_data, '/process?id=', 400)
+        form_data['children'] = [1, 1.2]
+        self.check_post(form_data, '/process?id=', 400)
+        form_data['children'] = [1, 'b']
         self.check_post(form_data, '/process?id=', 400)
         self.check_post({}, '/process?id=', 400)
-        self.check_post({'title': '', 'description': ''}, '/process?id=', 400)
-        self.check_post({'title': '', 'effort': 1}, '/process?id=', 400)
-        self.check_post({'description': '', 'effort': 1}, '/process?id=', 400)
-        form_data = {'title': None, 'description': 1, 'effort': 1.0}
+        form_data = {'title': '', 'description': ''}
+        self.check_post(form_data, '/process?id=', 400)
+        form_data = {'title': '', 'effort': 1.1}
+        self.check_post(form_data, '/process?id=', 400)
+        form_data = {'description': '', 'effort': 1.0}
+        self.check_post(form_data, '/process?id=', 400)
+        form_data = {'title': '', 'description': '',
+                     'effort': 1.1, 'children': [1, 2]}
         self.check_post(form_data, '/process?id=', 302, '/')
-        retrieved = Process.by_id(self.db_conn, 1)
-        self.assertEqual(retrieved.title.newest, 'None')
-        self.assertEqual([p.id_ for p in Process.all(self.db_conn)],
-                         [retrieved.id_])
+        retrieved_1 = Process.by_id(self.db_conn, 1)
+        self.assertEqual(retrieved_1.title.newest, '')
+        self.assertEqual(retrieved_1.child_ids, [1, 2])
+        form_data['children'] = []
+        self.check_post(form_data, '/process?id=', 302, '/')
+        retrieved_2 = Process.by_id(self.db_conn, 2)
+        self.assertEqual(retrieved_2.child_ids, [])
+        del form_data['children']
+        self.check_post(form_data, '/process?id=', 302, '/')
+        retrieved_3 = Process.by_id(self.db_conn, 3)
+        self.assertEqual(retrieved_2.child_ids, [])
         self.check_post(form_data, '/process?id=1', 302, '/')
+        self.assertEqual([p.id_ for p in Process.all(self.db_conn)],
+                         [retrieved_1.id_, retrieved_2.id_, retrieved_3.id_])
+        retrieved_1 = Process.by_id(self.db_conn, 1)
+        self.assertEqual(retrieved_1.child_ids, [])
         self.check_post(form_data, '/process', 302, '/')
 
     def test_do_GET(self) -> None:
diff --git a/tests/utils.py b/tests/utils.py
index d41e7b3..c80b34d 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -54,9 +54,9 @@ class TestCaseWithServer(TestCaseWithDB):
         self.assertEqual(self.conn.getresponse().status, expected_code)
 
     def check_post(self, data: Mapping[str, object], target: str,
-                   expected_code: int, redirect_location: str = '') -> None:
+                   expected_code: int, redirect_location: str = '/') -> None:
         """Check that POST of data to target yields expected_code."""
-        encoded_form_data = urlencode(data).encode('utf-8')
+        encoded_form_data = urlencode(data, doseq=True).encode('utf-8')
         headers = {'Content-Type': 'application/x-www-form-urlencoded',
                    'Content-Length': str(len(encoded_form_data))}
         self.conn.request('POST', target,
-- 
2.30.2