home · contact · privacy
Slightly improve and re-organize Condition tests.
[plomtask] / plomtask / db.py
index 2ea7421feec578dab3fb8ba9e1eecf547eb6a36c..13cdaef5b9c7d3e992f8c92730a9979b9eee2d73 100644 (file)
@@ -4,7 +4,7 @@ from os import listdir
 from os.path import isfile
 from difflib import Differ
 from sqlite3 import connect as sql_connect, Cursor, Row
 from os.path import isfile
 from difflib import Differ
 from sqlite3 import connect as sql_connect, Cursor, Row
-from typing import Any, Self, TypeVar, Generic
+from typing import Any, Self, TypeVar, Generic, Callable
 from plomtask.exceptions import HandledException, NotFoundException
 from plomtask.dating import valid_date
 
 from plomtask.exceptions import HandledException, NotFoundException
 from plomtask.dating import valid_date
 
@@ -18,8 +18,9 @@ class UnmigratedDbException(HandledException):
     """To identify case of unmigrated DB file."""
 
 
     """To identify case of unmigrated DB file."""
 
 
-class DatabaseFile:  # pylint: disable=too-few-public-methods
+class DatabaseFile:
     """Represents the sqlite3 database's file."""
     """Represents the sqlite3 database's file."""
+    # pylint: disable=too-few-public-methods
 
     def __init__(self, path: str) -> None:
         self.path = path
 
     def __init__(self, path: str) -> None:
         self.path = path
@@ -38,7 +39,7 @@ class DatabaseFile:  # pylint: disable=too-few-public-methods
     def migrate(cls, path: str) -> DatabaseFile:
         """Apply migrations from_version to EXPECTED_DB_VERSION."""
         migrations = cls._available_migrations()
     def migrate(cls, path: str) -> DatabaseFile:
         """Apply migrations from_version to EXPECTED_DB_VERSION."""
         migrations = cls._available_migrations()
-        from_version = cls.get_version_of_db(path)
+        from_version = cls._get_version_of_db(path)
         migrations_todo = migrations[from_version+1:]
         for j, filename in enumerate(migrations_todo):
             with sql_connect(path) as conn:
         migrations_todo = migrations[from_version+1:]
         for j, filename in enumerate(migrations_todo):
             with sql_connect(path) as conn:
@@ -54,7 +55,7 @@ class DatabaseFile:  # pylint: disable=too-few-public-methods
         """Check file exists, and is of proper DB version and schema."""
         if not isfile(self.path):
             raise NotFoundException
         """Check file exists, and is of proper DB version and schema."""
         if not isfile(self.path):
             raise NotFoundException
-        if self.user_version != EXPECTED_DB_VERSION:
+        if self._user_version != EXPECTED_DB_VERSION:
             raise UnmigratedDbException()
         self._validate_schema()
 
             raise UnmigratedDbException()
         self._validate_schema()
 
@@ -86,7 +87,7 @@ class DatabaseFile:  # pylint: disable=too-few-public-methods
         return migrations_list
 
     @staticmethod
         return migrations_list
 
     @staticmethod
-    def get_version_of_db(path: str) -> int:
+    def _get_version_of_db(path: str) -> int:
         """Get DB user_version, fail if outside expected range."""
         sql_for_db_version = 'PRAGMA user_version'
         with sql_connect(path) as conn:
         """Get DB user_version, fail if outside expected range."""
         sql_for_db_version = 'PRAGMA user_version'
         with sql_connect(path) as conn:
@@ -99,9 +100,9 @@ class DatabaseFile:  # pylint: disable=too-few-public-methods
         return db_version
 
     @property
         return db_version
 
     @property
-    def user_version(self) -> int:
+    def _user_version(self) -> int:
         """Get DB user_version."""
         """Get DB user_version."""
-        return self.__class__.get_version_of_db(self.path)
+        return self._get_version_of_db(self.path)
 
     def _validate_schema(self) -> None:
         """Compare found schema with what's stored at PATH_DB_SCHEMA."""
 
     def _validate_schema(self) -> None:
         """Compare found schema with what's stored at PATH_DB_SCHEMA."""
@@ -156,8 +157,7 @@ class DatabaseConnection:
     """A single connection to the database."""
 
     def __init__(self, db_file: DatabaseFile) -> None:
     """A single connection to the database."""
 
     def __init__(self, db_file: DatabaseFile) -> None:
-        self.file = db_file
-        self.conn = sql_connect(self.file.path)
+        self.conn = sql_connect(db_file.path)
 
     def commit(self) -> None:
         """Commit SQL transaction."""
 
     def commit(self) -> None:
         """Commit SQL transaction."""
