home · contact · privacy
Slightly improve and re-organize Condition tests.
[plomtask] / plomtask / db.py
index 99998a6ab29f760ba0d62f90395739dad4b521ff..13cdaef5b9c7d3e992f8c92730a9979b9eee2d73 100644 (file)
@@ -4,7 +4,7 @@ from os import listdir
 from os.path import isfile
 from difflib import Differ
 from sqlite3 import connect as sql_connect, Cursor, Row
-from typing import Any, Self, TypeVar, Generic
+from typing import Any, Self, TypeVar, Generic, Callable
 from plomtask.exceptions import HandledException, NotFoundException
 from plomtask.dating import valid_date
 
@@ -235,10 +235,13 @@ class BaseModel(Generic[BaseModelId]):
     to_save: list[str] = []
     to_save_versioned: list[str] = []
     to_save_relations: list[tuple[str, str, str, int]] = []
+    add_to_dict: list[str] = []
     id_: None | BaseModelId
     cache_: dict[BaseModelId, Self]
     to_search: list[str] = []
+    can_create_by_id = False
     _exists = True
+    sorters: dict[str, Callable[..., Any]] = {}
 
     def __init__(self, id_: BaseModelId | None) -> None:
         if isinstance(id_, int) and id_ < 1:
@@ -271,6 +274,84 @@ class BaseModel(Generic[BaseModelId]):
         assert isinstance(other.id_, int)
         return self.id_ < other.id_
 
+    @property
+    def as_dict(self) -> dict[str, object]:
+        """Return self as (json.dumps-compatible) dict."""
+        library: dict[str, dict[str | int, object]] = {}
+        d: dict[str, object] = {'id': self.id_, '_library': library}
+        for to_save in self.to_save:
+            attr = getattr(self, to_save)
+            if hasattr(attr, 'as_dict_into_reference'):
+                d[to_save] = attr.as_dict_into_reference(library)
+            else:
+                d[to_save] = attr
+        if len(self.to_save_versioned) > 0:
+            d['_versioned'] = {}
+        for k in self.to_save_versioned:
+            attr = getattr(self, k)
+            assert isinstance(d['_versioned'], dict)
+            d['_versioned'][k] = attr.history
+        for r in self.to_save_relations:
+            attr_name = r[2]
+            l: list[int | str] = []
+            for rel in getattr(self, attr_name):
+                l += [rel.as_dict_into_reference(library)]
+            d[attr_name] = l
+        for k in self.add_to_dict:
+            d[k] = [x.as_dict_into_reference(library)
+                    for x in getattr(self, k)]
+        return d
+
+    def as_dict_into_reference(self,
+                               library: dict[str, dict[str | int, object]]
+                               ) -> int | str:
+        """Return self.id_ while writing .as_dict into library."""
+        def into_library(library: dict[str, dict[str | int, object]],
+                         cls_name: str,
+                         id_: str | int,
+                         d: dict[str, object]
+                         ) -> None:
+            if cls_name not in library:
+                library[cls_name] = {}
+            if id_ in library[cls_name]:
+                if library[cls_name][id_] != d:
+                    msg = 'Unexpected inequality of entries for ' +\
+                            f'_library at: {cls_name}/{id_}'
+                    raise HandledException(msg)
+            else:
+                library[cls_name][id_] = d
+        as_dict = self.as_dict
+        assert isinstance(as_dict['_library'], dict)
+        for cls_name, dict_of_objs in as_dict['_library'].items():
+            for id_, obj in dict_of_objs.items():
+                into_library(library, cls_name, id_, obj)
+        del as_dict['_library']
+        assert self.id_ is not None
+        into_library(library, self.__class__.__name__, self.id_, as_dict)
+        assert isinstance(as_dict['id'], (int, str))
+        return as_dict['id']
+
+    @classmethod
+    def name_lowercase(cls) -> str:
+        """Convenience method to return cls' name in lowercase."""
+        return cls.__name__.lower()
+
+    @classmethod
+    def sort_by(cls, seq: list[Any], sort_key: str, default: str = 'title'
+                ) -> str:
+        """Sort cls list by cls.sorters[sort_key] (reverse if '-'-prefixed)."""
+        reverse = False
+        if len(sort_key) > 1 and '-' == sort_key[0]:
+            sort_key = sort_key[1:]
+            reverse = True
+        if sort_key not in cls.sorters:
+            sort_key = default
+        sorter: Callable[..., Any] = cls.sorters[sort_key]
+        seq.sort(key=sorter, reverse=reverse)
+        if reverse:
+            sort_key = f'-{sort_key}'
+        return sort_key
+
     # cache management
     # (we primarily use the cache to ensure we work on the same object in
     # memory no matter where and how we retrieve it, e.g. we don't want
