X-Git-Url: https://plomlompom.com/repos/?a=blobdiff_plain;f=plomtask%2Fdb.py;h=b5461a507e9612e2593643e6d6e198779e2fc456;hb=6b9970ff864e0e63527213fea5c0bed40ba877a7;hp=7962eabeffd28964c0892b87f7ce35e6052a2f3e;hpb=83266154e9140151c975586d21f393a5eb3f4ef4;p=plomtask diff --git a/plomtask/db.py b/plomtask/db.py index 7962eab..b5461a5 100644 --- a/plomtask/db.py +++ b/plomtask/db.py @@ -6,8 +6,9 @@ from difflib import Differ from sqlite3 import connect as sql_connect, Cursor, Row from typing import Any, Self, TypeVar, Generic from plomtask.exceptions import HandledException, NotFoundException +from plomtask.dating import valid_date -EXPECTED_DB_VERSION = 1 +EXPECTED_DB_VERSION = 4 MIGRATIONS_DIR = 'migrations' FILENAME_DB_SCHEMA = f'init_{EXPECTED_DB_VERSION}.sql' PATH_DB_SCHEMA = f'{MIGRATIONS_DIR}/{FILENAME_DB_SCHEMA}' @@ -131,6 +132,9 @@ class DatabaseFile: # pylint: disable=too-few-public-methods new_row += [f' {segment}'] new_row[0] = new_row[0].lstrip() new_row[-1] = new_row[-1].lstrip() + if new_row[-1] != ')' and new_row[-3][-1] != ',': + new_row[-3] = new_row[-3] + ',' + new_row[-2:] = [' ' + new_row[-1][:-1]] + [')'] new_rows += ['\n'.join(new_row)] return new_rows @@ -182,6 +186,17 @@ class DatabaseConnection: return list(self.exec(f'SELECT * FROM {table_name} WHERE {key} = ?', (target,))) + # def column_where_pattern(self, + # table_name: str, + # column: str, + # pattern: str, + # keys: list[str]) -> list[Any]: + # """Return column of rows where one of keys matches pattern.""" + # targets = tuple([f'%{pattern}%'] * len(keys)) + # haystack = ' OR '.join([f'{k} LIKE ?' for k in keys]) + # sql = f'SELECT {column} FROM {table_name} WHERE {haystack}' + # return [row[0] for row in self.exec(sql, targets)] + def column_where(self, table_name: str, column: str, key: str, target: int | str) -> list[Any]: """Return column of table where key == target.""" @@ -217,6 +232,7 @@ class BaseModel(Generic[BaseModelId]): to_save_relations: list[tuple[str, str, str]] = [] id_: None | BaseModelId cache_: dict[BaseModelId, Self] + to_search: list[str] = [] def __init__(self, id_: BaseModelId | None) -> None: if isinstance(id_, int) and id_ < 1: @@ -339,6 +355,49 @@ class BaseModel(Generic[BaseModelId]): items[item.id_] = item return list(items.values()) + @classmethod + def by_date_range_with_limits(cls: type[BaseModelInstance], + db_conn: DatabaseConnection, + date_range: tuple[str, str], + date_col: str = 'day' + ) -> tuple[list[BaseModelInstance], str, + str]: + """Return list of Days in database within (open) date_range interval. + + If no range values provided, defaults them to 'yesterday' and + 'tomorrow'. Knows to properly interpret these and 'today' as value. + """ + start_str = date_range[0] if date_range[0] else 'yesterday' + end_str = date_range[1] if date_range[1] else 'tomorrow' + start_date = valid_date(start_str) + end_date = valid_date(end_str) + items = [] + sql = f'SELECT id FROM {cls.table_name} ' + sql += f'WHERE {date_col} >= ? AND {date_col} <= ?' + for row in db_conn.exec(sql, (start_date, end_date)): + items += [cls.by_id(db_conn, row[0])] + return items, start_date, end_date + + @classmethod + def matching(cls: type[BaseModelInstance], db_conn: DatabaseConnection, + pattern: str) -> list[BaseModelInstance]: + """Return all objects whose .to_search match pattern.""" + items = cls.all(db_conn) + if pattern: + filtered = [] + for item in items: + for attr_name in cls.to_search: + toks = attr_name.split('.') + parent = item + for tok in toks: + attr = getattr(parent, tok) + parent = attr + if pattern in attr: + filtered += [item] + break + return filtered + return items + def save(self, db_conn: DatabaseConnection) -> None: """Write self to DB and cache and ensure .id_.