home · contact · privacy
Overhaul caching.
[plomtask] / plomtask / db.py
index 2ea7421feec578dab3fb8ba9e1eecf547eb6a36c..99998a6ab29f760ba0d62f90395739dad4b521ff 100644 (file)
@@ -18,8 +18,9 @@ class UnmigratedDbException(HandledException):
     """To identify case of unmigrated DB file."""
 
 
     """To identify case of unmigrated DB file."""
 
 
-class DatabaseFile:  # pylint: disable=too-few-public-methods
+class DatabaseFile:
     """Represents the sqlite3 database's file."""
     """Represents the sqlite3 database's file."""
+    # pylint: disable=too-few-public-methods
 
     def __init__(self, path: str) -> None:
         self.path = path
 
     def __init__(self, path: str) -> None:
         self.path = path
@@ -38,7 +39,7 @@ class DatabaseFile:  # pylint: disable=too-few-public-methods
     def migrate(cls, path: str) -> DatabaseFile:
         """Apply migrations from_version to EXPECTED_DB_VERSION."""
         migrations = cls._available_migrations()
     def migrate(cls, path: str) -> DatabaseFile:
         """Apply migrations from_version to EXPECTED_DB_VERSION."""
         migrations = cls._available_migrations()
-        from_version = cls.get_version_of_db(path)
+        from_version = cls._get_version_of_db(path)
         migrations_todo = migrations[from_version+1:]
         for j, filename in enumerate(migrations_todo):
             with sql_connect(path) as conn:
         migrations_todo = migrations[from_version+1:]
         for j, filename in enumerate(migrations_todo):
             with sql_connect(path) as conn:
@@ -54,7 +55,7 @@ class DatabaseFile:  # pylint: disable=too-few-public-methods
         """Check file exists, and is of proper DB version and schema."""
         if not isfile(self.path):
             raise NotFoundException
         """Check file exists, and is of proper DB version and schema."""
         if not isfile(self.path):
             raise NotFoundException
-        if self.user_version != EXPECTED_DB_VERSION:
+        if self._user_version != EXPECTED_DB_VERSION:
             raise UnmigratedDbException()
         self._validate_schema()
 
             raise UnmigratedDbException()
         self._validate_schema()
 
@@ -86,7 +87,7 @@ class DatabaseFile:  # pylint: disable=too-few-public-methods
         return migrations_list
 
     @staticmethod
         return migrations_list
 
     @staticmethod
-    def get_version_of_db(path: str) -> int:
+    def _get_version_of_db(path: str) -> int:
         """Get DB user_version, fail if outside expected range."""
         sql_for_db_version = 'PRAGMA user_version'
         with sql_connect(path) as conn:
         """Get DB user_version, fail if outside expected range."""
         sql_for_db_version = 'PRAGMA user_version'
         with sql_connect(path) as conn:
@@ -99,9 +100,9 @@ class DatabaseFile:  # pylint: disable=too-few-public-methods
         return db_version
 
     @property
         return db_version
 
     @property
-    def user_version(self) -> int:
+    def _user_version(self) -> int:
         """Get DB user_version."""
         """Get DB user_version."""
-        return self.__class__.get_version_of_db(self.path)
+        return self._get_version_of_db(self.path)
 
     def _validate_schema(self) -> None:
         """Compare found schema with what's stored at PATH_DB_SCHEMA."""
 
     def _validate_schema(self) -> None:
         """Compare found schema with what's stored at PATH_DB_SCHEMA."""
@@ -156,8 +157,7 @@ class DatabaseConnection:
     """A single connection to the database."""
 
     def __init__(self, db_file: DatabaseFile) -> None:
     """A single connection to the database."""
 
     def __init__(self, db_file: DatabaseFile) -> None:
-        self.file = db_file
-        self.conn = sql_connect(self.file.path)
+        self.conn = sql_connect(db_file.path)
 
     def commit(self) -> None:
         """Commit SQL transaction."""
 
     def commit(self) -> None:
         """Commit SQL transaction."""
@@ -167,6 +167,11 @@ class DatabaseConnection:
         """Add commands to SQL transaction."""
         return self.conn.execute(code, inputs)
 
         """Add commands to SQL transaction."""
         return self.conn.execute(code, inputs)
 