@@ -167,6 +167,11 @@ class DatabaseConnection:
         """Add commands to SQL transaction."""
         return self.conn.execute(code, inputs)
 
         """Add commands to SQL transaction."""
         return self.conn.execute(code, inputs)
 
+    def exec_on_vals(self, code: str, inputs: tuple[Any, ...]) -> Cursor:
+        """Wrapper around .exec appending adequate " (?, …)" to code."""
+        q_marks_from_values = '(' + ','.join(['?'] * len(inputs)) + ')'
+        return self.exec(f'{code} {q_marks_from_values}', inputs)
+
     def close(self) -> None:
         """Close DB connection."""
         self.conn.close()
     def close(self) -> None:
         """Close DB connection."""
         self.conn.close()
@@ -183,8 +188,7 @@ class DatabaseConnection:
         self.delete_where(table_name, key, target)
         for row in rows:
             values = tuple(row[:key_index] + [target] + row[key_index:])
         self.delete_where(table_name, key, target)
         for row in rows:
             values = tuple(row[:key_index] + [target] + row[key_index:])
-            q_marks = self.__class__.q_marks_from_values(values)
-            self.exec(f'INSERT INTO {table_name} VALUES {q_marks}', values)
+            self.exec_on_vals(f'INSERT INTO {table_name} VALUES', values)
 
     def row_where(self, table_name: str, key: str,
                   target: int | str) -> list[Row]:
 
     def row_where(self, table_name: str, key: str,
                   target: int | str) -> list[Row]:
@@ -220,11 +224,6 @@ class DatabaseConnection:
         """Delete from table where key == target."""
         self.exec(f'DELETE FROM {table_name} WHERE {key} = ?', (target,))
 
         """Delete from table where key == target."""
         self.exec(f'DELETE FROM {table_name} WHERE {key} = ?', (target,))
 
-    @staticmethod
-    def q_marks_from_values(values: tuple[Any]) -> str:
-        """Return placeholder to insert values into SQL code."""
-        return '(' + ','.join(['?'] * len(values)) + ')'
-
 
 BaseModelId = TypeVar('BaseModelId', int, str)
 BaseModelInstance = TypeVar('BaseModelInstance', bound='BaseModel[Any]')
 
 BaseModelId = TypeVar('BaseModelId', int, str)
 BaseModelInstance = TypeVar('BaseModelInstance', bound='BaseModel[Any]')
@@ -236,24 +235,36 @@ class BaseModel(Generic[BaseModelId]):
     to_save: list[str] = []
     to_save_versioned: list[str] = []
     to_save_relations: list[tuple[str, str, str, int]] = []
     to_save: list[str] = []
     to_save_versioned: list[str] = []
     to_save_relations: list[tuple[str, str, str, int]] = []
+    add_to_dict: list[str] = []
     id_: None | BaseModelId
     cache_: dict[BaseModelId, Self]
     to_search: list[str] = []
     id_: None | BaseModelId
     cache_: dict[BaseModelId, Self]
     to_search: list[str] = []
+    can_create_by_id = False
+    _exists = True
+    sorters: dict[str, Callable[..., Any]] = {}
 
     def __init__(self, id_: BaseModelId | None) -> None:
         if isinstance(id_, int) and id_ < 1:
             msg = f'illegal {self.__class__.__name__} ID, must be >=1: {id_}'
             raise HandledException(msg)
 
     def __init__(self, id_: BaseModelId | None) -> None:
         if isinstance(id_, int) and id_ < 1:
             msg = f'illegal {self.__class__.__name__} ID, must be >=1: {id_}'
             raise HandledException(msg)
+        if isinstance(id_, str) and "" == id_:
+            msg = f'illegal {self.__class__.__name__} ID, must be non-empty'
+            raise HandledException(msg)
         self.id_ = id_
 
         self.id_ = id_
 
+    def __hash__(self) -> int:
+        hashable = [self.id_] + [getattr(self, name) for name in self.to_save]
+        for definition in self.to_save_relations:
+            attr = getattr(self, definition[2])
+            hashable += [tuple(rel.id_ for rel in attr)]
+        for name in self.to_save_versioned:
+            hashable += [hash(getattr(self, name))]
+        return hash(tuple(hashable))
+
     def __eq__(self, other: object) -> bool:
         if not isinstance(other, self.__class__):
             return False
     def __eq__(self, other: object) -> bool:
         if not isinstance(other, self.__class__):
             return False
