def get_str(self, key: str, default: str = '') -> str:
"""Retrieve string value of key from self.params."""
- if key not in self.params:
+ if key not in self.params or 0 == len(self.params[key]):
return default
return self.params[key][0]
def get_int_or_none(self, key: str) -> int | None:
"""Retrieve int value of key from self.params, on empty return None."""
- if key not in self.params or not self.params[key]:
+ if key not in self.params or \
+ 0 == len(''.join(list(self.params[key]))):
return None
- val = self.params[key][0]
+ val_str = self.params[key][0]
try:
- return int(val)
+ return int(val_str)
except ValueError as e:
- raise BadFormatException(f'Bad ?{key}= value: {val}') from e
+ raise BadFormatException(f'Bad ?{key}= value: {val_str}') from e
class PostvarsParser:
def get_str(self, key: str) -> str:
"""Retrieve string value of key from self.postvars."""
- if key not in self.postvars:
- raise BadFormatException(f'missing value for form field: {key}')
- return self.postvars[key][0]
+ all_str = self.get_all_str(key)
+ if 0 == len(all_str):
+ raise BadFormatException(f'missing value for key: {key}')
+ return all_str[0]
+
+ def get_int(self, key: str) -> int:
+ """Retrieve int value of key from self.postvars."""
+ val = self.get_str(key)
+ try:
+ return int(val)
+ except ValueError as e:
+ msg = f'cannot int form field value: {val}'
+ raise BadFormatException(msg) from e
def get_float(self, key: str) -> float:
"""Retrieve float value of key from self.postvars."""
msg = f'cannot float form field value: {val}'
raise BadFormatException(msg) from e
+ def get_all_str(self, key: str) -> list[str]:
+ """Retrieve list of string values at key from self.postvars."""
+ if key not in self.postvars:
+ return []
+ return self.postvars[key]
+
+ def get_all_int(self, key: str) -> list[int]:
+ """Retrieve list of int values at key from self.postvars."""
+ all_str = self.get_all_str(key)
+ try:
+ return [int(s) for s in all_str if len(s) > 0]
+ except ValueError as e:
+ msg = f'cannot int a form field value: {all_str}'
+ raise BadFormatException(msg) from e
+
class TaskHandler(BaseHTTPRequestHandler):
"""Handles single HTTP request."""
params: ParamsParser) -> str:
"""Show process of ?id=."""
id_ = params.get_int_or_none('id')
+ process = Process.by_id(conn, id_, create=True)
return self.server.jinja.get_template('process.html').render(
- process=Process.by_id(conn, id_, create=True))
+ process=process, children=process.children(conn),
+ candidates=Process.all(conn))
def do_GET_processes(self, conn: DatabaseConnection,
_: ParamsParser) -> str:
conn, site, params = self._init_handling()
length = int(self.headers['content-length'])
postvars = parse_qs(self.rfile.read(length).decode(),
- keep_blank_values=True)
+ keep_blank_values=True, strict_parsing=True)
form_data = PostvarsParser(postvars)
if site in ('day', 'process'):
getattr(self, f'do_POST_{site}')(conn, params, form_data)
process.title.set(form_data.get_str('title'))
process.description.set(form_data.get_str('description'))
process.effort.set(form_data.get_float('effort'))
+ process.child_ids = form_data.get_all_int('children')
process.save(conn)
def _init_handling(self) -> tuple[DatabaseConnection, str, ParamsParser]:
conn = DatabaseConnection(self.server.db)
parsed_url = urlparse(self.path)
site = path_split(parsed_url.path)[1]
- params = ParamsParser(parse_qs(parsed_url.query))
+ params = ParamsParser(parse_qs(parsed_url.query, strict_parsing=True))
return conn, site, params
def _redirect(self, target: str) -> None:
self.title = VersionedAttribute(self, 'title', 'UNNAMED')
self.description = VersionedAttribute(self, 'description', '')
self.effort = VersionedAttribute(self, 'effort', 1.0)
+ self.child_ids: list[int] = []
@classmethod
def from_table_row(cls, row: Row) -> Process:
for row in db_conn.exec('SELECT * FROM process_efforts '
'WHERE process_id = ?', (process.id_,)):
process.effort.history[row[1]] = row[2]
+ for row in db_conn.exec('SELECT * FROM process_children '
+ 'WHERE parent_id = ?', (process.id_,)):
+ process.child_ids += [row[1]]
return process
+ def children(self, db_conn: DatabaseConnection) -> list[Process]:
+ """Return child Processes as determined by self.child_ids."""
+ return [self.__class__.by_id(db_conn, id_) for id_ in self.child_ids]
+
def save(self, db_conn: DatabaseConnection) -> None:
"""Add (or re-write) self and connected VersionedAttributes to DB."""
cursor = db_conn.exec('REPLACE INTO processes VALUES (?)', (self.id_,))
self.title.save(db_conn)
self.description.save(db_conn)
self.effort.save(db_conn)
+ db_conn.exec('DELETE FROM process_children WHERE parent_id = ?',
+ (self.id_,))
+ for child_id in self.child_ids:
+ db_conn.exec('INSERT INTO process_children VALUES (?, ?)',
+ (self.id_, child_id))
class VersionedAttribute:
date TEXT PRIMARY KEY,
comment TEXT NOT NULL
);
+CREATE TABLE process_children (
+ parent_id INTEGER NOT NULL,
+ child_id INTEGER NOT NULL,
+ FOREIGN KEY (parent_id) REFERENCES processes(id),
+ FOREIGN KEY (child_id) REFERENCES processes(id)
+);
CREATE TABLE process_descriptions (
process_id INTEGER NOT NULL,
timestamp TEXT NOT NULL,
title: <input name="title" value="{{process.title.newest|e}}" />
description: <input name="description" value="{{process.description.newest|e}}" />
default effort: <input name="effort" type="number" step=0.1 value={{process.effort.newest}} />
+<table>
+{% for child in children %}
+<tr>
+<td>
+<input type="checkbox" name="children" value="{{child.id_}}" checked />
+</td>
+<td>
+<a href="process?id={{child.id_}}">{{child.title.newest|e}}</a>
+</td>
+</tr>
+{% endfor %}
+</table>
+add child: <input name="children" list="candidates" autocomplete="off" />
+<datalist id="candidates">
+{% for candidate in candidates %}
+<option value="{{candidate.id_}}">{{candidate.title.newest|e}}</option>
+{% endfor %}
+</datalist>
<input type="submit" value="OK" />
</form>
{% endblock %}
"""Miscellaneous tests."""
+from unittest import TestCase
from tests.utils import TestCaseWithServer
+from plomtask.http import ParamsParser, PostvarsParser
+from plomtask.exceptions import BadFormatException
+
+
+class TestsSansServer(TestCase):
+ """Tests that do not require DB setup or a server."""
+
+ def test_params_parser(self) -> None:
+ """Test behavior of ParamsParser."""
+ self.assertEqual('',
+ ParamsParser({}).get_str('foo'))
+ self.assertEqual('bar',
+ ParamsParser({}).get_str('foo', 'bar'))
+ self.assertEqual('bar',
+ ParamsParser({'foo': []}).get_str('foo', 'bar'))
+ self.assertEqual('baz',
+ ParamsParser({'foo': ['baz']}).get_str('foo', 'bar'))
+ self.assertEqual(None,
+ ParamsParser({}).get_int_or_none('foo'))
+ self.assertEqual(None,
+ ParamsParser({'foo': []}).get_int_or_none('foo'))
+ self.assertEqual(None,
+ ParamsParser({'foo': ['']}).get_int_or_none('foo'))
+ self.assertEqual(0,
+ ParamsParser({'foo': ['0']}).get_int_or_none('foo'))
+ with self.assertRaises(BadFormatException):
+ ParamsParser({'foo': ['None']}).get_int_or_none('foo')
+ with self.assertRaises(BadFormatException):
+ ParamsParser({'foo': ['0.1']}).get_int_or_none('foo')
+ self.assertEqual(23,
+ ParamsParser({'foo': ['23']}).get_int_or_none('foo'))
+
+ def test_postvars_parser(self) -> None:
+ """Test behavior of PostvarsParser."""
+ self.assertEqual([],
+ PostvarsParser({}).get_all_str('foo'))
+ self.assertEqual([],
+ PostvarsParser({'foo': []}).get_all_str('foo'))
+ self.assertEqual(['bar'],
+ PostvarsParser({'foo': ['bar']}).get_all_str('foo'))
+ self.assertEqual(['bar', 'baz'],
+ PostvarsParser({'foo': ['bar', 'baz']}).
+ get_all_str('foo'))
+ self.assertEqual([],
+ PostvarsParser({}).get_all_int('foo'))
+ self.assertEqual([],
+ PostvarsParser({'foo': []}).get_all_int('foo'))
+ self.assertEqual([],
+ PostvarsParser({'foo': ['']}).get_all_int('foo'))
+ self.assertEqual([0],
+ PostvarsParser({'foo': ['0']}).get_all_int('foo'))
+ self.assertEqual([0, 17],
+ PostvarsParser({'foo': ['0', '17']}).
+ get_all_int('foo'))
+ with self.assertRaises(BadFormatException):
+ PostvarsParser({'foo': ['0.1', '17']}).get_all_int('foo')
+ with self.assertRaises(BadFormatException):
+ PostvarsParser({'foo': ['None', '17']}).get_all_int('foo')
+ with self.assertRaises(BadFormatException):
+ PostvarsParser({}).get_str('foo')
+ with self.assertRaises(BadFormatException):
+ PostvarsParser({'foo': []}).get_str('foo')
+ self.assertEqual('bar',
+ PostvarsParser({'foo': ['bar']}).get_str('foo'))
+ self.assertEqual('',
+ PostvarsParser({'foo': ['', 'baz']}).get_str('foo'))
+ with self.assertRaises(BadFormatException):
+ PostvarsParser({}).get_int('foo')
+ with self.assertRaises(BadFormatException):
+ PostvarsParser({'foo': []}).get_int('foo')
+ with self.assertRaises(BadFormatException):
+ PostvarsParser({'foo': ['']}).get_int('foo')
+ with self.assertRaises(BadFormatException):
+ PostvarsParser({'foo': ['bar']}).get_int('foo')
+ with self.assertRaises(BadFormatException):
+ PostvarsParser({'foo': ['0.1']}).get_int('foo')
+ self.assertEqual(0,
+ PostvarsParser({'foo': ['0']}).get_int('foo'))
+ self.assertEqual(17,
+ PostvarsParser({'foo': ['17', '23']}).get_int('foo'))
+ with self.assertRaises(BadFormatException):
+ PostvarsParser({}).get_float('foo')
+ with self.assertRaises(BadFormatException):
+ PostvarsParser({'foo': []}).get_float('foo')
+ with self.assertRaises(BadFormatException):
+ PostvarsParser({'foo': ['']}).get_float('foo')
+ with self.assertRaises(BadFormatException):
+ PostvarsParser({'foo': ['bar']}).get_float('foo')
+ self.assertEqual(0,
+ PostvarsParser({'foo': ['0']}).get_float('foo'))
+ self.assertEqual(0.1,
+ PostvarsParser({'foo': ['0.1']}).get_float('foo'))
+ self.assertEqual(1.23,
+ PostvarsParser({'foo': ['1.23', '456']}).
+ get_float('foo'))
class TestsWithServer(TestCaseWithServer):
def test_do_POST_process(self) -> None:
"""Test POST /process and its effect on the database."""
- form_data = {'title': 'foo', 'description': 'foo', 'effort': 1.0}
+ form_data = {'title': 'foo', 'description': 'foo',
+ 'effort': 1.1, 'children': [1, 2]}
self.check_post(form_data, '/process?id=FOO', 400)
form_data['effort'] = 'foo'
self.check_post(form_data, '/process?id=', 400)
- form_data['effort'] = None
+ form_data['effort'] = 1.1
+ form_data['children'] = 1.1
+ self.check_post(form_data, '/process?id=', 400)
+ form_data['children'] = 'a'
+ self.check_post(form_data, '/process?id=', 400)
+ form_data['children'] = [1, 1.2]
+ self.check_post(form_data, '/process?id=', 400)
+ form_data['children'] = [1, 'b']
self.check_post(form_data, '/process?id=', 400)
self.check_post({}, '/process?id=', 400)
- self.check_post({'title': '', 'description': ''}, '/process?id=', 400)
- self.check_post({'title': '', 'effort': 1}, '/process?id=', 400)
- self.check_post({'description': '', 'effort': 1}, '/process?id=', 400)
- form_data = {'title': None, 'description': 1, 'effort': 1.0}
+ form_data = {'title': '', 'description': ''}
+ self.check_post(form_data, '/process?id=', 400)
+ form_data = {'title': '', 'effort': 1.1}
+ self.check_post(form_data, '/process?id=', 400)
+ form_data = {'description': '', 'effort': 1.0}
+ self.check_post(form_data, '/process?id=', 400)
+ form_data = {'title': '', 'description': '',
+ 'effort': 1.1, 'children': [1, 2]}
self.check_post(form_data, '/process?id=', 302, '/')
- retrieved = Process.by_id(self.db_conn, 1)
- self.assertEqual(retrieved.title.newest, 'None')
- self.assertEqual([p.id_ for p in Process.all(self.db_conn)],
- [retrieved.id_])
+ retrieved_1 = Process.by_id(self.db_conn, 1)
+ self.assertEqual(retrieved_1.title.newest, '')
+ self.assertEqual(retrieved_1.child_ids, [1, 2])
+ form_data['children'] = []
+ self.check_post(form_data, '/process?id=', 302, '/')
+ retrieved_2 = Process.by_id(self.db_conn, 2)
+ self.assertEqual(retrieved_2.child_ids, [])
+ del form_data['children']
+ self.check_post(form_data, '/process?id=', 302, '/')
+ retrieved_3 = Process.by_id(self.db_conn, 3)
+ self.assertEqual(retrieved_2.child_ids, [])
self.check_post(form_data, '/process?id=1', 302, '/')
+ self.assertEqual([p.id_ for p in Process.all(self.db_conn)],
+ [retrieved_1.id_, retrieved_2.id_, retrieved_3.id_])
+ retrieved_1 = Process.by_id(self.db_conn, 1)
+ self.assertEqual(retrieved_1.child_ids, [])
self.check_post(form_data, '/process', 302, '/')
def test_do_GET(self) -> None:
self.assertEqual(self.conn.getresponse().status, expected_code)
def check_post(self, data: Mapping[str, object], target: str,
- expected_code: int, redirect_location: str = '') -> None:
+ expected_code: int, redirect_location: str = '/') -> None:
"""Check that POST of data to target yields expected_code."""
- encoded_form_data = urlencode(data).encode('utf-8')
+ encoded_form_data = urlencode(data, doseq=True).encode('utf-8')
headers = {'Content-Type': 'application/x-www-form-urlencoded',
'Content-Length': str(len(encoded_form_data))}
self.conn.request('POST', target,