From 14e7f26613b8ac213a1b82370a153f81df7726cf Mon Sep 17 00:00:00 2001 From: Christian Heller <c.heller@plomlompom.de> Date: Mon, 12 Aug 2024 13:58:27 +0200 Subject: [PATCH] Harmonize treatment of GET /[item]?id=. --- plomtask/db.py | 9 ++++-- plomtask/http.py | 8 ++++-- tests/conditions.py | 3 +- tests/days.py | 4 +-- tests/misc.py | 3 +- tests/processes.py | 12 ++------ tests/todos.py | 9 ++---- tests/utils.py | 67 +++++++++++++++++++++++++++++---------------- 8 files changed, 65 insertions(+), 50 deletions(-) diff --git a/plomtask/db.py b/plomtask/db.py index 1fdd3e1..f067cd3 100644 --- a/plomtask/db.py +++ b/plomtask/db.py @@ -5,7 +5,8 @@ from os.path import isfile from difflib import Differ from sqlite3 import connect as sql_connect, Cursor, Row from typing import Any, Self, TypeVar, Generic, Callable -from plomtask.exceptions import HandledException, NotFoundException +from plomtask.exceptions import (HandledException, NotFoundException, + BadFormatException) from plomtask.dating import valid_date EXPECTED_DB_VERSION = 5 @@ -246,10 +247,10 @@ class BaseModel(Generic[BaseModelId]): def __init__(self, id_: BaseModelId | None) -> None: if isinstance(id_, int) and id_ < 1: msg = f'illegal {self.__class__.__name__} ID, must be >=1: {id_}' - raise HandledException(msg) + raise BadFormatException(msg) if isinstance(id_, str) and "" == id_: msg = f'illegal {self.__class__.__name__} ID, must be non-empty' - raise HandledException(msg) + raise BadFormatException(msg) self.id_ = id_ def __hash__(self) -> int: @@ -437,6 +438,8 @@ class BaseModel(Generic[BaseModelId]): """ obj = None if id_ is not None: + if isinstance(id_, int) and id_ == 0: + raise BadFormatException('illegal ID of value 0') obj = cls._get_cached(id_) if not obj: for row in db_conn.row_where(cls.table_name, 'id', id_): diff --git a/plomtask/http.py b/plomtask/http.py index 4426bba..e242a36 100644 --- a/plomtask/http.py +++ b/plomtask/http.py @@ -50,7 +50,7 @@ class InputsParser: """Retrieve list of int values at key.""" all_str = self.get_all_str(key) try: - return [int(s) for s in all_str if len(s) > 0] + return [int(s) for s in all_str] except ValueError as e: msg = f'cannot int a form field value for key {key} in: {all_str}' raise BadFormatException(msg) from e @@ -305,7 +305,9 @@ class TaskHandler(BaseHTTPRequestHandler): # pylint: disable=protected-access # (because pylint here fails to detect the use of wrapper as a # method to self with respective access privileges) - id_ = self._params.get_int_or_none('id') + id_ = None + for val in self._params.get_all_int('id'): + id_ = val if target_class.can_create_by_id: item = target_class.by_id_or_create(self._conn, id_) else: @@ -348,7 +350,7 @@ class TaskHandler(BaseHTTPRequestHandler): def do_GET_day(self) -> dict[str, object]: """Show single Day of ?date=.""" - date = self._params.get_str_or_fail('date', date_in_n_days(0)) + date = self._params.get_str('date', date_in_n_days(0)) make_type = self._params.get_str_or_fail('make_type', 'full') # day = Day.by_id_or_create(self._conn, date) diff --git a/tests/conditions.py b/tests/conditions.py index a9b28bb..58fa18b 100644 --- a/tests/conditions.py +++ b/tests/conditions.py @@ -72,6 +72,7 @@ class ExpectedGetCondition(Expected): class TestsWithServer(TestCaseWithServer): """Module tests against our HTTP server/handler (and database).""" + checked_class = Condition def test_fail_POST_condition(self) -> None: """Test malformed/illegal POST /condition requests.""" @@ -152,8 +153,8 @@ class TestsWithServer(TestCaseWithServer): self.check_filter(exp, 'conditions', 'sort_by', 'is_active', [1, 2, 3]) self.check_filter(exp, 'conditions', 'sort_by', '-is_active', [3, 2, 1]) - # test pattern matching on title exp.set('sort_by', 'title') + # test pattern matching on title exp.lib_del('Condition', 1) self.check_filter(exp, 'conditions', 'pattern', 'ba', [2, 3]) # test pattern matching on description diff --git a/tests/days.py b/tests/days.py index aac150b..5edec50 100644 --- a/tests/days.py +++ b/tests/days.py @@ -159,12 +159,12 @@ class ExpectedGetDay(Expected): class TestsWithServer(TestCaseWithServer): """Tests against our HTTP server/handler (and database).""" + checked_class = Day def test_basic_GET_day(self) -> None: """Test basic (no Processes/Conditions/Todos) GET /day basics.""" # check illegal date parameters - self.check_get('/day?date=', 400) - self.check_get('/day?date=foo', 400) + self.check_get_defaults('/day', '2024-01-01', 'date') self.check_get('/day?date=2024-02-30', 400) # check undefined day date = _testing_date_in_n_days(0) diff --git a/tests/misc.py b/tests/misc.py index 86474c7..8159124 100644 --- a/tests/misc.py +++ b/tests/misc.py @@ -147,7 +147,8 @@ class TestsSansServer(TestCase): parser = InputsParser({'foo': []}) self.assertEqual([], parser.get_all_int('foo')) parser = InputsParser({'foo': ['']}) - self.assertEqual([], parser.get_all_int('foo')) + with self.assertRaises(BadFormatException): + parser.get_all_int('foo') parser = InputsParser({'foo': ['0']}) self.assertEqual([0], parser.get_all_int('foo')) parser = InputsParser({'foo': ['0', '17']}) diff --git a/tests/processes.py b/tests/processes.py index 24a62bd..2561fbb 100644 --- a/tests/processes.py +++ b/tests/processes.py @@ -284,6 +284,7 @@ class ExpectedGetProcesses(Expected): class TestsWithServer(TestCaseWithServer): """Module tests against our HTTP server/handler (and database).""" + checked_class = Process def _post_process(self, id_: int = 1, form_data: dict[str, Any] | None = None @@ -404,19 +405,10 @@ class TestsWithServer(TestCaseWithServer): p = p_min | {'kept_steps': [1, 2, 3], 'new_step_to_2': 5, 'step_of': 6} self.check_post(p, url, 400) - def test_GET(self) -> None: - """Test /process and /processes response codes.""" - self.check_get('/process', 200) - self.check_get('/process?id=', 200) - self.check_get('/process?id=1', 200) - self.check_get_defaults('/process') - self.check_get('/processes', 200) - def test_fail_GET_process(self) -> None: """Test invalid GET /process params.""" # check for invalid IDs - self.check_get('/process?id=foo', 400) - self.check_get('/process?id=0', 500) + self.check_get_defaults('/process') # check we catch invalid base64 self.check_get('/process?title_b64=foo', 400) # check failure on references to unknown processes; we create Process diff --git a/tests/todos.py b/tests/todos.py index f048d46..9f3874d 100644 --- a/tests/todos.py +++ b/tests/todos.py @@ -266,6 +266,7 @@ class ExpectedGetTodo(Expected): class TestsWithServer(TestCaseWithServer): """Tests against our HTTP server/handler (and database).""" + checked_class = Todo def _post_exp_todo( self, id_: int, payload: dict[str, Any], exp: Expected) -> None: @@ -278,7 +279,7 @@ class TestsWithServer(TestCaseWithServer): # test we cannot just POST into non-existing Todo self.check_post({}, '/todo', 404) self.check_post({}, '/todo?id=FOO', 400) - self.check_post({}, '/todo?id=0', 404) + self.check_post({}, '/todo?id=0', 400) self.check_post({}, '/todo?id=1', 404) # test malformed values on existing Todo self.post_exp_day([], {'new_todo': [1]}) @@ -463,11 +464,7 @@ class TestsWithServer(TestCaseWithServer): def test_GET_todo(self) -> None: """Test GET /todo response codes.""" # test malformed or illegal parameter values - self.check_get('/todo', 404) - self.check_get('/todo?id=', 404) - self.check_get('/todo?id=foo', 400) - self.check_get('/todo?id=0', 404) - self.check_get('/todo?id=2', 404) + self.check_get_defaults('/todo') # test all existing Processes are shown as available exp = ExpectedGetTodo(1) self.post_exp_process([exp], {}, 1) diff --git a/tests/utils.py b/tests/utils.py index 7945f61..75c7e50 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -35,17 +35,8 @@ class TestCaseAugmented(TestCase): default_init_kwargs: dict[str, Any] = {} @staticmethod - def _run_if_checked_class(f: Callable[..., None]) -> Callable[..., None]: - def wrapper(self: TestCase) -> None: - if hasattr(self, 'checked_class'): - f(self) - return wrapper - - @classmethod - def _run_on_versioned_attributes(cls, - f: Callable[..., None] + def _run_on_versioned_attributes(f: Callable[..., None] ) -> Callable[..., None]: - @cls._run_if_checked_class def wrapper(self: TestCase) -> None: assert isinstance(self, TestCaseAugmented) for attr_name in self.checked_class.to_save_versioned(): @@ -56,6 +47,23 @@ class TestCaseAugmented(TestCase): f(self, owner, attr_name, attr, default, to_set) return wrapper + @classmethod + def _run_if_sans_db(cls, f: Callable[..., None]) -> Callable[..., None]: + def wrapper(self: TestCaseSansDB) -> None: + if issubclass(cls, TestCaseSansDB): + f(self) + return wrapper + + @classmethod + def _run_if_with_db_but_not_server(cls, + f: Callable[..., None] + ) -> Callable[..., None]: + def wrapper(self: TestCaseWithDB) -> None: + if issubclass(cls, TestCaseWithDB) and\ + not issubclass(cls, TestCaseWithServer): + f(self) + return wrapper + @classmethod def _make_from_defaults(cls, id_: float | str | None) -> Any: return cls.checked_class(id_, **cls.default_init_kwargs) @@ -66,7 +74,7 @@ class TestCaseSansDB(TestCaseAugmented): legal_ids: list[str] | list[int] = [1, 5] illegal_ids: list[str] | list[int] = [0] - @TestCaseAugmented._run_if_checked_class + @TestCaseAugmented._run_if_sans_db def test_id_validation(self) -> None: """Test .id_ validation/setting.""" for id_ in self.illegal_ids: @@ -76,6 +84,7 @@ class TestCaseSansDB(TestCaseAugmented): obj = self._make_from_defaults(id_) self.assertEqual(obj.id_, id_) + @TestCaseAugmented._run_if_sans_db @TestCaseAugmented._run_on_versioned_attributes def test_versioned_set(self, _: Any, @@ -115,6 +124,7 @@ class TestCaseSansDB(TestCaseAugmented): attr.set(to_set[1]) self.assertEqual(timesorted_vals, expected) + @TestCaseAugmented._run_if_sans_db @TestCaseAugmented._run_on_versioned_attributes def test_versioned_newest(self, _: Any, @@ -134,6 +144,7 @@ class TestCaseSansDB(TestCaseAugmented): attr.set(default) self.assertEqual(attr.newest, default) + @TestCaseAugmented._run_if_sans_db @TestCaseAugmented._run_on_versioned_attributes def test_versioned_at(self, _: Any, @@ -277,6 +288,7 @@ class TestCaseWithDB(TestCaseAugmented): self.assertEqual(start, end) self.assertEqual(items, [obj_today]) + @TestCaseAugmented._run_if_with_db_but_not_server @TestCaseAugmented._run_on_versioned_attributes def test_saving_versioned_attributes(self, owner: Any, @@ -318,7 +330,7 @@ class TestCaseWithDB(TestCaseAugmented): attr_vals_saved = retrieve_attr_vals(attr) self.assertEqual(to_set, attr_vals_saved) - @TestCaseAugmented._run_if_checked_class + @TestCaseAugmented._run_if_with_db_but_not_server def test_saving_and_caching(self) -> None: """Test effects of .cache() and .save().""" id1 = self.default_ids[0] @@ -353,7 +365,7 @@ class TestCaseWithDB(TestCaseAugmented): with self.assertRaises(HandledException): obj1.save(self.db_conn) - @TestCaseAugmented._run_if_checked_class + @TestCaseAugmented._run_if_with_db_but_not_server def test_by_id(self) -> None: """Test .by_id().""" id1, id2, _ = self.default_ids @@ -369,7 +381,7 @@ class TestCaseWithDB(TestCaseAugmented): obj2.save(self.db_conn) self.assertEqual(obj2, self.checked_class.by_id(self.db_conn, id2)) - @TestCaseAugmented._run_if_checked_class + @TestCaseAugmented._run_if_with_db_but_not_server def test_by_id_or_create(self) -> None: """Test .by_id_or_create.""" # check .by_id_or_create fails if wrong class @@ -392,7 +404,7 @@ class TestCaseWithDB(TestCaseAugmented): self.checked_class.by_id(self.db_conn, item.id_) self.assertEqual(self.checked_class(item.id_), item) - @TestCaseAugmented._run_if_checked_class + @TestCaseAugmented._run_if_with_db_but_not_server def test_from_table_row(self) -> None: """Test .from_table_row() properly reads in class directly from DB.""" id_ = self.default_ids[0] @@ -416,6 +428,7 @@ class TestCaseWithDB(TestCaseAugmented): self.assertEqual({retrieved.id_: retrieved}, self.checked_class.get_cache()) + @TestCaseAugmented._run_if_with_db_but_not_server @TestCaseAugmented._run_on_versioned_attributes def test_versioned_history_from_row(self, owner: Any, @@ -439,7 +452,7 @@ class TestCaseWithDB(TestCaseAugmented): for timestamp, value in attr.history.items(): self.assertEqual(value, loaded_attr.history[timestamp]) - @TestCaseAugmented._run_if_checked_class + @TestCaseAugmented._run_if_with_db_but_not_server def test_all(self) -> None: """Test .all() and its relation to cache and savings.""" id1, id2, id3 = self.default_ids @@ -457,7 +470,7 @@ class TestCaseWithDB(TestCaseAugmented): self.assertEqual(sorted(self.checked_class.all(self.db_conn)), sorted([item1, item2, item3])) - @TestCaseAugmented._run_if_checked_class + @TestCaseAugmented._run_if_with_db_but_not_server def test_singularity(self) -> None: """Test pointers made for single object keep pointing to it.""" id1 = self.default_ids[0] @@ -469,6 +482,7 @@ class TestCaseWithDB(TestCaseAugmented): retrieved = self.checked_class.by_id(self.db_conn, id1) self.assertEqual(new_attr, getattr(retrieved, attr_name)) + @TestCaseAugmented._run_if_with_db_but_not_server @TestCaseAugmented._run_on_versioned_attributes def test_versioned_singularity(self, owner: Any, @@ -485,7 +499,7 @@ class TestCaseWithDB(TestCaseAugmented): attr_retrieved = getattr(retrieved, attr_name) self.assertEqual(attr.history, attr_retrieved.history) - @TestCaseAugmented._run_if_checked_class + @TestCaseAugmented._run_if_with_db_but_not_server def test_remove(self) -> None: """Test .remove() effects on DB and cache.""" id_ = self.default_ids[0] @@ -947,13 +961,18 @@ class TestCaseWithServer(TestCaseWithDB): else: self.assertEqual(self.conn.getresponse().status, expected_code) - def check_get_defaults(self, path: str) -> None: + def check_get_defaults(self, + path: str, + default_id: str = '1', + id_name: str = 'id' + ) -> None: """Some standard model paths to test.""" - self.check_get(path, 200) - self.check_get(f'{path}?id=', 200) - self.check_get(f'{path}?id=foo', 400) - self.check_get(f'/{path}?id=0', 500) - self.check_get(f'{path}?id=1', 200) + nonexist_status = 200 if self.checked_class.can_create_by_id else 404 + self.check_get(path, nonexist_status) + self.check_get(f'{path}?{id_name}=', 400) + self.check_get(f'{path}?{id_name}=foo', 400) + self.check_get(f'/{path}?{id_name}=0', 400) + self.check_get(f'{path}?{id_name}={default_id}', nonexist_status) def check_json_get(self, path: str, expected: Expected) -> None: """Compare JSON on GET path with expected. -- 2.30.2