home · contact · privacy
Refactor BaseModel.from_table_row testing.
[plomtask] / plomtask / db.py
index 054060e13b7c1fa4aa27274efab13ff1c2828815..853b4c68c65780e339b77785b961788370373648 100644 (file)
@@ -315,7 +315,13 @@ class BaseModel(Generic[BaseModelId]):
 
     @classmethod
     def empty_cache(cls) -> None:
-        """Empty class's cache."""
+        """Empty class's cache, and disappear all former inhabitants."""
+        # pylint: disable=protected-access
+        # (cause we remain within the class)
+        if hasattr(cls, 'cache_'):
+            to_disappear = list(cls.cache_.values())
+            for item in to_disappear:
+                item._disappear()
         cls.cache_ = {}
 
     @classmethod
@@ -338,7 +344,7 @@ class BaseModel(Generic[BaseModelId]):
             return obj
         return None
 
-    def _cache(self) -> None:
+    def cache(self) -> None:
         """Update object in class's cache.
 
         Also calls ._disappear if cache holds older reference to object of same
@@ -369,9 +375,15 @@ class BaseModel(Generic[BaseModelId]):
                        # pylint: disable=unused-argument
                        db_conn: DatabaseConnection,
                        row: Row | list[Any]) -> BaseModelInstance:
-        """Make from DB row, update DB cache with it."""
+        """Make from DB row (sans relations), update DB cache with it."""
         obj = cls(*row)
-        obj._cache()
+        assert obj.id_ is not None
+        for attr_name in cls.to_save_versioned:
+            attr = getattr(obj, attr_name)
+            table_name = attr.table_name
+            for row_ in db_conn.row_where(table_name, 'parent', obj.id_):
+                attr.history_from_row(row_)
+        obj.cache()
         return obj
 
     @classmethod
@@ -485,7 +497,7 @@ class BaseModel(Generic[BaseModelId]):
                                       values)
         if not isinstance(self.id_, str):
             self.id_ = cursor.lastrowid  # type: ignore[assignment]
-        self._cache()
+        self.cache()
         for attr_name in self.to_save_versioned:
             getattr(self, attr_name).save(db_conn)
         for table, column, attr_name, key_index in self.to_save_relations: