From 2d0d3a138de69e5e09208936ac094b53b0785c0b Mon Sep 17 00:00:00 2001
From: Christian Heller <c.heller@plomlompom.de>
Date: Fri, 19 Apr 2024 07:26:01 +0200
Subject: [PATCH] Hide (almost all) remaining SQL code in DB module.

---
 plomtask/conditions.py |  8 +++----
 plomtask/db.py         | 24 +++++++++++++++----
 plomtask/http.py       |  4 ++--
 plomtask/processes.py  | 30 ++++++++++++------------
 plomtask/todos.py      | 52 +++++++++++++++++++-----------------------
 tests/todos.py         |  6 ++---
 6 files changed, 69 insertions(+), 55 deletions(-)

diff --git a/plomtask/conditions.py b/plomtask/conditions.py
index b87e3ac..9a44200 100644
--- a/plomtask/conditions.py
+++ b/plomtask/conditions.py
@@ -27,7 +27,7 @@ class Condition(BaseModel):
         assert isinstance(condition, Condition)
         for name in ('title', 'description'):
             table_name = f'condition_{name}s'
-            for row_ in db_conn.all_where(table_name, 'parent', row[0]):
+            for row_ in db_conn.row_where(table_name, 'parent', row[0]):
                 getattr(condition, name).history_from_row(row_)
         return condition
 
@@ -38,9 +38,9 @@ class Condition(BaseModel):
         for id_, condition in db_conn.cached_conditions.items():
             conditions[id_] = condition
         already_recorded = conditions.keys()
-        for row in db_conn.exec('SELECT id FROM conditions'):
-            if row[0] not in already_recorded:
-                condition = cls.by_id(db_conn, row[0])
+        for id_ in db_conn.column_all('conditions', 'id'):
+            if id_ not in already_recorded:
+                condition = cls.by_id(db_conn, id_)
                 conditions[condition.id_] = condition
         return list(conditions.values())
 
diff --git a/plomtask/db.py b/plomtask/db.py
index f45d2b6..848e750 100644
--- a/plomtask/db.py
+++ b/plomtask/db.py
@@ -70,17 +70,34 @@ class DatabaseConnection:
     def rewrite_relations(self, table_name: str, key: str, target: int,
                           rows: list[list[Any]]) -> None:
         """Rewrite relations in table_name to target, with rows values."""
-        self.exec(f'DELETE FROM {table_name} WHERE {key} = ?', (target,))
+        self.delete_where(table_name, key, target)
         for row in rows:
             values = tuple([target] + row)
             q_marks = self.__class__.q_marks_from_values(values)
             self.exec(f'INSERT INTO {table_name} VALUES {q_marks}', values)
 
-    def all_where(self, table_name: str, key: str, target: int) -> list[Row]:
+    def row_where(self, table_name: str, key: str,
+                  target: int | str) -> list[Row]:
         """Return list of Rows at table where key == target."""
         return list(self.exec(f'SELECT * FROM {table_name} WHERE {key} = ?',
                               (target,)))
 
+    def column_where(self, table_name: str, column: str, key: str,
+                     target: int | str) -> list[Any]:
+        """Return column of table where key == target."""
+        return [row[0] for row in
+                self.exec(f'SELECT {column} FROM {table_name} '
+                          f'WHERE {key} = ?', (target,))]
+
+    def column_all(self, table_name: str, column: str) -> list[Any]:
+        """Return complete column of table."""
+        return [row[0] for row in
+                self.exec(f'SELECT {column} FROM {table_name}')]
+
+    def delete_where(self, table_name: str, key: str, target: int) -> None:
+        """Delete from table where key == target."""
+        self.exec(f'DELETE FROM {table_name} WHERE {key} = ?', (target,))
+
     @staticmethod
     def q_marks_from_values(values: tuple[Any]) -> str:
         """Return placeholder to insert values into SQL code."""
@@ -116,8 +133,7 @@ class BaseModel:
             obj = cache[id_]
             from_cache = True
         else:
-            for row in db_conn.exec(f'SELECT * FROM {cls.table_name} '
-                                    'WHERE id = ?', (id_,)):
+            for row in db_conn.row_where(cls.table_name, 'id', id_):
                 obj = cls.from_table_row(db_conn, row)
                 cache[id_] = obj
                 break
diff --git a/plomtask/http.py b/plomtask/http.py
index 5f739b6..cc4358c 100644
--- a/plomtask/http.py
+++ b/plomtask/http.py
@@ -134,7 +134,7 @@ class TaskHandler(BaseHTTPRequestHandler):
 
     def do_GET_todo(self) -> dict[str, object]:
         """Show single Todo of ?id=."""
