From 743dbe0d493ddeb47eca981fa5be6d78e4d754c9 Mon Sep 17 00:00:00 2001 From: Christian Heller Date: Wed, 15 Jan 2025 15:10:20 +0100 Subject: [PATCH] First commit. --- db.py | 192 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 192 insertions(+) create mode 100644 db.py diff --git a/db.py b/db.py new file mode 100644 index 0000000..e13ac14 --- /dev/null +++ b/db.py @@ -0,0 +1,192 @@ +"""Database management.""" +from difflib import Differ +from pathlib import Path +from sqlite3 import connect as sql_connect, Cursor as DbCursor +from typing import Any, Callable, Literal, Optional, Self, TypeVar +from abc import ABC, abstractmethod + + +_SQL_DB_VERSION = 'PRAGMA user_version' +TypePlomDbMigration = TypeVar('TypePlomDbMigration', bound='PlomDbMigration') +TypePlomDbFile = TypeVar('TypePlomDbFile', bound='PlomDbFile') + + +class PlomDbException(Exception): + """Collects 1) a terse machine-readable name, 2) human-friendly message.""" + + def __init__(self, name: str, *args: Any, msg: str = '', **kwargs: Any + ) -> None: + super().__init__(*args, **kwargs) + self.name = name + self.msg = msg + + +class PlomDbFile: + """File readable as DB of expected schema, user version.""" + indent_n: int = 4 + target_version: int + path_schema: Path + default_path: Path + mig_class: type['PlomDbMigration'] + + def __init__(self, + path: Optional[Path] = None, + skip_validations: bool = False + ) -> None: + self.path = path if path else self.default_path + if not self.path.is_file(): + raise PlomDbException('no_is_file', f'no DB file at {self.path}') + if skip_validations: + return + if (user_version := self._get_user_version()) != self.target_version: + raise PlomDbException( + 'bad_version', + f'wrong DB version {user_version} (!= {self.target_version})') + with PlomDbConn(self) as conn: + self._validate_schema(conn) + + @classmethod + def _validate_schema(cls, conn: 'PlomDbConn') -> None: + sch_rows_normed = [] + indent = cls.indent_n * ' ' + for row in [ + r[0] for r in conn.exec( + 'SELECT sql FROM sqlite_master ORDER BY sql') + if r[0]]: + row_normed = [] + for subrow in [sr.rstrip() for sr in row.split('\n')]: + in_parentheses = 0 + split_at = [] + for i, c in enumerate(subrow): + if '(' == c: + in_parentheses += 1 + elif ')' == c: + in_parentheses -= 1 + elif ',' == c and 0 == in_parentheses: + split_at += [i + 1] + prev_split = 0 + for i in split_at: + if segment := subrow[prev_split:i].strip(): + row_normed += [f'{indent}{segment}'] + prev_split = i + if segment := subrow[prev_split:].strip(): + row_normed += [f'{indent}{segment}'] + row_normed[0] = row_normed[0].lstrip() # no indent for opening … + row_normed[-1] = row_normed[-1].lstrip() # … and closing line + if row_normed[-1] != ')' and row_normed[-3][-1] != ',': + row_normed[-3] = row_normed[-3] + ',' + row_normed[-2:] = [indent + row_normed[-1][:-1]] + [')'] + row_normed[-1] = row_normed[-1] + ';' + sch_rows_normed += row_normed + expected_rows =\ + cls.path_schema.read_text(encoding='utf8').rstrip().splitlines() + if expected_rows != sch_rows_normed: + raise PlomDbException( + 'bad_schema', + 'Unexpected tables schema. Diff to {cls.path_schema}:\n' + + '\n'.join(Differ().compare(sch_rows_normed, expected_rows))) + + def _get_user_version(self) -> int: + with sql_connect(self.path) as conn: + val = list(conn.execute(_SQL_DB_VERSION))[0][0] + assert isinstance(val, int) + return val + + @classmethod + def create(cls, path_db: Optional[Path] = None) -> None: + """Create DB file at path_db according to file at self.path_schema..""" + path_db = path_db if path_db else cls.default_path + if path_db.exists(): + raise PlomDbException('no_create_path_exists', + f'There already exists a node at {path_db}.') + if not path_db.parent.is_dir(): + raise PlomDbException( + 'no_create_no_dir', + f'No directory {path_db.parent} found to write into.') + with sql_connect(path_db) as conn: + conn.executescript(cls.path_schema.read_text(encoding='utf8')) + conn.execute(f'{_SQL_DB_VERSION} = {cls.target_version}') + + def migrate(self, migrations: set[TypePlomDbMigration]) -> None: + """Migrate towards .target_version, following migrations.""" + from_version = self._get_user_version() + if from_version >= self.target_version: + raise PlomDbException( + 'no_migrate_path', + f'No migrating {from_version} to {self.target_version}.') + with PlomDbConn(self) as conn: + for migration in self.mig_class.gather(from_version, migrations): + migration.perform(conn) + self._validate_schema(conn) + conn.commit() + + +class PlomDbConn: + """SQL connection to PlomDbFile.""" + default_path: Path + + def __init__(self, db_file: Optional[TypePlomDbFile] = None) -> None: + self._conn = sql_connect( + db_file.path if db_file else self.default_path, + autocommit=False) + # additional sqlite3.Connection shortcuts beyond .exec + self.exec_script = self._conn.executescript + self.commit = self._conn.commit + + def __enter__(self) -> Self: # context manager entry + return self + + def __exit__(self, *_: Any) -> Literal[False]: # context manager exit + self._conn.close() + return False + + def exec(self, + sql: str, + inputs: tuple[Any, ...] = tuple(), + build_q_marks: bool = True + ) -> DbCursor: + """Wraps sqlite3.Connection.execute, appends (!) len(inputs) '?'s.""" + if len(inputs) > 0: + if build_q_marks: + q_marks = ('?' if len(inputs) == 1 + else '(' + ','.join(['?'] * len(inputs)) + ')') + return self._conn.execute(f'{sql} {q_marks}', inputs) + return self._conn.execute(sql, inputs) + return self._conn.execute(sql) + + +class PlomDbMigration(ABC): + """Collects and enacts PlomDbFile migration commands.""" + migs_dir_path: Path = Path() + + def __init__(self, + target_version: int, + sql_path: Optional[Path] = None, + post_sql_steps: Optional[Callable] = None + ) -> None: + if sql_path: + start_tok = sql_path.name.split('_', maxsplit=1)[0] + if (not start_tok.isdigit()) or int(start_tok) != target_version: + raise PlomDbException( + 'no_migrate_bad_path', + f'bad path {sql_path} for migration to {target_version}') + self.target_version = target_version + self._sql_path = sql_path + self._post_sql_steps = post_sql_steps + + def perform(self, conn: PlomDbConn) -> None: + """Do ._sql_path script and ._post_sql_steps, set .target_version.""" + if self._sql_path: + sql_path = self.__class__.migs_dir_path.joinpath(self._sql_path) + conn.exec_script(sql_path.read_text(encoding='utf8')) + if self._post_sql_steps: + self._post_sql_steps(conn) + conn.exec(f'{_SQL_DB_VERSION} = {self.target_version}') + + @classmethod + @abstractmethod + def gather(cls, + from_version: int, + base_set: set[TypePlomDbMigration] + ) -> list[TypePlomDbMigration]: + """Return sorted list of migrations to perform.""" -- 2.30.2