home · contact · privacy
Minor BaseModel code re-organization.
[plomtask] / plomtask / db.py
index 2ea7421feec578dab3fb8ba9e1eecf547eb6a36c..df98dd0f130bbd75553b2e628cd739d793e98616 100644 (file)
@@ -18,8 +18,9 @@ class UnmigratedDbException(HandledException):
     """To identify case of unmigrated DB file."""
 
 
     """To identify case of unmigrated DB file."""
 
 
-class DatabaseFile:  # pylint: disable=too-few-public-methods
+class DatabaseFile:
     """Represents the sqlite3 database's file."""
     """Represents the sqlite3 database's file."""
+    # pylint: disable=too-few-public-methods
 
     def __init__(self, path: str) -> None:
         self.path = path
 
     def __init__(self, path: str) -> None:
         self.path = path
@@ -38,7 +39,7 @@ class DatabaseFile:  # pylint: disable=too-few-public-methods
     def migrate(cls, path: str) -> DatabaseFile:
         """Apply migrations from_version to EXPECTED_DB_VERSION."""
         migrations = cls._available_migrations()
     def migrate(cls, path: str) -> DatabaseFile:
         """Apply migrations from_version to EXPECTED_DB_VERSION."""
         migrations = cls._available_migrations()
-        from_version = cls.get_version_of_db(path)
+        from_version = cls._get_version_of_db(path)
         migrations_todo = migrations[from_version+1:]
         for j, filename in enumerate(migrations_todo):
             with sql_connect(path) as conn:
         migrations_todo = migrations[from_version+1:]
         for j, filename in enumerate(migrations_todo):
             with sql_connect(path) as conn:
@@ -54,7 +55,7 @@ class DatabaseFile:  # pylint: disable=too-few-public-methods
         """Check file exists, and is of proper DB version and schema."""
         if not isfile(self.path):
             raise NotFoundException
         """Check file exists, and is of proper DB version and schema."""
         if not isfile(self.path):
             raise NotFoundException
-        if self.user_version != EXPECTED_DB_VERSION:
+        if self._user_version != EXPECTED_DB_VERSION:
             raise UnmigratedDbException()
         self._validate_schema()
 
             raise UnmigratedDbException()
         self._validate_schema()
 
@@ -86,7 +87,7 @@ class DatabaseFile:  # pylint: disable=too-few-public-methods
         return migrations_list
 
     @staticmethod
         return migrations_list
 
     @staticmethod
-    def get_version_of_db(path: str) -> int:
+    def _get_version_of_db(path: str) -> int:
         """Get DB user_version, fail if outside expected range."""
         sql_for_db_version = 'PRAGMA user_version'
         with sql_connect(path) as conn:
         """Get DB user_version, fail if outside expected range."""
         sql_for_db_version = 'PRAGMA user_version'
         with sql_connect(path) as conn:
@@ -99,9 +100,11 @@ class DatabaseFile:  # pylint: disable=too-few-public-methods
         return db_version
 
     @property
         return db_version
 
     @property
-    def user_version(self) -> int:
+    def _user_version(self) -> int:
         """Get DB user_version."""
         """Get DB user_version."""
-        return self.__class__.get_version_of_db(self.path)
+        # pylint: disable=protected-access
+        # (since we remain within class)
+        return self.__class__._get_version_of_db(self.path)
 
     def _validate_schema(self) -> None:
         """Compare found schema with what's stored at PATH_DB_SCHEMA."""
 
     def _validate_schema(self) -> None:
         """Compare found schema with what's stored at PATH_DB_SCHEMA."""
@@ -156,8 +159,7 @@ class DatabaseConnection:
     """A single connection to the database."""
 
     def __init__(self, db_file: DatabaseFile) -> None:
     """A single connection to the database."""
 
     def __init__(self, db_file: DatabaseFile) -> None:
-        self.file = db_file
-        self.conn = sql_connect(self.file.path)
+        self.conn = sql_connect(db_file.path)
 
     def commit(self) -> None:
         """Commit SQL transaction."""
 
     def commit(self) -> None:
         """Commit SQL transaction."""
@@ -167,6 +169,11 @@ class DatabaseConnection:
         """Add commands to SQL transaction."""
         return self.conn.execute(code, inputs)
 
         """Add commands to SQL transaction."""
         return self.conn.execute(code, inputs)
 
+    def exec_on_vals(self, code: str, inputs: tuple[Any, ...]) -> Cursor:
+        """Wrapper around .exec appending adequate " (?, …)" to code."""
+        q_marks_from_values = '(' + ','.join(['?'] * len(inputs)) + ')'
+        return self.exec(f'{code} {q_marks_from_values}', inputs)
+
     def close(self) -> None:
         """Close DB connection."""
         self.conn.close()
     def close(self) -> None:
         """Close DB connection."""
         self.conn.close()
@@ -183,8 +190,7 @@ class DatabaseConnection:
         self.delete_where(table_name, key, target)
         for row in rows:
             values = tuple(row[:key_index] + [target] + row[key_index:])
         self.delete_where(table_name, key, target)
         for row in rows:
             values = tuple(row[:key_index] + [target] + row[key_index:])
-            q_marks = self.__class__.q_marks_from_values(values)
-            self.exec(f'INSERT INTO {table_name} VALUES {q_marks}', values)
+            self.exec_on_vals(f'INSERT INTO {table_name} VALUES', values)
 
     def row_where(self, table_name: str, key: str,
                   target: int | str) -> list[Row]:
 
     def row_where(self, table_name: str, key: str,
                   target: int | str) -> list[Row]:
@@ -220,11 +226,6 @@ class DatabaseConnection:
         """Delete from table where key == target."""
         self.exec(f'DELETE FROM {table_name} WHERE {key} = ?', (target,))
 
         """Delete from table where key == target."""
         self.exec(f'DELETE FROM {table_name} WHERE {key} = ?', (target,))
 
-    @staticmethod
-    def q_marks_from_values(values: tuple[Any]) -> str:
-        """Return placeholder to insert values into SQL code."""
-        return '(' + ','.join(['?'] * len(values)) + ')'
-
 
 BaseModelId = TypeVar('BaseModelId', int, str)
 BaseModelInstance = TypeVar('BaseModelInstance', bound='BaseModel[Any]')
 
 BaseModelId = TypeVar('BaseModelId', int, str)
 BaseModelInstance = TypeVar('BaseModelInstance', bound='BaseModel[Any]')
@@ -244,16 +245,24 @@ class BaseModel(Generic[BaseModelId]):
         if isinstance(id_, int) and id_ < 1:
             msg = f'illegal {self.__class__.__name__} ID, must be >=1: {id_}'
             raise HandledException(msg)
         if isinstance(id_, int) and id_ < 1:
             msg = f'illegal {self.__class__.__name__} ID, must be >=1: {id_}'
             raise HandledException(msg)
+        if isinstance(id_, str) and "" == id_:
+            msg = f'illegal {self.__class__.__name__} ID, must be non-empty'
+            raise HandledException(msg)
         self.id_ = id_
 
         self.id_ = id_
 
+    def __hash__(self) -> int:
+        hashable = [self.id_] + [getattr(self, name) for name in self.to_save]
+        for definition in self.to_save_relations:
+            attr = getattr(self, definition[2])
+            hashable += [tuple(rel.id_ for rel in attr)]
+        for name in self.to_save_versioned:
+            hashable += [hash(getattr(self, name))]
+        return hash(tuple(hashable))
+
     def __eq__(self, other: object) -> bool:
         if not isinstance(other, self.__class__):
             return False
     def __eq__(self, other: object) -> bool:
         if not isinstance(other, self.__class__):
             return False
-        to_hash_me = tuple([self.id_] +
-                           [getattr(self, name) for name in self.to_save])
-        to_hash_other = tuple([other.id_] +
-                              [getattr(other, name) for name in other.to_save])
-        return hash(to_hash_me) == hash(to_hash_other)
+        return hash(self) == hash(other)
 
     def __lt__(self, other: Any) -> bool:
         if not isinstance(other, self.__class__):
 
     def __lt__(self, other: Any) -> bool:
         if not isinstance(other, self.__class__):
@@ -263,9 +272,11 @@ class BaseModel(Generic[BaseModelId]):
         assert isinstance(other.id_, int)
         return self.id_ < other.id_
 
         assert isinstance(other.id_, int)
         return self.id_ < other.id_
 
+    # cache management
+
     @classmethod
     @classmethod
