From: Christian Heller Date: Thu, 25 Apr 2024 05:06:00 +0000 (+0200) Subject: Fix bug of same dict being used for different Classes' caches. X-Git-Url: https://plomlompom.com/repos/%7B%7B%20web_path%20%7D%7D/decks/%7B%7Bprefix%7D%7D/%7B%7Btodo.date%7D%7D?a=commitdiff_plain;h=7b6b8d0b93b1d4dd85152e49e7105aacc647327c;p=plomtask Fix bug of same dict being used for different Classes' caches. --- diff --git a/plomtask/db.py b/plomtask/db.py index 2d9ae27..1753da4 100644 --- a/plomtask/db.py +++ b/plomtask/db.py @@ -121,7 +121,7 @@ class BaseModel(Generic[BaseModelId]): table_name = '' to_save: list[str] = [] id_: None | BaseModelId - cache_: dict[BaseModelId, Self] = {} + cache_: dict[BaseModelId, Self] @classmethod def from_table_row(cls: type[BaseModelInstance], @@ -174,8 +174,9 @@ class BaseModel(Generic[BaseModelId]): id_: BaseModelId) -> BaseModelInstance | None: """Get object of id_ from class's cache, or None if not found.""" # pylint: disable=consider-iterating-dictionary - if id_ in cls.cache_.keys(): - obj = cls.cache_[id_] + cache = cls.get_cache() + if id_ in cache.keys(): + obj = cache[id_] assert isinstance(obj, cls) return obj return None @@ -184,13 +185,15 @@ class BaseModel(Generic[BaseModelId]): """Update object in class's cache.""" if self.id_ is None: raise HandledException('Cannot cache object without ID.') - self.__class__.cache_[self.id_] = self + cache = self.__class__.get_cache() + cache[self.id_] = self def uncache(self) -> None: """Remove self from cache.""" if self.id_ is None: raise HandledException('Cannot un-cache object without ID.') - del self.__class__.cache_[self.id_] + cache = self.__class__.get_cache() + del cache[self.id_] @classmethod def empty_cache(cls) -> None: @@ -202,7 +205,7 @@ class BaseModel(Generic[BaseModelId]): db_conn: DatabaseConnection) -> list[BaseModelInstance]: """Collect all objects of class.""" items: dict[BaseModelId, BaseModelInstance] = {} - for k, v in cls.cache_.items(): + for k, v in cls.get_cache().items(): assert isinstance(v, cls) items[k] = v already_recorded = items.keys() @@ -212,3 +215,11 @@ class BaseModel(Generic[BaseModelId]): item = cls.by_id(db_conn, id_) # type: ignore[attr-defined] items[item.id_] = item return list(items.values()) + + @classmethod + def get_cache(cls: type[BaseModelInstance]) -> dict[Any, BaseModel[Any]]: + """Get cache dictionary, create it if not yet existing.""" + if not hasattr(cls, 'cache_'): + d: dict[Any, BaseModel[Any]] = {} + cls.cache_ = d + return cls.cache_