From ad94c0df56c82981c0832bae0a3969f91f49f042 Mon Sep 17 00:00:00 2001
From: Christian Heller <c.heller@plomlompom.de>
Date: Mon, 6 Jan 2025 17:20:11 +0100
Subject: [PATCH] Reorganize migrations code.

---
 src/ytplom/db.py         | 78 +++++++++++++++++++++++++++-------------
 src/ytplom/migrations.py | 18 +++++-----
 2 files changed, 63 insertions(+), 33 deletions(-)

diff --git a/src/ytplom/db.py b/src/ytplom/db.py
index c310b63..46cd9ab 100644
--- a/src/ytplom/db.py
+++ b/src/ytplom/db.py
@@ -14,10 +14,9 @@ from ytplom.primitives import (
 
 
 EXPECTED_DB_VERSION = 6
-PATH_DB = PATH_APP_DATA.joinpath('db.sql')
+PATH_DB = PATH_APP_DATA.joinpath('TESTdb.sql')
 
 SqlText = NewType('SqlText', str)
-MigrationsDict = dict[int, tuple[Optional[Path], Optional[Callable]]]
 
 _SQL_DB_VERSION = SqlText('PRAGMA user_version')
 _PATH_MIGRATIONS = PATH_APP_DATA.joinpath('migrations')
@@ -139,35 +138,66 @@ class DbFile:
             conn.executescript(_PATH_DB_SCHEMA.read_text(encoding='utf8'))
             conn.execute(f'{_SQL_DB_VERSION} = {EXPECTED_DB_VERSION}')
 
-    def migrate(self, migrations: MigrationsDict) -> None:
+    def migrate(self, migrations: set['DbMigration']) -> None:
         """Migrate self towards EXPECTED_DB_VERSION"""
         start_version = self._get_user_version()
-        if start_version >= EXPECTED_DB_VERSION:
+        if start_version == EXPECTED_DB_VERSION:
             raise HandledException(
-                f'Cannot migrate {start_version} to {EXPECTED_DB_VERSION}.')
-        migs_to_do = []
-        for n in [n+1 for n in range(start_version, EXPECTED_DB_VERSION)]:
-            if n not in migrations:
-                raise HandledException(f'Needed migration missing: {n}')
-            mig_tuple = migrations[n]
-            if path := mig_tuple[0]:
-                start_tok = str(path).split('_', maxsplit=1)[0]
-                if (not start_tok.isdigit()) or int(start_tok) != n:
-                    raise HandledException(
-                        f'migration {n} mapped to bad path {path}')
-            migs_to_do += [(n, *mig_tuple)]
+                f'Already at {EXPECTED_DB_VERSION}, nothing to migrate.')
+        if start_version > EXPECTED_DB_VERSION:
+            raise HandledException(
+                f'Cannot migrate backwards from {start_version}'
+                f'to {EXPECTED_DB_VERSION}.')
         with DbConn(self) as conn:
-            for version, filename_sql, after_sql_steps in migs_to_do:
-                if filename_sql:
-                    conn.exec_script(
-                        SqlText(_PATH_MIGRATIONS.joinpath(filename_sql)
-                                .read_text(encoding='utf8')))
-                if after_sql_steps:
-                    after_sql_steps(conn)
-                conn.exec(SqlText(f'{_SQL_DB_VERSION} = {version}'))
+            for migration in DbMigration.from_to_in_set(
+                    start_version, EXPECTED_DB_VERSION, migrations):
+                migration.perform(conn)
             conn.commit()
 
 
+class DbMigration:
+    """Representation of DbFile migration data."""
+
+    def __init__(self,
+                 version: int,
+                 sql_path: Optional[Path] = None,
+                 after_sql_steps: Optional[Callable[['DbConn'], None]] = None
+                 ) -> None:
+        if sql_path:
+            start_tok = str(sql_path).split('_', maxsplit=1)[0]
+            if (not start_tok.isdigit()) or int(start_tok) != version:
+                raise HandledException(
+                        f'migration {version} mapped to bad path {sql_path}')
+        self._version = version
+        self._sql_path = sql_path
+        self._after_sql_steps = after_sql_steps
+
+    @classmethod
+    def from_to_in_set(
+            cls, from_version: int, to_version: int, migs_set: set[Self]
+            ) -> list[Self]:
+        """From migs_set make sorted unbroken list from_version to_version."""
+        selected_migs = []
+        for version in [n+1 for n in range(from_version, to_version)]:
+            matching_migs = [m for m in migs_set if version == m._version]
+            if not matching_migs:
+                raise HandledException(f'Missing migration of v{version}')
+            if len(matching_migs) > 1:
+                raise HandledException(f'More than 1 Migration of v{version}')
+            selected_migs += [matching_migs[0]]
+        return selected_migs
+
+    def perform(self, conn: 'DbConn') -> None:
+        """Do 1) script at sql_path, 2) after_sql_steps, 3) versino setting."""
+        if self._sql_path:
+            conn.exec_script(
+                SqlText(_PATH_MIGRATIONS.joinpath(self._sql_path)
+                        .read_text(encoding='utf8')))
+        if self._after_sql_steps:
+            self._after_sql_steps(conn)
+        conn.exec(SqlText(f'{_SQL_DB_VERSION} = {self._version}'))
+
+
 class DbConn:
     """Wrapper for sqlite3 connections."""
 
diff --git a/src/ytplom/migrations.py b/src/ytplom/migrations.py
index be76b3d..f075cdf 100644
--- a/src/ytplom/migrations.py
+++ b/src/ytplom/migrations.py
@@ -4,7 +4,7 @@
 from pathlib import Path
 from typing import Callable
 # ourselves
-from ytplom.db import DbConn, DbFile, MigrationsDict, SqlText
+from ytplom.db import DbConn, DbFile, DbMigration, SqlText
 from ytplom.primitives import HandledException
 
 
@@ -55,14 +55,14 @@ def _mig_4_convert_digests(conn: DbConn) -> None:
     _rewrite_files_last_field_processing_first_field(conn, bytes.fromhex)
 
 
-_MIGRATIONS: MigrationsDict = {
-    0: (Path('0_init.sql'), None),
-    1: (Path('1_add_files_last_updated.sql'), None),
-    2: (Path('2_add_files_sha512.sql'), _mig_2_calc_digests),
-    3: (Path('3_files_redo.sql'), None),
-    4: (Path('4_add_files_sha512_blob.sql'), _mig_4_convert_digests),
-    5: (Path('5_files_redo.sql'), None),
-    6: (Path('6_add_files_tags.sql'), None)
+_MIGRATIONS: set[DbMigration] = {
+    DbMigration(0, Path('0_init.sql'), None),
+    DbMigration(1, Path('1_add_files_last_updated.sql'), None),
+    DbMigration(2, Path('2_add_files_sha512.sql'), _mig_2_calc_digests),
+    DbMigration(3, Path('3_files_redo.sql'), None),
+    DbMigration(4, Path('4_add_files_sha512_blob.sql'), _mig_4_convert_digests),
+    DbMigration(5, Path('5_files_redo.sql'), None),
+    DbMigration(6, Path('6_add_files_tags.sql'), None)
 }
 
 
-- 
2.30.2