home · contact · privacy
Enhance BaseModel comparisons by hashing versioned and relations attributes.
[plomtask] / plomtask / db.py
index 2ea7421feec578dab3fb8ba9e1eecf547eb6a36c..a47dff15917651940483afe83f41152399bedf7f 100644 (file)
@@ -18,8 +18,9 @@ class UnmigratedDbException(HandledException):
     """To identify case of unmigrated DB file."""
 
 
     """To identify case of unmigrated DB file."""
 
 
-class DatabaseFile:  # pylint: disable=too-few-public-methods
+class DatabaseFile:
     """Represents the sqlite3 database's file."""
     """Represents the sqlite3 database's file."""
+    # pylint: disable=too-few-public-methods
 
     def __init__(self, path: str) -> None:
         self.path = path
 
     def __init__(self, path: str) -> None:
         self.path = path
@@ -38,7 +39,7 @@ class DatabaseFile:  # pylint: disable=too-few-public-methods
     def migrate(cls, path: str) -> DatabaseFile:
         """Apply migrations from_version to EXPECTED_DB_VERSION."""
         migrations = cls._available_migrations()
     def migrate(cls, path: str) -> DatabaseFile:
         """Apply migrations from_version to EXPECTED_DB_VERSION."""
         migrations = cls._available_migrations()
-        from_version = cls.get_version_of_db(path)
+        from_version = cls._get_version_of_db(path)
         migrations_todo = migrations[from_version+1:]
         for j, filename in enumerate(migrations_todo):
             with sql_connect(path) as conn:
         migrations_todo = migrations[from_version+1:]
         for j, filename in enumerate(migrations_todo):
             with sql_connect(path) as conn:
@@ -54,7 +55,7 @@ class DatabaseFile:  # pylint: disable=too-few-public-methods
         """Check file exists, and is of proper DB version and schema."""
         if not isfile(self.path):
             raise NotFoundException
         """Check file exists, and is of proper DB version and schema."""
         if not isfile(self.path):
             raise NotFoundException
-        if self.user_version != EXPECTED_DB_VERSION:
+        if self._user_version != EXPECTED_DB_VERSION:
             raise UnmigratedDbException()
         self._validate_schema()
 
             raise UnmigratedDbException()
         self._validate_schema()
 
@@ -86,7 +87,7 @@ class DatabaseFile:  # pylint: disable=too-few-public-methods
         return migrations_list
 
     @staticmethod
         return migrations_list
 
     @staticmethod
-    def get_version_of_db(path: str) -> int:
+    def _get_version_of_db(path: str) -> int:
         """Get DB user_version, fail if outside expected range."""
         sql_for_db_version = 'PRAGMA user_version'
         with sql_connect(path) as conn:
         """Get DB user_version, fail if outside expected range."""
         sql_for_db_version = 'PRAGMA user_version'
         with sql_connect(path) as conn:
@@ -99,9 +100,11 @@ class DatabaseFile:  # pylint: disable=too-few-public-methods
         return db_version
 
     @property
         return db_version
 
     @property
-    def user_version(self) -> int:
+    def _user_version(self) -> int:
         """Get DB user_version."""
         """Get DB user_version."""
-        return self.__class__.get_version_of_db(self.path)
+        # pylint: disable=protected-access
+        # (since we remain within class)
+        return self.__class__._get_version_of_db(self.path)
 
     def _validate_schema(self) -> None:
         """Compare found schema with what's stored at PATH_DB_SCHEMA."""
 
     def _validate_schema(self) -> None:
         """Compare found schema with what's stored at PATH_DB_SCHEMA."""
@@ -156,8 +159,7 @@ class DatabaseConnection:
     """A single connection to the database."""
 
     def __init__(self, db_file: DatabaseFile) -> None:
     """A single connection to the database."""
 
     def __init__(self, db_file: DatabaseFile) -> None:
-        self.file = db_file
-        self.conn = sql_connect(self.file.path)
+        self.conn = sql_connect(db_file.path)
 
     def commit(self) -> None:
         """Commit SQL transaction."""
 
     def commit(self) -> None:
         """Commit SQL transaction."""
@@ -167,6 +169,11 @@ class DatabaseConnection:
         """Add commands to SQL transaction."""
         return self.conn.execute(code, inputs)
 
         """Add commands to SQL transaction."""
         return self.conn.execute(code, inputs)
 
+    def exec_on_vals(self, code: str, inputs: tuple[Any, ...]) -> Cursor:
+        """Wrapper around .exec appending adequate " (?, …)" to code."""
+        q_marks_from_values = '(' + ','.join(['?'] * len(inputs)) + ')'
+        return self.exec(f'{code} {q_marks_from_values}', inputs)
+
     def close(self) -> None:
         """Close DB connection."""
         self.conn.close()
     def close(self) -> None:
         """Close DB connection."""
         self.conn.close()
