From: Christian Heller Date: Sun, 5 Jan 2025 06:11:34 +0000 (+0100) Subject: Re-work migration mechanisms. X-Git-Url: https://plomlompom.com/repos/%7B%7Bdb.prefix%7D%7D/%7B%7B%20web_path%20%7D%7D/%7B%7Bprefix%7D%7D/pick_tasks?a=commitdiff_plain;h=a5e1094d8482bdeee477bfa51a20087d0ed1744b;p=ytplom Re-work migration mechanisms. --- diff --git a/src/ytplom/db.py b/src/ytplom/db.py index 51e94ac..c310b63 100644 --- a/src/ytplom/db.py +++ b/src/ytplom/db.py @@ -5,8 +5,8 @@ from base64 import urlsafe_b64decode, urlsafe_b64encode from difflib import Differ from hashlib import file_digest from pathlib import Path -from sqlite3 import (connect as sql_connect, Connection as SqlConnection, - Cursor as SqlCursor, Row as SqlRow) +from sqlite3 import ( + connect as sql_connect, Cursor as SqlCursor, Row as SqlRow) from typing import Callable, Literal, NewType, Optional, Self # ourselves from ytplom.primitives import ( @@ -17,7 +17,7 @@ EXPECTED_DB_VERSION = 6 PATH_DB = PATH_APP_DATA.joinpath('db.sql') SqlText = NewType('SqlText', str) -MigrationsList = list[tuple[Path, Optional[Callable]]] +MigrationsDict = dict[int, tuple[Optional[Path], Optional[Callable]]] _SQL_DB_VERSION = SqlText('PRAGMA user_version') _PATH_MIGRATIONS = PATH_APP_DATA.joinpath('migrations') @@ -66,9 +66,9 @@ class DbFile: path: Path = PATH_DB, version_to_validate: int = EXPECTED_DB_VERSION ) -> None: - self._path = path - if not self._path.is_file(): - raise HandledException(f'no DB file at {self._path}') + self.path = path + if not self.path.is_file(): + raise HandledException(f'no DB file at {self.path}') if version_to_validate < 0: return @@ -78,7 +78,7 @@ class DbFile: f'wrong DB version {user_version} (!= {version_to_validate})') # ensure schema - with sql_connect(self._path) as conn: + with sql_connect(self.path) as conn: schema_rows = [ r[0] for r in conn.execute('SELECT sql FROM sqlite_master ORDER BY sql') @@ -123,7 +123,7 @@ class DbFile: + '\n'.join(diff_msg)) def _get_user_version(self) -> int: - with sql_connect(self._path) as conn: + with sql_connect(self.path) as conn: return list(conn.execute(_SQL_DB_VERSION))[0][0] @staticmethod @@ -139,49 +139,41 @@ class DbFile: conn.executescript(_PATH_DB_SCHEMA.read_text(encoding='utf8')) conn.execute(f'{_SQL_DB_VERSION} = {EXPECTED_DB_VERSION}') - def connect(self) -> SqlConnection: - """Open database file into SQL connection, with autocommit off.""" - return sql_connect(self._path, autocommit=False) - - def migrate(self, migrations: MigrationsList) -> None: + def migrate(self, migrations: MigrationsDict) -> None: """Migrate self towards EXPECTED_DB_VERSION""" start_version = self._get_user_version() - if start_version == EXPECTED_DB_VERSION: - print('Database at expected version, no migrations to do.') - return - if start_version > EXPECTED_DB_VERSION: + if start_version >= EXPECTED_DB_VERSION: raise HandledException( - f'Cannot migrate backward from version {start_version} to ' - f'{EXPECTED_DB_VERSION}.') - print(f'Trying to migrate from DB version {start_version} to ' - f'{EXPECTED_DB_VERSION} …') + f'Cannot migrate {start_version} to {EXPECTED_DB_VERSION}.') migs_to_do = [] - migs_by_n = dict(enumerate(migrations)) for n in [n+1 for n in range(start_version, EXPECTED_DB_VERSION)]: - if n not in migs_by_n: + if n not in migrations: raise HandledException(f'Needed migration missing: {n}') - migs_to_do += [(n, *migs_by_n[n])] - for version, filename_sql, after_sql_steps in migs_to_do: - print(f'Running migration towards: {version}') - with DbConn(self.connect()) as conn: + mig_tuple = migrations[n] + if path := mig_tuple[0]: + start_tok = str(path).split('_', maxsplit=1)[0] + if (not start_tok.isdigit()) or int(start_tok) != n: + raise HandledException( + f'migration {n} mapped to bad path {path}') + migs_to_do += [(n, *mig_tuple)] + with DbConn(self) as conn: + for version, filename_sql, after_sql_steps in migs_to_do: if filename_sql: - print(f'Executing {filename_sql}') - path_sql = _PATH_MIGRATIONS.joinpath(filename_sql) conn.exec_script( - SqlText(path_sql.read_text(encoding='utf8'))) + SqlText(_PATH_MIGRATIONS.joinpath(filename_sql) + .read_text(encoding='utf8'))) if after_sql_steps: - print('Running additional steps') after_sql_steps(conn) conn.exec(SqlText(f'{_SQL_DB_VERSION} = {version}')) - conn.commit() - print('Finished migrations.') + conn.commit() class DbConn: """Wrapper for sqlite3 connections.""" - def __init__(self, sql_conn: Optional[SqlConnection] = None) -> None: - self._conn = sql_conn or DbFile().connect() + def __init__(self, db_file: Optional[DbFile] = None) -> None: + self._conn = sql_connect((db_file or DbFile()).path, autocommit=False) + self.commit = self._conn.commit def __enter__(self) -> Self: return self @@ -203,10 +195,6 @@ class DbConn: """Wrapper around sqlite3.Connection.executescript.""" self._conn.executescript(sql) - def commit(self) -> None: - """Commit changes (i.e. DbData.save() calls) to database.""" - self._conn.commit() - class DbData: """Abstraction of common DB operation.""" diff --git a/src/ytplom/migrations.py b/src/ytplom/migrations.py index 5cacc95..be76b3d 100644 --- a/src/ytplom/migrations.py +++ b/src/ytplom/migrations.py @@ -4,7 +4,7 @@ from pathlib import Path from typing import Callable # ourselves -from ytplom.db import DbConn, DbFile, MigrationsList, SqlText +from ytplom.db import DbConn, DbFile, MigrationsDict, SqlText from ytplom.primitives import HandledException @@ -55,15 +55,15 @@ def _mig_4_convert_digests(conn: DbConn) -> None: _rewrite_files_last_field_processing_first_field(conn, bytes.fromhex) -_MIGRATIONS: MigrationsList = [ - (Path('0_init.sql'), None), - (Path('1_add_files_last_updated.sql'), None), - (Path('2_add_files_sha512.sql'), _mig_2_calc_digests), - (Path('3_files_redo.sql'), None), - (Path('4_add_files_sha512_blob.sql'), _mig_4_convert_digests), - (Path('5_files_redo.sql'), None), - (Path('6_add_files_tags.sql'), None) -] +_MIGRATIONS: MigrationsDict = { + 0: (Path('0_init.sql'), None), + 1: (Path('1_add_files_last_updated.sql'), None), + 2: (Path('2_add_files_sha512.sql'), _mig_2_calc_digests), + 3: (Path('3_files_redo.sql'), None), + 4: (Path('4_add_files_sha512_blob.sql'), _mig_4_convert_digests), + 5: (Path('5_files_redo.sql'), None), + 6: (Path('6_add_files_tags.sql'), None) +} def migrate(): diff --git a/src/ytplom/sync.py b/src/ytplom/sync.py index 89a198d..5cc966b 100644 --- a/src/ytplom/sync.py +++ b/src/ytplom/sync.py @@ -75,8 +75,7 @@ def _sync_relations(host_names: tuple[str, str], def _sync_dbs(scp: SCPClient) -> None: """Download remote DB, run sync_(objects|relations), put remote DB back.""" scp.get(PATH_DB, _PATH_DB_REMOTE) - with DbConn(DbFile(PATH_DB).connect()) as db_local, \ - DbConn(DbFile(_PATH_DB_REMOTE).connect()) as db_remote: + with DbConn() as db_local, DbConn(DbFile(_PATH_DB_REMOTE)) as db_remote: for cls in (QuotaLog, YoutubeQuery, YoutubeVideo, VideoFile): _back_and_forth(_sync_objects, (db_local, db_remote), cls) for yt_video_local in YoutubeVideo.get_all(db_local):