-        to_hash_me = tuple([self.id_] +
-                           [getattr(self, name) for name in self.to_save])
-        to_hash_other = tuple([other.id_] +
-                              [getattr(other, name) for name in other.to_save])
-        return hash(to_hash_me) == hash(to_hash_other)
+        return hash(self) == hash(other)
 
     def __lt__(self, other: Any) -> bool:
         if not isinstance(other, self.__class__):
 
     def __lt__(self, other: Any) -> bool:
         if not isinstance(other, self.__class__):
@@ -263,21 +274,115 @@ class BaseModel(Generic[BaseModelId]):
         assert isinstance(other.id_, int)
         return self.id_ < other.id_
 
         assert isinstance(other.id_, int)
         return self.id_ < other.id_
 
+    @property
+    def as_dict(self) -> dict[str, object]:
+        """Return self as (json.dumps-compatible) dict."""
+        library: dict[str, dict[str | int, object]] = {}
+        d: dict[str, object] = {'id': self.id_, '_library': library}
+        for to_save in self.to_save:
+            attr = getattr(self, to_save)
+            if hasattr(attr, 'as_dict_into_reference'):
+                d[to_save] = attr.as_dict_into_reference(library)
+            else:
+                d[to_save] = attr
+        if len(self.to_save_versioned) > 0:
+            d['_versioned'] = {}
+        for k in self.to_save_versioned:
+            attr = getattr(self, k)
+            assert isinstance(d['_versioned'], dict)
+            d['_versioned'][k] = attr.history
+        for r in self.to_save_relations:
+            attr_name = r[2]
+            l: list[int | str] = []
+            for rel in getattr(self, attr_name):
+                l += [rel.as_dict_into_reference(library)]
+            d[attr_name] = l
+        for k in self.add_to_dict:
+            d[k] = [x.as_dict_into_reference(library)
+                    for x in getattr(self, k)]
+        return d
+
+    def as_dict_into_reference(self,
+                               library: dict[str, dict[str | int, object]]
+                               ) -> int | str:
+        """Return self.id_ while writing .as_dict into library."""
+        def into_library(library: dict[str, dict[str | int, object]],
+                         cls_name: str,
+                         id_: str | int,
+                         d: dict[str, object]
+                         ) -> None:
+            if cls_name not in library:
+                library[cls_name] = {}
+            if id_ in library[cls_name]:
+                if library[cls_name][id_] != d:
+                    msg = 'Unexpected inequality of entries for ' +\
+                            f'_library at: {cls_name}/{id_}'
+                    raise HandledException(msg)
+            else:
+                library[cls_name][id_] = d
+        as_dict = self.as_dict
+        assert isinstance(as_dict['_library'], dict)
+        for cls_name, dict_of_objs in as_dict['_library'].items():
+            for id_, obj in dict_of_objs.items():
+                into_library(library, cls_name, id_, obj)
+        del as_dict['_library']
+        assert self.id_ is not None
+        into_library(library, self.__class__.__name__, self.id_, as_dict)
+        assert isinstance(as_dict['id'], (int, str))
+        return as_dict['id']
+
     @classmethod
     @classmethod
-    def get_cached(cls: type[BaseModelInstance],
-                   id_: BaseModelId) -> BaseModelInstance | None:
-        """Get object of id_ from class's cache, or None if not found."""
-        # pylint: disable=consider-iterating-dictionary
-        cache = cls.get_cache()
-        if id_ in cache.keys():
-            obj = cache[id_]
-            assert isinstance(obj, cls)
-            return obj
-        return None
+    def name_lowercase(cls) -> str:
+        """Convenience method to return cls' name in lowercase."""
+        return cls.__name__.lower()
+
+    @classmethod
+    def sort_by(cls, seq: list[Any], sort_key: str, default: str = 'title'
+                ) -> str:
+        """Sort cls list by cls.sorters[sort_key] (reverse if '-'-prefixed)."""
+        reverse = False
+        if len(sort_key) > 1 and '-' == sort_key[0]:
+            sort_key = sort_key[1:]
+            reverse = True
+        if sort_key not in cls.sorters:
+            sort_key = default
+        sorter: Callable[..., Any] = cls.sorters[sort_key]
+        seq.sort(key=sorter, reverse=reverse)
+        if reverse:
+            sort_key = f'-{sort_key}'
+        return sort_key
+
+    # cache management
+    # (we primarily use the cache to ensure we work on the same object in
+    # memory no matter where and how we retrieve it, e.g. we don't want
+    # .by_id() calls to create a new object each time, but rather a pointer
+    # to the one already instantiated)
+
+    def __getattribute__(self, name: str) -> Any:
+        """Ensure fail if ._disappear() was called, except to check ._exists"""
+        if name != '_exists' and not super().__getattribute__('_exists'):
+            raise HandledException('Object does not exist.')
+        return super().__getattribute__(name)
+
+    def _disappear(self) -> None:
+        """Invalidate object, make future use raise exceptions."""
+        assert self.id_ is not None
+        if self._get_cached(self.id_):
+            self._uncache()
+        to_kill = list(self.__dict__.keys())
+        for attr in to_kill:
+            delattr(self, attr)
+        self._exists = False
 
     @classmethod
     def empty_cache(cls) -> None:
 
     @classmethod
     def empty_cache(cls) -> None:
