From: Christian Heller Date: Wed, 8 Jan 2025 13:18:26 +0000 (+0100) Subject: Re-structure DB migration to allow for post-SQL steps in Python code. X-Git-Url: https://plomlompom.com/repos/%7B%7Bprefix%7D%7D/%7B%7Bdb.prefix%7D%7D/%7B%7B%20web_path%20%7D%7D/te"st.html?a=commitdiff_plain;h=464b0110897b63c96d8997bc06f5fff95d367de9;p=plomtask Re-structure DB migration to allow for post-SQL steps in Python code. --- diff --git a/plomtask/db.py b/plomtask/db.py index 07c31e9..5076e85 100644 --- a/plomtask/db.py +++ b/plomtask/db.py @@ -14,12 +14,80 @@ EXPECTED_DB_VERSION = 5 MIGRATIONS_DIR = 'migrations' FILENAME_DB_SCHEMA = f'init_{EXPECTED_DB_VERSION}.sql' PATH_DB_SCHEMA = f'{MIGRATIONS_DIR}/{FILENAME_DB_SCHEMA}' +SQL_FOR_DB_VERSION = 'PRAGMA user_version' class UnmigratedDbException(HandledException): """To identify case of unmigrated DB file.""" +class DatabaseMigration: + """Collects Database migration data.""" + + def __init__(self, + target_version: int, + sql_path: str, + post_sql_steps: Callable[[SqlConnection], None] | None + ) -> None: + if sql_path: + start_tok = str(sql_path).split('_', maxsplit=1)[0] + if (not start_tok.isdigit()) or int(start_tok) != target_version: + raise HandledException(f'migration to {target_version} mapped ' + f'to bad path {sql_path}') + self._target_version = target_version + self._sql_path = sql_path + self._post_sql_steps = post_sql_steps + + @classmethod + def migrations_after(cls, starting_from: int) -> list[Self]: + """Make sorted unbroken list of available migrations >starting_from.""" + msg_prefix = 'Migration directory contains' + msg_bad_entry = f'{msg_prefix} unexpected entry: ' + migs = [] + total_migs = set() + post_sql_steps_added = set() + for entry in [e for e in listdir(MIGRATIONS_DIR) + if e != FILENAME_DB_SCHEMA]: + toks = entry.split('_', maxsplit=1) + if len(toks) < 2 or (not toks[0].isdigit()): + raise HandledException(f'{msg_bad_entry}{entry}') + i = int(toks[0]) + if i <= starting_from: + continue + if i > EXPECTED_DB_VERSION: + raise HandledException(f'{msg_prefix} uexpected version {i}') + post_sql_steps = MIGRATION_STEPS_POST_SQL.get(i, None) + if post_sql_steps: + post_sql_steps_added.add(i) + total_migs.add( + cls(i, f'{MIGRATIONS_DIR}/{entry}', post_sql_steps)) + for k in [k for k in MIGRATION_STEPS_POST_SQL + if k > starting_from + and k not in post_sql_steps_added]: + total_migs.add(cls(k, '', MIGRATION_STEPS_POST_SQL[k])) + for i in range(starting_from + 1, EXPECTED_DB_VERSION + 1): + # pylint: disable=protected-access + migs_found = [m for m in total_migs if m._target_version == i] + if not migs_found: + raise HandledException(f'{msg_prefix} no migration of v. {i}') + if len(migs_found) > 1: + raise HandledException(f'{msg_prefix} >1 migration of v. {i}') + migs += migs_found + return migs + + def perform(self, conn: SqlConnection) -> None: + """Do 1) script at sql_path, 2) after_sql_steps, 3) version setting.""" + if self._sql_path: + with open(self._sql_path, 'r', encoding='utf8') as f: + conn.executescript(f.read()) + if self._post_sql_steps: + self._post_sql_steps(conn) + conn.execute(f'{SQL_FOR_DB_VERSION} = {self._target_version}') + + +MIGRATION_STEPS_POST_SQL: dict[int, Callable[[SqlConnection], None]] = {} + + class DatabaseFile: """Represents the sqlite3 database's file.""" # pylint: disable=too-few-public-methods @@ -34,24 +102,19 @@ class DatabaseFile: with sql_connect(path) as conn: with open(PATH_DB_SCHEMA, 'r', encoding='utf-8') as f: conn.executescript(f.read()) - conn.execute(f'PRAGMA user_version = {EXPECTED_DB_VERSION}') + conn.execute(f'{SQL_FOR_DB_VERSION} = {EXPECTED_DB_VERSION}') return cls(path) @classmethod def migrate(cls, path: str) -> DatabaseFile: - """Apply migrations from_version to EXPECTED_DB_VERSION.""" + """Apply migrations from current version to EXPECTED_DB_VERSION.""" from_version = cls._get_version_of_db(path) if from_version >= EXPECTED_DB_VERSION: raise HandledException( f'Cannot migrate {from_version} to {EXPECTED_DB_VERSION}') - migrations = cls._available_migrations() - migrations_todo = migrations[from_version+1:] with sql_connect(path, autocommit=False) as conn: - for j, filename in enumerate(migrations_todo): - with open(f'{MIGRATIONS_DIR}/{filename}', 'r', - encoding='utf-8') as f: - conn.executescript(f.read()) - conn.execute(f'PRAGMA user_version = {from_version + j + 1}') + for mig in DatabaseMigration.migrations_after(from_version): + mig.perform(conn) cls._validate_schema(conn) conn.commit() return cls(path) @@ -65,39 +128,11 @@ class DatabaseFile: with sql_connect(self.path) as conn: self._validate_schema(conn) - @staticmethod - def _available_migrations() -> list[str]: - """Validate migrations directory and return sorted entries.""" - msg_too_big = 'Migration directory points beyond expected DB version.' - msg_bad_entry = 'Migration directory contains unexpected entry: ' - msg_missing = 'Migration directory misses migration of number: ' - migrations = {} - for entry in listdir(MIGRATIONS_DIR): - if entry == FILENAME_DB_SCHEMA: - continue - toks = entry.split('_', 1) - if len(toks) < 2: - raise HandledException(msg_bad_entry + entry) - try: - i = int(toks[0]) - except ValueError as e: - raise HandledException(msg_bad_entry + entry) from e - if i > EXPECTED_DB_VERSION: - raise HandledException(msg_too_big) - migrations[i] = toks[1] - migrations_list = [] - for i in range(EXPECTED_DB_VERSION + 1): - if i not in migrations: - raise HandledException(msg_missing + str(i)) - migrations_list += [f'{i}_{migrations[i]}'] - return migrations_list - @staticmethod def _get_version_of_db(path: str) -> int: """Get DB user_version, fail if outside expected range.""" - sql_for_db_version = 'PRAGMA user_version' with sql_connect(path) as conn: - db_version = list(conn.execute(sql_for_db_version))[0][0] + db_version = list(conn.execute(SQL_FOR_DB_VERSION))[0][0] assert isinstance(db_version, int) return db_version