+    def exec_on_vals(self, code: str, inputs: tuple[Any, ...]) -> Cursor:
+        """Wrapper around .exec appending adequate " (?, …)" to code."""
+        q_marks_from_values = '(' + ','.join(['?'] * len(inputs)) + ')'
+        return self.exec(f'{code} {q_marks_from_values}', inputs)
+
     def close(self) -> None:
         """Close DB connection."""
         self.conn.close()
     def close(self) -> None:
         """Close DB connection."""
         self.conn.close()
@@ -183,8 +188,7 @@ class DatabaseConnection:
         self.delete_where(table_name, key, target)
         for row in rows:
             values = tuple(row[:key_index] + [target] + row[key_index:])
         self.delete_where(table_name, key, target)
         for row in rows:
             values = tuple(row[:key_index] + [target] + row[key_index:])
-            q_marks = self.__class__.q_marks_from_values(values)
-            self.exec(f'INSERT INTO {table_name} VALUES {q_marks}', values)
+            self.exec_on_vals(f'INSERT INTO {table_name} VALUES', values)
 
     def row_where(self, table_name: str, key: str,
                   target: int | str) -> list[Row]:
 
     def row_where(self, table_name: str, key: str,
                   target: int | str) -> list[Row]:
@@ -220,11 +224,6 @@ class DatabaseConnection:
         """Delete from table where key == target."""
         self.exec(f'DELETE FROM {table_name} WHERE {key} = ?', (target,))
 
         """Delete from table where key == target."""
         self.exec(f'DELETE FROM {table_name} WHERE {key} = ?', (target,))
 
-    @staticmethod
-    def q_marks_from_values(values: tuple[Any]) -> str:
-        """Return placeholder to insert values into SQL code."""
-        return '(' + ','.join(['?'] * len(values)) + ')'
-
 
 BaseModelId = TypeVar('BaseModelId', int, str)
 BaseModelInstance = TypeVar('BaseModelInstance', bound='BaseModel[Any]')
 
 BaseModelId = TypeVar('BaseModelId', int, str)
 BaseModelInstance = TypeVar('BaseModelInstance', bound='BaseModel[Any]')
@@ -239,21 +238,30 @@ class BaseModel(Generic[BaseModelId]):
     id_: None | BaseModelId
     cache_: dict[BaseModelId, Self]
     to_search: list[str] = []
     id_: None | BaseModelId
     cache_: dict[BaseModelId, Self]
     to_search: list[str] = []
+    _exists = True
 
     def __init__(self, id_: BaseModelId | None) -> None:
         if isinstance(id_, int) and id_ < 1:
             msg = f'illegal {self.__class__.__name__} ID, must be >=1: {id_}'
             raise HandledException(msg)
 
     def __init__(self, id_: BaseModelId | None) -> None:
         if isinstance(id_, int) and id_ < 1:
             msg = f'illegal {self.__class__.__name__} ID, must be >=1: {id_}'
             raise HandledException(msg)
+        if isinstance(id_, str) and "" == id_:
+            msg = f'illegal {self.__class__.__name__} ID, must be non-empty'
+            raise HandledException(msg)
         self.id_ = id_
 
         self.id_ = id_
 
+    def __hash__(self) -> int:
+        hashable = [self.id_] + [getattr(self, name) for name in self.to_save]
+        for definition in self.to_save_relations:
+            attr = getattr(self, definition[2])
+            hashable += [tuple(rel.id_ for rel in attr)]
+        for name in self.to_save_versioned:
+            hashable += [hash(getattr(self, name))]
+        return hash(tuple(hashable))
+
     def __eq__(self, other: object) -> bool:
         if not isinstance(other, self.__class__):
             return False
     def __eq__(self, other: object) -> bool:
         if not isinstance(other, self.__class__):
             return False
-        to_hash_me = tuple([self.id_] +
-                           [getattr(self, name) for name in self.to_save])
-        to_hash_other = tuple([other.id_] +
-                              [getattr(other, name) for name in other.to_save])
-        return hash(to_hash_me) == hash(to_hash_other)
+        return hash(self) == hash(other)
 
     def __lt__(self, other: Any) -> bool:
         if not isinstance(other, self.__class__):
 
     def __lt__(self, other: Any) -> bool:
         if not isinstance(other, self.__class__):
