From 14e7f26613b8ac213a1b82370a153f81df7726cf Mon Sep 17 00:00:00 2001
From: Christian Heller <c.heller@plomlompom.de>
Date: Mon, 12 Aug 2024 13:58:27 +0200
Subject: [PATCH] Harmonize treatment of GET /[item]?id=.

---
 plomtask/db.py      |  9 ++++--
 plomtask/http.py    |  8 ++++--
 tests/conditions.py |  3 +-
 tests/days.py       |  4 +--
 tests/misc.py       |  3 +-
 tests/processes.py  | 12 ++------
 tests/todos.py      |  9 ++----
 tests/utils.py      | 67 +++++++++++++++++++++++++++++----------------
 8 files changed, 65 insertions(+), 50 deletions(-)

diff --git a/plomtask/db.py b/plomtask/db.py
index 1fdd3e1..f067cd3 100644
--- a/plomtask/db.py
+++ b/plomtask/db.py
@@ -5,7 +5,8 @@ from os.path import isfile
 from difflib import Differ
 from sqlite3 import connect as sql_connect, Cursor, Row
 from typing import Any, Self, TypeVar, Generic, Callable
-from plomtask.exceptions import HandledException, NotFoundException
+from plomtask.exceptions import (HandledException, NotFoundException,
+                                 BadFormatException)
 from plomtask.dating import valid_date
 
 EXPECTED_DB_VERSION = 5
@@ -246,10 +247,10 @@ class BaseModel(Generic[BaseModelId]):
     def __init__(self, id_: BaseModelId | None) -> None:
         if isinstance(id_, int) and id_ < 1:
             msg = f'illegal {self.__class__.__name__} ID, must be >=1: {id_}'
-            raise HandledException(msg)
+            raise BadFormatException(msg)
         if isinstance(id_, str) and "" == id_:
             msg = f'illegal {self.__class__.__name__} ID, must be non-empty'
-            raise HandledException(msg)
+            raise BadFormatException(msg)
         self.id_ = id_
 
     def __hash__(self) -> int:
