home · contact · privacy
Harmonize treatment of GET /[item]?id=.
authorChristian Heller <c.heller@plomlompom.de>
Mon, 12 Aug 2024 11:58:27 +0000 (13:58 +0200)
committerChristian Heller <c.heller@plomlompom.de>
Mon, 12 Aug 2024 11:58:27 +0000 (13:58 +0200)
plomtask/db.py
plomtask/http.py
tests/conditions.py
tests/days.py
tests/misc.py
tests/processes.py
tests/todos.py
tests/utils.py

index 1fdd3e1c21d96804b37611d13664b89ace333535..f067cd35246850d2600c05c31a5e57fcc3d2d925 100644 (file)
@@ -5,7 +5,8 @@ from os.path import isfile
 from difflib import Differ
 from sqlite3 import connect as sql_connect, Cursor, Row
 from typing import Any, Self, TypeVar, Generic, Callable
-from plomtask.exceptions import HandledException, NotFoundException
+from plomtask.exceptions import (HandledException, NotFoundException,
+                                 BadFormatException)
 from plomtask.dating import valid_date
 
 EXPECTED_DB_VERSION = 5
@@ -246,10 +247,10 @@ class BaseModel(Generic[BaseModelId]):
     def __init__(self, id_: BaseModelId | None) -> None:
         if isinstance(id_, int) and id_ < 1:
             msg = f'illegal {self.__class__.__name__} ID, must be >=1: {id_}'
-            raise HandledException(msg)
+            raise BadFormatException(msg)
         if isinstance(id_, str) and "" == id_:
             msg = f'illegal {self.__class__.__name__} ID, must be non-empty'
-            raise HandledException(msg)
+            raise BadFormatException(msg)
         self.id_ = id_
 
     def __hash__(self) -> int:
