--- /dev/null
+"""Database management."""
+from difflib import Differ
+from pathlib import Path
+from sqlite3 import connect as sql_connect, Cursor as DbCursor
+from typing import Any, Callable, Literal, Optional, Self, TypeVar
+from abc import ABC, abstractmethod
+
+
+_SQL_DB_VERSION = 'PRAGMA user_version'
+TypePlomDbMigration = TypeVar('TypePlomDbMigration', bound='PlomDbMigration')
+TypePlomDbFile = TypeVar('TypePlomDbFile', bound='PlomDbFile')
+
+
+class PlomDbException(Exception):
+ """Collects 1) a terse machine-readable name, 2) human-friendly message."""
+
+ def __init__(self, name: str, *args: Any, msg: str = '', **kwargs: Any
+ ) -> None:
+ super().__init__(*args, **kwargs)
+ self.name = name
+ self.msg = msg
+
+
+class PlomDbFile:
+ """File readable as DB of expected schema, user version."""
+ indent_n: int = 4
+ target_version: int
+ path_schema: Path
+ default_path: Path
+ mig_class: type['PlomDbMigration']
+
+ def __init__(self,
+ path: Optional[Path] = None,
+ skip_validations: bool = False
+ ) -> None:
+ self.path = path if path else self.default_path
+ if not self.path.is_file():
+ raise PlomDbException('no_is_file', f'no DB file at {self.path}')
+ if skip_validations:
+ return
+ if (user_version := self._get_user_version()) != self.target_version:
+ raise PlomDbException(
+ 'bad_version',
+ f'wrong DB version {user_version} (!= {self.target_version})')
+ with PlomDbConn(self) as conn:
+ self._validate_schema(conn)
+
+ @classmethod
+ def _validate_schema(cls, conn: 'PlomDbConn') -> None:
+ sch_rows_normed = []
+ indent = cls.indent_n * ' '
+ for row in [
+ r[0] for r in conn.exec(
+ 'SELECT sql FROM sqlite_master ORDER BY sql')
+ if r[0]]:
+ row_normed = []
+ for subrow in [sr.rstrip() for sr in row.split('\n')]:
+ in_parentheses = 0
+ split_at = []
+ for i, c in enumerate(subrow):
+ if '(' == c:
+ in_parentheses += 1
+ elif ')' == c:
+ in_parentheses -= 1
+ elif ',' == c and 0 == in_parentheses:
+ split_at += [i + 1]
+ prev_split = 0
+ for i in split_at:
+ if segment := subrow[prev_split:i].strip():
+ row_normed += [f'{indent}{segment}']
+ prev_split = i
+ if segment := subrow[prev_split:].strip():
+ row_normed += [f'{indent}{segment}']
+ row_normed[0] = row_normed[0].lstrip() # no indent for opening …
+ row_normed[-1] = row_normed[-1].lstrip() # … and closing line
+ if row_normed[-1] != ')' and row_normed[-3][-1] != ',':
+ row_normed[-3] = row_normed[-3] + ','
+ row_normed[-2:] = [indent + row_normed[-1][:-1]] + [')']
+ row_normed[-1] = row_normed[-1] + ';'
+ sch_rows_normed += row_normed
+ expected_rows =\
+ cls.path_schema.read_text(encoding='utf8').rstrip().splitlines()
+ if expected_rows != sch_rows_normed:
+ raise PlomDbException(
+ 'bad_schema',
+ 'Unexpected tables schema. Diff to {cls.path_schema}:\n'
+ + '\n'.join(Differ().compare(sch_rows_normed, expected_rows)))
+
+ def _get_user_version(self) -> int:
+ with sql_connect(self.path) as conn:
+ val = list(conn.execute(_SQL_DB_VERSION))[0][0]
+ assert isinstance(val, int)
+ return val
+
+ @classmethod
+ def create(cls, path_db: Optional[Path] = None) -> None:
+ """Create DB file at path_db according to file at self.path_schema.."""
+ path_db = path_db if path_db else cls.default_path
+ if path_db.exists():
+ raise PlomDbException('no_create_path_exists',
+ f'There already exists a node at {path_db}.')
+ if not path_db.parent.is_dir():
+ raise PlomDbException(
+ 'no_create_no_dir',
+ f'No directory {path_db.parent} found to write into.')
+ with sql_connect(path_db) as conn:
+ conn.executescript(cls.path_schema.read_text(encoding='utf8'))
+ conn.execute(f'{_SQL_DB_VERSION} = {cls.target_version}')
+
+ def migrate(self, migrations: set[TypePlomDbMigration]) -> None:
+ """Migrate towards .target_version, following migrations."""
+ from_version = self._get_user_version()
+ if from_version >= self.target_version:
+ raise PlomDbException(
+ 'no_migrate_path',
+ f'No migrating {from_version} to {self.target_version}.')
+ with PlomDbConn(self) as conn:
+ for migration in self.mig_class.gather(from_version, migrations):
+ migration.perform(conn)
+ self._validate_schema(conn)
+ conn.commit()
+
+
+class PlomDbConn:
+ """SQL connection to PlomDbFile."""
+ default_path: Path
+
+ def __init__(self, db_file: Optional[TypePlomDbFile] = None) -> None:
+ self._conn = sql_connect(
+ db_file.path if db_file else self.default_path,
+ autocommit=False)
+ # additional sqlite3.Connection shortcuts beyond .exec
+ self.exec_script = self._conn.executescript
+ self.commit = self._conn.commit
+
+ def __enter__(self) -> Self: # context manager entry
+ return self
+
+ def __exit__(self, *_: Any) -> Literal[False]: # context manager exit
+ self._conn.close()
+ return False
+
+ def exec(self,
+ sql: str,
+ inputs: tuple[Any, ...] = tuple(),
+ build_q_marks: bool = True
+ ) -> DbCursor:
+ """Wraps sqlite3.Connection.execute, appends (!) len(inputs) '?'s."""
+ if len(inputs) > 0:
+ if build_q_marks:
+ q_marks = ('?' if len(inputs) == 1
+ else '(' + ','.join(['?'] * len(inputs)) + ')')
+ return self._conn.execute(f'{sql} {q_marks}', inputs)
+ return self._conn.execute(sql, inputs)
+ return self._conn.execute(sql)
+
+
+class PlomDbMigration(ABC):
+ """Collects and enacts PlomDbFile migration commands."""
+ migs_dir_path: Path = Path()
+
+ def __init__(self,
+ target_version: int,
+ sql_path: Optional[Path] = None,
+ post_sql_steps: Optional[Callable] = None
+ ) -> None:
+ if sql_path:
+ start_tok = sql_path.name.split('_', maxsplit=1)[0]
+ if (not start_tok.isdigit()) or int(start_tok) != target_version:
+ raise PlomDbException(
+ 'no_migrate_bad_path',
+ f'bad path {sql_path} for migration to {target_version}')
+ self.target_version = target_version
+ self._sql_path = sql_path
+ self._post_sql_steps = post_sql_steps
+
+ def perform(self, conn: PlomDbConn) -> None:
+ """Do ._sql_path script and ._post_sql_steps, set .target_version."""
+ if self._sql_path:
+ sql_path = self.__class__.migs_dir_path.joinpath(self._sql_path)
+ conn.exec_script(sql_path.read_text(encoding='utf8'))
+ if self._post_sql_steps:
+ self._post_sql_steps(conn)
+ conn.exec(f'{_SQL_DB_VERSION} = {self.target_version}')
+
+ @classmethod
+ @abstractmethod
+ def gather(cls,
+ from_version: int,
+ base_set: set[TypePlomDbMigration]
+ ) -> list[TypePlomDbMigration]:
+ """Return sorted list of migrations to perform."""