home · contact · privacy
Slightly reduce the do_POST_todo code.
[plomtask] / plomtask / db.py
index 13cdaef5b9c7d3e992f8c92730a9979b9eee2d73..ee5f3b99f04e66460ef4e0a90c5d10d600210881 100644 (file)
@@ -232,9 +232,9 @@ BaseModelInstance = TypeVar('BaseModelInstance', bound='BaseModel[Any]')
 class BaseModel(Generic[BaseModelId]):
     """Template for most of the models we use/derive from the DB."""
     table_name = ''
-    to_save: list[str] = []
-    to_save_versioned: list[str] = []
+    to_save_simples: list[str] = []
     to_save_relations: list[tuple[str, str, str, int]] = []
+    versioned_defaults: dict[str, str | float] = {}
     add_to_dict: list[str] = []
     id_: None | BaseModelId
     cache_: dict[BaseModelId, Self]
@@ -253,11 +253,12 @@ class BaseModel(Generic[BaseModelId]):
         self.id_ = id_
 
     def __hash__(self) -> int:
-        hashable = [self.id_] + [getattr(self, name) for name in self.to_save]
+        hashable = [self.id_] + [getattr(self, name)
+                                 for name in self.to_save_simples]
         for definition in self.to_save_relations:
             attr = getattr(self, definition[2])
             hashable += [tuple(rel.id_ for rel in attr)]
-        for name in self.to_save_versioned:
+        for name in self.to_save_versioned():
             hashable += [hash(getattr(self, name))]
         return hash(tuple(hashable))
 
@@ -274,62 +275,35 @@ class BaseModel(Generic[BaseModelId]):
         assert isinstance(other.id_, int)
         return self.id_ < other.id_
 
+    @classmethod
+    def to_save_versioned(cls) -> list[str]:
+        """Return keys of cls.versioned_defaults assuming we wanna save 'em."""
+        return list(cls.versioned_defaults.keys())
+
     @property
-    def as_dict(self) -> dict[str, object]:
-        """Return self as (json.dumps-compatible) dict."""
-        library: dict[str, dict[str | int, object]] = {}
-        d: dict[str, object] = {'id': self.id_, '_library': library}
-        for to_save in self.to_save:
-            attr = getattr(self, to_save)
-            if hasattr(attr, 'as_dict_into_reference'):
-                d[to_save] = attr.as_dict_into_reference(library)
-            else:
-                d[to_save] = attr
-        if len(self.to_save_versioned) > 0:
+    def as_dict_and_refs(self) -> tuple[dict[str, object],
+                                        list[BaseModel[int] | BaseModel[str]]]:
+        """Return self as json.dumps-ready dict, list of referenced objects."""
+        d: dict[str, object] = {'id': self.id_}
+        refs: list[BaseModel[int] | BaseModel[str]] = []
+        for to_save in self.to_save_simples:
+            d[to_save] = getattr(self, to_save)
+        if len(self.to_save_versioned()) > 0:
             d['_versioned'] = {}
-        for k in self.to_save_versioned:
+        for k in self.to_save_versioned():
             attr = getattr(self, k)
             assert isinstance(d['_versioned'], dict)
             d['_versioned'][k] = attr.history
-        for r in self.to_save_relations:
-            attr_name = r[2]
-            l: list[int | str] = []
-            for rel in getattr(self, attr_name):
-                l += [rel.as_dict_into_reference(library)]
-            d[attr_name] = l
-        for k in self.add_to_dict:
-            d[k] = [x.as_dict_into_reference(library)
-                    for x in getattr(self, k)]
-        return d
-
-    def as_dict_into_reference(self,
-                               library: dict[str, dict[str | int, object]]
-                               ) -> int | str:
-        """Return self.id_ while writing .as_dict into library."""
-        def into_library(library: dict[str, dict[str | int, object]],
-                         cls_name: str,
-                         id_: str | int,
-                         d: dict[str, object]
-                         ) -> None:
-            if cls_name not in library:
-                library[cls_name] = {}
-            if id_ in library[cls_name]:
-                if library[cls_name][id_] != d:
-                    msg = 'Unexpected inequality of entries for ' +\
-                            f'_library at: {cls_name}/{id_}'
-                    raise HandledException(msg)
-            else:
-                library[cls_name][id_] = d
-        as_dict = self.as_dict
-        assert isinstance(as_dict['_library'], dict)
-        for cls_name, dict_of_objs in as_dict['_library'].items():
-            for id_, obj in dict_of_objs.items():
-                into_library(library, cls_name, id_, obj)
-        del as_dict['_library']
-        assert self.id_ is not None
-        into_library(library, self.__class__.__name__, self.id_, as_dict)
-        assert isinstance(as_dict['id'], (int, str))
-        return as_dict['id']
+        rels_to_collect = [rel[2] for rel in self.to_save_relations]
+        rels_to_collect += self.add_to_dict
+        for attr_name in rels_to_collect:
+            rel_list = []
+            for item in getattr(self, attr_name):
+                rel_list += [item.id_]
+                if item not in refs:
+                    refs += [item]
+            d[attr_name] = rel_list
+        return d, refs
 
     @classmethod
     def name_lowercase(cls) -> str:
