home · contact · privacy
Expand POST /todo adoption tests.
[plomtask] / plomtask / db.py
index 1cecc16f6985b555f25757ad6e9f65724311a287..67a7fc766ce607095520e174c944f957667011d7 100644 (file)
@@ -4,7 +4,7 @@ from os import listdir
 from os.path import isfile
 from difflib import Differ
 from sqlite3 import connect as sql_connect, Cursor, Row
 from os.path import isfile
 from difflib import Differ
 from sqlite3 import connect as sql_connect, Cursor, Row
-from typing import Any, Self, TypeVar, Generic
+from typing import Any, Self, TypeVar, Generic, Callable
 from plomtask.exceptions import HandledException, NotFoundException
 from plomtask.dating import valid_date
 
 from plomtask.exceptions import HandledException, NotFoundException
 from plomtask.dating import valid_date
 
@@ -232,13 +232,16 @@ BaseModelInstance = TypeVar('BaseModelInstance', bound='BaseModel[Any]')
 class BaseModel(Generic[BaseModelId]):
     """Template for most of the models we use/derive from the DB."""
     table_name = ''
 class BaseModel(Generic[BaseModelId]):
     """Template for most of the models we use/derive from the DB."""
     table_name = ''
-    to_save: list[str] = []
-    to_save_versioned: list[str] = []
+    to_save_simples: list[str] = []
     to_save_relations: list[tuple[str, str, str, int]] = []
     to_save_relations: list[tuple[str, str, str, int]] = []
+    versioned_defaults: dict[str, str | float] = {}
+    add_to_dict: list[str] = []
     id_: None | BaseModelId
     cache_: dict[BaseModelId, Self]
     to_search: list[str] = []
     id_: None | BaseModelId
     cache_: dict[BaseModelId, Self]
     to_search: list[str] = []
+    can_create_by_id = False
     _exists = True
     _exists = True
+    sorters: dict[str, Callable[..., Any]] = {}
 
     def __init__(self, id_: BaseModelId | None) -> None:
         if isinstance(id_, int) and id_ < 1:
 
     def __init__(self, id_: BaseModelId | None) -> None:
         if isinstance(id_, int) and id_ < 1:
@@ -250,11 +253,12 @@ class BaseModel(Generic[BaseModelId]):
         self.id_ = id_
 
     def __hash__(self) -> int:
         self.id_ = id_
 
     def __hash__(self) -> int:
-        hashable = [self.id_] + [getattr(self, name) for name in self.to_save]
+        hashable = [self.id_] + [getattr(self, name)
+                                 for name in self.to_save_simples]
         for definition in self.to_save_relations:
             attr = getattr(self, definition[2])
             hashable += [tuple(rel.id_ for rel in attr)]
         for definition in self.to_save_relations:
             attr = getattr(self, definition[2])
             hashable += [tuple(rel.id_ for rel in attr)]
-        for name in self.to_save_versioned:
+        for name in self.to_save_versioned():
             hashable += [hash(getattr(self, name))]
         return hash(tuple(hashable))
 
             hashable += [hash(getattr(self, name))]
         return hash(tuple(hashable))
 
@@ -271,22 +275,61 @@ class BaseModel(Generic[BaseModelId]):
         assert isinstance(other.id_, int)
         return self.id_ < other.id_
 
         assert isinstance(other.id_, int)
         return self.id_ < other.id_
 
+    @classmethod
+    def to_save_versioned(cls) -> list[str]:
+        """Return keys of cls.versioned_defaults assuming we wanna save 'em."""
+        return list(cls.versioned_defaults.keys())
+
     @property
     @property
-    def as_dict(self) -> dict[str, object]:
-        """Return self as (json.dumps-coompatible) dict."""
+    def as_dict_and_refs(self) -> tuple[dict[str, object],
+                                        list[BaseModel[int] | BaseModel[str]]]:
+        """Return self as json.dumps-ready dict, list of referenced objects."""
         d: dict[str, object] = {'id': self.id_}
         d: dict[str, object] = {'id': self.id_}
-        for k in self.to_save:
+        refs: list[BaseModel[int] | BaseModel[str]] = []
+        for to_save in self.to_save_simples:
+            d[to_save] = getattr(self, to_save)
+        if len(self.to_save_versioned()) > 0:
+            d['_versioned'] = {}
+        for k in self.to_save_versioned():
             attr = getattr(self, k)
             attr = getattr(self, k)
