home · contact · privacy
First commit.
authorChristian Heller <c.heller@plomlompom.de>
Wed, 15 Jan 2025 14:10:20 +0000 (15:10 +0100)
committerChristian Heller <c.heller@plomlompom.de>
Wed, 15 Jan 2025 14:10:20 +0000 (15:10 +0100)
db.py [new file with mode: 0644]

diff --git a/db.py b/db.py
new file mode 100644 (file)
index 0000000..e13ac14
--- /dev/null
+++ b/db.py
@@ -0,0 +1,192 @@
+"""Database management."""
+from difflib import Differ
+from pathlib import Path
+from sqlite3 import connect as sql_connect, Cursor as DbCursor
+from typing import Any, Callable, Literal, Optional, Self, TypeVar
+from abc import ABC, abstractmethod
+
+
+_SQL_DB_VERSION = 'PRAGMA user_version'
+TypePlomDbMigration = TypeVar('TypePlomDbMigration', bound='PlomDbMigration')
+TypePlomDbFile = TypeVar('TypePlomDbFile', bound='PlomDbFile')
+
+
+class PlomDbException(Exception):
+    """Collects 1) a terse machine-readable name, 2) human-friendly message."""
+
+    def __init__(self, name: str, *args: Any, msg: str = '', **kwargs: Any
+                 ) -> None:
+        super().__init__(*args, **kwargs)
+        self.name = name
+        self.msg = msg
+
+
+class PlomDbFile:
+    """File readable as DB of expected schema, user version."""
+    indent_n: int = 4
+    target_version: int
+    path_schema: Path
+    default_path: Path
+    mig_class: type['PlomDbMigration']
+
+    def __init__(self,
+                 path: Optional[Path] = None,
+                 skip_validations: bool = False
+                 ) -> None:
+        self.path = path if path else self.default_path
+        if not self.path.is_file():
+            raise PlomDbException('no_is_file', f'no DB file at {self.path}')
+        if skip_validations:
+            return
+        if (user_version := self._get_user_version()) != self.target_version:
+            raise PlomDbException(
+                'bad_version',
+                f'wrong DB version {user_version} (!= {self.target_version})')
+        with PlomDbConn(self) as conn:
+            self._validate_schema(conn)
+
+    @classmethod
+    def _validate_schema(cls, conn: 'PlomDbConn') -> None:
+        sch_rows_normed = []
+        indent = cls.indent_n * ' '
+        for row in [
+                r[0] for r in conn.exec(
+                    'SELECT sql FROM sqlite_master ORDER BY sql')
+                if r[0]]:
+            row_normed = []
+            for subrow in [sr.rstrip() for sr in row.split('\n')]:
+                in_parentheses = 0
+                split_at = []
+                for i, c in enumerate(subrow):
+                    if '(' == c:
+                        in_parentheses += 1
+                    elif ')' == c:
+                        in_parentheses -= 1
+                    elif ',' == c and 0 == in_parentheses:
+                        split_at += [i + 1]
+                prev_split = 0
+                for i in split_at:
+                    if segment := subrow[prev_split:i].strip():
+                        row_normed += [f'{indent}{segment}']
+                    prev_split = i
+                if segment := subrow[prev_split:].strip():
+                    row_normed += [f'{indent}{segment}']
+            row_normed[0] = row_normed[0].lstrip()  # no indent for opening …
+            row_normed[-1] = row_normed[-1].lstrip()  # … and closing line
+            if row_normed[-1] != ')' and row_normed[-3][-1] != ',':
+                row_normed[-3] = row_normed[-3] + ','
+                row_normed[-2:] = [indent + row_normed[-1][:-1]] + [')']
+            row_normed[-1] = row_normed[-1] + ';'
+            sch_rows_normed += row_normed
+        expected_rows =\
+            cls.path_schema.read_text(encoding='utf8').rstrip().splitlines()
+        if expected_rows != sch_rows_normed:
+            raise PlomDbException(
+               'bad_schema',
+               'Unexpected tables schema. Diff to {cls.path_schema}:\n'
+               + '\n'.join(Differ().compare(sch_rows_normed, expected_rows)))
+
+    def _get_user_version(self) -> int:
+        with sql_connect(self.path) as conn:
+            val = list(conn.execute(_SQL_DB_VERSION))[0][0]
+            assert isinstance(val, int)
+            return val
+
+    @classmethod
+    def create(cls, path_db: Optional[Path] = None) -> None:
+        """Create DB file at path_db according to file at self.path_schema.."""
+        path_db = path_db if path_db else cls.default_path
+        if path_db.exists():
+            raise PlomDbException('no_create_path_exists',
+                                  f'There already exists a node at {path_db}.')
+        if not path_db.parent.is_dir():
+            raise PlomDbException(
+                    'no_create_no_dir',
+                    f'No directory {path_db.parent} found to write into.')
+        with sql_connect(path_db) as conn:
+            conn.executescript(cls.path_schema.read_text(encoding='utf8'))
+            conn.execute(f'{_SQL_DB_VERSION} = {cls.target_version}')
+
+    def migrate(self, migrations: set[TypePlomDbMigration]) -> None:
+        """Migrate towards .target_version, following migrations."""
+        from_version = self._get_user_version()
+        if from_version >= self.target_version:
+            raise PlomDbException(
+                    'no_migrate_path',
+                    f'No migrating {from_version} to {self.target_version}.')
+        with PlomDbConn(self) as conn:
+            for migration in self.mig_class.gather(from_version, migrations):
+                migration.perform(conn)
+            self._validate_schema(conn)
+            conn.commit()
+
+
+class PlomDbConn:
+    """SQL connection to PlomDbFile."""
+    default_path: Path
+
+    def __init__(self, db_file: Optional[TypePlomDbFile] = None) -> None:
+        self._conn = sql_connect(
+                db_file.path if db_file else self.default_path,
+                autocommit=False)
+        # additional sqlite3.Connection shortcuts beyond .exec
+        self.exec_script = self._conn.executescript
+        self.commit = self._conn.commit
+
+    def __enter__(self) -> Self:  # context manager entry
+        return self
+
+    def __exit__(self, *_: Any) -> Literal[False]:  # context manager exit
+        self._conn.close()
+        return False
+
+    def exec(self,
+             sql: str,
+             inputs: tuple[Any, ...] = tuple(),
+             build_q_marks: bool = True
+             ) -> DbCursor:
+        """Wraps sqlite3.Connection.execute, appends (!) len(inputs) '?'s."""
+        if len(inputs) > 0:
+            if build_q_marks:
+                q_marks = ('?' if len(inputs) == 1
+                           else '(' + ','.join(['?'] * len(inputs)) + ')')
+                return self._conn.execute(f'{sql} {q_marks}', inputs)
+            return self._conn.execute(sql, inputs)
+        return self._conn.execute(sql)
+
+
+class PlomDbMigration(ABC):
+    """Collects and enacts PlomDbFile migration commands."""
+    migs_dir_path: Path = Path()
+
+    def __init__(self,
+                 target_version: int,
+                 sql_path: Optional[Path] = None,
+                 post_sql_steps: Optional[Callable] = None
+                 ) -> None:
+        if sql_path:
+            start_tok = sql_path.name.split('_', maxsplit=1)[0]
+            if (not start_tok.isdigit()) or int(start_tok) != target_version:
+                raise PlomDbException(
+                    'no_migrate_bad_path',
+                    f'bad path {sql_path} for migration to {target_version}')
+        self.target_version = target_version
+        self._sql_path = sql_path
+        self._post_sql_steps = post_sql_steps
+
+    def perform(self, conn: PlomDbConn) -> None:
+        """Do ._sql_path script and ._post_sql_steps, set .target_version."""
+        if self._sql_path:
+            sql_path = self.__class__.migs_dir_path.joinpath(self._sql_path)
+            conn.exec_script(sql_path.read_text(encoding='utf8'))
+        if self._post_sql_steps:
+            self._post_sql_steps(conn)
+        conn.exec(f'{_SQL_DB_VERSION} = {self.target_version}')
+
+    @classmethod
+    @abstractmethod
+    def gather(cls,
+               from_version: int,
+               base_set: set[TypePlomDbMigration]
+               ) -> list[TypePlomDbMigration]:
+        """Return sorted list of migrations to perform."""