@@ -437,6 +438,8 @@ class BaseModel(Generic[BaseModelId]):
         """
         obj = None
         if id_ is not None:
+            if isinstance(id_, int) and id_ == 0:
+                raise BadFormatException('illegal ID of value 0')
             obj = cls._get_cached(id_)
             if not obj:
                 for row in db_conn.row_where(cls.table_name, 'id', id_):
index 4426bbab1288623b01385c8b4f798eb2a019ed0e..e242a3647f752ca9545f6d41752c538ace109f72 100644 (file)
@@ -50,7 +50,7 @@ class InputsParser:
         """Retrieve list of int values at key."""
         all_str = self.get_all_str(key)
         try:
-            return [int(s) for s in all_str if len(s) > 0]
+            return [int(s) for s in all_str]
         except ValueError as e:
             msg = f'cannot int a form field value for key {key} in: {all_str}'
             raise BadFormatException(msg) from e
@@ -305,7 +305,9 @@ class TaskHandler(BaseHTTPRequestHandler):
                 # pylint: disable=protected-access
                 # (because pylint here fails to detect the use of wrapper as a
                 # method to self with respective access privileges)
-                id_ = self._params.get_int_or_none('id')
+                id_ = None
+                for val in self._params.get_all_int('id'):
+                    id_ = val
                 if target_class.can_create_by_id:
                     item = target_class.by_id_or_create(self._conn, id_)
                 else:
@@ -348,7 +350,7 @@ class TaskHandler(BaseHTTPRequestHandler):
 
     def do_GET_day(self) -> dict[str, object]:
         """Show single Day of ?date=."""
-        date = self._params.get_str_or_fail('date', date_in_n_days(0))
+        date = self._params.get_str('date', date_in_n_days(0))
         make_type = self._params.get_str_or_fail('make_type', 'full')
         #
         day = Day.by_id_or_create(self._conn, date)
index a9b28bbb9fd59a7b124ffb740ca45bdafd48c0e2..58fa18b09610c28a71956e1b3118367bb3dfa154 100644 (file)
@@ -72,6 +72,7 @@ class ExpectedGetCondition(Expected):
 
 class TestsWithServer(TestCaseWithServer):
     """Module tests against our HTTP server/handler (and database)."""
+    checked_class = Condition
 
     def test_fail_POST_condition(self) -> None:
         """Test malformed/illegal POST /condition requests."""
@@ -152,8 +153,8 @@ class TestsWithServer(TestCaseWithServer):
         self.check_filter(exp, 'conditions', 'sort_by', 'is_active', [1, 2, 3])
         self.check_filter(exp, 'conditions', 'sort_by', '-is_active',
                           [3, 2, 1])
-        # test pattern matching on title
         exp.set('sort_by', 'title')
+        # test pattern matching on title
         exp.lib_del('Condition', 1)
         self.check_filter(exp, 'conditions', 'pattern', 'ba', [2, 3])
         # test pattern matching on description
index aac150b91e62ac56a7ed6837172fa0b3fd16c797..5edec502ca5ef49ae811ac0da7d96940e98fd03b 100644 (file)
@@ -159,12 +159,12 @@ class ExpectedGetDay(Expected):
 
 class TestsWithServer(TestCaseWithServer):
     """Tests against our HTTP server/handler (and database)."""
+    checked_class = Day
 
     def test_basic_GET_day(self) -> None:
         """Test basic (no Processes/Conditions/Todos) GET /day basics."""
         # check illegal date parameters
-        self.check_get('/day?date=', 400)
-        self.check_get('/day?date=foo', 400)
+        self.check_get_defaults('/day', '2024-01-01', 'date')
         self.check_get('/day?date=2024-02-30', 400)
         # check undefined day
         date = _testing_date_in_n_days(0)
index 86474c7204cab11d2624a73a88a82554362b7df5..81591248c78b70c8b4592c10167e1061a4fb3ea5 100644 (file)
@@ -147,7 +147,8 @@ class TestsSansServer(TestCase):
         parser = InputsParser({'foo': []})
         self.assertEqual([], parser.get_all_int('foo'))
         parser = InputsParser({'foo': ['']})
-        self.assertEqual([], parser.get_all_int('foo'))
+        with self.assertRaises(BadFormatException):
+            parser.get_all_int('foo')
         parser = InputsParser({'foo': ['0']})
         self.assertEqual([0], parser.get_all_int('foo'))
         parser = InputsParser({'foo': ['0', '17']})
index 24a62bd02795e7cb34710238370b47c6dd89b2a5..2561fbbaa1e1615a082c124d1fe20b64265a53be 100644 (file)
@@ -284,6 +284,7 @@ class ExpectedGetProcesses(Expected):
 
 class TestsWithServer(TestCaseWithServer):
     """Module tests against our HTTP server/handler (and database)."""
+    checked_class = Process
 
     def _post_process(self, id_: int = 1,
                       form_data: dict[str, Any] | None = None
@@ -404,19 +405,10 @@ class TestsWithServer(TestCaseWithServer):
         p = p_min | {'kept_steps': [1, 2, 3], 'new_step_to_2': 5, 'step_of': 6}
         self.check_post(p, url, 400)
 
-    def test_GET(self) -> None:
-        """Test /process and /processes response codes."""
-        self.check_get('/process', 200)
-        self.check_get('/process?id=', 200)
-        self.check_get('/process?id=1', 200)
-        self.check_get_defaults('/process')
-        self.check_get('/processes', 200)
-
     def test_fail_GET_process(self) -> None:
         """Test invalid GET /process params."""
         # check for invalid IDs
-        self.check_get('/process?id=foo', 400)
-        self.check_get('/process?id=0', 500)
+        self.check_get_defaults('/process')
         # check we catch invalid base64
         self.check_get('/process?title_b64=foo', 400)
         # check failure on references to unknown processes; we create Process
index f048d468eee981db2f3ae68759b4ff6066898b41..9f3874d72efe6995bf8ffbc63e7f4b86ccfd9eb3 100644 (file)
@@ -266,6 +266,7 @@ class ExpectedGetTodo(Expected):
 
 class TestsWithServer(TestCaseWithServer):
     """Tests against our HTTP server/handler (and database)."""
+    checked_class = Todo
 
     def _post_exp_todo(
             self, id_: int, payload: dict[str, Any], exp: Expected) -> None:
@@ -278,7 +279,7 @@ class TestsWithServer(TestCaseWithServer):
         # test we cannot just POST into non-existing Todo
         self.check_post({}, '/todo', 404)
         self.check_post({}, '/todo?id=FOO', 400)
-        self.check_post({}, '/todo?id=0', 404)
+        self.check_post({}, '/todo?id=0', 400)
         self.check_post({}, '/todo?id=1', 404)
         # test malformed values on existing Todo
         self.post_exp_day([], {'new_todo': [1]})
@@ -463,11 +464,7 @@ class TestsWithServer(TestCaseWithServer):
     def test_GET_todo(self) -> None:
         """Test GET /todo response codes."""
         # test malformed or illegal parameter values
-        self.check_get('/todo', 404)
-        self.check_get('/todo?id=', 404)
-        self.check_get('/todo?id=foo', 400)
-        self.check_get('/todo?id=0', 404)
-        self.check_get('/todo?id=2', 404)
+        self.check_get_defaults('/todo')
         # test all existing Processes are shown as available
         exp = ExpectedGetTodo(1)
         self.post_exp_process([exp], {}, 1)
index 7945f61fc113b169aaf39827474df2fdc8466211..75c7e50f0b32342cb1f114d572b077ed95612c26 100644 (file)
@@ -35,17 +35,8 @@ class TestCaseAugmented(TestCase):
     default_init_kwargs: dict[str, Any] = {}
 
     @staticmethod
-    def _run_if_checked_class(f: Callable[..., None]) -> Callable[..., None]:
-        def wrapper(self: TestCase) -> None:
-            if hasattr(self, 'checked_class'):
-                f(self)
-        return wrapper
-
-    @classmethod
-    def _run_on_versioned_attributes(cls,
-                                     f: Callable[..., None]
+    def _run_on_versioned_attributes(f: Callable[..., None]
                                      ) -> Callable[..., None]:
-        @cls._run_if_checked_class
         def wrapper(self: TestCase) -> None:
             assert isinstance(self, TestCaseAugmented)
             for attr_name in self.checked_class.to_save_versioned():
@@ -56,6 +47,23 @@ class TestCaseAugmented(TestCase):
                 f(self, owner, attr_name, attr, default, to_set)
         return wrapper
 
+    @classmethod
+    def _run_if_sans_db(cls, f: Callable[..., None]) -> Callable[..., None]:
+        def wrapper(self: TestCaseSansDB) -> None:
+            if issubclass(cls, TestCaseSansDB):
+                f(self)
+        return wrapper
+
+    @classmethod
+    def _run_if_with_db_but_not_server(cls,
+                                       f: Callable[..., None]
+                                       ) -> Callable[..., None]:
+        def wrapper(self: TestCaseWithDB) -> None:
+            if issubclass(cls, TestCaseWithDB) and\
+                    not issubclass(cls, TestCaseWithServer):
+                f(self)
+        return wrapper
+
     @classmethod
     def _make_from_defaults(cls, id_: float | str | None) -> Any:
         return cls.checked_class(id_, **cls.default_init_kwargs)
@@ -66,7 +74,7 @@ class TestCaseSansDB(TestCaseAugmented):
     legal_ids: list[str] | list[int] = [1, 5]
     illegal_ids: list[str] | list[int] = [0]
 
-    @TestCaseAugmented._run_if_checked_class
+    @TestCaseAugmented._run_if_sans_db
     def test_id_validation(self) -> None:
         """Test .id_ validation/setting."""
         for id_ in self.illegal_ids:
@@ -76,6 +84,7 @@ class TestCaseSansDB(TestCaseAugmented):
             obj = self._make_from_defaults(id_)
             self.assertEqual(obj.id_, id_)
 
+    @TestCaseAugmented._run_if_sans_db
     @TestCaseAugmented._run_on_versioned_attributes
     def test_versioned_set(self,
                            _: Any,
@@ -115,6 +124,7 @@ class TestCaseSansDB(TestCaseAugmented):
         attr.set(to_set[1])
         self.assertEqual(timesorted_vals, expected)
 
+    @TestCaseAugmented._run_if_sans_db
     @TestCaseAugmented._run_on_versioned_attributes
     def test_versioned_newest(self,
                               _: Any,
@@ -134,6 +144,7 @@ class TestCaseSansDB(TestCaseAugmented):
         attr.set(default)
         self.assertEqual(attr.newest, default)
 
+    @TestCaseAugmented._run_if_sans_db
     @TestCaseAugmented._run_on_versioned_attributes
     def test_versioned_at(self,
                           _: Any,
@@ -277,6 +288,7 @@ class TestCaseWithDB(TestCaseAugmented):
         self.assertEqual(start, end)
         self.assertEqual(items, [obj_today])
 
+    @TestCaseAugmented._run_if_with_db_but_not_server
     @TestCaseAugmented._run_on_versioned_attributes
     def test_saving_versioned_attributes(self,
                                          owner: Any,
@@ -318,7 +330,7 @@ class TestCaseWithDB(TestCaseAugmented):
         attr_vals_saved = retrieve_attr_vals(attr)
         self.assertEqual(to_set, attr_vals_saved)
 
-    @TestCaseAugmented._run_if_checked_class
+    @TestCaseAugmented._run_if_with_db_but_not_server
     def test_saving_and_caching(self) -> None:
         """Test effects of .cache() and .save()."""
         id1 = self.default_ids[0]
@@ -353,7 +365,7 @@ class TestCaseWithDB(TestCaseAugmented):
         with self.assertRaises(HandledException):
             obj1.save(self.db_conn)
 
-    @TestCaseAugmented._run_if_checked_class
+    @TestCaseAugmented._run_if_with_db_but_not_server
     def test_by_id(self) -> None:
         """Test .by_id()."""
         id1, id2, _ = self.default_ids
@@ -369,7 +381,7 @@ class TestCaseWithDB(TestCaseAugmented):
         obj2.save(self.db_conn)
         self.assertEqual(obj2, self.checked_class.by_id(self.db_conn, id2))
 
-    @TestCaseAugmented._run_if_checked_class
+    @TestCaseAugmented._run_if_with_db_but_not_server
     def test_by_id_or_create(self) -> None:
         """Test .by_id_or_create."""
         # check .by_id_or_create fails if wrong class
@@ -392,7 +404,7 @@ class TestCaseWithDB(TestCaseAugmented):
             self.checked_class.by_id(self.db_conn, item.id_)
         self.assertEqual(self.checked_class(item.id_), item)
 
-    @TestCaseAugmented._run_if_checked_class
+    @TestCaseAugmented._run_if_with_db_but_not_server
     def test_from_table_row(self) -> None:
         """Test .from_table_row() properly reads in class directly from DB."""
         id_ = self.default_ids[0]
@@ -416,6 +428,7 @@ class TestCaseWithDB(TestCaseAugmented):
             self.assertEqual({retrieved.id_: retrieved},
                              self.checked_class.get_cache())
 
+    @TestCaseAugmented._run_if_with_db_but_not_server
     @TestCaseAugmented._run_on_versioned_attributes
     def test_versioned_history_from_row(self,
                                         owner: Any,
@@ -439,7 +452,7 @@ class TestCaseWithDB(TestCaseAugmented):
             for timestamp, value in attr.history.items():
                 self.assertEqual(value, loaded_attr.history[timestamp])
 
-    @TestCaseAugmented._run_if_checked_class
+    @TestCaseAugmented._run_if_with_db_but_not_server
     def test_all(self) -> None:
         """Test .all() and its relation to cache and savings."""
         id1, id2, id3 = self.default_ids
@@ -457,7 +470,7 @@ class TestCaseWithDB(TestCaseAugmented):
         self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
                          sorted([item1, item2, item3]))
 
-    @TestCaseAugmented._run_if_checked_class
+    @TestCaseAugmented._run_if_with_db_but_not_server
     def test_singularity(self) -> None:
         """Test pointers made for single object keep pointing to it."""
         id1 = self.default_ids[0]
@@ -469,6 +482,7 @@ class TestCaseWithDB(TestCaseAugmented):
         retrieved = self.checked_class.by_id(self.db_conn, id1)
         self.assertEqual(new_attr, getattr(retrieved, attr_name))
 
+    @TestCaseAugmented._run_if_with_db_but_not_server
     @TestCaseAugmented._run_on_versioned_attributes
     def test_versioned_singularity(self,
                                    owner: Any,
@@ -485,7 +499,7 @@ class TestCaseWithDB(TestCaseAugmented):
         attr_retrieved = getattr(retrieved, attr_name)
         self.assertEqual(attr.history, attr_retrieved.history)
 
-    @TestCaseAugmented._run_if_checked_class
+    @TestCaseAugmented._run_if_with_db_but_not_server
     def test_remove(self) -> None:
         """Test .remove() effects on DB and cache."""
         id_ = self.default_ids[0]
@@ -947,13 +961,18 @@ class TestCaseWithServer(TestCaseWithDB):
         else:
             self.assertEqual(self.conn.getresponse().status, expected_code)
 
-    def check_get_defaults(self, path: str) -> None:
+    def check_get_defaults(self,
+                           path: str,
+                           default_id: str = '1',
+                           id_name: str = 'id'
+                           ) -> None:
         """Some standard model paths to test."""
-        self.check_get(path, 200)
-        self.check_get(f'{path}?id=', 200)
-        self.check_get(f'{path}?id=foo', 400)
-        self.check_get(f'/{path}?id=0', 500)
-        self.check_get(f'{path}?id=1', 200)
+        nonexist_status = 200 if self.checked_class.can_create_by_id else 404
+        self.check_get(path, nonexist_status)
+        self.check_get(f'{path}?{id_name}=', 400)
+        self.check_get(f'{path}?{id_name}=foo', 400)
+        self.check_get(f'/{path}?{id_name}=0', 400)
+        self.check_get(f'{path}?{id_name}={default_id}', nonexist_status)
 
     def check_json_get(self, path: str, expected: Expected) -> None:
         """Compare JSON on GET path with expected.