X-Git-Url: https://plomlompom.com/repos/?a=blobdiff_plain;f=plomtask%2Fdb.py;h=2ea7421feec578dab3fb8ba9e1eecf547eb6a36c;hb=e3b01fe14d7a3b824b909382671acc4657e98145;hp=4396b444a61f0c2f59e2046b554e79c56e541faa;hpb=f92de64d072009c8c4bf96b9eeb9fa245045662b;p=plomtask diff --git a/plomtask/db.py b/plomtask/db.py index 4396b44..2ea7421 100644 --- a/plomtask/db.py +++ b/plomtask/db.py @@ -6,9 +6,9 @@ from difflib import Differ from sqlite3 import connect as sql_connect, Cursor, Row from typing import Any, Self, TypeVar, Generic from plomtask.exceptions import HandledException, NotFoundException -from plomtask.dating import (MIN_RANGE_DATE, MAX_RANGE_DATE, valid_date) +from plomtask.dating import valid_date -EXPECTED_DB_VERSION = 4 +EXPECTED_DB_VERSION = 5 MIGRATIONS_DIR = 'migrations' FILENAME_DB_SCHEMA = f'init_{EXPECTED_DB_VERSION}.sql' PATH_DB_SCHEMA = f'{MIGRATIONS_DIR}/{FILENAME_DB_SCHEMA}' @@ -172,11 +172,17 @@ class DatabaseConnection: self.conn.close() def rewrite_relations(self, table_name: str, key: str, target: int | str, - rows: list[list[Any]]) -> None: - """Rewrite relations in table_name to target, with rows values.""" + rows: list[list[Any]], key_index: int = 0) -> None: + # pylint: disable=too-many-arguments + """Rewrite relations in table_name to target, with rows values. + + Note that single rows are expected without the column and value + identified by key and target, which are inserted inside the function + at key_index. + """ self.delete_where(table_name, key, target) for row in rows: - values = tuple([target] + row) + values = tuple(row[:key_index] + [target] + row[key_index:]) q_marks = self.__class__.q_marks_from_values(values) self.exec(f'INSERT INTO {table_name} VALUES {q_marks}', values) @@ -229,7 +235,7 @@ class BaseModel(Generic[BaseModelId]): table_name = '' to_save: list[str] = [] to_save_versioned: list[str] = [] - to_save_relations: list[tuple[str, str, str]] = [] + to_save_relations: list[tuple[str, str, str, int]] = [] id_: None | BaseModelId cache_: dict[BaseModelId, Self] to_search: list[str] = [] @@ -358,19 +364,19 @@ class BaseModel(Generic[BaseModelId]): @classmethod def by_date_range_with_limits(cls: type[BaseModelInstance], db_conn: DatabaseConnection, - date_range: tuple[str, str] = ('', ''), + date_range: tuple[str, str], date_col: str = 'day' ) -> tuple[list[BaseModelInstance], str, str]: """Return list of Days in database within (open) date_range interval. - If no range values provided, defaults them to MIN_RANGE_DATE and - MAX_RANGE_DATE. Also knows to properly interpret 'today' as value. + If no range values provided, defaults them to 'yesterday' and + 'tomorrow'. Knows to properly interpret these and 'today' as value. """ - min_date = MIN_RANGE_DATE - max_date = MAX_RANGE_DATE - start_date = valid_date(date_range[0] if date_range[0] else min_date) - end_date = valid_date(date_range[1] if date_range[1] else max_date) + start_str = date_range[0] if date_range[0] else 'yesterday' + end_str = date_range[1] if date_range[1] else 'tomorrow' + start_date = valid_date(start_str) + end_date = valid_date(end_str) items = [] sql = f'SELECT id FROM {cls.table_name} ' sql += f'WHERE {date_col} >= ? AND {date_col} <= ?' @@ -420,11 +426,11 @@ class BaseModel(Generic[BaseModelId]): self.cache() for attr_name in self.to_save_versioned: getattr(self, attr_name).save(db_conn) - for table, column, attr_name in self.to_save_relations: + for table, column, attr_name, key_index in self.to_save_relations: assert isinstance(self.id_, (int, str)) db_conn.rewrite_relations(table, column, self.id_, [[i.id_] for i - in getattr(self, attr_name)]) + in getattr(self, attr_name)], key_index) def remove(self, db_conn: DatabaseConnection) -> None: """Remove from DB and cache, including dependencies.""" @@ -432,7 +438,7 @@ class BaseModel(Generic[BaseModelId]): raise HandledException('cannot remove unsaved item') for attr_name in self.to_save_versioned: getattr(self, attr_name).remove(db_conn) - for table, column, attr_name in self.to_save_relations: + for table, column, attr_name, _ in self.to_save_relations: db_conn.delete_where(table, column, self.id_) self.uncache() db_conn.delete_where(self.table_name, 'id', self.id_)