@@ -339,13 +313,18 @@ class BaseModel(Generic[BaseModelId]):
     @classmethod
     def sort_by(cls, seq: list[Any], sort_key: str, default: str = 'title'
                 ) -> str:
-        """Sort cls list by cls.sorters[sort_key] (reverse if '-'-prefixed)."""
+        """Sort cls list by cls.sorters[sort_key] (reverse if '-'-prefixed).
+
+        Before cls.sorters[sort_key] is applied, seq is sorted by .id_, to
+        ensure predictability where parts of seq are of same sort value.
+        """
         reverse = False
         if len(sort_key) > 1 and '-' == sort_key[0]:
             sort_key = sort_key[1:]
             reverse = True
         if sort_key not in cls.sorters:
             sort_key = default
+        seq.sort(key=lambda x: x.id_, reverse=reverse)
         sorter: Callable[..., Any] = cls.sorters[sort_key]
         seq.sort(key=sorter, reverse=reverse)
         if reverse:
@@ -361,7 +340,8 @@ class BaseModel(Generic[BaseModelId]):
     def __getattribute__(self, name: str) -> Any:
         """Ensure fail if ._disappear() was called, except to check ._exists"""
         if name != '_exists' and not super().__getattribute__('_exists'):
-            raise HandledException('Object does not exist.')
+            msg = f'Object for attribute does not exist: {name}'
+            raise HandledException(msg)
         return super().__getattribute__(name)
 
     def _disappear(self) -> None:
@@ -386,10 +366,11 @@ class BaseModel(Generic[BaseModelId]):
         cls.cache_ = {}
 
     @classmethod
-    def get_cache(cls: type[BaseModelInstance]) -> dict[Any, BaseModel[Any]]:
+    def get_cache(cls: type[BaseModelInstance]
+                  ) -> dict[Any, BaseModelInstance]:
         """Get cache dictionary, create it if not yet existing."""
         if not hasattr(cls, 'cache_'):
-            d: dict[Any, BaseModel[Any]] = {}
+            d: dict[Any, BaseModelInstance] = {}
             cls.cache_ = d
         return cls.cache_
 
@@ -438,7 +419,7 @@ class BaseModel(Generic[BaseModelId]):
         """Make from DB row (sans relations), update DB cache with it."""
         obj = cls(*row)
         assert obj.id_ is not None
-        for attr_name in cls.to_save_versioned:
+        for attr_name in cls.to_save_versioned():
             attr = getattr(obj, attr_name)
             table_name = attr.table_name
             for row_ in db_conn.row_where(table_name, 'parent', obj.id_):
@@ -498,7 +479,7 @@ class BaseModel(Generic[BaseModelId]):
                 item = cls.by_id(db_conn, id_)
                 assert item.id_ is not None
                 items[item.id_] = item
-        return list(items.values())
+        return sorted(list(items.values()))
 
     @classmethod
     def by_date_range_with_limits(cls: type[BaseModelInstance],
@@ -507,7 +488,7 @@ class BaseModel(Generic[BaseModelId]):
                                   date_col: str = 'day'
                                   ) -> tuple[list[BaseModelInstance], str,
                                              str]:
-        """Return list of items in database within (open) date_range interval.
+        """Return list of items in DB within (closed) date_range interval.
 
         If no range values provided, defaults them to 'yesterday' and
         'tomorrow'. Knows to properly interpret these and 'today' as value.
@@ -549,7 +530,7 @@ class BaseModel(Generic[BaseModelId]):
         """Write self to DB and cache and ensure .id_.
 
         Write both to DB, and to cache. To DB, write .id_ and attributes
-        listed in cls.to_save[_versioned|_relations].
+        listed in cls.to_save_[simples|versioned|_relations].
 
         Ensure self.id_ by setting it to what the DB command returns as the
         last saved row's ID (cursor.lastrowid), EXCEPT if self.id_ already
@@ -557,14 +538,14 @@ class BaseModel(Generic[BaseModelId]):
         only the case with the Day class, where it's to be a date string.
         """
         values = tuple([self.id_] + [getattr(self, key)
-                                     for key in self.to_save])
+                                     for key in self.to_save_simples])
         table_name = self.table_name
         cursor = db_conn.exec_on_vals(f'REPLACE INTO {table_name} VALUES',
                                       values)
         if not isinstance(self.id_, str):
             self.id_ = cursor.lastrowid  # type: ignore[assignment]
         self.cache()
-        for attr_name in self.to_save_versioned:
+        for attr_name in self.to_save_versioned():
             getattr(self, attr_name).save(db_conn)
         for table, column, attr_name, key_index in self.to_save_relations:
             assert isinstance(self.id_, (int, str))
@@ -576,7 +557,7 @@ class BaseModel(Generic[BaseModelId]):
         """Remove from DB and cache, including dependencies."""
         if self.id_ is None or self._get_cached(self.id_) is None:
             raise HandledException('cannot remove unsaved item')
-        for attr_name in self.to_save_versioned:
+        for attr_name in self.to_save_versioned():
             getattr(self, attr_name).remove(db_conn)
         for table, column, attr_name, _ in self.to_save_relations:
             db_conn.delete_where(table, column, self.id_)