X-Git-Url: https://plomlompom.com/repos//%22https:/validator.w3.org/check?a=blobdiff_plain;f=plomtask%2Fdb.py;h=cce2630cd58bfb8bf7283c2eb0d2f45006d8ba26;hb=c021152e6566c8374170de916c69d6b5c816cd54;hp=385e79855286c45087537b3b63b8b024fd553ca0;hpb=1e4c7cd5cde09a5c58bc601cae3f5a49eb615399;p=plomtask diff --git a/plomtask/db.py b/plomtask/db.py index 385e798..cce2630 100644 --- a/plomtask/db.py +++ b/plomtask/db.py @@ -238,6 +238,7 @@ class BaseModel(Generic[BaseModelId]): id_: None | BaseModelId cache_: dict[BaseModelId, Self] to_search: list[str] = [] + can_create_by_id = False _exists = True def __init__(self, id_: BaseModelId | None) -> None: @@ -336,15 +337,14 @@ class BaseModel(Generic[BaseModelId]): def _get_cached(cls: type[BaseModelInstance], id_: BaseModelId) -> BaseModelInstance | None: """Get object of id_ from class's cache, or None if not found.""" - # pylint: disable=consider-iterating-dictionary cache = cls.get_cache() - if id_ in cache.keys(): + if id_ in cache: obj = cache[id_] assert isinstance(obj, cls) return obj return None - def _cache(self) -> None: + def cache(self) -> None: """Update object in class's cache. Also calls ._disappear if cache holds older reference to object of same @@ -383,20 +383,15 @@ class BaseModel(Generic[BaseModelId]): table_name = attr.table_name for row_ in db_conn.row_where(table_name, 'parent', obj.id_): attr.history_from_row(row_) - obj._cache() + obj.cache() return obj @classmethod - def by_id(cls, db_conn: DatabaseConnection, - id_: BaseModelId | None, - # pylint: disable=unused-argument - create: bool = False) -> Self: + def by_id(cls, db_conn: DatabaseConnection, id_: BaseModelId) -> Self: """Retrieve by id_, on failure throw NotFoundException. First try to get from cls.cache_, only then check DB; if found, put into cache. - - If create=True, make anew (but do not cache yet). """ obj = None if id_ is not None: @@ -407,11 +402,22 @@ class BaseModel(Generic[BaseModelId]): break if obj: return obj - if create: - obj = cls(id_) - return obj raise NotFoundException(f'found no object of ID {id_}') + @classmethod + def by_id_or_create(cls, db_conn: DatabaseConnection, + id_: BaseModelId | None + ) -> Self: + """Wrapper around .by_id, creating (not caching/saving) if not find.""" + if not cls.can_create_by_id: + raise HandledException('Class cannot .by_id_or_create.') + if id_ is None: + return cls(None) + try: + return cls.by_id(db_conn, id_) + except NotFoundException: + return cls(id_) + @classmethod def all(cls: type[BaseModelInstance], db_conn: DatabaseConnection) -> list[BaseModelInstance]: @@ -497,7 +503,7 @@ class BaseModel(Generic[BaseModelId]): values) if not isinstance(self.id_, str): self.id_ = cursor.lastrowid # type: ignore[assignment] - self._cache() + self.cache() for attr_name in self.to_save_versioned: getattr(self, attr_name).save(db_conn) for table, column, attr_name, key_index in self.to_save_relations: