home · contact · privacy
Some test clean-ups.
[plomtask] / plomtask / db.py
index 6f0d13a97c4b54e64aa88aeb24258498a5645f91..f067cd35246850d2600c05c31a5e57fcc3d2d925 100644 (file)
@@ -5,7 +5,8 @@ from os.path import isfile
 from difflib import Differ
 from sqlite3 import connect as sql_connect, Cursor, Row
 from typing import Any, Self, TypeVar, Generic, Callable
-from plomtask.exceptions import HandledException, NotFoundException
+from plomtask.exceptions import (HandledException, NotFoundException,
+                                 BadFormatException)
 from plomtask.dating import valid_date
 
 EXPECTED_DB_VERSION = 5
@@ -246,10 +247,10 @@ class BaseModel(Generic[BaseModelId]):
     def __init__(self, id_: BaseModelId | None) -> None:
         if isinstance(id_, int) and id_ < 1:
             msg = f'illegal {self.__class__.__name__} ID, must be >=1: {id_}'
-            raise HandledException(msg)
+            raise BadFormatException(msg)
         if isinstance(id_, str) and "" == id_:
             msg = f'illegal {self.__class__.__name__} ID, must be non-empty'
-            raise HandledException(msg)
+            raise BadFormatException(msg)
         self.id_ = id_
 
     def __hash__(self) -> int:
@@ -281,68 +282,29 @@ class BaseModel(Generic[BaseModelId]):
         return list(cls.versioned_defaults.keys())
 
     @property
-    def as_dict(self) -> dict[str, object]:
-        """Return self as (json.dumps-compatible) dict."""
-        library: dict[str, dict[str, object] | dict[int, object]] = {}
-        d: dict[str, object] = {'id': self.id_, '_library': library}
+    def as_dict_and_refs(self) -> tuple[dict[str, object],
+                                        list[BaseModel[int] | BaseModel[str]]]:
+        """Return self as json.dumps-ready dict, list of referenced objects."""
+        d: dict[str, object] = {'id': self.id_}
+        refs: list[BaseModel[int] | BaseModel[str]] = []
         for to_save in self.to_save_simples:
-            attr = getattr(self, to_save)
-            if hasattr(attr, 'as_dict_into_reference'):
-                d[to_save] = attr.as_dict_into_reference(library)
-            else:
-                d[to_save] = attr
+            d[to_save] = getattr(self, to_save)
         if len(self.to_save_versioned()) > 0:
             d['_versioned'] = {}
         for k in self.to_save_versioned():
             attr = getattr(self, k)
             assert isinstance(d['_versioned'], dict)
             d['_versioned'][k] = attr.history
-        for r in self.to_save_relations:
-            attr_name = r[2]
-            l: list[int | str] = []
-            for rel in getattr(self, attr_name):
-                l += [rel.as_dict_into_reference(library)]
-            d[attr_name] = l
-        for k in self.add_to_dict:
-            d[k] = [x.as_dict_into_reference(library)
-                    for x in getattr(self, k)]
-        return d
-
-    def as_dict_into_reference(self,
-                               library: dict[str, dict[str | int, object]]
-                               ) -> int | str:
-        """Return self.id_ while writing .as_dict into library."""
-        # NB: For tighter mypy testing, we might prefer the library argument
-        # to be of type dict[str, dict[str, object] | dict[int, object]
-        # instead. But my current coding knowledge only manage to make that
-        # work by turning the code much more complex, so let's leave it at
-        # that for now …
-
-        def into_library(library: dict[str, dict[str | int, object]],
-                         cls_name: str,
-                         id_: str | int,
-                         d: dict[str, object]
-                         ) -> None:
-            if cls_name not in library:
-                library[cls_name] = {}
-            if id_ in library[cls_name]:
-                if library[cls_name][id_] != d:
-                    msg = 'Unexpected inequality of entries for ' +\
-                            f'_library at: {cls_name}/{id_}'
-                    raise HandledException(msg)
-            else:
-                library[cls_name][id_] = d
-
-        as_dict = self.as_dict
-        assert isinstance(as_dict['_library'], dict)
-        for cls_name, dict_of_objs in as_dict['_library'].items():
-            for id_, obj in dict_of_objs.items():
-                into_library(library, cls_name, id_, obj)
-        del as_dict['_library']
-        assert self.id_ is not None
-        into_library(library, self.__class__.__name__, self.id_, as_dict)
-        assert isinstance(as_dict['id'], (int, str))
-        return as_dict['id']
+        rels_to_collect = [rel[2] for rel in self.to_save_relations]
+        rels_to_collect += self.add_to_dict
+        for attr_name in rels_to_collect:
+            rel_list = []
+            for item in getattr(self, attr_name):
+                rel_list += [item.id_]
+                if item not in refs:
+                    refs += [item]
+            d[attr_name] = rel_list
+        return d, refs
 
     @classmethod
     def name_lowercase(cls) -> str:
@@ -379,7 +341,8 @@ class BaseModel(Generic[BaseModelId]):
     def __getattribute__(self, name: str) -> Any:
         """Ensure fail if ._disappear() was called, except to check ._exists"""
         if name != '_exists' and not super().__getattribute__('_exists'):
-            raise HandledException('Object does not exist.')
+            msg = f'Object for attribute does not exist: {name}'
+            raise HandledException(msg)
         return super().__getattribute__(name)
 
     def _disappear(self) -> None:
@@ -404,16 +367,18 @@ class BaseModel(Generic[BaseModelId]):
         cls.cache_ = {}
 
     @classmethod
-    def get_cache(cls: type[BaseModelInstance]) -> dict[Any, BaseModel[Any]]:
+    def get_cache(cls: type[BaseModelInstance]
+                  ) -> dict[Any, BaseModelInstance]:
         """Get cache dictionary, create it if not yet existing."""
         if not hasattr(cls, 'cache_'):
-            d: dict[Any, BaseModel[Any]] = {}
+            d: dict[Any, BaseModelInstance] = {}
             cls.cache_ = d
         return cls.cache_
 
     @classmethod
     def _get_cached(cls: type[BaseModelInstance],
-                    id_: BaseModelId) -> BaseModelInstance | None:
+                    id_: BaseModelId
+                    ) -> BaseModelInstance | None:
         """Get object of id_ from class's cache, or None if not found."""
         cache = cls.get_cache()
         if id_ in cache:
@@ -473,6 +438,8 @@ class BaseModel(Generic[BaseModelId]):
         """
         obj = None
         if id_ is not None:
+            if isinstance(id_, int) and id_ == 0:
+                raise BadFormatException('illegal ID of value 0')
             obj = cls._get_cached(id_)
             if not obj:
                 for row in db_conn.row_where(cls.table_name, 'id', id_):
@@ -486,7 +453,7 @@ class BaseModel(Generic[BaseModelId]):
     def by_id_or_create(cls, db_conn: DatabaseConnection,
                         id_: BaseModelId | None
                         ) -> Self:
-        """Wrapper around .by_id, creating (not caching/saving) if not find."""
+        """Wrapper around .by_id, creating (not caching/saving) if no find."""
         if not cls.can_create_by_id:
             raise HandledException('Class cannot .by_id_or_create.')
         if id_ is None:
@@ -516,7 +483,7 @@ class BaseModel(Generic[BaseModelId]):
                 item = cls.by_id(db_conn, id_)
                 assert item.id_ is not None
                 items[item.id_] = item
-        return list(items.values())
+        return sorted(list(items.values()))
 
     @classmethod
     def by_date_range_with_limits(cls: type[BaseModelInstance],