From: Christian Heller Date: Fri, 19 Apr 2024 05:26:01 +0000 (+0200) Subject: Hide (almost all) remaining SQL code in DB module. X-Git-Url: https://plomlompom.com/repos/?a=commitdiff_plain;h=2d0d3a138de69e5e09208936ac094b53b0785c0b;p=plomtask Hide (almost all) remaining SQL code in DB module. --- diff --git a/plomtask/conditions.py b/plomtask/conditions.py index b87e3ac..9a44200 100644 --- a/plomtask/conditions.py +++ b/plomtask/conditions.py @@ -27,7 +27,7 @@ class Condition(BaseModel): assert isinstance(condition, Condition) for name in ('title', 'description'): table_name = f'condition_{name}s' - for row_ in db_conn.all_where(table_name, 'parent', row[0]): + for row_ in db_conn.row_where(table_name, 'parent', row[0]): getattr(condition, name).history_from_row(row_) return condition @@ -38,9 +38,9 @@ class Condition(BaseModel): for id_, condition in db_conn.cached_conditions.items(): conditions[id_] = condition already_recorded = conditions.keys() - for row in db_conn.exec('SELECT id FROM conditions'): - if row[0] not in already_recorded: - condition = cls.by_id(db_conn, row[0]) + for id_ in db_conn.column_all('conditions', 'id'): + if id_ not in already_recorded: + condition = cls.by_id(db_conn, id_) conditions[condition.id_] = condition return list(conditions.values()) diff --git a/plomtask/db.py b/plomtask/db.py index f45d2b6..848e750 100644 --- a/plomtask/db.py +++ b/plomtask/db.py @@ -70,17 +70,34 @@ class DatabaseConnection: def rewrite_relations(self, table_name: str, key: str, target: int, rows: list[list[Any]]) -> None: """Rewrite relations in table_name to target, with rows values.""" - self.exec(f'DELETE FROM {table_name} WHERE {key} = ?', (target,)) + self.delete_where(table_name, key, target) for row in rows: values = tuple([target] + row) q_marks = self.__class__.q_marks_from_values(values) self.exec(f'INSERT INTO {table_name} VALUES {q_marks}', values) - def all_where(self, table_name: str, key: str, target: int) -> list[Row]: + def row_where(self, table_name: str, key: str, + target: int | str) -> list[Row]: """Return list of Rows at table where key == target.""" return list(self.exec(f'SELECT * FROM {table_name} WHERE {key} = ?', (target,))) + def column_where(self, table_name: str, column: str, key: str, + target: int | str) -> list[Any]: + """Return column of table where key == target.""" + return [row[0] for row in + self.exec(f'SELECT {column} FROM {table_name} ' + f'WHERE {key} = ?', (target,))] + + def column_all(self, table_name: str, column: str) -> list[Any]: + """Return complete column of table.""" + return [row[0] for row in + self.exec(f'SELECT {column} FROM {table_name}')] + + def delete_where(self, table_name: str, key: str, target: int) -> None: + """Delete from table where key == target.""" + self.exec(f'DELETE FROM {table_name} WHERE {key} = ?', (target,)) + @staticmethod def q_marks_from_values(values: tuple[Any]) -> str: """Return placeholder to insert values into SQL code.""" @@ -116,8 +133,7 @@ class BaseModel: obj = cache[id_] from_cache = True else: - for row in db_conn.exec(f'SELECT * FROM {cls.table_name} ' - 'WHERE id = ?', (id_,)): + for row in db_conn.row_where(cls.table_name, 'id', id_): obj = cls.from_table_row(db_conn, row) cache[id_] = obj break diff --git a/plomtask/http.py b/plomtask/http.py index 5f739b6..cc4358c 100644 --- a/plomtask/http.py +++ b/plomtask/http.py @@ -134,7 +134,7 @@ class TaskHandler(BaseHTTPRequestHandler): def do_GET_todo(self) -> dict[str, object]: """Show single Todo of ?id=.""" - id_ = self.params.get_int_or_none('id') + id_ = self.params.get_int('id') todo = Todo.by_id(self.conn, id_) return {'todo': todo, 'todo_candidates': Todo.by_date(self.conn, todo.date), @@ -198,7 +198,7 @@ class TaskHandler(BaseHTTPRequestHandler): def do_POST_todo(self) -> None: """Update Todo and its children.""" - id_ = self.params.get_int_or_none('id') + id_ = self.params.get_int('id') todo = Todo.by_id(self.conn, id_) child_id = self.form_data.get_int_or_none('adopt') if child_id is not None: diff --git a/plomtask/processes.py b/plomtask/processes.py index e5851d0..490acc3 100644 --- a/plomtask/processes.py +++ b/plomtask/processes.py @@ -30,9 +30,9 @@ class Process(BaseModel): for id_, process in db_conn.cached_processes.items(): processes[id_] = process already_recorded = processes.keys() - for row in db_conn.exec('SELECT id FROM processes'): - if row[0] not in already_recorded: - process = cls.by_id(db_conn, row[0]) + for id_ in db_conn.column_all('processes', 'id'): + if id_ not in already_recorded: + process = cls.by_id(db_conn, id_) processes[process.id_] = process return list(processes.values()) @@ -50,26 +50,29 @@ class Process(BaseModel): if isinstance(process.id_, int): for name in ('title', 'description', 'effort'): table = f'process_{name}s' - for row in db_conn.all_where(table, 'parent', process.id_): + for row in db_conn.row_where(table, 'parent', process.id_): getattr(process, name).history_from_row(row) - for row in db_conn.all_where('process_steps', 'owner', + for row in db_conn.row_where('process_steps', 'owner', process.id_): step = ProcessStep.from_table_row(db_conn, row) process.explicit_steps += [step] for name in ('conditions', 'fulfills', 'undoes'): table = f'process_{name}' - for row in db_conn.all_where(table, 'process', process.id_): + for cond_id in db_conn.column_where(table, 'condition', + 'process', process.id_): target = getattr(process, name) - target += [Condition.by_id(db_conn, row[1])] + target += [Condition.by_id(db_conn, cond_id)] assert isinstance(process, Process) return process def used_as_step_by(self, db_conn: DatabaseConnection) -> list[Process]: """Return Processes using self for a ProcessStep.""" + if not self.id_: + return [] owner_ids = set() - for owner_id in db_conn.exec('SELECT owner FROM process_steps WHERE' - ' step_process = ?', (self.id_,)): - owner_ids.add(owner_id[0]) + for id_ in db_conn.column_where('process_steps', 'owner', + 'step_process', self.id_): + owner_ids.add(id_) return [self.__class__.by_id(db_conn, id_) for id_ in owner_ids] def get_steps(self, db_conn: DatabaseConnection, external_owner: @@ -162,12 +165,12 @@ class Process(BaseModel): def set_steps(self, db_conn: DatabaseConnection, steps: list[tuple[int | None, int, int | None]]) -> None: """Set self.explicit_steps in bulk.""" + assert isinstance(self.id_, int) for step in self.explicit_steps: assert isinstance(step.id_, int) del db_conn.cached_process_steps[step.id_] self.explicit_steps = [] - db_conn.exec('DELETE FROM process_steps WHERE owner = ?', - (self.id_,)) + db_conn.delete_where('process_steps', 'owner', self.id_) for step_tuple in steps: self._add_step(db_conn, step_tuple[0], step_tuple[1], step_tuple[2]) @@ -185,8 +188,7 @@ class Process(BaseModel): [[c.id_] for c in self.fulfills]) db_conn.rewrite_relations('process_undoes', 'process', self.id_, [[c.id_] for c in self.undoes]) - db_conn.exec('DELETE FROM process_steps WHERE owner = ?', - (self.id_,)) + db_conn.delete_where('process_steps', 'owner', self.id_) for step in self.explicit_steps: step.save(db_conn) db_conn.cached_processes[self.id_] = self diff --git a/plomtask/todos.py b/plomtask/todos.py index 348dbdd..fd72af6 100644 --- a/plomtask/todos.py +++ b/plomtask/todos.py @@ -47,30 +47,24 @@ class Todo(BaseModel): return todo @classmethod - def by_id(cls, db_conn: DatabaseConnection, id_: int | None) -> Todo: + def by_id(cls, db_conn: DatabaseConnection, id_: int) -> Todo: """Get Todo of .id_=id_ and children (from DB cache if possible).""" - if id_: - todo, from_cache = super()._by_id(db_conn, id_) - else: - todo, from_cache = None, False + todo, from_cache = super()._by_id(db_conn, id_) if todo is None: raise NotFoundException(f'Todo of ID not found: {id_}') if not from_cache: - for row in db_conn.exec('SELECT child FROM todo_children ' - 'WHERE parent = ?', (id_,)): - todo.children += [cls.by_id(db_conn, row[0])] - for row in db_conn.exec('SELECT parent FROM todo_children ' - 'WHERE child = ?', (id_,)): - todo.parents += [cls.by_id(db_conn, row[0])] - for row in db_conn.exec('SELECT condition FROM todo_conditions ' - 'WHERE todo = ?', (id_,)): - todo.conditions += [Condition.by_id(db_conn, row[0])] - for row in db_conn.exec('SELECT condition FROM todo_fulfills ' - 'WHERE todo = ?', (id_,)): - todo.fulfills += [Condition.by_id(db_conn, row[0])] - for row in db_conn.exec('SELECT condition FROM todo_undoes ' - 'WHERE todo = ?', (id_,)): - todo.undoes += [Condition.by_id(db_conn, row[0])] + for t_id in db_conn.column_where('todo_children', 'child', + 'parent', id_): + todo.children += [cls.by_id(db_conn, t_id)] + for t_id in db_conn.column_where('todo_children', 'parent', + 'child', id_): + todo.parents += [cls.by_id(db_conn, t_id)] + for name in ('conditions', 'fulfills', 'undoes'): + table = f'todo_{name}' + for cond_id in db_conn.column_where(table, 'condition', + 'todo', todo.id_): + target = getattr(todo, name) + target += [Condition.by_id(db_conn, cond_id)] assert isinstance(todo, Todo) return todo @@ -78,18 +72,19 @@ class Todo(BaseModel): def by_date(cls, db_conn: DatabaseConnection, date: str) -> list[Todo]: """Collect all Todos for Day of date.""" todos = [] - for row in db_conn.exec('SELECT id FROM todos WHERE day = ?', (date,)): - todos += [cls.by_id(db_conn, row[0])] + for id_ in db_conn.column_where('todos', 'id', 'day', date): + todos += [cls.by_id(db_conn, id_)] return todos @classmethod def enablers_for_at(cls, db_conn: DatabaseConnection, condition: Condition, date: str) -> list[Todo]: """Collect all Todos of day that enable condition.""" + assert isinstance(condition.id_, int) enablers = [] - for row in db_conn.exec('SELECT todo FROM todo_fulfills ' - 'WHERE condition = ?', (condition.id_,)): - todo = cls.by_id(db_conn, row[0]) + for id_ in db_conn.column_where('todo_fulfills', 'todo', 'condition', + condition.id_): + todo = cls.by_id(db_conn, id_) if todo.date == date: enablers += [todo] return enablers @@ -98,10 +93,11 @@ class Todo(BaseModel): def disablers_for_at(cls, db_conn: DatabaseConnection, condition: Condition, date: str) -> list[Todo]: """Collect all Todos of day that disable condition.""" + assert isinstance(condition.id_, int) disablers = [] - for row in db_conn.exec('SELECT todo FROM todo_undoes ' - 'WHERE condition = ?', (condition.id_,)): - todo = cls.by_id(db_conn, row[0]) + for id_ in db_conn.column_where('todo_undoes', 'todo', 'condition', + condition.id_): + todo = cls.by_id(db_conn, id_) if todo.date == date: disablers += [todo] return disablers diff --git a/tests/todos.py b/tests/todos.py index 17454c5..426bb91 100644 --- a/tests/todos.py +++ b/tests/todos.py @@ -210,7 +210,7 @@ class TestsWithServer(TestCaseWithServer): self.check_post(form_data, '/day?date=2024-01-01', 302, '/') form_data = {} self.check_post(form_data, '/todo=', 404) - self.check_post(form_data, '/todo?id=', 404) + self.check_post(form_data, '/todo?id=', 400) self.check_post(form_data, '/todo?id=FOO', 400) self.check_post(form_data, '/todo?id=0', 404) todo1 = post_and_reload(form_data) @@ -249,8 +249,8 @@ class TestsWithServer(TestCaseWithServer): self.check_post(form_data, '/process?id=', 302, '/') form_data = {'comment': '', 'new_todo': 1} self.check_post(form_data, '/day?date=2024-01-01', 302, '/') - self.check_get('/todo', 404) - self.check_get('/todo?id=', 404) + self.check_get('/todo', 400) + self.check_get('/todo?id=', 400) self.check_get('/todo?id=foo', 400) self.check_get('/todo?id=0', 404) self.check_get('/todo?id=1', 200)