-            if hasattr(attr, 'as_dict'):
-                d[k] = attr.as_dict
-            d[k] = attr
-        for k in self.to_save_versioned:
-            attr = getattr(self, k)
-            d[k] = attr.as_dict
-        for r in self.to_save_relations:
-            attr_name = r[2]
-            d[attr_name] = [x.as_dict for x in getattr(self, attr_name)]
-        return d
+            assert isinstance(d['_versioned'], dict)
+            d['_versioned'][k] = attr.history
+        rels_to_collect = [rel[2] for rel in self.to_save_relations]
+        rels_to_collect += self.add_to_dict
+        for attr_name in rels_to_collect:
+            rel_list = []
+            for item in getattr(self, attr_name):
+                rel_list += [item.id_]
+                if item not in refs:
+                    refs += [item]
+            d[attr_name] = rel_list
+        return d, refs
+
+    @classmethod
+    def name_lowercase(cls) -> str:
+        """Convenience method to return cls' name in lowercase."""
+        return cls.__name__.lower()
+
+    @classmethod
+    def sort_by(cls, seq: list[Any], sort_key: str, default: str = 'title'
+                ) -> str:
+        """Sort cls list by cls.sorters[sort_key] (reverse if '-'-prefixed).
+
+        Before cls.sorters[sort_key] is applied, seq is sorted by .id_, to
+        ensure predictability where parts of seq are of same sort value.
+        """
+        reverse = False
+        if len(sort_key) > 1 and '-' == sort_key[0]:
+            sort_key = sort_key[1:]
+            reverse = True
+        if sort_key not in cls.sorters:
+            sort_key = default
+        seq.sort(key=lambda x: x.id_, reverse=reverse)
+        sorter: Callable[..., Any] = cls.sorters[sort_key]
+        seq.sort(key=sorter, reverse=reverse)
+        if reverse:
+            sort_key = f'-{sort_key}'
+        return sort_key
 
     # cache management
     # (we primarily use the cache to ensure we work on the same object in
 
     # cache management
     # (we primarily use the cache to ensure we work on the same object in
@@ -312,7 +355,13 @@ class BaseModel(Generic[BaseModelId]):
 
     @classmethod
     def empty_cache(cls) -> None:
 
     @classmethod
     def empty_cache(cls) -> None:
-        """Empty class's cache."""
+        """Empty class's cache, and disappear all former inhabitants."""
+        # pylint: disable=protected-access
+        # (cause we remain within the class)
+        if hasattr(cls, 'cache_'):
+            to_disappear = list(cls.cache_.values())
+            for item in to_disappear:
+                item._disappear()
         cls.cache_ = {}
 
     @classmethod
         cls.cache_ = {}
 
     @classmethod
@@ -327,15 +376,14 @@ class BaseModel(Generic[BaseModelId]):
     def _get_cached(cls: type[BaseModelInstance],
                     id_: BaseModelId) -> BaseModelInstance | None:
         """Get object of id_ from class's cache, or None if not found."""
     def _get_cached(cls: type[BaseModelInstance],
                     id_: BaseModelId) -> BaseModelInstance | None:
         """Get object of id_ from class's cache, or None if not found."""
-        # pylint: disable=consider-iterating-dictionary
         cache = cls.get_cache()
         cache = cls.get_cache()
-        if id_ in cache.keys():
+        if id_ in cache:
             obj = cache[id_]
             assert isinstance(obj, cls)
             return obj
         return None
 
             obj = cache[id_]
             assert isinstance(obj, cls)
             return obj
         return None
 
-    def _cache(self) -> None:
+    def cache(self) -> None:
         """Update object in class's cache.
 
         Also calls ._disappear if cache holds older reference to object of same
         """Update object in class's cache.
 
         Also calls ._disappear if cache holds older reference to object of same
@@ -366,22 +414,23 @@ class BaseModel(Generic[BaseModelId]):
                        # pylint: disable=unused-argument
                        db_conn: DatabaseConnection,
                        row: Row | list[Any]) -> BaseModelInstance:
                        # pylint: disable=unused-argument
                        db_conn: DatabaseConnection,
                        row: Row | list[Any]) -> BaseModelInstance:
-        """Make from DB row, update DB cache with it."""
+        """Make from DB row (sans relations), update DB cache with it."""
         obj = cls(*row)
         obj = cls(*row)
-        obj._cache()
+        assert obj.id_ is not None
+        for attr_name in cls.to_save_versioned():
+            attr = getattr(obj, attr_name)
+            table_name = attr.table_name
+            for row_ in db_conn.row_where(table_name, 'parent', obj.id_):
+                attr.history_from_row(row_)
+        obj.cache()
         return obj
 
     @classmethod
         return obj
 
     @classmethod
-    def by_id(cls, db_conn: DatabaseConnection,
-              id_: BaseModelId | None,
-              # pylint: disable=unused-argument
-              create: bool = False) -> Self:
+    def by_id(cls, db_conn: DatabaseConnection, id_: BaseModelId) -> Self:
         """Retrieve by id_, on failure throw NotFoundException.
 
         First try to get from cls.cache_, only then check DB; if found,
         put into cache.
         """Retrieve by id_, on failure throw NotFoundException.
 
         First try to get from cls.cache_, only then check DB; if found,
         put into cache.
