X-Git-Url: https://plomlompom.com/repos/todo?a=blobdiff_plain;f=plomtask%2Fdb.py;h=b3f1db00986b1142f5f31be34060864840ab5bdc;hb=21df71ef1fde304b158da5989692c01f463515b5;hp=1cecc16f6985b555f25757ad6e9f65724311a287;hpb=db62e6559fdd577dae38d4b6f5cbd5ef6a14cc57;p=plomtask

diff --git a/plomtask/db.py b/plomtask/db.py
index 1cecc16..b3f1db0 100644
--- a/plomtask/db.py
+++ b/plomtask/db.py
@@ -238,6 +238,7 @@ class BaseModel(Generic[BaseModelId]):
     id_: None | BaseModelId
     cache_: dict[BaseModelId, Self]
     to_search: list[str] = []
+    can_create_by_id = False
     _exists = True
 
     def __init__(self, id_: BaseModelId | None) -> None:
@@ -273,21 +274,58 @@ class BaseModel(Generic[BaseModelId]):
 
     @property
     def as_dict(self) -> dict[str, object]:
-        """Return self as (json.dumps-coompatible) dict."""
-        d: dict[str, object] = {'id': self.id_}
-        for k in self.to_save:
-            attr = getattr(self, k)
-            if hasattr(attr, 'as_dict'):
-                d[k] = attr.as_dict
-            d[k] = attr
+        """Return self as (json.dumps-compatible) dict."""
+        library: dict[str, dict[str | int, object]] = {}
+        d: dict[str, object] = {'id': self.id_, '_library': library}
+        for to_save in self.to_save:
+            attr = getattr(self, to_save)
+            if hasattr(attr, 'as_dict_into_reference'):
+                d[to_save] = attr.as_dict_into_reference(library)
+            else:
+                d[to_save] = attr
+        if len(self.to_save_versioned) > 0:
+            d['_versioned'] = {}
         for k in self.to_save_versioned:
             attr = getattr(self, k)
-            d[k] = attr.as_dict
+            assert isinstance(d['_versioned'], dict)
+            d['_versioned'][k] = attr.history
         for r in self.to_save_relations:
             attr_name = r[2]
-            d[attr_name] = [x.as_dict for x in getattr(self, attr_name)]
+            l: list[int | str] = []
+            for rel in getattr(self, attr_name):
+                l += [rel.as_dict_into_reference(library)]
+            d[attr_name] = l
         return d
 
+    def as_dict_into_reference(self,
+                               library: dict[str, dict[str | int, object]]
+                               ) -> int | str:
+        """Return self.id_ while writing .as_dict into library."""
+        def into_library(library: dict[str, dict[str | int, object]],
+                         cls_name: str,
+                         id_: str | int,
+                         d: dict[str, object]
+                         ) -> None:
+            if cls_name not in library:
+                library[cls_name] = {}
+            if id_ in library[cls_name]:
+                if library[cls_name][id_] != d:
+                    msg = 'Unexpected inequality of entries for ' +\
+                          f'_library at: {cls_name}/{id_}'
+                    raise HandledException(msg)
+            else:
+                library[cls_name][id_] = d
+        as_dict = self.as_dict
+        assert isinstance(as_dict['_library'], dict)
+        for cls_name, dict_of_objs in as_dict['_library'].items():
+            for id_, obj in dict_of_objs.items():
+                into_library(library, cls_name, id_, obj)
+        del as_dict['_library']
+        assert self.id_ is not None
+        into_library(library, self.__class__.__name__, self.id_, as_dict)
+        assert isinstance(as_dict['id'], (int, str))
+        return as_dict['id']
+
     # cache management
     # (we primarily use the cache to ensure we work on the same object in
     # memory no matter where and how we retrieve it, e.g. we don't want
@@ -312,7 +350,13 @@ class BaseModel(Generic[BaseModelId]):
 
     @classmethod
     def empty_cache(cls) -> None:
-        """Empty class's cache."""
+        """Empty class's cache, and disappear all former inhabitants."""
+        # pylint: disable=protected-access
+        # (cause we remain within the class)
+        if hasattr(cls, 'cache_'):
+            to_disappear = list(cls.cache_.values())
+            for item in to_disappear:
+                item._disappear()
         cls.cache_ = {}
 
     @classmethod
@@ -327,15 +371,14 @@ class BaseModel(Generic[BaseModelId]):
     def _get_cached(cls: type[BaseModelInstance],
                     id_: BaseModelId) -> BaseModelInstance | None:
         """Get object of id_ from class's cache, or None if not found."""
-        # pylint: disable=consider-iterating-dictionary
         cache = cls.get_cache()
-        if id_ in cache.keys():
+        if id_ in cache:
             obj = cache[id_]
             assert isinstance(obj, cls)
             return obj
         return None
 
-    def _cache(self) -> None:
+    def cache(self) -> None:
         """Update object in class's cache.
 
         Also calls ._disappear if cache holds older reference to object of same
@@ -366,22 +409,23 @@ class BaseModel(Generic[BaseModelId]):
                        # pylint: disable=unused-argument
                        db_conn: DatabaseConnection,
                        row: Row | list[Any]) -> BaseModelInstance:
-        """Make from DB row, update DB cache with it."""
+        """Make from DB row (sans relations), update DB cache with it."""
         obj = cls(*row)
-        obj._cache()
+        assert obj.id_ is not None
+        for attr_name in cls.to_save_versioned:
+            attr = getattr(obj, attr_name)
+            table_name = attr.table_name
+            for row_ in db_conn.row_where(table_name, 'parent', obj.id_):
+                attr.history_from_row(row_)
+        obj.cache()
         return obj
 
     @classmethod
-    def by_id(cls, db_conn: DatabaseConnection,
-              id_: BaseModelId | None,
-              # pylint: disable=unused-argument
-              create: bool = False) -> Self:
+    def by_id(cls, db_conn: DatabaseConnection, id_: BaseModelId) -> Self:
         """Retrieve by id_, on failure throw NotFoundException.
 
         First try to get from cls.cache_, only then check DB; if found,
         put into cache.
-
-        If create=True, make anew (but do not cache yet).
         """
         obj = None
         if id_ is not None:
@@ -392,11 +436,22 @@ class BaseModel(Generic[BaseModelId]):
                     break
         if obj:
             return obj
-        if create:
-            obj = cls(id_)
-            return obj
         raise NotFoundException(f'found no object of ID {id_}')
 
+    @classmethod
+    def by_id_or_create(cls, db_conn: DatabaseConnection,
+                        id_: BaseModelId | None
+                        ) -> Self:
+        """Wrapper around .by_id, creating (not caching/saving) on a miss."""
+        if not cls.can_create_by_id:
+            raise HandledException('Class cannot .by_id_or_create.')
+        if id_ is None:
+            return cls(None)
+        try:
+            return cls.by_id(db_conn, id_)
+        except NotFoundException:
+            return cls(id_)
+
     @classmethod
     def all(cls: type[BaseModelInstance],
             db_conn: DatabaseConnection) -> list[BaseModelInstance]:
@@ -482,7 +537,7 @@ class BaseModel(Generic[BaseModelId]):
                                   values)
             if not isinstance(self.id_, str):
                 self.id_ = cursor.lastrowid  # type: ignore[assignment]
-        self._cache()
+        self.cache()
         for attr_name in self.to_save_versioned:
             getattr(self, attr_name).save(db_conn)
         for table, column, attr_name, key_index in self.to_save_relations:
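
For orientation, the referencing scheme that the new as_dict/as_dict_into_reference pair introduces can be sketched in isolation: relations serialize as bare ids, while the referenced objects' full dicts get filed exactly once per class name and id under a shared '_library' mapping, with a conflict check on unequal duplicates. The sketch below is self-contained and hypothetical, not plomtask code; the Node class and into_library helper are made up for illustration, and only the '_library'/id convention mirrors the diff.

    from __future__ import annotations
    from typing import Any

    Library = dict[str, dict[str | int, dict[str, Any]]]


    def into_library(library: Library, cls_name: str, id_: str | int,
                     d: dict[str, Any]) -> None:
        """File d under library[cls_name][id_]; reject conflicting entries."""
        per_cls = library.setdefault(cls_name, {})
        if id_ in per_cls and per_cls[id_] != d:
            raise ValueError(f'conflicting entries at {cls_name}/{id_}')
        per_cls[id_] = d


    class Node:
        """Toy model with one relation, standing in for a BaseModel subclass."""

        def __init__(self, id_: int, children: list[Node] | None = None) -> None:
            self.id_ = id_
            self.children = children or []

        def as_dict_into_reference(self, library: Library) -> int:
            """Return own id after filing own (and children's) dicts in library."""
            child_ids = [c.as_dict_into_reference(library)
                         for c in self.children]
            into_library(library, type(self).__name__, self.id_,
                         {'id': self.id_, 'children': child_ids})
            return self.id_


    shared = Node(1)
    root = Node(2, [shared, shared])              # same child referenced twice
    library: Library = {}
    print(root.as_dict_into_reference(library))   # 2
    print(library['Node'][2]['children'])         # [1, 1], ids only
    print(len(library['Node']))                   # 2, 'shared' filed just once

In this shape, a client deserializing the dump can resolve any id it encounters by a single lookup into the '_library' section, instead of walking nested copies of the same object.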
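
The lookup policy behind by_id and the new by_id_or_create can likewise be summarized: in-memory cache first, then the database, and creation of an uncached, unsaved object only for classes that opt in via can_create_by_id. Again a hedged, self-contained sketch under assumed names (a plain set stands in for the DatabaseConnection, and the Record/OnDemand classes are invented); it is not the project's API.

    from __future__ import annotations


    class NotFound(Exception):
        """Stand-in for plomtask's NotFoundException."""


    class Record:
        """Toy model mimicking the cache-then-DB-then-maybe-create order."""

        can_create_by_id = False
        cache_: dict[int, Record] = {}

        def __init__(self, id_: int | None) -> None:
            self.id_ = id_

        @classmethod
        def by_id(cls, db_rows: set[int], id_: int) -> Record:
            if id_ in cls.cache_:             # 1. in-memory cache first
                return cls.cache_[id_]
            if id_ in db_rows:                # 2. then the database
                obj = cls(id_)
                cls.cache_[id_] = obj         # cache what the DB yielded
                return obj
            raise NotFound(f'found no object of ID {id_}')

        @classmethod
        def by_id_or_create(cls, db_rows: set[int],
                            id_: int | None) -> Record:
            if not cls.can_create_by_id:      # creation must be opted into
                raise RuntimeError('class cannot .by_id_or_create')
            if id_ is None:
                return cls(None)
            try:
                return cls.by_id(db_rows, id_)
            except NotFound:
                return cls(id_)               # fresh, neither cached nor saved


    class OnDemand(Record):
        """Hypothetical subclass that allows creation by id."""
        can_create_by_id = True
        cache_: dict[int, Record] = {}


    rows = {1}                                # pretend the DB only knows id 1
    assert OnDemand.by_id_or_create(rows, 1).id_ == 1   # found and cached
    assert OnDemand.by_id_or_create(rows, 5).id_ == 5   # created, not cached
    assert 5 not in OnDemand.cache_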