-        """Empty class's cache."""
+        """Empty class's cache, and disappear all former inhabitants."""
+        # pylint: disable=protected-access
+        # (cause we remain within the class)
+        if hasattr(cls, 'cache_'):
+            to_disappear = list(cls.cache_.values())
+            for item in to_disappear:
+                item._disappear()
         cls.cache_ = {}
 
     @classmethod
         cls.cache_ = {}
 
     @classmethod
@@ -288,57 +393,91 @@ class BaseModel(Generic[BaseModelId]):
             cls.cache_ = d
         return cls.cache_
 
             cls.cache_ = d
         return cls.cache_
 
+    @classmethod
+    def _get_cached(cls: type[BaseModelInstance],
+                    id_: BaseModelId) -> BaseModelInstance | None:
+        """Get object of id_ from class's cache, or None if not found."""
+        cache = cls.get_cache()
+        if id_ in cache:
+            obj = cache[id_]
+            assert isinstance(obj, cls)
+            return obj
+        return None
+
     def cache(self) -> None:
     def cache(self) -> None:
-        """Update object in class's cache."""
+        """Update object in class's cache.
+
+        Also calls ._disappear if cache holds older reference to object of same
+        ID, but different memory address, to avoid doing anything with
+        dangling leftovers.
+        """
         if self.id_ is None:
             raise HandledException('Cannot cache object without ID.')
         if self.id_ is None:
             raise HandledException('Cannot cache object without ID.')
-        cache = self.__class__.get_cache()
+        cache = self.get_cache()
+        old_cached = self._get_cached(self.id_)
+        if old_cached and id(old_cached) != id(self):
+            # pylint: disable=protected-access
+            # (cause we remain within the class)
+            old_cached._disappear()
         cache[self.id_] = self
 
         cache[self.id_] = self
 
-    def uncache(self) -> None:
+    def _uncache(self) -> None:
         """Remove self from cache."""
         if self.id_ is None:
             raise HandledException('Cannot un-cache object without ID.')
         """Remove self from cache."""
         if self.id_ is None:
             raise HandledException('Cannot un-cache object without ID.')
-        cache = self.__class__.get_cache()
+        cache = self.get_cache()
         del cache[self.id_]
 
         del cache[self.id_]
 
+    # object retrieval and generation
+
     @classmethod
     def from_table_row(cls: type[BaseModelInstance],
                        # pylint: disable=unused-argument
                        db_conn: DatabaseConnection,
                        row: Row | list[Any]) -> BaseModelInstance:
     @classmethod
     def from_table_row(cls: type[BaseModelInstance],
                        # pylint: disable=unused-argument
                        db_conn: DatabaseConnection,
                        row: Row | list[Any]) -> BaseModelInstance:
-        """Make from DB row, write to DB cache."""
+        """Make from DB row (sans relations), update DB cache with it."""
         obj = cls(*row)
         obj = cls(*row)
+        assert obj.id_ is not None
+        for attr_name in cls.to_save_versioned:
+            attr = getattr(obj, attr_name)
+            table_name = attr.table_name
+            for row_ in db_conn.row_where(table_name, 'parent', obj.id_):
+                attr.history_from_row(row_)
         obj.cache()
         return obj
 
     @classmethod
         obj.cache()
         return obj
 
     @classmethod
-    def by_id(cls, db_conn: DatabaseConnection,
-              id_: BaseModelId | None,
-              # pylint: disable=unused-argument
-              create: bool = False) -> Self:
+    def by_id(cls, db_conn: DatabaseConnection, id_: BaseModelId) -> Self:
         """Retrieve by id_, on failure throw NotFoundException.
 
         First try to get from cls.cache_, only then check DB; if found,
         put into cache.
         """Retrieve by id_, on failure throw NotFoundException.
 
         First try to get from cls.cache_, only then check DB; if found,
         put into cache.
-
-        If create=True, make anew (but do not cache yet).
         """
         obj = None
         if id_ is not None:
         """
         obj = None
         if id_ is not None:
