X-Git-Url: https://plomlompom.com/repos/foo.html?a=blobdiff_plain;ds=sidebyside;f=plomtask%2Fdb.py;h=2ea7421feec578dab3fb8ba9e1eecf547eb6a36c;hb=30aef71506f7d6215b04cddaba8fddba1788f883;hp=4396b444a61f0c2f59e2046b554e79c56e541faa;hpb=f92de64d072009c8c4bf96b9eeb9fa245045662b;p=plomtask
diff --git a/plomtask/db.py b/plomtask/db.py
index 4396b44..2ea7421 100644
--- a/plomtask/db.py
+++ b/plomtask/db.py
@@ -6,9 +6,9 @@ from difflib import Differ
from sqlite3 import connect as sql_connect, Cursor, Row
from typing import Any, Self, TypeVar, Generic
from plomtask.exceptions import HandledException, NotFoundException
-from plomtask.dating import (MIN_RANGE_DATE, MAX_RANGE_DATE, valid_date)
+from plomtask.dating import valid_date
-EXPECTED_DB_VERSION = 4
+EXPECTED_DB_VERSION = 5
MIGRATIONS_DIR = 'migrations'
FILENAME_DB_SCHEMA = f'init_{EXPECTED_DB_VERSION}.sql'
PATH_DB_SCHEMA = f'{MIGRATIONS_DIR}/{FILENAME_DB_SCHEMA}'
@@ -172,11 +172,17 @@ class DatabaseConnection:
self.conn.close()
def rewrite_relations(self, table_name: str, key: str, target: int | str,
- rows: list[list[Any]]) -> None:
- """Rewrite relations in table_name to target, with rows values."""
+ rows: list[list[Any]], key_index: int = 0) -> None:
+ # pylint: disable=too-many-arguments
+ """Rewrite relations in table_name to target, with rows values.
+
+ Note that single rows are expected without the column and value
+ identified by key and target, which are inserted inside the function
+ at key_index.
+ """
self.delete_where(table_name, key, target)
for row in rows:
- values = tuple([target] + row)
+ values = tuple(row[:key_index] + [target] + row[key_index:])
q_marks = self.__class__.q_marks_from_values(values)
self.exec(f'INSERT INTO {table_name} VALUES {q_marks}', values)
@@ -229,7 +235,7 @@ class BaseModel(Generic[BaseModelId]):
table_name = ''
to_save: list[str] = []
to_save_versioned: list[str] = []
- to_save_relations: list[tuple[str, str, str]] = []
+ to_save_relations: list[tuple[str, str, str, int]] = []
id_: None | BaseModelId
cache_: dict[BaseModelId, Self]
to_search: list[str] = []
@@ -358,19 +364,19 @@ class BaseModel(Generic[BaseModelId]):
@classmethod
def by_date_range_with_limits(cls: type[BaseModelInstance],
db_conn: DatabaseConnection,
- date_range: tuple[str, str] = ('', ''),
+ date_range: tuple[str, str],
date_col: str = 'day'
) -> tuple[list[BaseModelInstance], str,
str]:
"""Return list of Days in database within (open) date_range interval.
- If no range values provided, defaults them to MIN_RANGE_DATE and
- MAX_RANGE_DATE. Also knows to properly interpret 'today' as value.
+ If no range values provided, defaults them to 'yesterday' and
+ 'tomorrow'. Knows to properly interpret these and 'today' as value.
"""
- min_date = MIN_RANGE_DATE
- max_date = MAX_RANGE_DATE
- start_date = valid_date(date_range[0] if date_range[0] else min_date)
- end_date = valid_date(date_range[1] if date_range[1] else max_date)
+ start_str = date_range[0] if date_range[0] else 'yesterday'
+ end_str = date_range[1] if date_range[1] else 'tomorrow'
+ start_date = valid_date(start_str)
+ end_date = valid_date(end_str)
items = []
sql = f'SELECT id FROM {cls.table_name} '
sql += f'WHERE {date_col} >= ? AND {date_col} <= ?'
@@ -420,11 +426,11 @@ class BaseModel(Generic[BaseModelId]):
self.cache()
for attr_name in self.to_save_versioned:
getattr(self, attr_name).save(db_conn)
- for table, column, attr_name in self.to_save_relations:
+ for table, column, attr_name, key_index in self.to_save_relations:
assert isinstance(self.id_, (int, str))
db_conn.rewrite_relations(table, column, self.id_,
[[i.id_] for i
- in getattr(self, attr_name)])
+ in getattr(self, attr_name)], key_index)
def remove(self, db_conn: DatabaseConnection) -> None:
"""Remove from DB and cache, including dependencies."""
@@ -432,7 +438,7 @@ class BaseModel(Generic[BaseModelId]):
raise HandledException('cannot remove unsaved item')
for attr_name in self.to_save_versioned:
getattr(self, attr_name).remove(db_conn)
- for table, column, attr_name in self.to_save_relations:
+ for table, column, attr_name, _ in self.to_save_relations:
db_conn.delete_where(table, column, self.id_)
self.uncache()
db_conn.delete_where(self.table_name, 'id', self.id_)