-
-        If create=True, make anew (but do not cache yet).
         """
         obj = None
         if id_ is not None:
         """
         obj = None
         if id_ is not None:
@@ -392,11 +441,22 @@ class BaseModel(Generic[BaseModelId]):
                     break
         if obj:
             return obj
                     break
         if obj:
             return obj
-        if create:
-            obj = cls(id_)
-            return obj
         raise NotFoundException(f'found no object of ID {id_}')
 
         raise NotFoundException(f'found no object of ID {id_}')
 
+    @classmethod
+    def by_id_or_create(cls, db_conn: DatabaseConnection,
+                        id_: BaseModelId | None
+                        ) -> Self:
+        """Wrapper around .by_id, creating (not caching/saving) if not find."""
+        if not cls.can_create_by_id:
+            raise HandledException('Class cannot .by_id_or_create.')
+        if id_ is None:
+            return cls(None)
+        try:
+            return cls.by_id(db_conn, id_)
+        except NotFoundException:
+            return cls(id_)
+
     @classmethod
     def all(cls: type[BaseModelInstance],
             db_conn: DatabaseConnection) -> list[BaseModelInstance]:
     @classmethod
     def all(cls: type[BaseModelInstance],
             db_conn: DatabaseConnection) -> list[BaseModelInstance]:
@@ -426,7 +486,7 @@ class BaseModel(Generic[BaseModelId]):
                                   date_col: str = 'day'
                                   ) -> tuple[list[BaseModelInstance], str,
                                              str]:
                                   date_col: str = 'day'
                                   ) -> tuple[list[BaseModelInstance], str,
                                              str]:
-        """Return list of items in database within (open) date_range interval.
+        """Return list of items in DB within (closed) date_range interval.
 
         If no range values provided, defaults them to 'yesterday' and
         'tomorrow'. Knows to properly interpret these and 'today' as value.
 
         If no range values provided, defaults them to 'yesterday' and
         'tomorrow'. Knows to properly interpret these and 'today' as value.
@@ -468,7 +528,7 @@ class BaseModel(Generic[BaseModelId]):
         """Write self to DB and cache and ensure .id_.
 
         Write both to DB, and to cache. To DB, write .id_ and attributes
         """Write self to DB and cache and ensure .id_.
 
         Write both to DB, and to cache. To DB, write .id_ and attributes
-        listed in cls.to_save[_versioned|_relations].
+        listed in cls.to_save_[simples|versioned|_relations].
 
         Ensure self.id_ by setting it to what the DB command returns as the
         last saved row's ID (cursor.lastrowid), EXCEPT if self.id_ already
 
         Ensure self.id_ by setting it to what the DB command returns as the
         last saved row's ID (cursor.lastrowid), EXCEPT if self.id_ already
@@ -476,14 +536,14 @@ class BaseModel(Generic[BaseModelId]):
         only the case with the Day class, where it's to be a date string.
         """
         values = tuple([self.id_] + [getattr(self, key)
         only the case with the Day class, where it's to be a date string.
         """
         values = tuple([self.id_] + [getattr(self, key)
-                                     for key in self.to_save])
+                                     for key in self.to_save_simples])
         table_name = self.table_name
         cursor = db_conn.exec_on_vals(f'REPLACE INTO {table_name} VALUES',
                                       values)
         if not isinstance(self.id_, str):
             self.id_ = cursor.lastrowid  # type: ignore[assignment]
         table_name = self.table_name
         cursor = db_conn.exec_on_vals(f'REPLACE INTO {table_name} VALUES',
                                       values)
         if not isinstance(self.id_, str):
             self.id_ = cursor.lastrowid  # type: ignore[assignment]
-        self._cache()
-        for attr_name in self.to_save_versioned:
+        self.cache()
+        for attr_name in self.to_save_versioned():
             getattr(self, attr_name).save(db_conn)
         for table, column, attr_name, key_index in self.to_save_relations:
             assert isinstance(self.id_, (int, str))
             getattr(self, attr_name).save(db_conn)
         for table, column, attr_name, key_index in self.to_save_relations:
             assert isinstance(self.id_, (int, str))
@@ -495,7 +555,7 @@ class BaseModel(Generic[BaseModelId]):
         """Remove from DB and cache, including dependencies."""
         if self.id_ is None or self._get_cached(self.id_) is None:
             raise HandledException('cannot remove unsaved item')
         """Remove from DB and cache, including dependencies."""
         if self.id_ is None or self._get_cached(self.id_) is None:
             raise HandledException('cannot remove unsaved item')
-        for attr_name in self.to_save_versioned:
+        for attr_name in self.to_save_versioned():
             getattr(self, attr_name).remove(db_conn)
         for table, column, attr_name, _ in self.to_save_relations:
             db_conn.delete_where(table, column, self.id_)
             getattr(self, attr_name).remove(db_conn)
         for table, column, attr_name, _ in self.to_save_relations:
             db_conn.delete_where(table, column, self.id_)