-    def get_cached(cls: type[BaseModelInstance],
-                   id_: BaseModelId) -> BaseModelInstance | None:
+    def _get_cached(cls: type[BaseModelInstance],
+                    id_: BaseModelId) -> BaseModelInstance | None:
         """Get object of id_ from class's cache, or None if not found."""
         # pylint: disable=consider-iterating-dictionary
         cache = cls.get_cache()
         """Get object of id_ from class's cache, or None if not found."""
         # pylint: disable=consider-iterating-dictionary
         cache = cls.get_cache()
@@ -302,6 +313,8 @@ class BaseModel(Generic[BaseModelId]):
         cache = self.__class__.get_cache()
         del cache[self.id_]
 
         cache = self.__class__.get_cache()
         del cache[self.id_]
 
+    # object retrieval and generation
+
     @classmethod
     def from_table_row(cls: type[BaseModelInstance],
                        # pylint: disable=unused-argument
     @classmethod
     def from_table_row(cls: type[BaseModelInstance],
                        # pylint: disable=unused-argument
@@ -326,7 +339,7 @@ class BaseModel(Generic[BaseModelId]):
         """
         obj = None
         if id_ is not None:
         """
         obj = None
         if id_ is not None:
-            obj = cls.get_cached(id_)
+            obj = cls._get_cached(id_)
             if not obj:
                 for row in db_conn.row_where(cls.table_name, 'id', id_):
                     obj = cls.from_table_row(db_conn, row)
             if not obj:
                 for row in db_conn.row_where(cls.table_name, 'id', id_):
                     obj = cls.from_table_row(db_conn, row)
@@ -368,7 +381,7 @@ class BaseModel(Generic[BaseModelId]):
                                   date_col: str = 'day'
                                   ) -> tuple[list[BaseModelInstance], str,
                                              str]:
                                   date_col: str = 'day'
                                   ) -> tuple[list[BaseModelInstance], str,
                                              str]:
-        """Return list of Days in database within (open) date_range interval.
+        """Return list of items in database within (open) date_range interval.
 
         If no range values provided, defaults them to 'yesterday' and
         'tomorrow'. Knows to properly interpret these and 'today' as value.
 
         If no range values provided, defaults them to 'yesterday' and
         'tomorrow'. Knows to properly interpret these and 'today' as value.
@@ -404,6 +417,8 @@ class BaseModel(Generic[BaseModelId]):
             return filtered
         return items
 
             return filtered
         return items
 
+    # database writing
+
     def save(self, db_conn: DatabaseConnection) -> None:
         """Write self to DB and cache and ensure .id_.
 
     def save(self, db_conn: DatabaseConnection) -> None:
         """Write self to DB and cache and ensure .id_.
 
@@ -417,10 +432,9 @@ class BaseModel(Generic[BaseModelId]):
         """
         values = tuple([self.id_] + [getattr(self, key)
                                      for key in self.to_save])
         """
         values = tuple([self.id_] + [getattr(self, key)
                                      for key in self.to_save])
-        q_marks = DatabaseConnection.q_marks_from_values(values)
         table_name = self.table_name
         table_name = self.table_name
-        cursor = db_conn.exec(f'REPLACE INTO {table_name} VALUES {q_marks}',
-                              values)
+        cursor = db_conn.exec_on_vals(f'REPLACE INTO {table_name} VALUES',
+                                      values)
         if not isinstance(self.id_, str):
             self.id_ = cursor.lastrowid  # type: ignore[assignment]
         self.cache()
         if not isinstance(self.id_, str):
             self.id_ = cursor.lastrowid  # type: ignore[assignment]
         self.cache()
@@ -434,7 +448,9 @@ class BaseModel(Generic[BaseModelId]):
 
     def remove(self, db_conn: DatabaseConnection) -> None:
         """Remove from DB and cache, including dependencies."""
 
     def remove(self, db_conn: DatabaseConnection) -> None:
         """Remove from DB and cache, including dependencies."""
-        if self.id_ is None or self.__class__.get_cached(self.id_) is None:
+        # pylint: disable=protected-access
+        # (since we remain within class)
+        if self.id_ is None or self.__class__._get_cached(self.id_) is None:
             raise HandledException('cannot remove unsaved item')
         for attr_name in self.to_save_versioned:
             getattr(self, attr_name).remove(db_conn)
             raise HandledException('cannot remove unsaved item')
         for attr_name in self.to_save_versioned:
             getattr(self, attr_name).remove(db_conn)