From: Christian Heller <c.heller@plomlompom.de>
Date: Thu, 2 Jan 2025 15:45:46 +0000 (+0100)
Subject: Fuse DatabaseConnection.exec and .exec_on_vals.
X-Git-Url: https://plomlompom.com/repos/%7B%7Bdb.prefix%7D%7D/static/%7B%7B%20web_path%20%7D%7D/process_titles?a=commitdiff_plain;h=1da23c9594f6308aa03d6e09273d1f9fa4ce61e7;p=plomtask

Fuse DatabaseConnection.exec and .exec_on_vals.
---

diff --git a/plomtask/db.py b/plomtask/db.py
index f067cd3..dc75648 100644
--- a/plomtask/db.py
+++ b/plomtask/db.py
@@ -164,14 +164,19 @@ class DatabaseConnection:
         """Commit SQL transaction."""
         self.conn.commit()
 
-    def exec(self, code: str, inputs: tuple[Any, ...] = tuple()) -> Cursor:
-        """Add commands to SQL transaction."""
-        return self.conn.execute(code, inputs)
-
-    def exec_on_vals(self, code: str, inputs: tuple[Any, ...]) -> Cursor:
-        """Wrapper around .exec appending adequate " (?, …)" to code."""
-        q_marks_from_values = '(' + ','.join(['?'] * len(inputs)) + ')'
-        return self.exec(f'{code} {q_marks_from_values}', inputs)
+    def exec(self,
+             code: str,
+             inputs: tuple[Any, ...] = tuple(),
+             build_q_marks: bool = True
+             ) -> Cursor:
+        """Wrapper around sqlite3.Connection.execute, building '?' if inputs"""
+        if len(inputs) > 0:
+            if build_q_marks:
+                q_marks = ('?' if len(inputs) == 1
+                           else '(' + ','.join(['?'] * len(inputs)) + ')')
+                return self.conn.execute(f'{code} {q_marks}', inputs)
+            return self.conn.execute(code, inputs)
+        return self.conn.execute(code)
 
     def close(self) -> None:
         """Close DB connection."""
@@ -189,12 +194,12 @@ class DatabaseConnection:
         self.delete_where(table_name, key, target)
         for row in rows:
             values = tuple(row[:key_index] + [target] + row[key_index:])
-            self.exec_on_vals(f'INSERT INTO {table_name} VALUES', values)
+            self.exec(f'INSERT INTO {table_name} VALUES', values)
 
     def row_where(self, table_name: str, key: str,
                   target: int | str) -> list[Row]:
         """Return list of Rows at table where key == target."""
-        return list(self.exec(f'SELECT * FROM {table_name} WHERE {key} = ?',
+        return list(self.exec(f'SELECT * FROM {table_name} WHERE {key} =',
                               (target,)))
 
     # def column_where_pattern(self,
@@ -213,7 +218,7 @@ class DatabaseConnection:
         """Return column of table where key == target."""
         return [row[0] for row in
                 self.exec(f'SELECT {column} FROM {table_name} '
-                          f'WHERE {key} = ?', (target,))]
+                          f'WHERE {key} =', (target,))]
 
     def column_all(self, table_name: str, column: str) -> list[Any]:
         """Return complete column of table."""
@@ -223,7 +228,7 @@ class DatabaseConnection:
     def delete_where(self, table_name: str, key: str,
                      target: int | str) -> None:
         """Delete from table where key == target."""
-        self.exec(f'DELETE FROM {table_name} WHERE {key} = ?', (target,))
+        self.exec(f'DELETE FROM {table_name} WHERE {key} =', (target,))
 
 
 BaseModelId = TypeVar('BaseModelId', int, str)
@@ -504,7 +509,8 @@ class BaseModel(Generic[BaseModelId]):
         items = []
         sql = f'SELECT id FROM {cls.table_name} '
         sql += f'WHERE {date_col} >= ? AND {date_col} <= ?'
-        for row in db_conn.exec(sql, (start_date, end_date)):
+        for row in db_conn.exec(sql, (start_date, end_date),
+                                build_q_marks=False):
             items += [cls.by_id(db_conn, row[0])]
         return items, start_date, end_date
 
@@ -544,8 +550,7 @@ class BaseModel(Generic[BaseModelId]):
         values = tuple([self.id_] + [getattr(self, key)
                                      for key in self.to_save_simples])
         table_name = self.table_name
-        cursor = db_conn.exec_on_vals(f'REPLACE INTO {table_name} VALUES',
-                                      values)
+        cursor = db_conn.exec(f'REPLACE INTO {table_name} VALUES', values)
         if not isinstance(self.id_, str):
             self.id_ = cursor.lastrowid  # type: ignore[assignment]
         self.cache()