From 464b0110897b63c96d8997bc06f5fff95d367de9 Mon Sep 17 00:00:00 2001
From: Christian Heller <c.heller@plomlompom.de>
Date: Wed, 8 Jan 2025 14:18:26 +0100
Subject: [PATCH] Re-structure DB migration to allow for post-SQL steps in
 Python code.

---
 plomtask/db.py | 111 ++++++++++++++++++++++++++++++++-----------------
 1 file changed, 73 insertions(+), 38 deletions(-)

diff --git a/plomtask/db.py b/plomtask/db.py
index 07c31e9..5076e85 100644
--- a/plomtask/db.py
+++ b/plomtask/db.py
@@ -14,12 +14,80 @@ EXPECTED_DB_VERSION = 5
 MIGRATIONS_DIR = 'migrations'
 FILENAME_DB_SCHEMA = f'init_{EXPECTED_DB_VERSION}.sql'
 PATH_DB_SCHEMA = f'{MIGRATIONS_DIR}/{FILENAME_DB_SCHEMA}'
+SQL_FOR_DB_VERSION = 'PRAGMA user_version'
 
 
 class UnmigratedDbException(HandledException):
     """To identify case of unmigrated DB file."""
 
 
+class DatabaseMigration:
+    """Collects Database migration data."""
+
+    def __init__(self,
+                 target_version: int,
+                 sql_path: str,
+                 post_sql_steps: Callable[[SqlConnection], None] | None
+                 ) -> None:
+        if sql_path:
+            start_tok = str(sql_path).split('_', maxsplit=1)[0]
+            if (not start_tok.isdigit()) or int(start_tok) != target_version:
+                raise HandledException(f'migration to {target_version} mapped '
+                                       f'to bad path {sql_path}')
+        self._target_version = target_version
+        self._sql_path = sql_path
+        self._post_sql_steps = post_sql_steps
+
+    @classmethod
+    def migrations_after(cls, starting_from: int) -> list[Self]:
+        """Make sorted unbroken list of available migrations >starting_from."""
+        msg_prefix = 'Migration directory contains'
+        msg_bad_entry = f'{msg_prefix} unexpected entry: '
+        migs = []
+        total_migs = set()
+        post_sql_steps_added = set()
+        for entry in [e for e in listdir(MIGRATIONS_DIR)
+                      if e != FILENAME_DB_SCHEMA]:
+            toks = entry.split('_', maxsplit=1)
+            if len(toks) < 2 or (not toks[0].isdigit()):
+                raise HandledException(f'{msg_bad_entry}{entry}')
+            i = int(toks[0])
+            if i <= starting_from:
+                continue
+            if i > EXPECTED_DB_VERSION:
+                raise HandledException(f'{msg_prefix} uexpected version {i}')
+            post_sql_steps = MIGRATION_STEPS_POST_SQL.get(i, None)
+            if post_sql_steps:
+                post_sql_steps_added.add(i)
+            total_migs.add(
+                    cls(i, f'{MIGRATIONS_DIR}/{entry}', post_sql_steps))
+        for k in [k for k in MIGRATION_STEPS_POST_SQL
+                  if k > starting_from
+                  and k not in post_sql_steps_added]:
+            total_migs.add(cls(k, '', MIGRATION_STEPS_POST_SQL[k]))
+        for i in range(starting_from + 1, EXPECTED_DB_VERSION + 1):
+            # pylint: disable=protected-access
+            migs_found = [m for m in total_migs if m._target_version == i]
+            if not migs_found:
+                raise HandledException(f'{msg_prefix} no migration of v. {i}')
+            if len(migs_found) > 1:
+                raise HandledException(f'{msg_prefix} >1 migration of v. {i}')
+            migs += migs_found
+        return migs
+
+    def perform(self, conn: SqlConnection) -> None:
+        """Do 1) script at sql_path, 2) after_sql_steps, 3) version setting."""
+        if self._sql_path:
+            with open(self._sql_path, 'r', encoding='utf8') as f:
+                conn.executescript(f.read())
+        if self._post_sql_steps:
+            self._post_sql_steps(conn)
+        conn.execute(f'{SQL_FOR_DB_VERSION} = {self._target_version}')
+
+
+MIGRATION_STEPS_POST_SQL: dict[int, Callable[[SqlConnection], None]] = {}
+
+
 class DatabaseFile:
     """Represents the sqlite3 database's file."""
     # pylint: disable=too-few-public-methods
@@ -34,24 +102,19 @@ class DatabaseFile:
         with sql_connect(path) as conn:
             with open(PATH_DB_SCHEMA, 'r', encoding='utf-8') as f:
                 conn.executescript(f.read())
-            conn.execute(f'PRAGMA user_version = {EXPECTED_DB_VERSION}')
+            conn.execute(f'{SQL_FOR_DB_VERSION} = {EXPECTED_DB_VERSION}')
         return cls(path)
 
     @classmethod
     def migrate(cls, path: str) -> DatabaseFile:
-        """Apply migrations from_version to EXPECTED_DB_VERSION."""
+        """Apply migrations from current version to EXPECTED_DB_VERSION."""
         from_version = cls._get_version_of_db(path)
         if from_version >= EXPECTED_DB_VERSION:
             raise HandledException(
                     f'Cannot migrate {from_version} to {EXPECTED_DB_VERSION}')
-        migrations = cls._available_migrations()
-        migrations_todo = migrations[from_version+1:]
         with sql_connect(path, autocommit=False) as conn:
-            for j, filename in enumerate(migrations_todo):
-                with open(f'{MIGRATIONS_DIR}/{filename}', 'r',
-                          encoding='utf-8') as f:
-                    conn.executescript(f.read())
-                conn.execute(f'PRAGMA user_version = {from_version + j + 1}')
+            for mig in DatabaseMigration.migrations_after(from_version):
+                mig.perform(conn)
             cls._validate_schema(conn)
             conn.commit()
         return cls(path)
@@ -65,39 +128,11 @@ class DatabaseFile:
         with sql_connect(self.path) as conn:
             self._validate_schema(conn)
 
-    @staticmethod
-    def _available_migrations() -> list[str]:
-        """Validate migrations directory and return sorted entries."""
-        msg_too_big = 'Migration directory points beyond expected DB version.'
-        msg_bad_entry = 'Migration directory contains unexpected entry: '
-        msg_missing = 'Migration directory misses migration of number: '
-        migrations = {}
-        for entry in listdir(MIGRATIONS_DIR):
-            if entry == FILENAME_DB_SCHEMA:
-                continue
-            toks = entry.split('_', 1)
-            if len(toks) < 2:
-                raise HandledException(msg_bad_entry + entry)
-            try:
-                i = int(toks[0])
-            except ValueError as e:
-                raise HandledException(msg_bad_entry + entry) from e
-            if i > EXPECTED_DB_VERSION:
-                raise HandledException(msg_too_big)
-            migrations[i] = toks[1]
-        migrations_list = []
-        for i in range(EXPECTED_DB_VERSION + 1):
-            if i not in migrations:
-                raise HandledException(msg_missing + str(i))
-            migrations_list += [f'{i}_{migrations[i]}']
-        return migrations_list
-
     @staticmethod
     def _get_version_of_db(path: str) -> int:
         """Get DB user_version, fail if outside expected range."""
-        sql_for_db_version = 'PRAGMA user_version'
         with sql_connect(path) as conn:
-            db_version = list(conn.execute(sql_for_db_version))[0][0]
+            db_version = list(conn.execute(SQL_FOR_DB_VERSION))[0][0]
         assert isinstance(db_version, int)
         return db_version
 
-- 
2.30.2