@@ -263,17 +271,27 @@ class BaseModel(Generic[BaseModelId]):
         assert isinstance(other.id_, int)
         return self.id_ < other.id_
 
         assert isinstance(other.id_, int)
         return self.id_ < other.id_
 
-    @classmethod
-    def get_cached(cls: type[BaseModelInstance],
-                   id_: BaseModelId) -> BaseModelInstance | None:
-        """Get object of id_ from class's cache, or None if not found."""
-        # pylint: disable=consider-iterating-dictionary
-        cache = cls.get_cache()
-        if id_ in cache.keys():
-            obj = cache[id_]
-            assert isinstance(obj, cls)
-            return obj
-        return None
+    # cache management
+    # (we primarily use the cache to ensure we work on the same object in
+    # memory no matter where and how we retrieve it, e.g. we don't want
+    # .by_id() calls to create a new object each time, but rather a pointer
+    # to the one already instantiated)
+
+    def __getattribute__(self, name: str) -> Any:
+        """Fail if ._disappear() was called, except when checking ._exists."""
+        if name != '_exists' and not super().__getattribute__('_exists'):
+            raise HandledException('Object does not exist.')
+        return super().__getattribute__(name)
+
+    def _disappear(self) -> None:
+        """Invalidate object, make future use raise exceptions."""
+        assert self.id_ is not None
+        if self._get_cached(self.id_):
+            self._uncache()
+        to_kill = list(self.__dict__.keys())
+        for attr in to_kill:
+            delattr(self, attr)
+        self._exists = False
 
     @classmethod
     def empty_cache(cls) -> None:
 
     @classmethod
     def empty_cache(cls) -> None:
@@ -288,28 +306,52 @@ class BaseModel(Generic[BaseModelId]):
             cls.cache_ = d
         return cls.cache_
 
             cls.cache_ = d
         return cls.cache_
 
-    def cache(self) -> None:
-        """Update object in class's cache."""
+    @classmethod
+    def _get_cached(cls: type[BaseModelInstance],
+                    id_: BaseModelId) -> BaseModelInstance | None:
+        """Get object of id_ from class's cache, or None if not found."""
+        # pylint: disable=consider-iterating-dictionary
+        cache = cls.get_cache()
+        if id_ in cache.keys():
+            obj = cache[id_]
+            assert isinstance(obj, cls)
+            return obj
+        return None
+
+    def _cache(self) -> None:
+        """Update object in class's cache.
+
+        Also calls ._disappear if cache holds older reference to object of same
+        ID, but different memory address, to avoid doing anything with
+        dangling leftovers.
+        """
         if self.id_ is None:
             raise HandledException('Cannot cache object without ID.')
         if self.id_ is None:
             raise HandledException('Cannot cache object without ID.')
-        cache = self.__class__.get_cache()
+        cache = self.get_cache()
+        old_cached = self._get_cached(self.id_)
+        if old_cached and id(old_cached) != id(self):
+            # pylint: disable=protected-access
+            # (because we remain within the class)
+            old_cached._disappear()
         cache[self.id_] = self
 
         cache[self.id_] = self
 
-    def uncache(self) -> None:
+    def _uncache(self) -> None:
         """Remove self from cache."""
         if self.id_ is None:
             raise HandledException('Cannot un-cache object without ID.')
         """Remove self from cache."""
         if self.id_ is None:
             raise HandledException('Cannot un-cache object without ID.')
-        cache = self.__class__.get_cache()
+        cache = self.get_cache()
         del cache[self.id_]
 
         del cache[self.id_]
 
+    # object retrieval and generation
+
     @classmethod
     def from_table_row(cls: type[BaseModelInstance],
                        # pylint: disable=unused-argument
                        db_conn: DatabaseConnection,
                        row: Row | list[Any]) -> BaseModelInstance:
     @classmethod
     def from_table_row(cls: type[BaseModelInstance],
                        # pylint: disable=unused-argument
                        db_conn: DatabaseConnection,
                        row: Row | list[Any]) -> BaseModelInstance:
-        """Make from DB row, write to DB cache."""
+        """Make from DB row, update DB cache with it."""
         obj = cls(*row)
         obj = cls(*row)
-        obj.cache()
+        obj._cache()
         return obj
 
     @classmethod
         return obj
 
     @classmethod
