home · contact · privacy
Hide (almost all) remaining SQL code in DB module.
[plomtask] / plomtask / todos.py
index cfac5b536e91514f02d938fc760818d3d7278129..fd72af6bf8c0842f2a2727185c2b4c3f707a20d4 100644 (file)
@@ -1,5 +1,7 @@
 """Actionables."""
 from __future__ import annotations
+from typing import Any
+from sqlite3 import Row
 from plomtask.db import DatabaseConnection, BaseModel
 from plomtask.processes import Process
 from plomtask.conditions import Condition
@@ -32,38 +34,37 @@ class Todo(BaseModel):
             self.undoes = process.undoes[:]
 
     @classmethod
-    def by_id(cls, db_conn: DatabaseConnection, id_: int | None) -> Todo:
+    def from_table_row(cls, db_conn: DatabaseConnection,
+                       row: Row | list[Any]) -> Todo:
+        """Make from DB row, write to DB cache."""
+        if row[1] == 0:
+            raise NotFoundException('calling Todo of '
+                                    'unsaved Process')
+        row_as_list = list(row)
+        row_as_list[1] = Process.by_id(db_conn, row[1])
+        todo = super().from_table_row(db_conn, row_as_list)
+        assert isinstance(todo, Todo)
+        return todo
+
+    @classmethod
+    def by_id(cls, db_conn: DatabaseConnection, id_: int) -> Todo:
         """Get Todo of .id_=id_ and children (from DB cache if possible)."""
-        if id_ in db_conn.cached_todos.keys():
-            todo = db_conn.cached_todos[id_]
-        else:
-            todo = None
-            for row in db_conn.exec('SELECT * FROM todos WHERE id = ?',
-                                    (id_,)):
-                row = list(row)
-                if row[1] == 0:
-                    raise NotFoundException('calling Todo of '
-                                            'unsaved Process')
-                row[1] = Process.by_id(db_conn, row[1])
-                todo = cls.from_table_row(db_conn, row)
-                break
-            if todo is None:
-                raise NotFoundException(f'Todo of ID not found: {id_}')
-            for row in db_conn.exec('SELECT child FROM todo_children '
-                                    'WHERE parent = ?', (id_,)):
-                todo.children += [cls.by_id(db_conn, row[0])]
-            for row in db_conn.exec('SELECT parent FROM todo_children '
-                                    'WHERE child = ?', (id_,)):
-                todo.parents += [cls.by_id(db_conn, row[0])]
-            for row in db_conn.exec('SELECT condition FROM todo_conditions '
-                                    'WHERE todo = ?', (id_,)):
-                todo.conditions += [Condition.by_id(db_conn, row[0])]
-            for row in db_conn.exec('SELECT condition FROM todo_fulfills '
-                                    'WHERE todo = ?', (id_,)):
-                todo.fulfills += [Condition.by_id(db_conn, row[0])]
-            for row in db_conn.exec('SELECT condition FROM todo_undoes '
-                                    'WHERE todo = ?', (id_,)):
-                todo.undoes += [Condition.by_id(db_conn, row[0])]
+        todo, from_cache = super()._by_id(db_conn, id_)
+        if todo is None:
+            raise NotFoundException(f'Todo of ID not found: {id_}')
+        if not from_cache:
+            for t_id in db_conn.column_where('todo_children', 'child',
+                                             'parent', id_):
+                todo.children += [cls.by_id(db_conn, t_id)]
+            for t_id in db_conn.column_where('todo_children', 'parent',
+                                             'child', id_):
+                todo.parents += [cls.by_id(db_conn, t_id)]
+            for name in ('conditions', 'fulfills', 'undoes'):
+                table = f'todo_{name}'
+                for cond_id in db_conn.column_where(table, 'condition',
+                                                    'todo', todo.id_):
+                    target = getattr(todo, name)
+                    target += [Condition.by_id(db_conn, cond_id)]
         assert isinstance(todo, Todo)
         return todo
 