@@ -183,8 +190,7 @@ class DatabaseConnection:
         self.delete_where(table_name, key, target)
         for row in rows:
             values = tuple(row[:key_index] + [target] + row[key_index:])
         self.delete_where(table_name, key, target)
         for row in rows:
             values = tuple(row[:key_index] + [target] + row[key_index:])
-            q_marks = self.__class__.q_marks_from_values(values)
-            self.exec(f'INSERT INTO {table_name} VALUES {q_marks}', values)
+            self.exec_on_vals(f'INSERT INTO {table_name} VALUES', values)
 
     def row_where(self, table_name: str, key: str,
                   target: int | str) -> list[Row]:
 
     def row_where(self, table_name: str, key: str,
                   target: int | str) -> list[Row]:
@@ -220,11 +226,6 @@ class DatabaseConnection:
         """Delete from table where key == target."""
         self.exec(f'DELETE FROM {table_name} WHERE {key} = ?', (target,))
 
         """Delete from table where key == target."""
         self.exec(f'DELETE FROM {table_name} WHERE {key} = ?', (target,))
 
-    @staticmethod
-    def q_marks_from_values(values: tuple[Any]) -> str:
-        """Return placeholder to insert values into SQL code."""
-        return '(' + ','.join(['?'] * len(values)) + ')'
-
 
 BaseModelId = TypeVar('BaseModelId', int, str)
 BaseModelInstance = TypeVar('BaseModelInstance', bound='BaseModel[Any]')
 
 BaseModelId = TypeVar('BaseModelId', int, str)
 BaseModelInstance = TypeVar('BaseModelInstance', bound='BaseModel[Any]')
@@ -244,16 +245,24 @@ class BaseModel(Generic[BaseModelId]):
         if isinstance(id_, int) and id_ < 1:
             msg = f'illegal {self.__class__.__name__} ID, must be >=1: {id_}'
             raise HandledException(msg)
         if isinstance(id_, int) and id_ < 1:
             msg = f'illegal {self.__class__.__name__} ID, must be >=1: {id_}'
             raise HandledException(msg)
+        if isinstance(id_, str) and "" == id_:
+            msg = f'illegal {self.__class__.__name__} ID, must be non-empty'
+            raise HandledException(msg)
         self.id_ = id_
 
         self.id_ = id_
 
+    def __hash__(self) -> int:
+        hashable = [self.id_] + [getattr(self, name) for name in self.to_save]
+        for definition in self.to_save_relations:
+            attr = getattr(self, definition[2])
+            hashable += [tuple(rel.id_ for rel in attr)]
+        for name in self.to_save_versioned:
+            hashable += [hash(getattr(self, name))]
+        return hash(tuple(hashable))
+
     def __eq__(self, other: object) -> bool:
         if not isinstance(other, self.__class__):
             return False
     def __eq__(self, other: object) -> bool:
         if not isinstance(other, self.__class__):
             return False
-        to_hash_me = tuple([self.id_] +
-                           [getattr(self, name) for name in self.to_save])
-        to_hash_other = tuple([other.id_] +
-                              [getattr(other, name) for name in other.to_save])
-        return hash(to_hash_me) == hash(to_hash_other)
+        return hash(self) == hash(other)
 
     def __lt__(self, other: Any) -> bool:
         if not isinstance(other, self.__class__):
 
     def __lt__(self, other: Any) -> bool:
         if not isinstance(other, self.__class__):
@@ -368,7 +377,7 @@ class BaseModel(Generic[BaseModelId]):
                                   date_col: str = 'day'
                                   ) -> tuple[list[BaseModelInstance], str,
                                              str]:
                                   date_col: str = 'day'
                                   ) -> tuple[list[BaseModelInstance], str,
                                              str]:
-        """Return list of Days in database within (open) date_range interval.
+        """Return list of items in database within (open) date_range interval.
 
         If no range values provided, defaults them to 'yesterday' and
         'tomorrow'. Knows to properly interpret these and 'today' as value.
 
         If no range values provided, defaults them to 'yesterday' and
         'tomorrow'. Knows to properly interpret these and 'today' as value.
@@ -417,10 +426,9 @@ class BaseModel(Generic[BaseModelId]):
         """
         values = tuple([self.id_] + [getattr(self, key)
                                      for key in self.to_save])
         """
         values = tuple([self.id_] + [getattr(self, key)
                                      for key in self.to_save])
-        q_marks = DatabaseConnection.q_marks_from_values(values)
         table_name = self.table_name
         table_name = self.table_name
-        cursor = db_conn.exec(f'REPLACE INTO {table_name} VALUES {q_marks}',
-                              values)
+        cursor = db_conn.exec_on_vals(f'REPLACE INTO {table_name} VALUES',
+                                      values)
         if not isinstance(self.id_, str):
             self.id_ = cursor.lastrowid  # type: ignore[assignment]
         self.cache()
         if not isinstance(self.id_, str):
             self.id_ = cursor.lastrowid  # type: ignore[assignment]
         self.cache()