@@ -326,11 +368,10 @@ class BaseModel(Generic[BaseModelId]):
         """
         obj = None
         if id_ is not None:
         """
         obj = None
         if id_ is not None:
-            obj = cls.get_cached(id_)
+            obj = cls._get_cached(id_)
             if not obj:
                 for row in db_conn.row_where(cls.table_name, 'id', id_):
                     obj = cls.from_table_row(db_conn, row)
             if not obj:
                 for row in db_conn.row_where(cls.table_name, 'id', id_):
                     obj = cls.from_table_row(db_conn, row)
-                    obj.cache()
                     break
         if obj:
             return obj
                     break
         if obj:
             return obj
@@ -368,7 +409,7 @@ class BaseModel(Generic[BaseModelId]):
                                   date_col: str = 'day'
                                   ) -> tuple[list[BaseModelInstance], str,
                                              str]:
                                   date_col: str = 'day'
                                   ) -> tuple[list[BaseModelInstance], str,
                                              str]:
-        """Return list of Days in database within (open) date_range interval.
+        """Return list of items in database within (open) date_range interval.
 
         If no range values provided, defaults them to 'yesterday' and
         'tomorrow'. Knows to properly interpret these and 'today' as value.
 
         If no range values provided, defaults them to 'yesterday' and
         'tomorrow'. Knows to properly interpret these and 'today' as value.
@@ -404,6 +445,8 @@ class BaseModel(Generic[BaseModelId]):
             return filtered
         return items
 
             return filtered
         return items
 
+    # database writing
+
     def save(self, db_conn: DatabaseConnection) -> None:
         """Write self to DB and cache and ensure .id_.
 
     def save(self, db_conn: DatabaseConnection) -> None:
         """Write self to DB and cache and ensure .id_.
 
@@ -417,13 +460,12 @@ class BaseModel(Generic[BaseModelId]):
         """
         values = tuple([self.id_] + [getattr(self, key)
                                      for key in self.to_save])
         """
         values = tuple([self.id_] + [getattr(self, key)
                                      for key in self.to_save])
-        q_marks = DatabaseConnection.q_marks_from_values(values)
         table_name = self.table_name
         table_name = self.table_name
-        cursor = db_conn.exec(f'REPLACE INTO {table_name} VALUES {q_marks}',
-                              values)
+        cursor = db_conn.exec_on_vals(f'REPLACE INTO {table_name} VALUES',
+                                      values)
         if not isinstance(self.id_, str):
             self.id_ = cursor.lastrowid  # type: ignore[assignment]
         if not isinstance(self.id_, str):
             self.id_ = cursor.lastrowid  # type: ignore[assignment]
-        self.cache()
+        self._cache()
         for attr_name in self.to_save_versioned:
             getattr(self, attr_name).save(db_conn)
         for table, column, attr_name, key_index in self.to_save_relations:
         for attr_name in self.to_save_versioned:
             getattr(self, attr_name).save(db_conn)
         for table, column, attr_name, key_index in self.to_save_relations:
@@ -434,11 +476,12 @@ class BaseModel(Generic[BaseModelId]):
 
     def remove(self, db_conn: DatabaseConnection) -> None:
         """Remove from DB and cache, including dependencies."""
 
     def remove(self, db_conn: DatabaseConnection) -> None:
         """Remove from DB and cache, including dependencies."""
-        if self.id_ is None or self.__class__.get_cached(self.id_) is None:
+        if self.id_ is None or self._get_cached(self.id_) is None:
             raise HandledException('cannot remove unsaved item')
         for attr_name in self.to_save_versioned:
             getattr(self, attr_name).remove(db_conn)
         for table, column, attr_name, _ in self.to_save_relations:
             db_conn.delete_where(table, column, self.id_)
             raise HandledException('cannot remove unsaved item')
         for attr_name in self.to_save_versioned:
             getattr(self, attr_name).remove(db_conn)
         for table, column, attr_name, _ in self.to_save_relations:
             db_conn.delete_where(table, column, self.id_)
-        self.uncache()
+        self._uncache()
         db_conn.delete_where(self.table_name, 'id', self.id_)
         db_conn.delete_where(self.table_name, 'id', self.id_)
+        self._disappear()