X-Git-Url: https://plomlompom.com/repos/?a=blobdiff_plain;f=plomtask%2Fdb.py;fp=plomtask%2Fdb.py;h=90ec8332fb655b7d56a20ed9e4dcae05cd429ab0;hb=0630d9cfdc47e306b96ad05b4077ee96eec71226;hp=b5461a507e9612e2593643e6d6e198779e2fc456;hpb=2337eb37f27fdc60e3cb000052f00218d815c49f;p=plomtask diff --git a/plomtask/db.py b/plomtask/db.py index b5461a5..90ec833 100644 --- a/plomtask/db.py +++ b/plomtask/db.py @@ -172,11 +172,17 @@ class DatabaseConnection: self.conn.close() def rewrite_relations(self, table_name: str, key: str, target: int | str, - rows: list[list[Any]]) -> None: - """Rewrite relations in table_name to target, with rows values.""" + rows: list[list[Any]], key_index: int = 0) -> None: + # pylint: disable=too-many-arguments + """Rewrite relations in table_name to target, with rows values. + + Note that single rows are expected without the column and value + identified by key and target, which are inserted inside the function + at key_index. + """ self.delete_where(table_name, key, target) for row in rows: - values = tuple([target] + row) + values = tuple(row[:key_index] + [target] + row[key_index:]) q_marks = self.__class__.q_marks_from_values(values) self.exec(f'INSERT INTO {table_name} VALUES {q_marks}', values) @@ -229,7 +235,7 @@ class BaseModel(Generic[BaseModelId]): table_name = '' to_save: list[str] = [] to_save_versioned: list[str] = [] - to_save_relations: list[tuple[str, str, str]] = [] + to_save_relations: list[tuple[str, str, str, int]] = [] id_: None | BaseModelId cache_: dict[BaseModelId, Self] to_search: list[str] = [] @@ -420,11 +426,11 @@ class BaseModel(Generic[BaseModelId]): self.cache() for attr_name in self.to_save_versioned: getattr(self, attr_name).save(db_conn) - for table, column, attr_name in self.to_save_relations: + for table, column, attr_name, key_index in self.to_save_relations: assert isinstance(self.id_, (int, str)) db_conn.rewrite_relations(table, column, self.id_, [[i.id_] for i - in getattr(self, attr_name)]) + in getattr(self, attr_name)], key_index) def remove(self, db_conn: DatabaseConnection) -> None: """Remove from DB and cache, including dependencies.""" @@ -432,7 +438,7 @@ class BaseModel(Generic[BaseModelId]): raise HandledException('cannot remove unsaved item') for attr_name in self.to_save_versioned: getattr(self, attr_name).remove(db_conn) - for table, column, attr_name in self.to_save_relations: + for table, column, attr_name, _ in self.to_save_relations: db_conn.delete_where(table, column, self.id_) self.uncache() db_conn.delete_where(self.table_name, 'id', self.id_)