From a5e1094d8482bdeee477bfa51a20087d0ed1744b Mon Sep 17 00:00:00 2001
From: Christian Heller <c.heller@plomlompom.de>
Date: Sun, 5 Jan 2025 07:11:34 +0100
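Subject: [PATCH] Re-work migration mechanisms.

- Migrations are now passed as a dict keyed by target DB version
  (MigrationsDict) instead of as a positional list (MigrationsList).
- DbFile.migrate() checks for each migration with an SQL path that the
  part of the path before its first "_" is a number equal to the
  migration's version, and raises HandledException otherwise.
- DbFile.migrate() runs all pending migrations on a single DbConn and
  commits once after the last one, instead of opening a connection and
  committing per migration.
- DbFile.migrate() now raises HandledException also if the DB already is
  at EXPECTED_DB_VERSION (previously it printed a notice and returned),
  and no longer prints progress messages.
- DbFile.connect() is removed and DbFile._path becomes the public
  DbFile.path; DbConn now takes an optional DbFile (default: DbFile())
  and opens the sqlite3 connection itself with autocommit=False.
- sync.py is adapted to the new DbConn constructor.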
---
 src/ytplom/db.py         | 66 ++++++++++++++++------------------------
 src/ytplom/migrations.py | 20 ++++++------
 src/ytplom/sync.py       |  3 +-
 3 files changed, 38 insertions(+), 51 deletions(-)

diff --git a/src/ytplom/db.py b/src/ytplom/db.py
index 51e94ac..c310b63 100644
--- a/src/ytplom/db.py
+++ b/src/ytplom/db.py
@@ -5,8 +5,8 @@ from base64 import urlsafe_b64decode, urlsafe_b64encode
 from difflib import Differ
 from hashlib import file_digest
 from pathlib import Path
-from sqlite3 import (connect as sql_connect, Connection as SqlConnection,
-                     Cursor as SqlCursor, Row as SqlRow)
+from sqlite3 import (
+        connect as sql_connect, Cursor as SqlCursor, Row as SqlRow)
 from typing import Callable, Literal, NewType, Optional, Self
 # ourselves
 from ytplom.primitives import (
@@ -17,7 +17,7 @@ EXPECTED_DB_VERSION = 6
 PATH_DB = PATH_APP_DATA.joinpath('db.sql')
 
 SqlText = NewType('SqlText', str)
-MigrationsList = list[tuple[Path, Optional[Callable]]]
+MigrationsDict = dict[int, tuple[Optional[Path], Optional[Callable]]]
 
 _SQL_DB_VERSION = SqlText('PRAGMA user_version')
 _PATH_MIGRATIONS = PATH_APP_DATA.joinpath('migrations')
@@ -66,9 +66,9 @@ class DbFile:
                  path: Path = PATH_DB,
                  version_to_validate: int = EXPECTED_DB_VERSION
                  ) -> None:
-        self._path = path
-        if not self._path.is_file():
-            raise HandledException(f'no DB file at {self._path}')
+        self.path = path
+        if not self.path.is_file():
+            raise HandledException(f'no DB file at {self.path}')
 
         if version_to_validate < 0:
             return
@@ -78,7 +78,7 @@ class DbFile:
                 f'wrong DB version {user_version} (!= {version_to_validate})')
 
         # ensure schema
-        with sql_connect(self._path) as conn:
+        with sql_connect(self.path) as conn:
             schema_rows = [
                     r[0] for r in
                     conn.execute('SELECT sql FROM sqlite_master ORDER BY sql')
@@ -123,7 +123,7 @@ class DbFile:
                                    + '\n'.join(diff_msg))
 
     def _get_user_version(self) -> int:
-        with sql_connect(self._path) as conn:
+        with sql_connect(self.path) as conn:
             return list(conn.execute(_SQL_DB_VERSION))[0][0]
 
     @staticmethod
@@ -139,49 +139,41 @@ class DbFile:
             conn.executescript(_PATH_DB_SCHEMA.read_text(encoding='utf8'))
             conn.execute(f'{_SQL_DB_VERSION} = {EXPECTED_DB_VERSION}')
 
-    def connect(self) -> SqlConnection:
-        """Open database file into SQL connection, with autocommit off."""
-        return sql_connect(self._path, autocommit=False)
-
-    def migrate(self, migrations: MigrationsList) -> None:
+    def migrate(self, migrations: MigrationsDict) -> None:
         """Migrate self towards EXPECTED_DB_VERSION"""
         start_version = self._get_user_version()
-        if start_version == EXPECTED_DB_VERSION:
-            print('Database at expected version, no migrations to do.')
-            return
-        if start_version > EXPECTED_DB_VERSION:
+        if start_version >= EXPECTED_DB_VERSION:
             raise HandledException(
-                    f'Cannot migrate backward from version {start_version} to '
-                    f'{EXPECTED_DB_VERSION}.')
-        print(f'Trying to migrate from DB version {start_version} to '
-              f'{EXPECTED_DB_VERSION} …')
+                f'Cannot migrate {start_version} to {EXPECTED_DB_VERSION}.')
         migs_to_do = []
-        migs_by_n = dict(enumerate(migrations))
         for n in [n+1 for n in range(start_version, EXPECTED_DB_VERSION)]:
-            if n not in migs_by_n:
+            if n not in migrations:
                 raise HandledException(f'Needed migration missing: {n}')
-            migs_to_do += [(n, *migs_by_n[n])]
-        for version, filename_sql, after_sql_steps in migs_to_do:
-            print(f'Running migration towards: {version}')
-            with DbConn(self.connect()) as conn:
+            mig_tuple = migrations[n]
+            if path := mig_tuple[0]:
+                start_tok = str(path).split('_', maxsplit=1)[0]
+                if (not start_tok.isdigit()) or int(start_tok) != n:
+                    raise HandledException(
+                        f'migration {n} mapped to bad path {path}')
+            migs_to_do += [(n, *mig_tuple)]
+        with DbConn(self) as conn:
+            for version, filename_sql, after_sql_steps in migs_to_do:
                 if filename_sql:
-                    print(f'Executing {filename_sql}')
-                    path_sql = _PATH_MIGRATIONS.joinpath(filename_sql)
                     conn.exec_script(
-                            SqlText(path_sql.read_text(encoding='utf8')))
+                        SqlText(_PATH_MIGRATIONS.joinpath(filename_sql)
+                                .read_text(encoding='utf8')))
                 if after_sql_steps:
-                    print('Running additional steps')
                     after_sql_steps(conn)
                 conn.exec(SqlText(f'{_SQL_DB_VERSION} = {version}'))
-                conn.commit()
-        print('Finished migrations.')
+            conn.commit()
 
 
 class DbConn:
     """Wrapper for sqlite3 connections."""
 
-    def __init__(self, sql_conn: Optional[SqlConnection] = None) -> None:
-        self._conn = sql_conn or DbFile().connect()
+    def __init__(self, db_file: Optional[DbFile] = None) -> None:
+        self._conn = sql_connect((db_file or DbFile()).path, autocommit=False)
+        self.commit = self._conn.commit
 
     def __enter__(self) -> Self:
         return self
@@ -203,10 +195,6 @@ class DbConn:
         """Wrapper around sqlite3.Connection.executescript."""
         self._conn.executescript(sql)
 
-    def commit(self) -> None:
-        """Commit changes (i.e. DbData.save() calls) to database."""
-        self._conn.commit()
-
 
 class DbData:
     """Abstraction of common DB operation."""
diff --git a/src/ytplom/migrations.py b/src/ytplom/migrations.py
index 5cacc95..be76b3d 100644
--- a/src/ytplom/migrations.py
+++ b/src/ytplom/migrations.py
@@ -4,7 +4,7 @@
 from pathlib import Path
 from typing import Callable
 # ourselves
-from ytplom.db import DbConn, DbFile, MigrationsList, SqlText
+from ytplom.db import DbConn, DbFile, MigrationsDict, SqlText
 from ytplom.primitives import HandledException
 
 
@@ -55,15 +55,15 @@ def _mig_4_convert_digests(conn: DbConn) -> None:
     _rewrite_files_last_field_processing_first_field(conn, bytes.fromhex)
 
 
-_MIGRATIONS: MigrationsList = [
-    (Path('0_init.sql'), None),
-    (Path('1_add_files_last_updated.sql'), None),
-    (Path('2_add_files_sha512.sql'), _mig_2_calc_digests),
-    (Path('3_files_redo.sql'), None),
-    (Path('4_add_files_sha512_blob.sql'), _mig_4_convert_digests),
-    (Path('5_files_redo.sql'), None),
-    (Path('6_add_files_tags.sql'), None)
-]
+_MIGRATIONS: MigrationsDict = {
+    0: (Path('0_init.sql'), None),
+    1: (Path('1_add_files_last_updated.sql'), None),
+    2: (Path('2_add_files_sha512.sql'), _mig_2_calc_digests),
+    3: (Path('3_files_redo.sql'), None),
+    4: (Path('4_add_files_sha512_blob.sql'), _mig_4_convert_digests),
+    5: (Path('5_files_redo.sql'), None),
+    6: (Path('6_add_files_tags.sql'), None)
+}
 
 
 def migrate():
diff --git a/src/ytplom/sync.py b/src/ytplom/sync.py
index 89a198d..5cc966b 100644
--- a/src/ytplom/sync.py
+++ b/src/ytplom/sync.py
@@ -75,8 +75,7 @@ def _sync_relations(host_names: tuple[str, str],
 def _sync_dbs(scp: SCPClient) -> None:
     """Download remote DB, run sync_(objects|relations), put remote DB back."""
     scp.get(PATH_DB, _PATH_DB_REMOTE)
-    with DbConn(DbFile(PATH_DB).connect()) as db_local, \
-            DbConn(DbFile(_PATH_DB_REMOTE).connect()) as db_remote:
+    with DbConn() as db_local, DbConn(DbFile(_PATH_DB_REMOTE)) as db_remote:
         for cls in (QuotaLog, YoutubeQuery, YoutubeVideo, VideoFile):
             _back_and_forth(_sync_objects, (db_local, db_remote), cls)
         for yt_video_local in YoutubeVideo.get_all(db_local):
-- 
2.30.2