home · contact · privacy
Expand POST /todo adoption tests.
[plomtask] / plomtask / db.py
index cce2630cd58bfb8bf7283c2eb0d2f45006d8ba26..67a7fc766ce607095520e174c944f957667011d7 100644 (file)
@@ -4,7 +4,7 @@ from os import listdir
 from os.path import isfile
 from difflib import Differ
 from sqlite3 import connect as sql_connect, Cursor, Row
 from os.path import isfile
 from difflib import Differ
 from sqlite3 import connect as sql_connect, Cursor, Row
-from typing import Any, Self, TypeVar, Generic
+from typing import Any, Self, TypeVar, Generic, Callable
 from plomtask.exceptions import HandledException, NotFoundException
 from plomtask.dating import valid_date
 
 from plomtask.exceptions import HandledException, NotFoundException
 from plomtask.dating import valid_date
 
@@ -232,14 +232,16 @@ BaseModelInstance = TypeVar('BaseModelInstance', bound='BaseModel[Any]')
 class BaseModel(Generic[BaseModelId]):
     """Template for most of the models we use/derive from the DB."""
     table_name = ''
 class BaseModel(Generic[BaseModelId]):
     """Template for most of the models we use/derive from the DB."""
     table_name = ''
-    to_save: list[str] = []
-    to_save_versioned: list[str] = []
+    to_save_simples: list[str] = []
     to_save_relations: list[tuple[str, str, str, int]] = []
     to_save_relations: list[tuple[str, str, str, int]] = []
+    versioned_defaults: dict[str, str | float] = {}
+    add_to_dict: list[str] = []
     id_: None | BaseModelId
     cache_: dict[BaseModelId, Self]
     to_search: list[str] = []
     can_create_by_id = False
     _exists = True
     id_: None | BaseModelId
     cache_: dict[BaseModelId, Self]
     to_search: list[str] = []
     can_create_by_id = False
     _exists = True
+    sorters: dict[str, Callable[..., Any]] = {}
 
     def __init__(self, id_: BaseModelId | None) -> None:
         if isinstance(id_, int) and id_ < 1:
 
     def __init__(self, id_: BaseModelId | None) -> None:
         if isinstance(id_, int) and id_ < 1:
@@ -251,11 +253,12 @@ class BaseModel(Generic[BaseModelId]):
         self.id_ = id_
 
     def __hash__(self) -> int:
         self.id_ = id_
 
     def __hash__(self) -> int:
-        hashable = [self.id_] + [getattr(self, name) for name in self.to_save]
+        hashable = [self.id_] + [getattr(self, name)
+                                 for name in self.to_save_simples]
         for definition in self.to_save_relations:
             attr = getattr(self, definition[2])
             hashable += [tuple(rel.id_ for rel in attr)]
         for definition in self.to_save_relations:
             attr = getattr(self, definition[2])
             hashable += [tuple(rel.id_ for rel in attr)]
-        for name in self.to_save_versioned:
+        for name in self.to_save_versioned():
             hashable += [hash(getattr(self, name))]
         return hash(tuple(hashable))
 
             hashable += [hash(getattr(self, name))]
         return hash(tuple(hashable))
 
@@ -272,25 +275,61 @@ class BaseModel(Generic[BaseModelId]):
         assert isinstance(other.id_, int)
         return self.id_ < other.id_
 
         assert isinstance(other.id_, int)
         return self.id_ < other.id_
 
+    @classmethod
+    def to_save_versioned(cls) -> list[str]:
+        """Return keys of cls.versioned_defaults assuming we wanna save 'em."""
+        return list(cls.versioned_defaults.keys())
+
     @property
     @property
-    def as_dict(self) -> dict[str, object]:
-        """Return self as (json.dumps-coompatible) dict."""
+    def as_dict_and_refs(self) -> tuple[dict[str, object],
+                                        list[BaseModel[int] | BaseModel[str]]]:
+        """Return self as json.dumps-ready dict, list of referenced objects."""
         d: dict[str, object] = {'id': self.id_}
         d: dict[str, object] = {'id': self.id_}
-        if len(self.to_save_versioned) > 0:
+        refs: list[BaseModel[int] | BaseModel[str]] = []
+        for to_save in self.to_save_simples:
+            d[to_save] = getattr(self, to_save)
+        if len(self.to_save_versioned()) > 0:
             d['_versioned'] = {}
             d['_versioned'] = {}
-        for k in self.to_save:
-            attr = getattr(self, k)
-            if hasattr(attr, 'as_dict'):
-                d[k] = attr.as_dict
-            d[k] = attr
-        for k in self.to_save_versioned:
+        for k in self.to_save_versioned():
             attr = getattr(self, k)
             assert isinstance(d['_versioned'], dict)
             d['_versioned'][k] = attr.history
             attr = getattr(self, k)
             assert isinstance(d['_versioned'], dict)
             d['_versioned'][k] = attr.history
