From: Christian Heller Date: Tue, 16 Apr 2024 03:29:15 +0000 (+0200) Subject: Unify ParamsParser and PostvarsParser to InputsParser. X-Git-Url: https://plomlompom.com/repos/%7B%7Bprefix%7D%7D/static/%7B%7B%20web_path%20%7D%7D/decks/template?a=commitdiff_plain;h=30538b26b3af74041fb6c907c2c142d8f2c11a0e;p=plomtask Unify ParamsParser and PostvarsParser to InputsParser. --- diff --git a/plomtask/http.py b/plomtask/http.py index b00ebeb..912b635 100644 --- a/plomtask/http.py +++ b/plomtask/http.py @@ -26,81 +26,63 @@ class TaskServer(HTTPServer): self.jinja = JinjaEnv(loader=JinjaFSLoader(TEMPLATES_DIR)) -class ParamsParser: - """Wrapper for validating and retrieving GET params.""" - - def __init__(self, params: dict[str, list[str]]) -> None: - self.params = params - - def get_str(self, key: str, default: str = '') -> str: - """Retrieve string value of key from self.params.""" - if key not in self.params or 0 == len(self.params[key]): +class InputsParser: + """Wrapper for validating and retrieving dict-like HTTP inputs.""" + + def __init__(self, dict_: dict[str, list[str]], + strictness: bool = True) -> None: + self.inputs = dict_ + self.strict = strictness + + def get_str(self, key: str, default: str = '', + ignore_strict: bool = False) -> str: + """Retrieve single/first string value of key, or default.""" + if key not in self.inputs.keys() or 0 == len(self.inputs[key]): + if self.strict and not ignore_strict: + raise BadFormatException(f'no value found for key {key}') return default - return self.params[key][0] + return self.inputs[key][0] + + def get_int(self, key: str) -> int: + """Retrieve single/first value of key as int, error if empty.""" + val = self.get_int_or_none(key) + if val is None: + raise BadFormatException(f'unexpected empty value for: {key}') + return val def get_int_or_none(self, key: str) -> int | None: - """Retrieve int value of key from self.params, on empty return None.""" - if key not in self.params or \ - 0 == len(''.join(list(self.params[key]))): + """Retrieve single/first value of key as int, return None if empty.""" + val = self.get_str(key, ignore_strict=True) + if val == '': return None - val_str = self.params[key][0] - try: - return int(val_str) - except ValueError as e: - raise BadFormatException(f'Bad ?{key}= value: {val_str}') from e - - -class PostvarsParser: - """Postvars wrapper for validating and retrieving form data.""" - - def __init__(self, postvars: dict[str, list[str]]) -> None: - self.postvars = postvars - - def get_str(self, key: str) -> str: - """Retrieve string value of key from self.postvars.""" - all_str = self.get_all_str(key) - if 0 == len(all_str): - raise BadFormatException(f'missing value for key: {key}') - return all_str[0] - - def get_int(self, key: str) -> int: - """Retrieve int value of key from self.postvars.""" - val = self.get_str(key) try: return int(val) except ValueError as e: - msg = f'cannot int form field value: {val}' + msg = f'cannot int form field value for key {key}: {val}' raise BadFormatException(msg) from e - def get_int_or_none(self, key: str) -> int | None: - """Retrieve int value of key from self.postvars, or None.""" - if key not in self.postvars or \ - 0 == len(''.join(list(self.postvars[key]))): - return None - return self.get_int(key) - def get_float(self, key: str) -> float: """Retrieve float value of key from self.postvars.""" val = self.get_str(key) try: return float(val) except ValueError as e: - msg = f'cannot float form field value: {val}' + msg = f'cannot float form field value for key {key}: {val}' raise BadFormatException(msg) from e def get_all_str(self, key: str) -> list[str]: - """Retrieve list of string values at key from self.postvars.""" - if key not in self.postvars: + """Retrieve list of string values at key.""" + if key not in self.inputs.keys(): return [] - return self.postvars[key] + return self.inputs[key] def get_all_int(self, key: str) -> list[int]: - """Retrieve list of int values at key from self.postvars.""" + """Retrieve list of int values at key.""" all_str = self.get_all_str(key) try: return [int(s) for s in all_str if len(s) > 0] except ValueError as e: - msg = f'cannot int a form field value: {all_str}' + msg = f'cannot int a form field value for key {key} in: {all_str}' raise BadFormatException(msg) from e @@ -127,7 +109,7 @@ class TaskHandler(BaseHTTPRequestHandler): conn.close() def do_GET_calendar(self, conn: DatabaseConnection, - params: ParamsParser) -> str: + params: InputsParser) -> str: """Show Days from ?start= to ?end=.""" start = params.get_str('start') end = params.get_str('end') @@ -136,7 +118,7 @@ class TaskHandler(BaseHTTPRequestHandler): days=days, start=start, end=end) def do_GET_day(self, conn: DatabaseConnection, - params: ParamsParser) -> str: + params: InputsParser) -> str: """Show single Day of ?date=.""" date = params.get_str('date', todays_date()) day = Day.by_date(conn, date, create=True) @@ -154,7 +136,7 @@ class TaskHandler(BaseHTTPRequestHandler): conditions_listing=conditions_listing) def do_GET_todo(self, conn: DatabaseConnection, params: - ParamsParser) -> str: + InputsParser) -> str: """Show single Todo of ?id=.""" id_ = params.get_int_or_none('id') todo = Todo.by_id(conn, id_) @@ -164,13 +146,13 @@ class TaskHandler(BaseHTTPRequestHandler): condition_candidates=Condition.all(conn)) def do_GET_conditions(self, conn: DatabaseConnection, - _: ParamsParser) -> str: + _: InputsParser) -> str: """Show all Conditions.""" return self.server.jinja.get_template('conditions.html').render( conditions=Condition.all(conn)) def do_GET_condition(self, conn: DatabaseConnection, - params: ParamsParser) -> str: + params: InputsParser) -> str: """Show Condition of ?id=.""" id_ = params.get_int_or_none('id') condition = Condition.by_id(conn, id_, create=True) @@ -178,7 +160,7 @@ class TaskHandler(BaseHTTPRequestHandler): condition=condition) def do_GET_process(self, conn: DatabaseConnection, - params: ParamsParser) -> str: + params: InputsParser) -> str: """Show process of ?id=.""" id_ = params.get_int_or_none('id') process = Process.by_id(conn, id_, create=True) @@ -189,7 +171,7 @@ class TaskHandler(BaseHTTPRequestHandler): condition_candidates=Condition.all(conn)) def do_GET_processes(self, conn: DatabaseConnection, - _: ParamsParser) -> str: + _: InputsParser) -> str: """Show all Processes.""" return self.server.jinja.get_template('processes.html').render( processes=Process.all(conn)) @@ -201,7 +183,8 @@ class TaskHandler(BaseHTTPRequestHandler): length = int(self.headers['content-length']) postvars = parse_qs(self.rfile.read(length).decode(), keep_blank_values=True, strict_parsing=True) - form_data = PostvarsParser(postvars) + # form_data = PostvarsParser(postvars) + form_data = InputsParser(postvars) if site in ('day', 'process', 'todo', 'condition'): getattr(self, f'do_POST_{site}')(conn, params, form_data) conn.commit() @@ -214,8 +197,8 @@ class TaskHandler(BaseHTTPRequestHandler): finally: conn.close() - def do_POST_day(self, conn: DatabaseConnection, params: ParamsParser, - form_data: PostvarsParser) -> None: + def do_POST_day(self, conn: DatabaseConnection, params: InputsParser, + form_data: InputsParser) -> None: """Update or insert Day of date and Todos mapped to it.""" date = params.get_str('date') day = Day.by_date(conn, date, create=True) @@ -227,8 +210,8 @@ class TaskHandler(BaseHTTPRequestHandler): todo = Todo(None, process, False, day) todo.save(conn) - def do_POST_todo(self, conn: DatabaseConnection, params: ParamsParser, - form_data: PostvarsParser) -> None: + def do_POST_todo(self, conn: DatabaseConnection, params: InputsParser, + form_data: InputsParser) -> None: """Update Todo and its children.""" id_ = params.get_int_or_none('id') todo = Todo.by_id(conn, id_) @@ -246,8 +229,8 @@ class TaskHandler(BaseHTTPRequestHandler): for condition in todo.undoes: condition.save(conn) - def do_POST_process(self, conn: DatabaseConnection, params: ParamsParser, - form_data: PostvarsParser) -> None: + def do_POST_process(self, conn: DatabaseConnection, params: InputsParser, + form_data: InputsParser) -> None: """Update or insert Process of ?id= and fields defined in postvars.""" id_ = params.get_int_or_none('id') process = Process.by_id(conn, id_, create=True) @@ -273,8 +256,8 @@ class TaskHandler(BaseHTTPRequestHandler): process.add_step(conn, None, step_process_id, None) process.fix_steps(conn) - def do_POST_condition(self, conn: DatabaseConnection, params: ParamsParser, - form_data: PostvarsParser) -> None: + def do_POST_condition(self, conn: DatabaseConnection, params: InputsParser, + form_data: InputsParser) -> None: """Update/insert Condition of ?id= and fields defined in postvars.""" id_ = params.get_int_or_none('id') condition = Condition.by_id(conn, id_, create=True) @@ -282,11 +265,12 @@ class TaskHandler(BaseHTTPRequestHandler): condition.description.set(form_data.get_str('description')) condition.save(conn) - def _init_handling(self) -> tuple[DatabaseConnection, str, ParamsParser]: + def _init_handling(self) -> tuple[DatabaseConnection, str, InputsParser]: conn = DatabaseConnection(self.server.db) parsed_url = urlparse(self.path) site = path_split(parsed_url.path)[1] - params = ParamsParser(parse_qs(parsed_url.query, strict_parsing=True)) + params = InputsParser(parse_qs(parsed_url.query, strict_parsing=True), + False) return conn, site, params def _redirect(self, target: str) -> None: diff --git a/tests/misc.py b/tests/misc.py index 87b3a6e..d49870f 100644 --- a/tests/misc.py +++ b/tests/misc.py @@ -1,100 +1,100 @@ """Miscellaneous tests.""" from unittest import TestCase from tests.utils import TestCaseWithServer -from plomtask.http import ParamsParser, PostvarsParser +from plomtask.http import InputsParser from plomtask.exceptions import BadFormatException class TestsSansServer(TestCase): """Tests that do not require DB setup or a server.""" - def test_params_parser(self) -> None: - """Test behavior of ParamsParser.""" - self.assertEqual('', - ParamsParser({}).get_str('foo')) - self.assertEqual('bar', - ParamsParser({}).get_str('foo', 'bar')) - self.assertEqual('bar', - ParamsParser({'foo': []}).get_str('foo', 'bar')) - self.assertEqual('baz', - ParamsParser({'foo': ['baz']}).get_str('foo', 'bar')) - self.assertEqual(None, - ParamsParser({}).get_int_or_none('foo')) - self.assertEqual(None, - ParamsParser({'foo': []}).get_int_or_none('foo')) - self.assertEqual(None, - ParamsParser({'foo': ['']}).get_int_or_none('foo')) - self.assertEqual(0, - ParamsParser({'foo': ['0']}).get_int_or_none('foo')) + def test_InputsParser_non_strict(self) -> None: + """Test behavior of non-strict (= params) InputsParser.""" + params = InputsParser({}, False) + self.assertEqual('', params.get_str('foo')) + params = InputsParser({}, False) + self.assertEqual('bar', params.get_str('foo', 'bar')) + params = InputsParser({'foo': []}, False) + self.assertEqual('bar', params.get_str('foo', 'bar')) + params = InputsParser({'foo': ['baz']}, False) + self.assertEqual('baz', params.get_str('foo', 'bar')) + params = InputsParser({}, False) + self.assertEqual(None, params.get_int_or_none('foo')) + params = InputsParser({'foo': []}, False) + self.assertEqual(None, params.get_int_or_none('foo')) + params = InputsParser({'foo': ['']}, False) + self.assertEqual(None, params.get_int_or_none('foo')) + params = InputsParser({'foo': ['0']}, False) + self.assertEqual(0, params.get_int_or_none('foo')) with self.assertRaises(BadFormatException): - ParamsParser({'foo': ['None']}).get_int_or_none('foo') + InputsParser({'foo': ['None']}, False).get_int_or_none('foo') with self.assertRaises(BadFormatException): - ParamsParser({'foo': ['0.1']}).get_int_or_none('foo') - self.assertEqual(23, - ParamsParser({'foo': ['23']}).get_int_or_none('foo')) + InputsParser({'foo': ['0.1']}, False).get_int_or_none('foo') + params = InputsParser({'foo': ['23']}, False) + self.assertEqual(23, params.get_int_or_none('foo')) - def test_postvars_parser(self) -> None: - """Test behavior of PostvarsParser.""" + def test_InputsParser_strict(self) -> None: + """Test behavior of strict (= postvars) InputsParser.""" self.assertEqual([], - PostvarsParser({}).get_all_str('foo')) + InputsParser({}).get_all_str('foo')) self.assertEqual([], - PostvarsParser({'foo': []}).get_all_str('foo')) + InputsParser({'foo': []}).get_all_str('foo')) self.assertEqual(['bar'], - PostvarsParser({'foo': ['bar']}).get_all_str('foo')) + InputsParser({'foo': ['bar']}).get_all_str('foo')) self.assertEqual(['bar', 'baz'], - PostvarsParser({'foo': ['bar', 'baz']}). + InputsParser({'foo': ['bar', 'baz']}). get_all_str('foo')) self.assertEqual([], - PostvarsParser({}).get_all_int('foo')) + InputsParser({}).get_all_int('foo')) self.assertEqual([], - PostvarsParser({'foo': []}).get_all_int('foo')) + InputsParser({'foo': []}).get_all_int('foo')) self.assertEqual([], - PostvarsParser({'foo': ['']}).get_all_int('foo')) + InputsParser({'foo': ['']}).get_all_int('foo')) self.assertEqual([0], - PostvarsParser({'foo': ['0']}).get_all_int('foo')) + InputsParser({'foo': ['0']}).get_all_int('foo')) self.assertEqual([0, 17], - PostvarsParser({'foo': ['0', '17']}). + InputsParser({'foo': ['0', '17']}). get_all_int('foo')) with self.assertRaises(BadFormatException): - PostvarsParser({'foo': ['0.1', '17']}).get_all_int('foo') + InputsParser({'foo': ['0.1', '17']}).get_all_int('foo') with self.assertRaises(BadFormatException): - PostvarsParser({'foo': ['None', '17']}).get_all_int('foo') + InputsParser({'foo': ['None', '17']}).get_all_int('foo') with self.assertRaises(BadFormatException): - PostvarsParser({}).get_str('foo') + InputsParser({}).get_str('foo') with self.assertRaises(BadFormatException): - PostvarsParser({'foo': []}).get_str('foo') + InputsParser({'foo': []}).get_str('foo') self.assertEqual('bar', - PostvarsParser({'foo': ['bar']}).get_str('foo')) + InputsParser({'foo': ['bar']}).get_str('foo')) self.assertEqual('', - PostvarsParser({'foo': ['', 'baz']}).get_str('foo')) + InputsParser({'foo': ['', 'baz']}).get_str('foo')) with self.assertRaises(BadFormatException): - PostvarsParser({}).get_int('foo') + InputsParser({}).get_int('foo') with self.assertRaises(BadFormatException): - PostvarsParser({'foo': []}).get_int('foo') + InputsParser({'foo': []}).get_int('foo') with self.assertRaises(BadFormatException): - PostvarsParser({'foo': ['']}).get_int('foo') + InputsParser({'foo': ['']}).get_int('foo') with self.assertRaises(BadFormatException): - PostvarsParser({'foo': ['bar']}).get_int('foo') + InputsParser({'foo': ['bar']}).get_int('foo') with self.assertRaises(BadFormatException): - PostvarsParser({'foo': ['0.1']}).get_int('foo') + InputsParser({'foo': ['0.1']}).get_int('foo') self.assertEqual(0, - PostvarsParser({'foo': ['0']}).get_int('foo')) + InputsParser({'foo': ['0']}).get_int('foo')) self.assertEqual(17, - PostvarsParser({'foo': ['17', '23']}).get_int('foo')) + InputsParser({'foo': ['17', '23']}).get_int('foo')) with self.assertRaises(BadFormatException): - PostvarsParser({}).get_float('foo') + InputsParser({}).get_float('foo') with self.assertRaises(BadFormatException): - PostvarsParser({'foo': []}).get_float('foo') + InputsParser({'foo': []}).get_float('foo') with self.assertRaises(BadFormatException): - PostvarsParser({'foo': ['']}).get_float('foo') + InputsParser({'foo': ['']}).get_float('foo') with self.assertRaises(BadFormatException): - PostvarsParser({'foo': ['bar']}).get_float('foo') + InputsParser({'foo': ['bar']}).get_float('foo') self.assertEqual(0, - PostvarsParser({'foo': ['0']}).get_float('foo')) + InputsParser({'foo': ['0']}).get_float('foo')) self.assertEqual(0.1, - PostvarsParser({'foo': ['0.1']}).get_float('foo')) + InputsParser({'foo': ['0.1']}).get_float('foo')) self.assertEqual(1.23, - PostvarsParser({'foo': ['1.23', '456']}). + InputsParser({'foo': ['1.23', '456']}). get_float('foo'))