-        id_ = self.params.get_int_or_none('id')
+        id_ = self.params.get_int('id')
         todo = Todo.by_id(self.conn, id_)
         return {'todo': todo,
                 'todo_candidates': Todo.by_date(self.conn, todo.date),
@@ -198,7 +198,7 @@ class TaskHandler(BaseHTTPRequestHandler):
 
     def do_POST_todo(self) -> None:
         """Update Todo and its children."""
-        id_ = self.params.get_int_or_none('id')
+        id_ = self.params.get_int('id')
         todo = Todo.by_id(self.conn, id_)
         child_id = self.form_data.get_int_or_none('adopt')
         if child_id is not None:
diff --git a/plomtask/processes.py b/plomtask/processes.py
index e5851d0..490acc3 100644
--- a/plomtask/processes.py
+++ b/plomtask/processes.py
@@ -30,9 +30,9 @@ class Process(BaseModel):
         for id_, process in db_conn.cached_processes.items():
             processes[id_] = process
         already_recorded = processes.keys()
-        for row in db_conn.exec('SELECT id FROM processes'):
-            if row[0] not in already_recorded:
-                process = cls.by_id(db_conn, row[0])
+        for id_ in db_conn.column_all('processes', 'id'):
+            if id_ not in already_recorded:
+                process = cls.by_id(db_conn, id_)
                 processes[process.id_] = process
         return list(processes.values())
 
@@ -50,26 +50,29 @@ class Process(BaseModel):
         if isinstance(process.id_, int):
             for name in ('title', 'description', 'effort'):
                 table = f'process_{name}s'
-                for row in db_conn.all_where(table, 'parent', process.id_):
+                for row in db_conn.row_where(table, 'parent', process.id_):
                     getattr(process, name).history_from_row(row)
-            for row in db_conn.all_where('process_steps', 'owner',
+            for row in db_conn.row_where('process_steps', 'owner',
                                          process.id_):
                 step = ProcessStep.from_table_row(db_conn, row)
                 process.explicit_steps += [step]
             for name in ('conditions', 'fulfills', 'undoes'):
                 table = f'process_{name}'
-                for row in db_conn.all_where(table, 'process', process.id_):
+                for cond_id in db_conn.column_where(table, 'condition',
+                                                    'process', process.id_):
                     target = getattr(process, name)
-                    target += [Condition.by_id(db_conn, row[1])]
+                    target += [Condition.by_id(db_conn, cond_id)]
         assert isinstance(process, Process)
         return process
 
     def used_as_step_by(self, db_conn: DatabaseConnection) -> list[Process]:
         """Return Processes using self for a ProcessStep."""
+        if not self.id_:
+            return []
         owner_ids = set()
-        for owner_id in db_conn.exec('SELECT owner FROM process_steps WHERE'
-                                     ' step_process = ?', (self.id_,)):
-            owner_ids.add(owner_id[0])
+        for id_ in db_conn.column_where('process_steps', 'owner',
+                                        'step_process', self.id_):
+            owner_ids.add(id_)
         return [self.__class__.by_id(db_conn, id_) for id_ in owner_ids]
 
     def get_steps(self, db_conn: DatabaseConnection, external_owner:
@@ -162,12 +165,12 @@ class Process(BaseModel):
     def set_steps(self, db_conn: DatabaseConnection,
                   steps: list[tuple[int | None, int, int | None]]) -> None:
         """Set self.explicit_steps in bulk."""
+        assert isinstance(self.id_, int)
         for step in self.explicit_steps:
             assert isinstance(step.id_, int)
             del db_conn.cached_process_steps[step.id_]
         self.explicit_steps = []
-        db_conn.exec('DELETE FROM process_steps WHERE owner = ?',
-                     (self.id_,))
+        db_conn.delete_where('process_steps', 'owner', self.id_)
         for step_tuple in steps:
             self._add_step(db_conn, step_tuple[0],
                            step_tuple[1], step_tuple[2])
@@ -185,8 +188,7 @@ class Process(BaseModel):
                                   [[c.id_] for c in self.fulfills])
         db_conn.rewrite_relations('process_undoes', 'process', self.id_,
                                   [[c.id_] for c in self.undoes])
-        db_conn.exec('DELETE FROM process_steps WHERE owner = ?',
-                     (self.id_,))
+        db_conn.delete_where('process_steps', 'owner', self.id_)
         for step in self.explicit_steps:
             step.save(db_conn)
         db_conn.cached_processes[self.id_] = self
diff --git a/plomtask/todos.py b/plomtask/todos.py
index 348dbdd..fd72af6 100644
--- a/plomtask/todos.py
+++ b/plomtask/todos.py
@@ -47,30 +47,24 @@ class Todo(BaseModel):
         return todo
 
     @classmethod
-    def by_id(cls, db_conn: DatabaseConnection, id_: int | None) -> Todo:
+    def by_id(cls, db_conn: DatabaseConnection, id_: int) -> Todo:
         """Get Todo of .id_=id_ and children (from DB cache if possible)."""
-        if id_:
-            todo, from_cache = super()._by_id(db_conn, id_)
-        else:
-            todo, from_cache = None, False
+        todo, from_cache = super()._by_id(db_conn, id_)
         if todo is None:
             raise NotFoundException(f'Todo of ID not found: {id_}')
         if not from_cache:
-            for row in db_conn.exec('SELECT child FROM todo_children '
-                                    'WHERE parent = ?', (id_,)):
-                todo.children += [cls.by_id(db_conn, row[0])]
-            for row in db_conn.exec('SELECT parent FROM todo_children '
-                                    'WHERE child = ?', (id_,)):
-                todo.parents += [cls.by_id(db_conn, row[0])]
-            for row in db_conn.exec('SELECT condition FROM todo_conditions '
-                                    'WHERE todo = ?', (id_,)):
-                todo.conditions += [Condition.by_id(db_conn, row[0])]
-            for row in db_conn.exec('SELECT condition FROM todo_fulfills '
-                                    'WHERE todo = ?', (id_,)):
-                todo.fulfills += [Condition.by_id(db_conn, row[0])]
-            for row in db_conn.exec('SELECT condition FROM todo_undoes '
-                                    'WHERE todo = ?', (id_,)):
-                todo.undoes += [Condition.by_id(db_conn, row[0])]
+            for t_id in db_conn.column_where('todo_children', 'child',
+                                             'parent', id_):
+                todo.children += [cls.by_id(db_conn, t_id)]
+            for t_id in db_conn.column_where('todo_children', 'parent',
+                                             'child', id_):
+                todo.parents += [cls.by_id(db_conn, t_id)]
+            for name in ('conditions', 'fulfills', 'undoes'):
+                table = f'todo_{name}'
+                for cond_id in db_conn.column_where(table, 'condition',
+                                                    'todo', todo.id_):
+                    target = getattr(todo, name)
+                    target += [Condition.by_id(db_conn, cond_id)]
         assert isinstance(todo, Todo)
         return todo
 
@@ -78,18 +72,19 @@ class Todo(BaseModel):
     def by_date(cls, db_conn: DatabaseConnection, date: str) -> list[Todo]:
         """Collect all Todos for Day of date."""
         todos = []
-        for row in db_conn.exec('SELECT id FROM todos WHERE day = ?', (date,)):
-            todos += [cls.by_id(db_conn, row[0])]
+        for id_ in db_conn.column_where('todos', 'id', 'day', date):
+            todos += [cls.by_id(db_conn, id_)]
         return todos
 
     @classmethod
     def enablers_for_at(cls, db_conn: DatabaseConnection, condition: Condition,
                         date: str) -> list[Todo]:
         """Collect all Todos of day that enable condition."""
+        assert isinstance(condition.id_, int)
         enablers = []
-        for row in db_conn.exec('SELECT todo FROM todo_fulfills '
-                                'WHERE condition = ?', (condition.id_,)):
-            todo = cls.by_id(db_conn, row[0])
+        for id_ in db_conn.column_where('todo_fulfills', 'todo', 'condition',
+                                        condition.id_):
+            todo = cls.by_id(db_conn, id_)
             if todo.date == date:
                 enablers += [todo]
         return enablers
@@ -98,10 +93,11 @@ class Todo(BaseModel):
     def disablers_for_at(cls, db_conn: DatabaseConnection,
                          condition: Condition, date: str) -> list[Todo]:
         """Collect all Todos of day that disable condition."""
+        assert isinstance(condition.id_, int)
         disablers = []
-        for row in db_conn.exec('SELECT todo FROM todo_undoes '
-                                'WHERE condition = ?', (condition.id_,)):
-            todo = cls.by_id(db_conn, row[0])
+        for id_ in db_conn.column_where('todo_undoes', 'todo', 'condition',
+                                        condition.id_):
+            todo = cls.by_id(db_conn, id_)
             if todo.date == date:
                 disablers += [todo]
         return disablers
diff --git a/tests/todos.py b/tests/todos.py
index 17454c5..426bb91 100644
--- a/tests/todos.py
+++ b/tests/todos.py
@@ -210,7 +210,7 @@ class TestsWithServer(TestCaseWithServer):
         self.check_post(form_data, '/day?date=2024-01-01', 302, '/')
         form_data = {}
         self.check_post(form_data, '/todo=', 404)
-        self.check_post(form_data, '/todo?id=', 404)
+        self.check_post(form_data, '/todo?id=', 400)
         self.check_post(form_data, '/todo?id=FOO', 400)
         self.check_post(form_data, '/todo?id=0', 404)
         todo1 = post_and_reload(form_data)
@@ -249,8 +249,8 @@ class TestsWithServer(TestCaseWithServer):
         self.check_post(form_data, '/process?id=', 302, '/')
         form_data = {'comment': '', 'new_todo': 1}
         self.check_post(form_data, '/day?date=2024-01-01', 302, '/')
-        self.check_get('/todo', 404)
-        self.check_get('/todo?id=', 404)
+        self.check_get('/todo', 400)
+        self.check_get('/todo?id=', 400)
         self.check_get('/todo?id=foo', 400)
         self.check_get('/todo?id=0', 404)
         self.check_get('/todo?id=1', 200)
-- 
2.30.2