X-Git-Url: https://plomlompom.com/repos/day?a=blobdiff_plain;f=plomtask%2Fdb.py;h=13cdaef5b9c7d3e992f8c92730a9979b9eee2d73;hb=HEAD;hp=cce2630cd58bfb8bf7283c2eb0d2f45006d8ba26;hpb=c021152e6566c8374170de916c69d6b5c816cd54;p=plomtask diff --git a/plomtask/db.py b/plomtask/db.py index cce2630..f067cd3 100644 --- a/plomtask/db.py +++ b/plomtask/db.py @@ -4,8 +4,9 @@ from os import listdir from os.path import isfile from difflib import Differ from sqlite3 import connect as sql_connect, Cursor, Row -from typing import Any, Self, TypeVar, Generic -from plomtask.exceptions import HandledException, NotFoundException +from typing import Any, Self, TypeVar, Generic, Callable +from plomtask.exceptions import (HandledException, NotFoundException, + BadFormatException) from plomtask.dating import valid_date EXPECTED_DB_VERSION = 5 @@ -232,30 +233,33 @@ BaseModelInstance = TypeVar('BaseModelInstance', bound='BaseModel[Any]') class BaseModel(Generic[BaseModelId]): """Template for most of the models we use/derive from the DB.""" table_name = '' - to_save: list[str] = [] - to_save_versioned: list[str] = [] + to_save_simples: list[str] = [] to_save_relations: list[tuple[str, str, str, int]] = [] + versioned_defaults: dict[str, str | float] = {} + add_to_dict: list[str] = [] id_: None | BaseModelId cache_: dict[BaseModelId, Self] to_search: list[str] = [] can_create_by_id = False _exists = True + sorters: dict[str, Callable[..., Any]] = {} def __init__(self, id_: BaseModelId | None) -> None: if isinstance(id_, int) and id_ < 1: msg = f'illegal {self.__class__.__name__} ID, must be >=1: {id_}' - raise HandledException(msg) + raise BadFormatException(msg) if isinstance(id_, str) and "" == id_: msg = f'illegal {self.__class__.__name__} ID, must be non-empty' - raise HandledException(msg) + raise BadFormatException(msg) self.id_ = id_ def __hash__(self) -> int: - hashable = [self.id_] + [getattr(self, name) for name in self.to_save] + hashable = [self.id_] + [getattr(self, name) + for name in self.to_save_simples] for definition in self.to_save_relations: attr = getattr(self, definition[2]) hashable += [tuple(rel.id_ for rel in attr)] - for name in self.to_save_versioned: + for name in self.to_save_versioned(): hashable += [hash(getattr(self, name))] return hash(tuple(hashable)) @@ -272,25 +276,61 @@ class BaseModel(Generic[BaseModelId]): assert isinstance(other.id_, int) return self.id_ < other.id_ + @classmethod + def to_save_versioned(cls) -> list[str]: + """Return keys of cls.versioned_defaults assuming we wanna save 'em.""" + return list(cls.versioned_defaults.keys()) + @property - def as_dict(self) -> dict[str, object]: - """Return self as (json.dumps-coompatible) dict.""" + def as_dict_and_refs(self) -> tuple[dict[str, object], + list[BaseModel[int] | BaseModel[str]]]: + """Return self as json.dumps-ready dict, list of referenced objects.""" d: dict[str, object] = {'id': self.id_} - if len(self.to_save_versioned) > 0: + refs: list[BaseModel[int] | BaseModel[str]] = [] + for to_save in self.to_save_simples: + d[to_save] = getattr(self, to_save) + if len(self.to_save_versioned()) > 0: d['_versioned'] = {} - for k in self.to_save: - attr = getattr(self, k) - if hasattr(attr, 'as_dict'): - d[k] = attr.as_dict - d[k] = attr - for k in self.to_save_versioned: + for k in self.to_save_versioned(): attr = getattr(self, k) assert isinstance(d['_versioned'], dict) d['_versioned'][k] = attr.history - for r in self.to_save_relations: - attr_name = r[2] - d[attr_name] = [x.as_dict for x in getattr(self, attr_name)] - return d + rels_to_collect = [rel[2] for rel in self.to_save_relations] + rels_to_collect += self.add_to_dict + for attr_name in rels_to_collect: + rel_list = [] + for item in getattr(self, attr_name): + rel_list += [item.id_] + if item not in refs: + refs += [item] + d[attr_name] = rel_list + return d, refs + + @classmethod + def name_lowercase(cls) -> str: + """Convenience method to return cls' name in lowercase.""" + return cls.__name__.lower() + + @classmethod + def sort_by(cls, seq: list[Any], sort_key: str, default: str = 'title' + ) -> str: + """Sort cls list by cls.sorters[sort_key] (reverse if '-'-prefixed). + + Before cls.sorters[sort_key] is applied, seq is sorted by .id_, to + ensure predictability where parts of seq are of same sort value. + """ + reverse = False + if len(sort_key) > 1 and '-' == sort_key[0]: + sort_key = sort_key[1:] + reverse = True + if sort_key not in cls.sorters: + sort_key = default + seq.sort(key=lambda x: x.id_, reverse=reverse) + sorter: Callable[..., Any] = cls.sorters[sort_key] + seq.sort(key=sorter, reverse=reverse) + if reverse: + sort_key = f'-{sort_key}' + return sort_key # cache management # (we primarily use the cache to ensure we work on the same object in @@ -301,7 +341,8 @@ class BaseModel(Generic[BaseModelId]): def __getattribute__(self, name: str) -> Any: """Ensure fail if ._disappear() was called, except to check ._exists""" if name != '_exists' and not super().__getattribute__('_exists'): - raise HandledException('Object does not exist.') + msg = f'Object for attribute does not exist: {name}' + raise HandledException(msg) return super().__getattribute__(name) def _disappear(self) -> None: @@ -326,16 +367,18 @@ class BaseModel(Generic[BaseModelId]): cls.cache_ = {} @classmethod - def get_cache(cls: type[BaseModelInstance]) -> dict[Any, BaseModel[Any]]: + def get_cache(cls: type[BaseModelInstance] + ) -> dict[Any, BaseModelInstance]: """Get cache dictionary, create it if not yet existing.""" if not hasattr(cls, 'cache_'): - d: dict[Any, BaseModel[Any]] = {} + d: dict[Any, BaseModelInstance] = {} cls.cache_ = d return cls.cache_ @classmethod def _get_cached(cls: type[BaseModelInstance], - id_: BaseModelId) -> BaseModelInstance | None: + id_: BaseModelId + ) -> BaseModelInstance | None: """Get object of id_ from class's cache, or None if not found.""" cache = cls.get_cache() if id_ in cache: @@ -378,7 +421,7 @@ class BaseModel(Generic[BaseModelId]): """Make from DB row (sans relations), update DB cache with it.""" obj = cls(*row) assert obj.id_ is not None - for attr_name in cls.to_save_versioned: + for attr_name in cls.to_save_versioned(): attr = getattr(obj, attr_name) table_name = attr.table_name for row_ in db_conn.row_where(table_name, 'parent', obj.id_): @@ -395,6 +438,8 @@ class BaseModel(Generic[BaseModelId]): """ obj = None if id_ is not None: + if isinstance(id_, int) and id_ == 0: + raise BadFormatException('illegal ID of value 0') obj = cls._get_cached(id_) if not obj: for row in db_conn.row_where(cls.table_name, 'id', id_): @@ -408,7 +453,7 @@ class BaseModel(Generic[BaseModelId]): def by_id_or_create(cls, db_conn: DatabaseConnection, id_: BaseModelId | None ) -> Self: - """Wrapper around .by_id, creating (not caching/saving) if not find.""" + """Wrapper around .by_id, creating (not caching/saving) if no find.""" if not cls.can_create_by_id: raise HandledException('Class cannot .by_id_or_create.') if id_ is None: @@ -438,7 +483,7 @@ class BaseModel(Generic[BaseModelId]): item = cls.by_id(db_conn, id_) assert item.id_ is not None items[item.id_] = item - return list(items.values()) + return sorted(list(items.values())) @classmethod def by_date_range_with_limits(cls: type[BaseModelInstance], @@ -447,7 +492,7 @@ class BaseModel(Generic[BaseModelId]): date_col: str = 'day' ) -> tuple[list[BaseModelInstance], str, str]: - """Return list of items in database within (open) date_range interval. + """Return list of items in DB within (closed) date_range interval. If no range values provided, defaults them to 'yesterday' and 'tomorrow'. Knows to properly interpret these and 'today' as value. @@ -489,7 +534,7 @@ class BaseModel(Generic[BaseModelId]): """Write self to DB and cache and ensure .id_. Write both to DB, and to cache. To DB, write .id_ and attributes - listed in cls.to_save[_versioned|_relations]. + listed in cls.to_save_[simples|versioned|_relations]. Ensure self.id_ by setting it to what the DB command returns as the last saved row's ID (cursor.lastrowid), EXCEPT if self.id_ already @@ -497,14 +542,14 @@ class BaseModel(Generic[BaseModelId]): only the case with the Day class, where it's to be a date string. """ values = tuple([self.id_] + [getattr(self, key) - for key in self.to_save]) + for key in self.to_save_simples]) table_name = self.table_name cursor = db_conn.exec_on_vals(f'REPLACE INTO {table_name} VALUES', values) if not isinstance(self.id_, str): self.id_ = cursor.lastrowid # type: ignore[assignment] self.cache() - for attr_name in self.to_save_versioned: + for attr_name in self.to_save_versioned(): getattr(self, attr_name).save(db_conn) for table, column, attr_name, key_index in self.to_save_relations: assert isinstance(self.id_, (int, str)) @@ -516,7 +561,7 @@ class BaseModel(Generic[BaseModelId]): """Remove from DB and cache, including dependencies.""" if self.id_ is None or self._get_cached(self.id_) is None: raise HandledException('cannot remove unsaved item') - for attr_name in self.to_save_versioned: + for attr_name in self.to_save_versioned(): getattr(self, attr_name).remove(db_conn) for table, column, attr_name, _ in self.to_save_relations: db_conn.delete_where(table, column, self.id_)