From ad94c0df56c82981c0832bae0a3969f91f49f042 Mon Sep 17 00:00:00 2001 From: Christian Heller Date: Mon, 6 Jan 2025 17:20:11 +0100 Subject: [PATCH] Reorganize migrations code. --- src/ytplom/db.py | 78 +++++++++++++++++++++++++++------------- src/ytplom/migrations.py | 18 +++++----- 2 files changed, 63 insertions(+), 33 deletions(-) diff --git a/src/ytplom/db.py b/src/ytplom/db.py index c310b63..46cd9ab 100644 --- a/src/ytplom/db.py +++ b/src/ytplom/db.py @@ -14,10 +14,9 @@ from ytplom.primitives import ( EXPECTED_DB_VERSION = 6 -PATH_DB = PATH_APP_DATA.joinpath('db.sql') +PATH_DB = PATH_APP_DATA.joinpath('TESTdb.sql') SqlText = NewType('SqlText', str) -MigrationsDict = dict[int, tuple[Optional[Path], Optional[Callable]]] _SQL_DB_VERSION = SqlText('PRAGMA user_version') _PATH_MIGRATIONS = PATH_APP_DATA.joinpath('migrations') @@ -139,35 +138,66 @@ class DbFile: conn.executescript(_PATH_DB_SCHEMA.read_text(encoding='utf8')) conn.execute(f'{_SQL_DB_VERSION} = {EXPECTED_DB_VERSION}') - def migrate(self, migrations: MigrationsDict) -> None: + def migrate(self, migrations: set['DbMigration']) -> None: """Migrate self towards EXPECTED_DB_VERSION""" start_version = self._get_user_version() - if start_version >= EXPECTED_DB_VERSION: + if start_version == EXPECTED_DB_VERSION: raise HandledException( - f'Cannot migrate {start_version} to {EXPECTED_DB_VERSION}.') - migs_to_do = [] - for n in [n+1 for n in range(start_version, EXPECTED_DB_VERSION)]: - if n not in migrations: - raise HandledException(f'Needed migration missing: {n}') - mig_tuple = migrations[n] - if path := mig_tuple[0]: - start_tok = str(path).split('_', maxsplit=1)[0] - if (not start_tok.isdigit()) or int(start_tok) != n: - raise HandledException( - f'migration {n} mapped to bad path {path}') - migs_to_do += [(n, *mig_tuple)] + f'Already at {EXPECTED_DB_VERSION}, nothing to migrate.') + if start_version > EXPECTED_DB_VERSION: + raise HandledException( + f'Cannot migrate backwards from {start_version}' + f'to {EXPECTED_DB_VERSION}.') with DbConn(self) as conn: - for version, filename_sql, after_sql_steps in migs_to_do: - if filename_sql: - conn.exec_script( - SqlText(_PATH_MIGRATIONS.joinpath(filename_sql) - .read_text(encoding='utf8'))) - if after_sql_steps: - after_sql_steps(conn) - conn.exec(SqlText(f'{_SQL_DB_VERSION} = {version}')) + for migration in DbMigration.from_to_in_set( + start_version, EXPECTED_DB_VERSION, migrations): + migration.perform(conn) conn.commit() +class DbMigration: + """Representation of DbFile migration data.""" + + def __init__(self, + version: int, + sql_path: Optional[Path] = None, + after_sql_steps: Optional[Callable[['DbConn'], None]] = None + ) -> None: + if sql_path: + start_tok = str(sql_path).split('_', maxsplit=1)[0] + if (not start_tok.isdigit()) or int(start_tok) != version: + raise HandledException( + f'migration {version} mapped to bad path {sql_path}') + self._version = version + self._sql_path = sql_path + self._after_sql_steps = after_sql_steps + + @classmethod + def from_to_in_set( + cls, from_version: int, to_version: int, migs_set: set[Self] + ) -> list[Self]: + """From migs_set make sorted unbroken list from_version to_version.""" + selected_migs = [] + for version in [n+1 for n in range(from_version, to_version)]: + matching_migs = [m for m in migs_set if version == m._version] + if not matching_migs: + raise HandledException(f'Missing migration of v{version}') + if len(matching_migs) > 1: + raise HandledException(f'More than 1 Migration of v{version}') + selected_migs += [matching_migs[0]] + return selected_migs + + def perform(self, conn: 'DbConn') -> None: + """Do 1) script at sql_path, 2) after_sql_steps, 3) versino setting.""" + if self._sql_path: + conn.exec_script( + SqlText(_PATH_MIGRATIONS.joinpath(self._sql_path) + .read_text(encoding='utf8'))) + if self._after_sql_steps: + self._after_sql_steps(conn) + conn.exec(SqlText(f'{_SQL_DB_VERSION} = {self._version}')) + + class DbConn: """Wrapper for sqlite3 connections.""" diff --git a/src/ytplom/migrations.py b/src/ytplom/migrations.py index be76b3d..f075cdf 100644 --- a/src/ytplom/migrations.py +++ b/src/ytplom/migrations.py @@ -4,7 +4,7 @@ from pathlib import Path from typing import Callable # ourselves -from ytplom.db import DbConn, DbFile, MigrationsDict, SqlText +from ytplom.db import DbConn, DbFile, DbMigration, SqlText from ytplom.primitives import HandledException @@ -55,14 +55,14 @@ def _mig_4_convert_digests(conn: DbConn) -> None: _rewrite_files_last_field_processing_first_field(conn, bytes.fromhex) -_MIGRATIONS: MigrationsDict = { - 0: (Path('0_init.sql'), None), - 1: (Path('1_add_files_last_updated.sql'), None), - 2: (Path('2_add_files_sha512.sql'), _mig_2_calc_digests), - 3: (Path('3_files_redo.sql'), None), - 4: (Path('4_add_files_sha512_blob.sql'), _mig_4_convert_digests), - 5: (Path('5_files_redo.sql'), None), - 6: (Path('6_add_files_tags.sql'), None) +_MIGRATIONS: set[DbMigration] = { + DbMigration(0, Path('0_init.sql'), None), + DbMigration(1, Path('1_add_files_last_updated.sql'), None), + DbMigration(2, Path('2_add_files_sha512.sql'), _mig_2_calc_digests), + DbMigration(3, Path('3_files_redo.sql'), None), + DbMigration(4, Path('4_add_files_sha512_blob.sql'), _mig_4_convert_digests), + DbMigration(5, Path('5_files_redo.sql'), None), + DbMigration(6, Path('6_add_files_tags.sql'), None) } -- 2.30.2