@@ -295,7 +376,13 @@ class BaseModel(Generic[BaseModelId]):
 
     @classmethod
     def empty_cache(cls) -> None:
-        """Empty class's cache."""
+        """Empty class's cache, and disappear all former inhabitants."""
+        # pylint: disable=protected-access
+        # (cause we remain within the class)
+        if hasattr(cls, 'cache_'):
+            to_disappear = list(cls.cache_.values())
+            for item in to_disappear:
+                item._disappear()
         cls.cache_ = {}
 
     @classmethod
@@ -310,15 +397,14 @@ class BaseModel(Generic[BaseModelId]):
     def _get_cached(cls: type[BaseModelInstance],
                     id_: BaseModelId) -> BaseModelInstance | None:
         """Get object of id_ from class's cache, or None if not found."""
-        # pylint: disable=consider-iterating-dictionary
         cache = cls.get_cache()
-        if id_ in cache.keys():
+        if id_ in cache:
             obj = cache[id_]
             assert isinstance(obj, cls)
             return obj
         return None
 
-    def _cache(self) -> None:
+    def cache(self) -> None:
         """Update object in class's cache.
 
         Also calls ._disappear if cache holds older reference to object of same
@@ -349,22 +435,23 @@ class BaseModel(Generic[BaseModelId]):
                        # pylint: disable=unused-argument
                        db_conn: DatabaseConnection,
                        row: Row | list[Any]) -> BaseModelInstance:
-        """Make from DB row, update DB cache with it."""
+        """Make from DB row (sans relations), update DB cache with it."""
         obj = cls(*row)
-        obj._cache()
+        assert obj.id_ is not None
+        for attr_name in cls.to_save_versioned:
+            attr = getattr(obj, attr_name)
+            table_name = attr.table_name
+            for row_ in db_conn.row_where(table_name, 'parent', obj.id_):
+                attr.history_from_row(row_)
+        obj.cache()
         return obj
 
     @classmethod
-    def by_id(cls, db_conn: DatabaseConnection,
-              id_: BaseModelId | None,
-              # pylint: disable=unused-argument
-              create: bool = False) -> Self:
+    def by_id(cls, db_conn: DatabaseConnection, id_: BaseModelId) -> Self:
         """Retrieve by id_, on failure throw NotFoundException.
 
         First try to get from cls.cache_, only then check DB; if found,
         put into cache.
-
-        If create=True, make anew (but do not cache yet).
         """
         obj = None
         if id_ is not None:
@@ -375,11 +462,22 @@ class BaseModel(Generic[BaseModelId]):
                     break
         if obj:
             return obj
-        if create:
-            obj = cls(id_)
-            return obj
         raise NotFoundException(f'found no object of ID {id_}')
 
+    @classmethod
+    def by_id_or_create(cls, db_conn: DatabaseConnection,
+                        id_: BaseModelId | None
+                        ) -> Self:
+        """Wrapper around .by_id, creating (not caching/saving) if not find."""
+        if not cls.can_create_by_id:
+            raise HandledException('Class cannot .by_id_or_create.')
+        if id_ is None:
+            return cls(None)
+        try:
+            return cls.by_id(db_conn, id_)
+        except NotFoundException:
+            return cls(id_)
+
     @classmethod
     def all(cls: type[BaseModelInstance],
             db_conn: DatabaseConnection) -> list[BaseModelInstance]:
@@ -465,7 +563,7 @@ class BaseModel(Generic[BaseModelId]):
                                       values)
         if not isinstance(self.id_, str):
             self.id_ = cursor.lastrowid  # type: ignore[assignment]
-        self._cache()
+        self.cache()
         for attr_name in self.to_save_versioned:
             getattr(self, attr_name).save(db_conn)
         for table, column, attr_name, key_index in self.to_save_relations: