From: Christian Heller Date: Fri, 14 Jun 2024 21:06:09 +0000 (+0200) Subject: Minor BaseModel code re-organization. X-Git-Url: https://plomlompom.com/repos/te"st.html?a=commitdiff_plain;h=b4a34a415fb31a00ee1e092fcc2a6b5d97edd52a;p=plomtask Minor BaseModel code re-organization. --- diff --git a/plomtask/db.py b/plomtask/db.py index a47dff1..df98dd0 100644 --- a/plomtask/db.py +++ b/plomtask/db.py @@ -272,9 +272,11 @@ class BaseModel(Generic[BaseModelId]): assert isinstance(other.id_, int) return self.id_ < other.id_ + # cache management + @classmethod - def get_cached(cls: type[BaseModelInstance], - id_: BaseModelId) -> BaseModelInstance | None: + def _get_cached(cls: type[BaseModelInstance], + id_: BaseModelId) -> BaseModelInstance | None: """Get object of id_ from class's cache, or None if not found.""" # pylint: disable=consider-iterating-dictionary cache = cls.get_cache() @@ -311,6 +313,8 @@ class BaseModel(Generic[BaseModelId]): cache = self.__class__.get_cache() del cache[self.id_] + # object retrieval and generation + @classmethod def from_table_row(cls: type[BaseModelInstance], # pylint: disable=unused-argument @@ -335,7 +339,7 @@ class BaseModel(Generic[BaseModelId]): """ obj = None if id_ is not None: - obj = cls.get_cached(id_) + obj = cls._get_cached(id_) if not obj: for row in db_conn.row_where(cls.table_name, 'id', id_): obj = cls.from_table_row(db_conn, row) @@ -413,6 +417,8 @@ class BaseModel(Generic[BaseModelId]): return filtered return items + # database writing + def save(self, db_conn: DatabaseConnection) -> None: """Write self to DB and cache and ensure .id_. @@ -442,7 +448,9 @@ class BaseModel(Generic[BaseModelId]): def remove(self, db_conn: DatabaseConnection) -> None: """Remove from DB and cache, including dependencies.""" - if self.id_ is None or self.__class__.get_cached(self.id_) is None: + # pylint: disable=protected-access + # (since we remain within class) + if self.id_ is None or self.__class__._get_cached(self.id_) is None: raise HandledException('cannot remove unsaved item') for attr_name in self.to_save_versioned: getattr(self, attr_name).remove(db_conn)