from os.path import isfile
from difflib import Differ
from sqlite3 import connect as sql_connect, Cursor, Row
-from typing import Any, Self, TypeVar, Generic
+from typing import Any, Self, TypeVar, Generic, Callable
from plomtask.exceptions import HandledException, NotFoundException
from plomtask.dating import valid_date
class BaseModel(Generic[BaseModelId]):
"""Template for most of the models we use/derive from the DB."""
table_name = ''
- to_save: list[str] = []
- to_save_versioned: list[str] = []
+ to_save_simples: list[str] = []
to_save_relations: list[tuple[str, str, str, int]] = []
+ versioned_defaults: dict[str, str | float] = {}
+ add_to_dict: list[str] = []
id_: None | BaseModelId
cache_: dict[BaseModelId, Self]
to_search: list[str] = []
can_create_by_id = False
_exists = True
+ sorters: dict[str, Callable[..., Any]] = {}
def __init__(self, id_: BaseModelId | None) -> None:
if isinstance(id_, int) and id_ < 1:
self.id_ = id_
def __hash__(self) -> int:
- hashable = [self.id_] + [getattr(self, name) for name in self.to_save]
+ hashable = [self.id_] + [getattr(self, name)
+ for name in self.to_save_simples]
for definition in self.to_save_relations:
attr = getattr(self, definition[2])
hashable += [tuple(rel.id_ for rel in attr)]
- for name in self.to_save_versioned:
+ for name in self.to_save_versioned():
hashable += [hash(getattr(self, name))]
return hash(tuple(hashable))
assert isinstance(other.id_, int)
return self.id_ < other.id_
+ @classmethod
+ def to_save_versioned(cls) -> list[str]:
+ """Return keys of cls.versioned_defaults assuming we wanna save 'em."""
+ return list(cls.versioned_defaults.keys())
+
@property
- def as_dict(self) -> dict[str, object]:
- """Return self as (json.dumps-coompatible) dict."""
+ def as_dict_and_refs(self) -> tuple[dict[str, object],
+ list[BaseModel[int] | BaseModel[str]]]:
+ """Return self as json.dumps-ready dict, list of referenced objects."""
d: dict[str, object] = {'id': self.id_}
- if len(self.to_save_versioned) > 0:
+ refs: list[BaseModel[int] | BaseModel[str]] = []
+ for to_save in self.to_save_simples:
+ d[to_save] = getattr(self, to_save)
+ if len(self.to_save_versioned()) > 0:
d['_versioned'] = {}
- for k in self.to_save:
- attr = getattr(self, k)
- if hasattr(attr, 'as_dict'):
- d[k] = attr.as_dict
- d[k] = attr
- for k in self.to_save_versioned:
+ for k in self.to_save_versioned():
attr = getattr(self, k)
assert isinstance(d['_versioned'], dict)
d['_versioned'][k] = attr.history
- for r in self.to_save_relations:
- attr_name = r[2]
- d[attr_name] = [x.as_dict for x in getattr(self, attr_name)]
- return d
+ rels_to_collect = [rel[2] for rel in self.to_save_relations]
+ rels_to_collect += self.add_to_dict
+ for attr_name in rels_to_collect:
+ rel_list = []
+ for item in getattr(self, attr_name):
+ rel_list += [item.id_]
+ if item not in refs:
+ refs += [item]
+ d[attr_name] = rel_list
+ return d, refs
+
+ @classmethod
+ def name_lowercase(cls) -> str:
+ """Convenience method to return cls' name in lowercase."""
+ return cls.__name__.lower()
+
+ @classmethod
+ def sort_by(cls, seq: list[Any], sort_key: str, default: str = 'title'
+ ) -> str:
+ """Sort cls list by cls.sorters[sort_key] (reverse if '-'-prefixed).
+
+ Before cls.sorters[sort_key] is applied, seq is sorted by .id_, to
+ ensure predictability where parts of seq are of same sort value.
+ """
+ reverse = False
+ if len(sort_key) > 1 and '-' == sort_key[0]:
+ sort_key = sort_key[1:]
+ reverse = True
+ if sort_key not in cls.sorters:
+ sort_key = default
+ seq.sort(key=lambda x: x.id_, reverse=reverse)
+ sorter: Callable[..., Any] = cls.sorters[sort_key]
+ seq.sort(key=sorter, reverse=reverse)
+ if reverse:
+ sort_key = f'-{sort_key}'
+ return sort_key
# cache management
# (we primarily use the cache to ensure we work on the same object in
def __getattribute__(self, name: str) -> Any:
"""Ensure fail if ._disappear() was called, except to check ._exists"""
if name != '_exists' and not super().__getattribute__('_exists'):
- raise HandledException('Object does not exist.')
+ msg = f'Object for attribute does not exist: {name}'
+ raise HandledException(msg)
return super().__getattribute__(name)
def _disappear(self) -> None:
cls.cache_ = {}
@classmethod
- def get_cache(cls: type[BaseModelInstance]) -> dict[Any, BaseModel[Any]]:
+ def get_cache(cls: type[BaseModelInstance]
+ ) -> dict[Any, BaseModelInstance]:
"""Get cache dictionary, create it if not yet existing."""
if not hasattr(cls, 'cache_'):
- d: dict[Any, BaseModel[Any]] = {}
+ d: dict[Any, BaseModelInstance] = {}
cls.cache_ = d
return cls.cache_
"""Make from DB row (sans relations), update DB cache with it."""
obj = cls(*row)
assert obj.id_ is not None
- for attr_name in cls.to_save_versioned:
+ for attr_name in cls.to_save_versioned():
attr = getattr(obj, attr_name)
table_name = attr.table_name
for row_ in db_conn.row_where(table_name, 'parent', obj.id_):
item = cls.by_id(db_conn, id_)
assert item.id_ is not None
items[item.id_] = item
- return list(items.values())
+ return sorted(list(items.values()))
@classmethod
def by_date_range_with_limits(cls: type[BaseModelInstance],
date_col: str = 'day'
) -> tuple[list[BaseModelInstance], str,
str]:
- """Return list of items in database within (open) date_range interval.
+ """Return list of items in DB within (closed) date_range interval.
If no range values provided, defaults them to 'yesterday' and
'tomorrow'. Knows to properly interpret these and 'today' as value.
"""Write self to DB and cache and ensure .id_.
Write both to DB, and to cache. To DB, write .id_ and attributes
- listed in cls.to_save[_versioned|_relations].
+ listed in cls.to_save_[simples|versioned|_relations].
Ensure self.id_ by setting it to what the DB command returns as the
last saved row's ID (cursor.lastrowid), EXCEPT if self.id_ already
only the case with the Day class, where it's to be a date string.
"""
values = tuple([self.id_] + [getattr(self, key)
- for key in self.to_save])
+ for key in self.to_save_simples])
table_name = self.table_name
cursor = db_conn.exec_on_vals(f'REPLACE INTO {table_name} VALUES',
values)
if not isinstance(self.id_, str):
self.id_ = cursor.lastrowid # type: ignore[assignment]
self.cache()
- for attr_name in self.to_save_versioned:
+ for attr_name in self.to_save_versioned():
getattr(self, attr_name).save(db_conn)
for table, column, attr_name, key_index in self.to_save_relations:
assert isinstance(self.id_, (int, str))
"""Remove from DB and cache, including dependencies."""
if self.id_ is None or self._get_cached(self.id_) is None:
raise HandledException('cannot remove unsaved item')
- for attr_name in self.to_save_versioned:
+ for attr_name in self.to_save_versioned():
getattr(self, attr_name).remove(db_conn)
for table, column, attr_name, _ in self.to_save_relations:
db_conn.delete_where(table, column, self.id_)