-            obj = cls.get_cached(id_)
+            obj = cls._get_cached(id_)
             if not obj:
                 for row in db_conn.row_where(cls.table_name, 'id', id_):
                     obj = cls.from_table_row(db_conn, row)
             if not obj:
                 for row in db_conn.row_where(cls.table_name, 'id', id_):
                     obj = cls.from_table_row(db_conn, row)
-                    obj.cache()
                     break
         if obj:
             return obj
                     break
         if obj:
             return obj
-        if create:
-            obj = cls(id_)
-            return obj
         raise NotFoundException(f'found no object of ID {id_}')
 
         raise NotFoundException(f'found no object of ID {id_}')
 
+    @classmethod
+    def by_id_or_create(cls, db_conn: DatabaseConnection,
+                        id_: BaseModelId | None
+                        ) -> Self:
+        """Wrapper around .by_id, creating (not caching/saving) if not find."""
+        if not cls.can_create_by_id:
+            raise HandledException('Class cannot .by_id_or_create.')
+        if id_ is None:
+            return cls(None)
+        try:
+            return cls.by_id(db_conn, id_)
+        except NotFoundException:
+            return cls(id_)
+
     @classmethod
     def all(cls: type[BaseModelInstance],
             db_conn: DatabaseConnection) -> list[BaseModelInstance]:
     @classmethod
     def all(cls: type[BaseModelInstance],
             db_conn: DatabaseConnection) -> list[BaseModelInstance]:
@@ -368,7 +507,7 @@ class BaseModel(Generic[BaseModelId]):
                                   date_col: str = 'day'
                                   ) -> tuple[list[BaseModelInstance], str,
                                              str]:
                                   date_col: str = 'day'
                                   ) -> tuple[list[BaseModelInstance], str,
                                              str]:
-        """Return list of Days in database within (open) date_range interval.
+        """Return list of items in database within (open) date_range interval.
 
         If no range values provided, defaults them to 'yesterday' and
         'tomorrow'. Knows to properly interpret these and 'today' as value.
 
         If no range values provided, defaults them to 'yesterday' and
         'tomorrow'. Knows to properly interpret these and 'today' as value.
@@ -404,6 +543,8 @@ class BaseModel(Generic[BaseModelId]):
             return filtered
         return items
 
             return filtered
         return items
 
+    # database writing
+
     def save(self, db_conn: DatabaseConnection) -> None:
         """Write self to DB and cache and ensure .id_.
 
     def save(self, db_conn: DatabaseConnection) -> None:
         """Write self to DB and cache and ensure .id_.
 
@@ -417,10 +558,9 @@ class BaseModel(Generic[BaseModelId]):
         """
         values = tuple([self.id_] + [getattr(self, key)
                                      for key in self.to_save])
         """
         values = tuple([self.id_] + [getattr(self, key)
                                      for key in self.to_save])
-        q_marks = DatabaseConnection.q_marks_from_values(values)
         table_name = self.table_name
         table_name = self.table_name
-        cursor = db_conn.exec(f'REPLACE INTO {table_name} VALUES {q_marks}',
-                              values)
+        cursor = db_conn.exec_on_vals(f'REPLACE INTO {table_name} VALUES',
+                                      values)
         if not isinstance(self.id_, str):
             self.id_ = cursor.lastrowid  # type: ignore[assignment]
         self.cache()
         if not isinstance(self.id_, str):
             self.id_ = cursor.lastrowid  # type: ignore[assignment]
         self.cache()
@@ -434,11 +574,12 @@ class BaseModel(Generic[BaseModelId]):
 
     def remove(self, db_conn: DatabaseConnection) -> None:
         """Remove from DB and cache, including dependencies."""
 
     def remove(self, db_conn: DatabaseConnection) -> None:
         """Remove from DB and cache, including dependencies."""
-        if self.id_ is None or self.__class__.get_cached(self.id_) is None:
+        if self.id_ is None or self._get_cached(self.id_) is None:
             raise HandledException('cannot remove unsaved item')
         for attr_name in self.to_save_versioned:
             getattr(self, attr_name).remove(db_conn)
         for table, column, attr_name, _ in self.to_save_relations:
             db_conn.delete_where(table, column, self.id_)
             raise HandledException('cannot remove unsaved item')
         for attr_name in self.to_save_versioned:
             getattr(self, attr_name).remove(db_conn)
         for table, column, attr_name, _ in self.to_save_relations:
             db_conn.delete_where(table, column, self.id_)
-        self.uncache()
+        self._uncache()
         db_conn.delete_where(self.table_name, 'id', self.id_)
         db_conn.delete_where(self.table_name, 'id', self.id_)
+        self._disappear()