@@ -71,18 +72,19 @@ class Todo(BaseModel):
     def by_date(cls, db_conn: DatabaseConnection, date: str) -> list[Todo]:
         """Collect all Todos for Day of date."""
         todos = []
-        for row in db_conn.exec('SELECT id FROM todos WHERE day = ?', (date,)):
-            todos += [cls.by_id(db_conn, row[0])]
+        for id_ in db_conn.column_where('todos', 'id', 'day', date):
+            todos += [cls.by_id(db_conn, id_)]
         return todos
 
     @classmethod
     def enablers_for_at(cls, db_conn: DatabaseConnection, condition: Condition,
                         date: str) -> list[Todo]:
         """Collect all Todos of day that enable condition."""
+        assert isinstance(condition.id_, int)
         enablers = []
-        for row in db_conn.exec('SELECT todo FROM todo_fulfills '
-                                'WHERE condition = ?', (condition.id_,)):
-            todo = cls.by_id(db_conn, row[0])
+        for id_ in db_conn.column_where('todo_fulfills', 'todo', 'condition',
+                                        condition.id_):
+            todo = cls.by_id(db_conn, id_)
             if todo.date == date:
                 enablers += [todo]
         return enablers
@@ -91,10 +93,11 @@ class Todo(BaseModel):
     def disablers_for_at(cls, db_conn: DatabaseConnection,
                          condition: Condition, date: str) -> list[Todo]:
         """Collect all Todos of day that disable condition."""
+        assert isinstance(condition.id_, int)
         disablers = []
-        for row in db_conn.exec('SELECT todo FROM todo_undoes '
-                                'WHERE condition = ?', (condition.id_,)):
-            todo = cls.by_id(db_conn, row[0])
+        for id_ in db_conn.column_where('todo_undoes', 'todo', 'condition',
+                                        condition.id_):
+            todo = cls.by_id(db_conn, id_)
             if todo.date == date:
                 disablers += [todo]
         return disablers
@@ -174,29 +177,11 @@ class Todo(BaseModel):
         self.save_core(db_conn)
         assert isinstance(self.id_, int)
         db_conn.cached_todos[self.id_] = self
-        db_conn.exec('DELETE FROM todo_children WHERE parent = ?',
-                     (self.id_,))
-        for child in self.children:
-            db_conn.exec('INSERT INTO todo_children VALUES (?, ?)',
-                         (self.id_, child.id_))
-        db_conn.exec('DELETE FROM todo_fulfills WHERE todo = ?', (self.id_,))
-        for condition in self.fulfills:
-            if condition.id_ is None:
-                raise NotFoundException('Fulfilled Condition of Todo '
-                                        'without ID (not saved?)')
-            db_conn.exec('INSERT INTO todo_fulfills VALUES (?, ?)',
-                         (self.id_, condition.id_))
-        db_conn.exec('DELETE FROM todo_undoes WHERE todo = ?', (self.id_,))
-        for condition in self.undoes:
-            if condition.id_ is None:
-                raise NotFoundException('Undone Condition of Todo '
-                                        'without ID (not saved?)')
-            db_conn.exec('INSERT INTO todo_undoes VALUES (?, ?)',
-                         (self.id_, condition.id_))
-        db_conn.exec('DELETE FROM todo_conditions WHERE todo = ?', (self.id_,))
-        for condition in self.conditions:
-            if condition.id_ is None:
-                raise NotFoundException('Condition of Todo '
-                                        'without ID (not saved?)')
-            db_conn.exec('INSERT INTO todo_conditions VALUES (?, ?)',
-                         (self.id_, condition.id_))
+        db_conn.rewrite_relations('todo_children', 'parent', self.id_,
+                                  [[c.id_] for c in self.children])
+        db_conn.rewrite_relations('todo_conditions', 'todo', self.id_,
+                                  [[c.id_] for c in self.conditions])
+        db_conn.rewrite_relations('todo_fulfills', 'todo', self.id_,
+                                  [[c.id_] for c in self.fulfills])
+        db_conn.rewrite_relations('todo_undoes', 'todo', self.id_,
+                                  [[c.id_] for c in self.undoes])