home · contact · privacy
Refactor saving and caching tests, treatment of None IDs.
[plomtask] / plomtask / db.py
index 3917ce0558e3f7deb1f917a07eb5dd4df7b279e6..797b08e8412809e3ac23bb7f04fc24fef955be03 100644 (file)
@@ -238,6 +238,7 @@ class BaseModel(Generic[BaseModelId]):
     id_: None | BaseModelId
     cache_: dict[BaseModelId, Self]
     to_search: list[str] = []
+    can_create_by_id = False
     _exists = True
 
     def __init__(self, id_: BaseModelId | None) -> None:
@@ -315,7 +316,13 @@ class BaseModel(Generic[BaseModelId]):
 
     @classmethod
     def empty_cache(cls) -> None:
-        """Empty class's cache."""
+        """Empty class's cache, and disappear all former inhabitants."""
+        # pylint: disable=protected-access
+        # (cause we remain within the class)
+        if hasattr(cls, 'cache_'):
+            to_disappear = list(cls.cache_.values())
+            for item in to_disappear:
+                item._disappear()
         cls.cache_ = {}
 
     @classmethod
@@ -338,7 +345,7 @@ class BaseModel(Generic[BaseModelId]):
             return obj
         return None
 
-    def _cache(self) -> None:
+    def cache(self) -> None:
         """Update object in class's cache.
 
         Also calls ._disappear if cache holds older reference to object of same
@@ -377,20 +384,15 @@ class BaseModel(Generic[BaseModelId]):
             table_name = attr.table_name
             for row_ in db_conn.row_where(table_name, 'parent', obj.id_):
                 attr.history_from_row(row_)
-        obj._cache()
+        obj.cache()
         return obj
 
     @classmethod
-    def by_id(cls, db_conn: DatabaseConnection,
-              id_: BaseModelId | None,
-              # pylint: disable=unused-argument
-              create: bool = False) -> Self:
+    def by_id(cls, db_conn: DatabaseConnection, id_: BaseModelId) -> Self:
         """Retrieve by id_, on failure throw NotFoundException.
 
         First try to get from cls.cache_, only then check DB; if found,
         put into cache.
-
-        If create=True, make anew (but do not cache yet).
         """
         obj = None
         if id_ is not None:
@@ -401,11 +403,22 @@ class BaseModel(Generic[BaseModelId]):
                     break
         if obj:
             return obj
-        if create:
-            obj = cls(id_)
-            return obj
         raise NotFoundException(f'found no object of ID {id_}')
 
+    @classmethod
+    def by_id_or_create(cls, db_conn: DatabaseConnection,
+                        id_: BaseModelId | None
+                        ) -> Self:
+        """Wrapper around .by_id, creating (not caching/saving) if not find."""
+        if not cls.can_create_by_id:
+            raise HandledException('Class cannot .by_id_or_create.')
+        if id_ is None:
+            return cls(None)
+        try:
+            return cls.by_id(db_conn, id_)
+        except NotFoundException:
+            return cls(id_)
+
     @classmethod
     def all(cls: type[BaseModelInstance],
             db_conn: DatabaseConnection) -> list[BaseModelInstance]:
@@ -491,7 +504,7 @@ class BaseModel(Generic[BaseModelId]):
                                       values)
         if not isinstance(self.id_, str):
             self.id_ = cursor.lastrowid  # type: ignore[assignment]
-        self._cache()
+        self.cache()
         for attr_name in self.to_save_versioned:
             getattr(self, attr_name).save(db_conn)
         for table, column, attr_name, key_index in self.to_save_relations: