From f02e0fc13b49dc5b38924ba3ad8c485007a72cb2 Mon Sep 17 00:00:00 2001
From: Christian Heller <c.heller@plomlompom.de>
Date: Fri, 21 Jun 2024 15:26:37 +0200
Subject: [PATCH] Refactor request handler identifying items by ID param on
 GET.

---
 plomtask/http.py | 69 ++++++++++++++++++++++++++++--------------------
 tests/todos.py   |  4 +--
 2 files changed, 42 insertions(+), 31 deletions(-)

diff --git a/plomtask/http.py b/plomtask/http.py
index 4c0d6a3..28812df 100644
--- a/plomtask/http.py
+++ b/plomtask/http.py
@@ -221,6 +221,25 @@ class TaskHandler(BaseHTTPRequestHandler):
 
     # GET handlers
 
+    @staticmethod
+    def _get_item(target_class: Any
+                  ) -> Callable[..., Callable[[TaskHandler],
+                                              dict[str, object]]]:
+        def decorator(f: Callable[..., dict[str, object]]
+                      ) -> Callable[[TaskHandler], dict[str, object]]:
+            def wrapper(self: TaskHandler) -> dict[str, object]:
+                # pylint: disable=protected-access
+                # (because pylint here fails to detect the use of wrapper as a
+                # method to self with respective access privileges)
+                id_ = self._params.get_int_or_none('id')
+                if target_class.can_create_by_id:
+                    item = target_class.by_id_or_create(self.conn, id_)
+                else:
+                    item = target_class.by_id(self.conn, id_)
+                return f(self, item)
+            return wrapper
+        return decorator
+
     def do_GET_(self) -> str:
         """Return redirect target on GET /."""
         return '/day'
@@ -279,7 +298,8 @@ class TaskHandler(BaseHTTPRequestHandler):
                 'conditions_present': conditions_present,
                 'processes': Process.all(self.conn)}
 
-    def do_GET_todo(self) -> dict[str, object]:
+    @_get_item(Todo)
+    def do_GET_todo(self, todo: Todo) -> dict[str, object]:
         """Show single Todo of ?id=."""
 
         @dataclass
@@ -330,8 +350,6 @@ class TaskHandler(BaseHTTPRequestHandler):
                 ids = ids | collect_adoptables_keys(node.children)
             return ids
 
-        id_ = self._params.get_int('id')
-        todo = Todo.by_id(self.conn, id_)
         todo_steps = [step.todo for step in todo.get_step_tree(set()).children]
         process_tree = todo.process.get_steps(self.conn, None)
         steps_todo_to_process: list[TodoStepsNode] = []
@@ -407,10 +425,9 @@ class TaskHandler(BaseHTTPRequestHandler):
                 'sort_by': sort_by,
                 'pattern': pattern}
 
-    def do_GET_condition(self) -> dict[str, object]:
+    @_get_item(Condition)
+    def do_GET_condition(self, c: Condition) -> dict[str, object]:
         """Show Condition of ?id=."""
-        id_ = self._params.get_int_or_none('id')
-        c = Condition.by_id_or_create(self.conn, id_)
         ps = Process.all(self.conn)
         return {'condition': c, 'is_new': c.id_ is None,
                 'enabled_processes': [p for p in ps if c in p.conditions],
@@ -418,22 +435,19 @@ class TaskHandler(BaseHTTPRequestHandler):
                 'enabling_processes': [p for p in ps if c in p.enables],
                 'disabling_processes': [p for p in ps if c in p.disables]}
 
-    def do_GET_condition_titles(self) -> dict[str, object]:
+    @_get_item(Condition)
+    def do_GET_condition_titles(self, c: Condition) -> dict[str, object]:
         """Show title history of Condition of ?id=."""
-        id_ = self._params.get_int('id')
-        condition = Condition.by_id(self.conn, id_)
-        return {'condition': condition}
+        return {'condition': c}
 
-    def do_GET_condition_descriptions(self) -> dict[str, object]:
+    @_get_item(Condition)
+    def do_GET_condition_descriptions(self, c: Condition) -> dict[str, object]:
         """Show description historys of Condition of ?id=."""
-        id_ = self._params.get_int('id')
-        condition = Condition.by_id(self.conn, id_)
-        return {'condition': condition}
+        return {'condition': c}
 
-    def do_GET_process(self) -> dict[str, object]:
+    @_get_item(Process)
+    def do_GET_process(self, process: Process) -> dict[str, object]:
         """Show Process of ?id=."""
-        id_ = self._params.get_int_or_none('id')
-        process = Process.by_id_or_create(self.conn, id_)
         title_64 = self._params.get_str('title_b64')
         if title_64:
             title = b64decode(title_64.encode()).decode()
@@ -451,23 +465,20 @@ class TaskHandler(BaseHTTPRequestHandler):
                 'process_candidates': Process.all(self.conn),
                 'condition_candidates': Condition.all(self.conn)}
 
-    def do_GET_process_titles(self) -> dict[str, object]:
+    @_get_item(Process)
+    def do_GET_process_titles(self, p: Process) -> dict[str, object]:
         """Show title history of Process of ?id=."""
-        id_ = self._params.get_int('id')
-        process = Process.by_id(self.conn, id_)
-        return {'process': process}
+        return {'process': p}
 
-    def do_GET_process_descriptions(self) -> dict[str, object]:
+    @_get_item(Process)
+    def do_GET_process_descriptions(self, p: Process) -> dict[str, object]:
         """Show description historys of Process of ?id=."""
-        id_ = self._params.get_int('id')
-        process = Process.by_id(self.conn, id_)
-        return {'process': process}
+        return {'process': p}
 
-    def do_GET_process_efforts(self) -> dict[str, object]:
+    @_get_item(Process)
+    def do_GET_process_efforts(self, p: Process) -> dict[str, object]:
         """Show default effort history of Process of ?id=."""
-        id_ = self._params.get_int('id')
-        process = Process.by_id(self.conn, id_)
-        return {'process': process}
+        return {'process': p}
 
     def do_GET_processes(self) -> dict[str, object]:
         """Show all Processes."""
diff --git a/tests/todos.py b/tests/todos.py
index 626b744..0998c69 100644
--- a/tests/todos.py
+++ b/tests/todos.py
@@ -414,8 +414,8 @@ class TestsWithServer(TestCaseWithServer):
         self.post_process()
         form_data = {'day_comment': '', 'new_todo': 1, 'make_type': 'full'}
         self.check_post(form_data, '/day?date=2024-01-01&make_type=full', 302)
-        self.check_get('/todo', 400)
-        self.check_get('/todo?id=', 400)
+        self.check_get('/todo', 404)
+        self.check_get('/todo?id=', 404)
         self.check_get('/todo?id=foo', 400)
         self.check_get('/todo?id=0', 404)
         self.check_get('/todo?id=1', 200)
-- 
2.30.2