home · contact · privacy
Split BaseModel.by_id into .by_id and by_id_or_create, refactor tests.
[plomtask] / plomtask / db.py
index 853b4c68c65780e339b77785b961788370373648..f6ef1cb6724a4a8cb13cc3d74d8c745338294463 100644 (file)
@@ -238,6 +238,7 @@ class BaseModel(Generic[BaseModelId]):
     id_: None | BaseModelId
     cache_: dict[BaseModelId, Self]
     to_search: list[str] = []
+    can_create_by_id = False
     _exists = True
 
     def __init__(self, id_: BaseModelId | None) -> None:
@@ -388,15 +389,12 @@ class BaseModel(Generic[BaseModelId]):
 
     @classmethod
     def by_id(cls, db_conn: DatabaseConnection,
-              id_: BaseModelId | None,
-              # pylint: disable=unused-argument
-              create: bool = False) -> Self:
+              id_: BaseModelId | None
+              ) -> Self:
         """Retrieve by id_, on failure throw NotFoundException.
 
         First try to get from cls.cache_, only then check DB; if found,
         put into cache.
-
-        If create=True, make anew (but do not cache yet).
         """
         obj = None
         if id_ is not None:
@@ -407,10 +405,20 @@ class BaseModel(Generic[BaseModelId]):
                     break
         if obj:
             return obj
-        if create:
+        raise NotFoundException(f'found no object of ID {id_}')
+
+    @classmethod
+    def by_id_or_create(cls, db_conn: DatabaseConnection,
+                        id_: BaseModelId | None
+                        ) -> Self:
+        """Wrapper around .by_id, creating (not caching/saving) if not find."""
+        if not cls.can_create_by_id:
+            raise HandledException('Class cannot .by_id_or_create.')
+        try:
+            return cls.by_id(db_conn, id_)
+        except NotFoundException:
             obj = cls(id_)
             return obj
-        raise NotFoundException(f'found no object of ID {id_}')
 
     @classmethod
     def all(cls: type[BaseModelInstance],