-        for r in self.to_save_relations:
-            attr_name = r[2]
-            d[attr_name] = [x.as_dict for x in getattr(self, attr_name)]
-        return d
+        rels_to_collect = [rel[2] for rel in self.to_save_relations]
+        rels_to_collect += self.add_to_dict
+        for attr_name in rels_to_collect:
+            rel_list = []
+            for item in getattr(self, attr_name):
+                rel_list += [item.id_]
+                if item not in refs:
+                    refs += [item]
+            d[attr_name] = rel_list
+        return d, refs
+
+    @classmethod
+    def name_lowercase(cls) -> str:
+        """Convenience method to return cls' name in lowercase."""
+        return cls.__name__.lower()
+
+    @classmethod
+    def sort_by(cls, seq: list[Any], sort_key: str, default: str = 'title'
+                ) -> str:
+        """Sort cls list by cls.sorters[sort_key] (reverse if '-'-prefixed).
+
+        Before cls.sorters[sort_key] is applied, seq is sorted by .id_, to
+        ensure predictability where parts of seq are of same sort value.
+        """
+        reverse = False
+        if len(sort_key) > 1 and '-' == sort_key[0]:
+            sort_key = sort_key[1:]
+            reverse = True
+        if sort_key not in cls.sorters:
+            sort_key = default
+        seq.sort(key=lambda x: x.id_, reverse=reverse)
+        sorter: Callable[..., Any] = cls.sorters[sort_key]
+        seq.sort(key=sorter, reverse=reverse)
+        if reverse:
+            sort_key = f'-{sort_key}'
+        return sort_key
 
     # cache management
     # (we primarily use the cache to ensure we work on the same object in
 
     # cache management
     # (we primarily use the cache to ensure we work on the same object in
@@ -378,7 +417,7 @@ class BaseModel(Generic[BaseModelId]):
         """Make from DB row (sans relations), update DB cache with it."""
         obj = cls(*row)
         assert obj.id_ is not None
         """Make from DB row (sans relations), update DB cache with it."""
         obj = cls(*row)
         assert obj.id_ is not None
-        for attr_name in cls.to_save_versioned:
+        for attr_name in cls.to_save_versioned():
             attr = getattr(obj, attr_name)
             table_name = attr.table_name
             for row_ in db_conn.row_where(table_name, 'parent', obj.id_):
             attr = getattr(obj, attr_name)
             table_name = attr.table_name
             for row_ in db_conn.row_where(table_name, 'parent', obj.id_):
@@ -447,7 +486,7 @@ class BaseModel(Generic[BaseModelId]):
                                   date_col: str = 'day'
                                   ) -> tuple[list[BaseModelInstance], str,
                                              str]:
                                   date_col: str = 'day'
                                   ) -> tuple[list[BaseModelInstance], str,
                                              str]:
-        """Return list of items in database within (open) date_range interval.
+        """Return list of items in DB within (closed) date_range interval.
 
         If no range values provided, defaults them to 'yesterday' and
         'tomorrow'. Knows to properly interpret these and 'today' as value.
 
         If no range values provided, defaults them to 'yesterday' and
         'tomorrow'. Knows to properly interpret these and 'today' as value.
@@ -489,7 +528,7 @@ class BaseModel(Generic[BaseModelId]):
         """Write self to DB and cache and ensure .id_.
 
         Write both to DB, and to cache. To DB, write .id_ and attributes
         """Write self to DB and cache and ensure .id_.
 
         Write both to DB, and to cache. To DB, write .id_ and attributes
-        listed in cls.to_save[_versioned|_relations].
+        listed in cls.to_save_[simples|versioned|_relations].
 
         Ensure self.id_ by setting it to what the DB command returns as the
         last saved row's ID (cursor.lastrowid), EXCEPT if self.id_ already
 
         Ensure self.id_ by setting it to what the DB command returns as the
         last saved row's ID (cursor.lastrowid), EXCEPT if self.id_ already
@@ -497,14 +536,14 @@ class BaseModel(Generic[BaseModelId]):
         only the case with the Day class, where it's to be a date string.
         """
         values = tuple([self.id_] + [getattr(self, key)
         only the case with the Day class, where it's to be a date string.
         """
         values = tuple([self.id_] + [getattr(self, key)
-                                     for key in self.to_save])
+                                     for key in self.to_save_simples])
         table_name = self.table_name
         cursor = db_conn.exec_on_vals(f'REPLACE INTO {table_name} VALUES',
                                       values)
         if not isinstance(self.id_, str):
             self.id_ = cursor.lastrowid  # type: ignore[assignment]
         self.cache()
         table_name = self.table_name
         cursor = db_conn.exec_on_vals(f'REPLACE INTO {table_name} VALUES',
                                       values)
         if not isinstance(self.id_, str):
             self.id_ = cursor.lastrowid  # type: ignore[assignment]
         self.cache()
-        for attr_name in self.to_save_versioned:
+        for attr_name in self.to_save_versioned():
             getattr(self, attr_name).save(db_conn)
         for table, column, attr_name, key_index in self.to_save_relations:
             assert isinstance(self.id_, (int, str))
             getattr(self, attr_name).save(db_conn)
         for table, column, attr_name, key_index in self.to_save_relations:
             assert isinstance(self.id_, (int, str))
@@ -516,7 +555,7 @@ class BaseModel(Generic[BaseModelId]):
         """Remove from DB and cache, including dependencies."""
         if self.id_ is None or self._get_cached(self.id_) is None:
             raise HandledException('cannot remove unsaved item')
         """Remove from DB and cache, including dependencies."""
         if self.id_ is None or self._get_cached(self.id_) is None:
             raise HandledException('cannot remove unsaved item')
-        for attr_name in self.to_save_versioned:
+        for attr_name in self.to_save_versioned():
             getattr(self, attr_name).remove(db_conn)
         for table, column, attr_name, _ in self.to_save_relations:
             db_conn.delete_where(table, column, self.id_)
             getattr(self, attr_name).remove(db_conn)
         for table, column, attr_name, _ in self.to_save_relations:
             db_conn.delete_where(table, column, self.id_)