@@ -437,6 +438,8 @@ class BaseModel(Generic[BaseModelId]):
         """
         obj = None
         if id_ is not None:
+            if isinstance(id_, int) and id_ == 0:
+                raise BadFormatException('illegal ID of value 0')
             obj = cls._get_cached(id_)
             if not obj:
                 for row in db_conn.row_where(cls.table_name, 'id', id_):
diff --git a/plomtask/http.py b/plomtask/http.py
index 4426bba..e242a36 100644
--- a/plomtask/http.py
+++ b/plomtask/http.py
@@ -50,7 +50,7 @@ class InputsParser:
         """Retrieve list of int values at key."""
         all_str = self.get_all_str(key)
         try:
-            return [int(s) for s in all_str if len(s) > 0]
+            return [int(s) for s in all_str]
         except ValueError as e:
             msg = f'cannot int a form field value for key {key} in: {all_str}'
             raise BadFormatException(msg) from e
@@ -305,7 +305,9 @@ class TaskHandler(BaseHTTPRequestHandler):
                 # pylint: disable=protected-access
                 # (because pylint here fails to detect the use of wrapper as a
                 # method to self with respective access privileges)
-                id_ = self._params.get_int_or_none('id')
+                id_ = None
+                for val in self._params.get_all_int('id'):
+                    id_ = val
                 if target_class.can_create_by_id:
                     item = target_class.by_id_or_create(self._conn, id_)
                 else:
@@ -348,7 +350,7 @@ class TaskHandler(BaseHTTPRequestHandler):
 
     def do_GET_day(self) -> dict[str, object]:
         """Show single Day of ?date=."""
-        date = self._params.get_str_or_fail('date', date_in_n_days(0))
+        date = self._params.get_str('date', date_in_n_days(0))
         make_type = self._params.get_str_or_fail('make_type', 'full')
         #
         day = Day.by_id_or_create(self._conn, date)
diff --git a/tests/conditions.py b/tests/conditions.py
index a9b28bb..58fa18b 100644
--- a/tests/conditions.py
+++ b/tests/conditions.py
@@ -72,6 +72,7 @@ class ExpectedGetCondition(Expected):
 
 class TestsWithServer(TestCaseWithServer):
     """Module tests against our HTTP server/handler (and database)."""
+    checked_class = Condition
 
     def test_fail_POST_condition(self) -> None:
         """Test malformed/illegal POST /condition requests."""
@@ -152,8 +153,8 @@ class TestsWithServer(TestCaseWithServer):
         self.check_filter(exp, 'conditions', 'sort_by', 'is_active', [1, 2, 3])
         self.check_filter(exp, 'conditions', 'sort_by', '-is_active',
                           [3, 2, 1])
-        # test pattern matching on title
         exp.set('sort_by', 'title')
+        # test pattern matching on title
         exp.lib_del('Condition', 1)
         self.check_filter(exp, 'conditions', 'pattern', 'ba', [2, 3])
         # test pattern matching on description
diff --git a/tests/days.py b/tests/days.py
index aac150b..5edec50 100644
--- a/tests/days.py
+++ b/tests/days.py
@@ -159,12 +159,12 @@ class ExpectedGetDay(Expected):
 
 class TestsWithServer(TestCaseWithServer):
     """Tests against our HTTP server/handler (and database)."""
+    checked_class = Day
 
     def test_basic_GET_day(self) -> None:
         """Test basic (no Processes/Conditions/Todos) GET /day basics."""
         # check illegal date parameters
-        self.check_get('/day?date=', 400)
-        self.check_get('/day?date=foo', 400)
+        self.check_get_defaults('/day', '2024-01-01', 'date')
         self.check_get('/day?date=2024-02-30', 400)
         # check undefined day
         date = _testing_date_in_n_days(0)
diff --git a/tests/misc.py b/tests/misc.py
index 86474c7..8159124 100644
--- a/tests/misc.py
+++ b/tests/misc.py
@@ -147,7 +147,8 @@ class TestsSansServer(TestCase):
         parser = InputsParser({'foo': []})
         self.assertEqual([], parser.get_all_int('foo'))
         parser = InputsParser({'foo': ['']})
-        self.assertEqual([], parser.get_all_int('foo'))
+        with self.assertRaises(BadFormatException):
+            parser.get_all_int('foo')
         parser = InputsParser({'foo': ['0']})
         self.assertEqual([0], parser.get_all_int('foo'))
         parser = InputsParser({'foo': ['0', '17']})
diff --git a/tests/processes.py b/tests/processes.py
index 24a62bd..2561fbb 100644
--- a/tests/processes.py
+++ b/tests/processes.py
@@ -284,6 +284,7 @@ class ExpectedGetProcesses(Expected):
 
 class TestsWithServer(TestCaseWithServer):
     """Module tests against our HTTP server/handler (and database)."""
+    checked_class = Process
 
     def _post_process(self, id_: int = 1,
                       form_data: dict[str, Any] | None = None
@@ -404,19 +405,10 @@ class TestsWithServer(TestCaseWithServer):
         p = p_min | {'kept_steps': [1, 2, 3], 'new_step_to_2': 5, 'step_of': 6}
         self.check_post(p, url, 400)
 
-    def test_GET(self) -> None:
-        """Test /process and /processes response codes."""
-        self.check_get('/process', 200)
-        self.check_get('/process?id=', 200)
-        self.check_get('/process?id=1', 200)
-        self.check_get_defaults('/process')
-        self.check_get('/processes', 200)
-
     def test_fail_GET_process(self) -> None:
         """Test invalid GET /process params."""
         # check for invalid IDs
-        self.check_get('/process?id=foo', 400)
-        self.check_get('/process?id=0', 500)
+        self.check_get_defaults('/process')
         # check we catch invalid base64
         self.check_get('/process?title_b64=foo', 400)
         # check failure on references to unknown processes; we create Process
diff --git a/tests/todos.py b/tests/todos.py
index f048d46..9f3874d 100644
--- a/tests/todos.py
+++ b/tests/todos.py
@@ -266,6 +266,7 @@ class ExpectedGetTodo(Expected):
 
 class TestsWithServer(TestCaseWithServer):
     """Tests against our HTTP server/handler (and database)."""
+    checked_class = Todo
 
     def _post_exp_todo(
             self, id_: int, payload: dict[str, Any], exp: Expected) -> None:
@@ -278,7 +279,7 @@ class TestsWithServer(TestCaseWithServer):
         # test we cannot just POST into non-existing Todo
         self.check_post({}, '/todo', 404)
         self.check_post({}, '/todo?id=FOO', 400)
-        self.check_post({}, '/todo?id=0', 404)
+        self.check_post({}, '/todo?id=0', 400)
         self.check_post({}, '/todo?id=1', 404)
         # test malformed values on existing Todo
         self.post_exp_day([], {'new_todo': [1]})
@@ -463,11 +464,7 @@ class TestsWithServer(TestCaseWithServer):
     def test_GET_todo(self) -> None:
         """Test GET /todo response codes."""
         # test malformed or illegal parameter values
-        self.check_get('/todo', 404)
-        self.check_get('/todo?id=', 404)
-        self.check_get('/todo?id=foo', 400)
-        self.check_get('/todo?id=0', 404)
-        self.check_get('/todo?id=2', 404)
+        self.check_get_defaults('/todo')
         # test all existing Processes are shown as available
         exp = ExpectedGetTodo(1)
         self.post_exp_process([exp], {}, 1)
diff --git a/tests/utils.py b/tests/utils.py
index 7945f61..75c7e50 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -35,17 +35,8 @@ class TestCaseAugmented(TestCase):
     default_init_kwargs: dict[str, Any] = {}
 
     @staticmethod
-    def _run_if_checked_class(f: Callable[..., None]) -> Callable[..., None]:
-        def wrapper(self: TestCase) -> None:
-            if hasattr(self, 'checked_class'):
-                f(self)
-        return wrapper
-
-    @classmethod
-    def _run_on_versioned_attributes(cls,
-                                     f: Callable[..., None]
+    def _run_on_versioned_attributes(f: Callable[..., None]
                                      ) -> Callable[..., None]:
-        @cls._run_if_checked_class
         def wrapper(self: TestCase) -> None:
             assert isinstance(self, TestCaseAugmented)
             for attr_name in self.checked_class.to_save_versioned():
@@ -56,6 +47,23 @@ class TestCaseAugmented(TestCase):
                 f(self, owner, attr_name, attr, default, to_set)
         return wrapper
 
+    @classmethod
+    def _run_if_sans_db(cls, f: Callable[..., None]) -> Callable[..., None]:
+        def wrapper(self: TestCaseSansDB) -> None:
+            if issubclass(cls, TestCaseSansDB):
+                f(self)
+        return wrapper
+
+    @classmethod
+    def _run_if_with_db_but_not_server(cls,
+                                       f: Callable[..., None]
+                                       ) -> Callable[..., None]:
+        def wrapper(self: TestCaseWithDB) -> None:
+            if issubclass(cls, TestCaseWithDB) and\
+                    not issubclass(cls, TestCaseWithServer):
+                f(self)
+        return wrapper
+
     @classmethod
     def _make_from_defaults(cls, id_: float | str | None) -> Any:
         return cls.checked_class(id_, **cls.default_init_kwargs)
@@ -66,7 +74,7 @@ class TestCaseSansDB(TestCaseAugmented):
     legal_ids: list[str] | list[int] = [1, 5]
     illegal_ids: list[str] | list[int] = [0]
 
-    @TestCaseAugmented._run_if_checked_class
+    @TestCaseAugmented._run_if_sans_db
     def test_id_validation(self) -> None:
         """Test .id_ validation/setting."""
         for id_ in self.illegal_ids:
@@ -76,6 +84,7 @@ class TestCaseSansDB(TestCaseAugmented):
             obj = self._make_from_defaults(id_)
             self.assertEqual(obj.id_, id_)
 
+    @TestCaseAugmented._run_if_sans_db
     @TestCaseAugmented._run_on_versioned_attributes
     def test_versioned_set(self,
                            _: Any,
@@ -115,6 +124,7 @@ class TestCaseSansDB(TestCaseAugmented):
         attr.set(to_set[1])
         self.assertEqual(timesorted_vals, expected)
 
+    @TestCaseAugmented._run_if_sans_db
     @TestCaseAugmented._run_on_versioned_attributes
     def test_versioned_newest(self,
                               _: Any,
@@ -134,6 +144,7 @@ class TestCaseSansDB(TestCaseAugmented):
         attr.set(default)
         self.assertEqual(attr.newest, default)
 
+    @TestCaseAugmented._run_if_sans_db
     @TestCaseAugmented._run_on_versioned_attributes
     def test_versioned_at(self,
                           _: Any,
@@ -277,6 +288,7 @@ class TestCaseWithDB(TestCaseAugmented):
         self.assertEqual(start, end)
         self.assertEqual(items, [obj_today])
 
+    @TestCaseAugmented._run_if_with_db_but_not_server
     @TestCaseAugmented._run_on_versioned_attributes
     def test_saving_versioned_attributes(self,
                                          owner: Any,
@@ -318,7 +330,7 @@ class TestCaseWithDB(TestCaseAugmented):
         attr_vals_saved = retrieve_attr_vals(attr)
         self.assertEqual(to_set, attr_vals_saved)
 
-    @TestCaseAugmented._run_if_checked_class
+    @TestCaseAugmented._run_if_with_db_but_not_server
     def test_saving_and_caching(self) -> None:
         """Test effects of .cache() and .save()."""
         id1 = self.default_ids[0]
@@ -353,7 +365,7 @@ class TestCaseWithDB(TestCaseAugmented):
         with self.assertRaises(HandledException):
             obj1.save(self.db_conn)
 
-    @TestCaseAugmented._run_if_checked_class
+    @TestCaseAugmented._run_if_with_db_but_not_server
     def test_by_id(self) -> None:
         """Test .by_id()."""
         id1, id2, _ = self.default_ids
@@ -369,7 +381,7 @@ class TestCaseWithDB(TestCaseAugmented):
         obj2.save(self.db_conn)
         self.assertEqual(obj2, self.checked_class.by_id(self.db_conn, id2))
 
-    @TestCaseAugmented._run_if_checked_class
+    @TestCaseAugmented._run_if_with_db_but_not_server
     def test_by_id_or_create(self) -> None:
         """Test .by_id_or_create."""
         # check .by_id_or_create fails if wrong class
@@ -392,7 +404,7 @@ class TestCaseWithDB(TestCaseAugmented):
             self.checked_class.by_id(self.db_conn, item.id_)
         self.assertEqual(self.checked_class(item.id_), item)
 
-    @TestCaseAugmented._run_if_checked_class
+    @TestCaseAugmented._run_if_with_db_but_not_server
     def test_from_table_row(self) -> None:
         """Test .from_table_row() properly reads in class directly from DB."""
         id_ = self.default_ids[0]
@@ -416,6 +428,7 @@ class TestCaseWithDB(TestCaseAugmented):
             self.assertEqual({retrieved.id_: retrieved},
                              self.checked_class.get_cache())
 
+    @TestCaseAugmented._run_if_with_db_but_not_server
     @TestCaseAugmented._run_on_versioned_attributes
     def test_versioned_history_from_row(self,
                                         owner: Any,
@@ -439,7 +452,7 @@ class TestCaseWithDB(TestCaseAugmented):
             for timestamp, value in attr.history.items():
                 self.assertEqual(value, loaded_attr.history[timestamp])
 
-    @TestCaseAugmented._run_if_checked_class
+    @TestCaseAugmented._run_if_with_db_but_not_server
     def test_all(self) -> None:
         """Test .all() and its relation to cache and savings."""
         id1, id2, id3 = self.default_ids
@@ -457,7 +470,7 @@ class TestCaseWithDB(TestCaseAugmented):
         self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
                          sorted([item1, item2, item3]))
 
-    @TestCaseAugmented._run_if_checked_class
+    @TestCaseAugmented._run_if_with_db_but_not_server
     def test_singularity(self) -> None:
         """Test pointers made for single object keep pointing to it."""
         id1 = self.default_ids[0]
@@ -469,6 +482,7 @@ class TestCaseWithDB(TestCaseAugmented):
         retrieved = self.checked_class.by_id(self.db_conn, id1)
         self.assertEqual(new_attr, getattr(retrieved, attr_name))
 
+    @TestCaseAugmented._run_if_with_db_but_not_server
     @TestCaseAugmented._run_on_versioned_attributes
     def test_versioned_singularity(self,
                                    owner: Any,
@@ -485,7 +499,7 @@ class TestCaseWithDB(TestCaseAugmented):
         attr_retrieved = getattr(retrieved, attr_name)
         self.assertEqual(attr.history, attr_retrieved.history)
 
-    @TestCaseAugmented._run_if_checked_class
+    @TestCaseAugmented._run_if_with_db_but_not_server
     def test_remove(self) -> None:
         """Test .remove() effects on DB and cache."""
         id_ = self.default_ids[0]
@@ -947,13 +961,18 @@ class TestCaseWithServer(TestCaseWithDB):
         else:
             self.assertEqual(self.conn.getresponse().status, expected_code)
 
-    def check_get_defaults(self, path: str) -> None:
+    def check_get_defaults(self,
+                           path: str,
+                           default_id: str = '1',
+                           id_name: str = 'id'
+                           ) -> None:
         """Some standard model paths to test."""
-        self.check_get(path, 200)
-        self.check_get(f'{path}?id=', 200)
-        self.check_get(f'{path}?id=foo', 400)
-        self.check_get(f'/{path}?id=0', 500)
-        self.check_get(f'{path}?id=1', 200)
+        nonexist_status = 200 if self.checked_class.can_create_by_id else 404
+        self.check_get(path, nonexist_status)
+        self.check_get(f'{path}?{id_name}=', 400)
+        self.check_get(f'{path}?{id_name}=foo', 400)
+        self.check_get(f'/{path}?{id_name}=0', 400)
+        self.check_get(f'{path}?{id_name}={default_id}', nonexist_status)
 
     def check_json_get(self, path: str, expected: Expected) -> None:
         """Compare JSON on GET path with expected.
-- 
2.30.2