from os.path import isfile
from difflib import Differ
from sqlite3 import connect as sql_connect, Cursor, Row
-from typing import Any, Self, TypeVar, Generic
+from typing import Any, Self, TypeVar, Generic, Callable
from plomtask.exceptions import HandledException, NotFoundException
from plomtask.dating import valid_date
@property
def _user_version(self) -> int:
"""Get DB user_version."""
- # pylint: disable=protected-access
- # (since we remain within class)
- return self.__class__._get_version_of_db(self.path)
+ return self._get_version_of_db(self.path)
def _validate_schema(self) -> None:
"""Compare found schema with what's stored at PATH_DB_SCHEMA."""
class BaseModel(Generic[BaseModelId]):
"""Template for most of the models we use/derive from the DB."""
table_name = ''
- to_save: list[str] = []
- to_save_versioned: list[str] = []
+ to_save_simples: list[str] = []
to_save_relations: list[tuple[str, str, str, int]] = []
+ versioned_defaults: dict[str, str | float] = {}
+ add_to_dict: list[str] = []
id_: None | BaseModelId
cache_: dict[BaseModelId, Self]
to_search: list[str] = []
+ can_create_by_id = False
+ _exists = True
+ sorters: dict[str, Callable[..., Any]] = {}
def __init__(self, id_: BaseModelId | None) -> None:
if isinstance(id_, int) and id_ < 1:
self.id_ = id_
def __hash__(self) -> int:
- hashable = [self.id_] + [getattr(self, name) for name in self.to_save]
+ hashable = [self.id_] + [getattr(self, name)
+ for name in self.to_save_simples]
for definition in self.to_save_relations:
attr = getattr(self, definition[2])
hashable += [tuple(rel.id_ for rel in attr)]
- for name in self.to_save_versioned:
+ for name in self.to_save_versioned():
hashable += [hash(getattr(self, name))]
return hash(tuple(hashable))
assert isinstance(other.id_, int)
return self.id_ < other.id_
- # cache management
+ @classmethod
+ def to_save_versioned(cls) -> list[str]:
+ """Return keys of cls.versioned_defaults assuming we wanna save 'em."""
+ return list(cls.versioned_defaults.keys())
+
+ @property
+ def as_dict_and_refs(self) -> tuple[dict[str, object],
+ list[BaseModel[int] | BaseModel[str]]]:
+ """Return self as json.dumps-ready dict, list of referenced objects."""
+ d: dict[str, object] = {'id': self.id_}
+ refs: list[BaseModel[int] | BaseModel[str]] = []
+ for to_save in self.to_save_simples:
+ d[to_save] = getattr(self, to_save)
+ if len(self.to_save_versioned()) > 0:
+ d['_versioned'] = {}
+ for k in self.to_save_versioned():
+ attr = getattr(self, k)
+ assert isinstance(d['_versioned'], dict)
+ d['_versioned'][k] = attr.history
+ rels_to_collect = [rel[2] for rel in self.to_save_relations]
+ rels_to_collect += self.add_to_dict
+ for attr_name in rels_to_collect:
+ rel_list = []
+ for item in getattr(self, attr_name):
+ rel_list += [item.id_]
+ if item not in refs:
+ refs += [item]
+ d[attr_name] = rel_list
+ return d, refs
@classmethod
- def _get_cached(cls: type[BaseModelInstance],
- id_: BaseModelId) -> BaseModelInstance | None:
- """Get object of id_ from class's cache, or None if not found."""
- # pylint: disable=consider-iterating-dictionary
- cache = cls.get_cache()
- if id_ in cache.keys():
- obj = cache[id_]
- assert isinstance(obj, cls)
- return obj
- return None
+ def name_lowercase(cls) -> str:
+ """Convenience method to return cls' name in lowercase."""
+ return cls.__name__.lower()
+
+ @classmethod
+ def sort_by(cls, seq: list[Any], sort_key: str, default: str = 'title'
+ ) -> str:
+ """Sort cls list by cls.sorters[sort_key] (reverse if '-'-prefixed).
+
+ Before cls.sorters[sort_key] is applied, seq is sorted by .id_, to
+ ensure predictability where parts of seq are of same sort value.
+ """
+ reverse = False
+ if len(sort_key) > 1 and '-' == sort_key[0]:
+ sort_key = sort_key[1:]
+ reverse = True
+ if sort_key not in cls.sorters:
+ sort_key = default
+ seq.sort(key=lambda x: x.id_, reverse=reverse)
+ sorter: Callable[..., Any] = cls.sorters[sort_key]
+ seq.sort(key=sorter, reverse=reverse)
+ if reverse:
+ sort_key = f'-{sort_key}'
+ return sort_key
+
+ # cache management
+ # (we primarily use the cache to ensure we work on the same object in
+ # memory no matter where and how we retrieve it, e.g. we don't want
+ # .by_id() calls to create a new object each time, but rather a pointer
+ # to the one already instantiated)
+
+ def __getattribute__(self, name: str) -> Any:
+ """Ensure fail if ._disappear() was called, except to check ._exists"""
+ if name != '_exists' and not super().__getattribute__('_exists'):
+ msg = f'Object for attribute does not exist: {name}'
+ raise HandledException(msg)
+ return super().__getattribute__(name)
+
+ def _disappear(self) -> None:
+ """Invalidate object, make future use raise exceptions."""
+ assert self.id_ is not None
+ if self._get_cached(self.id_):
+ self._uncache()
+ to_kill = list(self.__dict__.keys())
+ for attr in to_kill:
+ delattr(self, attr)
+ self._exists = False
@classmethod
def empty_cache(cls) -> None:
- """Empty class's cache."""
+ """Empty class's cache, and disappear all former inhabitants."""
+ # pylint: disable=protected-access
+ # (cause we remain within the class)
+ if hasattr(cls, 'cache_'):
+ to_disappear = list(cls.cache_.values())
+ for item in to_disappear:
+ item._disappear()
cls.cache_ = {}
@classmethod
- def get_cache(cls: type[BaseModelInstance]) -> dict[Any, BaseModel[Any]]:
+ def get_cache(cls: type[BaseModelInstance]
+ ) -> dict[Any, BaseModelInstance]:
"""Get cache dictionary, create it if not yet existing."""
if not hasattr(cls, 'cache_'):
- d: dict[Any, BaseModel[Any]] = {}
+ d: dict[Any, BaseModelInstance] = {}
cls.cache_ = d
return cls.cache_
+ @classmethod
+ def _get_cached(cls: type[BaseModelInstance],
+ id_: BaseModelId) -> BaseModelInstance | None:
+ """Get object of id_ from class's cache, or None if not found."""
+ cache = cls.get_cache()
+ if id_ in cache:
+ obj = cache[id_]
+ assert isinstance(obj, cls)
+ return obj
+ return None
+
def cache(self) -> None:
- """Update object in class's cache."""
+ """Update object in class's cache.
+
+ Also calls ._disappear if cache holds older reference to object of same
+ ID, but different memory address, to avoid doing anything with
+ dangling leftovers.
+ """
if self.id_ is None:
raise HandledException('Cannot cache object without ID.')
- cache = self.__class__.get_cache()
+ cache = self.get_cache()
+ old_cached = self._get_cached(self.id_)
+ if old_cached and id(old_cached) != id(self):
+ # pylint: disable=protected-access
+ # (cause we remain within the class)
+ old_cached._disappear()
cache[self.id_] = self
- def uncache(self) -> None:
+ def _uncache(self) -> None:
"""Remove self from cache."""
if self.id_ is None:
raise HandledException('Cannot un-cache object without ID.')
- cache = self.__class__.get_cache()
+ cache = self.get_cache()
del cache[self.id_]
# object retrieval and generation
# pylint: disable=unused-argument
db_conn: DatabaseConnection,
row: Row | list[Any]) -> BaseModelInstance:
- """Make from DB row, write to DB cache."""
+ """Make from DB row (sans relations), update DB cache with it."""
obj = cls(*row)
+ assert obj.id_ is not None
+ for attr_name in cls.to_save_versioned():
+ attr = getattr(obj, attr_name)
+ table_name = attr.table_name
+ for row_ in db_conn.row_where(table_name, 'parent', obj.id_):
+ attr.history_from_row(row_)
obj.cache()
return obj
@classmethod
- def by_id(cls, db_conn: DatabaseConnection,
- id_: BaseModelId | None,
- # pylint: disable=unused-argument
- create: bool = False) -> Self:
+ def by_id(cls, db_conn: DatabaseConnection, id_: BaseModelId) -> Self:
"""Retrieve by id_, on failure throw NotFoundException.
First try to get from cls.cache_, only then check DB; if found,
put into cache.
-
- If create=True, make anew (but do not cache yet).
"""
obj = None
if id_ is not None:
if not obj:
for row in db_conn.row_where(cls.table_name, 'id', id_):
obj = cls.from_table_row(db_conn, row)
- obj.cache()
break
if obj:
return obj
- if create:
- obj = cls(id_)
- return obj
raise NotFoundException(f'found no object of ID {id_}')
+ @classmethod
+ def by_id_or_create(cls, db_conn: DatabaseConnection,
+ id_: BaseModelId | None
+ ) -> Self:
+ """Wrapper around .by_id, creating (not caching/saving) if not find."""
+ if not cls.can_create_by_id:
+ raise HandledException('Class cannot .by_id_or_create.')
+ if id_ is None:
+ return cls(None)
+ try:
+ return cls.by_id(db_conn, id_)
+ except NotFoundException:
+ return cls(id_)
+
@classmethod
def all(cls: type[BaseModelInstance],
db_conn: DatabaseConnection) -> list[BaseModelInstance]:
item = cls.by_id(db_conn, id_)
assert item.id_ is not None
items[item.id_] = item
- return list(items.values())
+ return sorted(list(items.values()))
@classmethod
def by_date_range_with_limits(cls: type[BaseModelInstance],
date_col: str = 'day'
) -> tuple[list[BaseModelInstance], str,
str]:
- """Return list of items in database within (open) date_range interval.
+ """Return list of items in DB within (closed) date_range interval.
If no range values provided, defaults them to 'yesterday' and
'tomorrow'. Knows to properly interpret these and 'today' as value.
"""Write self to DB and cache and ensure .id_.
Write both to DB, and to cache. To DB, write .id_ and attributes
- listed in cls.to_save[_versioned|_relations].
+ listed in cls.to_save_[simples|versioned|_relations].
Ensure self.id_ by setting it to what the DB command returns as the
last saved row's ID (cursor.lastrowid), EXCEPT if self.id_ already
only the case with the Day class, where it's to be a date string.
"""
values = tuple([self.id_] + [getattr(self, key)
- for key in self.to_save])
+ for key in self.to_save_simples])
table_name = self.table_name
cursor = db_conn.exec_on_vals(f'REPLACE INTO {table_name} VALUES',
values)
if not isinstance(self.id_, str):
self.id_ = cursor.lastrowid # type: ignore[assignment]
self.cache()
- for attr_name in self.to_save_versioned:
+ for attr_name in self.to_save_versioned():
getattr(self, attr_name).save(db_conn)
for table, column, attr_name, key_index in self.to_save_relations:
assert isinstance(self.id_, (int, str))
def remove(self, db_conn: DatabaseConnection) -> None:
"""Remove from DB and cache, including dependencies."""
- # pylint: disable=protected-access
- # (since we remain within class)
- if self.id_ is None or self.__class__._get_cached(self.id_) is None:
+ if self.id_ is None or self._get_cached(self.id_) is None:
raise HandledException('cannot remove unsaved item')
- for attr_name in self.to_save_versioned:
+ for attr_name in self.to_save_versioned():
getattr(self, attr_name).remove(db_conn)
for table, column, attr_name, _ in self.to_save_relations:
db_conn.delete_where(table, column, self.id_)
- self.uncache()
+ self._uncache()
db_conn.delete_where(self.table_name, 'id', self.id_)
+ self._disappear()