MIGRATIONS_DIR = 'migrations'
FILENAME_DB_SCHEMA = f'init_{EXPECTED_DB_VERSION}.sql'
PATH_DB_SCHEMA = f'{MIGRATIONS_DIR}/{FILENAME_DB_SCHEMA}'
+SQL_FOR_DB_VERSION = 'PRAGMA user_version'
class UnmigratedDbException(HandledException):
"""To identify case of unmigrated DB file."""
+class DatabaseMigration:
+ """Collects Database migration data."""
+
+ def __init__(self,
+ target_version: int,
+ sql_path: str,
+ post_sql_steps: Callable[[SqlConnection], None] | None
+ ) -> None:
+ if sql_path:
+ start_tok = str(sql_path).split('_', maxsplit=1)[0]
+ if (not start_tok.isdigit()) or int(start_tok) != target_version:
+ raise HandledException(f'migration to {target_version} mapped '
+ f'to bad path {sql_path}')
+ self._target_version = target_version
+ self._sql_path = sql_path
+ self._post_sql_steps = post_sql_steps
+
+ @classmethod
+ def migrations_after(cls, starting_from: int) -> list[Self]:
+ """Make sorted unbroken list of available migrations >starting_from."""
+ msg_prefix = 'Migration directory contains'
+ msg_bad_entry = f'{msg_prefix} unexpected entry: '
+ migs = []
+ total_migs = set()
+ post_sql_steps_added = set()
+ for entry in [e for e in listdir(MIGRATIONS_DIR)
+ if e != FILENAME_DB_SCHEMA]:
+ toks = entry.split('_', maxsplit=1)
+ if len(toks) < 2 or (not toks[0].isdigit()):
+ raise HandledException(f'{msg_bad_entry}{entry}')
+ i = int(toks[0])
+ if i <= starting_from:
+ continue
+ if i > EXPECTED_DB_VERSION:
+ raise HandledException(f'{msg_prefix} uexpected version {i}')
+ post_sql_steps = MIGRATION_STEPS_POST_SQL.get(i, None)
+ if post_sql_steps:
+ post_sql_steps_added.add(i)
+ total_migs.add(
+ cls(i, f'{MIGRATIONS_DIR}/{entry}', post_sql_steps))
+ for k in [k for k in MIGRATION_STEPS_POST_SQL
+ if k > starting_from
+ and k not in post_sql_steps_added]:
+ total_migs.add(cls(k, '', MIGRATION_STEPS_POST_SQL[k]))
+ for i in range(starting_from + 1, EXPECTED_DB_VERSION + 1):
+ # pylint: disable=protected-access
+ migs_found = [m for m in total_migs if m._target_version == i]
+ if not migs_found:
+ raise HandledException(f'{msg_prefix} no migration of v. {i}')
+ if len(migs_found) > 1:
+ raise HandledException(f'{msg_prefix} >1 migration of v. {i}')
+ migs += migs_found
+ return migs
+
+ def perform(self, conn: SqlConnection) -> None:
+ """Do 1) script at sql_path, 2) after_sql_steps, 3) version setting."""
+ if self._sql_path:
+ with open(self._sql_path, 'r', encoding='utf8') as f:
+ conn.executescript(f.read())
+ if self._post_sql_steps:
+ self._post_sql_steps(conn)
+ conn.execute(f'{SQL_FOR_DB_VERSION} = {self._target_version}')
+
+
+MIGRATION_STEPS_POST_SQL: dict[int, Callable[[SqlConnection], None]] = {}
+
+
class DatabaseFile:
"""Represents the sqlite3 database's file."""
# pylint: disable=too-few-public-methods
with sql_connect(path) as conn:
with open(PATH_DB_SCHEMA, 'r', encoding='utf-8') as f:
conn.executescript(f.read())
- conn.execute(f'PRAGMA user_version = {EXPECTED_DB_VERSION}')
+ conn.execute(f'{SQL_FOR_DB_VERSION} = {EXPECTED_DB_VERSION}')
return cls(path)
@classmethod
def migrate(cls, path: str) -> DatabaseFile:
- """Apply migrations from_version to EXPECTED_DB_VERSION."""
+ """Apply migrations from current version to EXPECTED_DB_VERSION."""
from_version = cls._get_version_of_db(path)
if from_version >= EXPECTED_DB_VERSION:
raise HandledException(
f'Cannot migrate {from_version} to {EXPECTED_DB_VERSION}')
- migrations = cls._available_migrations()
- migrations_todo = migrations[from_version+1:]
with sql_connect(path, autocommit=False) as conn:
- for j, filename in enumerate(migrations_todo):
- with open(f'{MIGRATIONS_DIR}/{filename}', 'r',
- encoding='utf-8') as f:
- conn.executescript(f.read())
- conn.execute(f'PRAGMA user_version = {from_version + j + 1}')
+ for mig in DatabaseMigration.migrations_after(from_version):
+ mig.perform(conn)
cls._validate_schema(conn)
conn.commit()
return cls(path)
with sql_connect(self.path) as conn:
self._validate_schema(conn)
- @staticmethod
- def _available_migrations() -> list[str]:
- """Validate migrations directory and return sorted entries."""
- msg_too_big = 'Migration directory points beyond expected DB version.'
- msg_bad_entry = 'Migration directory contains unexpected entry: '
- msg_missing = 'Migration directory misses migration of number: '
- migrations = {}
- for entry in listdir(MIGRATIONS_DIR):
- if entry == FILENAME_DB_SCHEMA:
- continue
- toks = entry.split('_', 1)
- if len(toks) < 2:
- raise HandledException(msg_bad_entry + entry)
- try:
- i = int(toks[0])
- except ValueError as e:
- raise HandledException(msg_bad_entry + entry) from e
- if i > EXPECTED_DB_VERSION:
- raise HandledException(msg_too_big)
- migrations[i] = toks[1]
- migrations_list = []
- for i in range(EXPECTED_DB_VERSION + 1):
- if i not in migrations:
- raise HandledException(msg_missing + str(i))
- migrations_list += [f'{i}_{migrations[i]}']
- return migrations_list
-
@staticmethod
def _get_version_of_db(path: str) -> int:
"""Get DB user_version, fail if outside expected range."""
- sql_for_db_version = 'PRAGMA user_version'
with sql_connect(path) as conn:
- db_version = list(conn.execute(sql_for_db_version))[0][0]
+ db_version = list(conn.execute(SQL_FOR_DB_VERSION))[0][0]
assert isinstance(db_version, int)
return db_version