From: Christian Heller Date: Thu, 2 Jan 2025 15:45:46 +0000 (+0100) Subject: Fuse DatabaseConnection.exec and .exec_on_vals. X-Git-Url: https://plomlompom.com/repos/%7B%7Bprefix%7D%7D/%7B%7Bdb.prefix%7D%7D/templates?a=commitdiff_plain;h=1da23c9594f6308aa03d6e09273d1f9fa4ce61e7;p=plomtask Fuse DatabaseConnection.exec and .exec_on_vals. --- diff --git a/plomtask/db.py b/plomtask/db.py index f067cd3..dc75648 100644 --- a/plomtask/db.py +++ b/plomtask/db.py @@ -164,14 +164,19 @@ class DatabaseConnection: """Commit SQL transaction.""" self.conn.commit() - def exec(self, code: str, inputs: tuple[Any, ...] = tuple()) -> Cursor: - """Add commands to SQL transaction.""" - return self.conn.execute(code, inputs) - - def exec_on_vals(self, code: str, inputs: tuple[Any, ...]) -> Cursor: - """Wrapper around .exec appending adequate " (?, …)" to code.""" - q_marks_from_values = '(' + ','.join(['?'] * len(inputs)) + ')' - return self.exec(f'{code} {q_marks_from_values}', inputs) + def exec(self, + code: str, + inputs: tuple[Any, ...] = tuple(), + build_q_marks: bool = True + ) -> Cursor: + """Wrapper around sqlite3.Connection.execute, building '?' if inputs""" + if len(inputs) > 0: + if build_q_marks: + q_marks = ('?' if len(inputs) == 1 + else '(' + ','.join(['?'] * len(inputs)) + ')') + return self.conn.execute(f'{code} {q_marks}', inputs) + return self.conn.execute(code, inputs) + return self.conn.execute(code) def close(self) -> None: """Close DB connection.""" @@ -189,12 +194,12 @@ class DatabaseConnection: self.delete_where(table_name, key, target) for row in rows: values = tuple(row[:key_index] + [target] + row[key_index:]) - self.exec_on_vals(f'INSERT INTO {table_name} VALUES', values) + self.exec(f'INSERT INTO {table_name} VALUES', values) def row_where(self, table_name: str, key: str, target: int | str) -> list[Row]: """Return list of Rows at table where key == target.""" - return list(self.exec(f'SELECT * FROM {table_name} WHERE {key} = ?', + return list(self.exec(f'SELECT * FROM {table_name} WHERE {key} =', (target,))) # def column_where_pattern(self, @@ -213,7 +218,7 @@ class DatabaseConnection: """Return column of table where key == target.""" return [row[0] for row in self.exec(f'SELECT {column} FROM {table_name} ' - f'WHERE {key} = ?', (target,))] + f'WHERE {key} =', (target,))] def column_all(self, table_name: str, column: str) -> list[Any]: """Return complete column of table.""" @@ -223,7 +228,7 @@ class DatabaseConnection: def delete_where(self, table_name: str, key: str, target: int | str) -> None: """Delete from table where key == target.""" - self.exec(f'DELETE FROM {table_name} WHERE {key} = ?', (target,)) + self.exec(f'DELETE FROM {table_name} WHERE {key} =', (target,)) BaseModelId = TypeVar('BaseModelId', int, str) @@ -504,7 +509,8 @@ class BaseModel(Generic[BaseModelId]): items = [] sql = f'SELECT id FROM {cls.table_name} ' sql += f'WHERE {date_col} >= ? AND {date_col} <= ?' - for row in db_conn.exec(sql, (start_date, end_date)): + for row in db_conn.exec(sql, (start_date, end_date), + build_q_marks=False): items += [cls.by_id(db_conn, row[0])] return items, start_date, end_date @@ -544,8 +550,7 @@ class BaseModel(Generic[BaseModelId]): values = tuple([self.id_] + [getattr(self, key) for key in self.to_save_simples]) table_name = self.table_name - cursor = db_conn.exec_on_vals(f'REPLACE INTO {table_name} VALUES', - values) + cursor = db_conn.exec(f'REPLACE INTO {table_name} VALUES', values) if not isinstance(self.id_, str): self.id_ = cursor.lastrowid # type: ignore[assignment] self.cache()