home · contact · privacy
Initialize Days with links to their Todos as early as possible.
[plomtask] / plomtask / db.py
index b5461a507e9612e2593643e6d6e198779e2fc456..bd609685096b9d2dbc3dd0b0ea4243ec47904058 100644 (file)
@@ -8,7 +8,7 @@ from typing import Any, Self, TypeVar, Generic
 from plomtask.exceptions import HandledException, NotFoundException
 from plomtask.dating import valid_date
 
 from plomtask.exceptions import HandledException, NotFoundException
 from plomtask.dating import valid_date
 
-EXPECTED_DB_VERSION = 4
+EXPECTED_DB_VERSION = 5
 MIGRATIONS_DIR = 'migrations'
 FILENAME_DB_SCHEMA = f'init_{EXPECTED_DB_VERSION}.sql'
 PATH_DB_SCHEMA = f'{MIGRATIONS_DIR}/{FILENAME_DB_SCHEMA}'
 MIGRATIONS_DIR = 'migrations'
 FILENAME_DB_SCHEMA = f'init_{EXPECTED_DB_VERSION}.sql'
 PATH_DB_SCHEMA = f'{MIGRATIONS_DIR}/{FILENAME_DB_SCHEMA}'
@@ -172,11 +172,17 @@ class DatabaseConnection:
         self.conn.close()
 
     def rewrite_relations(self, table_name: str, key: str, target: int | str,
         self.conn.close()
 
     def rewrite_relations(self, table_name: str, key: str, target: int | str,
-                          rows: list[list[Any]]) -> None:
-        """Rewrite relations in table_name to target, with rows values."""
+                          rows: list[list[Any]], key_index: int = 0) -> None:
+        # pylint: disable=too-many-arguments
+        """Rewrite relations in table_name to target, with rows values.
+
+        Note that single rows are expected without the column and value
+        identified by key and target, which are inserted inside the function
+        at key_index.
+        """
         self.delete_where(table_name, key, target)
         for row in rows:
         self.delete_where(table_name, key, target)
         for row in rows:
-            values = tuple([target] + row)
+            values = tuple(row[:key_index] + [target] + row[key_index:])
             q_marks = self.__class__.q_marks_from_values(values)
             self.exec(f'INSERT INTO {table_name} VALUES {q_marks}', values)
 
             q_marks = self.__class__.q_marks_from_values(values)
             self.exec(f'INSERT INTO {table_name} VALUES {q_marks}', values)
 
@@ -229,7 +235,7 @@ class BaseModel(Generic[BaseModelId]):
     table_name = ''
     to_save: list[str] = []
     to_save_versioned: list[str] = []
     table_name = ''
     to_save: list[str] = []
     to_save_versioned: list[str] = []
-    to_save_relations: list[tuple[str, str, str]] = []
+    to_save_relations: list[tuple[str, str, str, int]] = []
     id_: None | BaseModelId
     cache_: dict[BaseModelId, Self]
     to_search: list[str] = []
     id_: None | BaseModelId
     cache_: dict[BaseModelId, Self]
     to_search: list[str] = []
@@ -362,7 +368,7 @@ class BaseModel(Generic[BaseModelId]):
                                   date_col: str = 'day'
                                   ) -> tuple[list[BaseModelInstance], str,
                                              str]:
                                   date_col: str = 'day'
                                   ) -> tuple[list[BaseModelInstance], str,
                                              str]:
-        """Return list of Days in database within (open) date_range interval.
+        """Return list of items in database within (open) date_range interval.
 
         If no range values provided, defaults them to 'yesterday' and
         'tomorrow'. Knows to properly interpret these and 'today' as value.
 
         If no range values provided, defaults them to 'yesterday' and
         'tomorrow'. Knows to properly interpret these and 'today' as value.
@@ -420,11 +426,11 @@ class BaseModel(Generic[BaseModelId]):
         self.cache()
         for attr_name in self.to_save_versioned:
             getattr(self, attr_name).save(db_conn)
         self.cache()
         for attr_name in self.to_save_versioned:
             getattr(self, attr_name).save(db_conn)
-        for table, column, attr_name in self.to_save_relations:
+        for table, column, attr_name, key_index in self.to_save_relations:
             assert isinstance(self.id_, (int, str))
             db_conn.rewrite_relations(table, column, self.id_,
                                       [[i.id_] for i
             assert isinstance(self.id_, (int, str))
             db_conn.rewrite_relations(table, column, self.id_,
                                       [[i.id_] for i
-                                       in getattr(self, attr_name)])
+                                       in getattr(self, attr_name)], key_index)
 
     def remove(self, db_conn: DatabaseConnection) -> None:
         """Remove from DB and cache, including dependencies."""
 
     def remove(self, db_conn: DatabaseConnection) -> None:
         """Remove from DB and cache, including dependencies."""
@@ -432,7 +438,7 @@ class BaseModel(Generic[BaseModelId]):
             raise HandledException('cannot remove unsaved item')
         for attr_name in self.to_save_versioned:
             getattr(self, attr_name).remove(db_conn)
             raise HandledException('cannot remove unsaved item')
         for attr_name in self.to_save_versioned:
             getattr(self, attr_name).remove(db_conn)
-        for table, column, attr_name in self.to_save_relations:
+        for table, column, attr_name, _ in self.to_save_relations:
             db_conn.delete_where(table, column, self.id_)
         self.uncache()
         db_conn.delete_where(self.table_name, 'id', self.id_)
             db_conn.delete_where(table, column, self.id_)
         self.uncache()
         db_conn.delete_where(self.table_name, 'id', self.id_)