X-Git-Url: https://plomlompom.com/repos/%7B%7B%20web_path%20%7D%7D/decks/%7B%7Bdeck_id%7D%7D/cards/%7B%7Bcard_id%7D%7D/static/git-logo.png?a=blobdiff_plain;f=plomtask%2Fdb.py;h=df98dd0f130bbd75553b2e628cd739d793e98616;hb=b4a34a415fb31a00ee1e092fcc2a6b5d97edd52a;hp=b2f2142c9c6957c19e90674270a1635082050f59;hpb=99672306cdb97d76d00829b2e491f2df0abcbbd5;p=plomtask diff --git a/plomtask/db.py b/plomtask/db.py index b2f2142..df98dd0 100644 --- a/plomtask/db.py +++ b/plomtask/db.py @@ -250,14 +250,19 @@ class BaseModel(Generic[BaseModelId]): raise HandledException(msg) self.id_ = id_ + def __hash__(self) -> int: + hashable = [self.id_] + [getattr(self, name) for name in self.to_save] + for definition in self.to_save_relations: + attr = getattr(self, definition[2]) + hashable += [tuple(rel.id_ for rel in attr)] + for name in self.to_save_versioned: + hashable += [hash(getattr(self, name))] + return hash(tuple(hashable)) + def __eq__(self, other: object) -> bool: if not isinstance(other, self.__class__): return False - to_hash_me = tuple([self.id_] + - [getattr(self, name) for name in self.to_save]) - to_hash_other = tuple([other.id_] + - [getattr(other, name) for name in other.to_save]) - return hash(to_hash_me) == hash(to_hash_other) + return hash(self) == hash(other) def __lt__(self, other: Any) -> bool: if not isinstance(other, self.__class__): @@ -267,9 +272,11 @@ class BaseModel(Generic[BaseModelId]): assert isinstance(other.id_, int) return self.id_ < other.id_ + # cache management + @classmethod - def get_cached(cls: type[BaseModelInstance], - id_: BaseModelId) -> BaseModelInstance | None: + def _get_cached(cls: type[BaseModelInstance], + id_: BaseModelId) -> BaseModelInstance | None: """Get object of id_ from class's cache, or None if not found.""" # pylint: disable=consider-iterating-dictionary cache = cls.get_cache() @@ -306,6 +313,8 @@ class BaseModel(Generic[BaseModelId]): cache = self.__class__.get_cache() del cache[self.id_] + # object retrieval and generation + @classmethod def from_table_row(cls: type[BaseModelInstance], # pylint: disable=unused-argument @@ -330,7 +339,7 @@ class BaseModel(Generic[BaseModelId]): """ obj = None if id_ is not None: - obj = cls.get_cached(id_) + obj = cls._get_cached(id_) if not obj: for row in db_conn.row_where(cls.table_name, 'id', id_): obj = cls.from_table_row(db_conn, row) @@ -408,6 +417,8 @@ class BaseModel(Generic[BaseModelId]): return filtered return items + # database writing + def save(self, db_conn: DatabaseConnection) -> None: """Write self to DB and cache and ensure .id_. @@ -437,7 +448,9 @@ class BaseModel(Generic[BaseModelId]): def remove(self, db_conn: DatabaseConnection) -> None: """Remove from DB and cache, including dependencies.""" - if self.id_ is None or self.__class__.get_cached(self.id_) is None: + # pylint: disable=protected-access + # (since we remain within class) + if self.id_ is None or self.__class__._get_cached(self.id_) is None: raise HandledException('cannot remove unsaved item') for attr_name in self.to_save_versioned: getattr(self, attr_name).remove(db_conn)