+#!/usr/bin/env python3
+"""Script to migrate DB to most recent schema."""
+from sys import exit as sys_exit
+from os import scandir
+from os.path import basename, isfile
+from sqlite3 import connect as sql_connect
+from ytplom.misc import (
+ EXPECTED_DB_VERSION, PATH_DB, PATH_DB_SCHEMA, PATH_MIGRATIONS,
+ SQL_DB_VERSION, HandledException, get_db_version)
+
+
+def main() -> None:
+ """Try to migrate DB towards EXPECTED_DB_VERSION."""
+ start_version = get_db_version(PATH_DB)
+ if start_version == EXPECTED_DB_VERSION:
+ print('Database at expected version, no migrations to do.')
+ sys_exit(0)
+ elif start_version > EXPECTED_DB_VERSION:
+ raise HandledException(
+ f'Cannot migrate backward from version {start_version} to '
+ f'{EXPECTED_DB_VERSION}.')
+ print(f'Trying to migrate from DB version {start_version} to '
+ f'{EXPECTED_DB_VERSION} …')
+ needed = [n+1 for n in range(start_version, EXPECTED_DB_VERSION)]
+ migrations = {}
+ for entry in [entry for entry in scandir(PATH_MIGRATIONS)
+ if isfile(entry) and entry.path != PATH_DB_SCHEMA]:
+ toks = basename(entry.path).split('_')
+ try:
+ version = int(toks[0])
+ except ValueError as e:
+ msg = f'Found illegal migration path {entry.path}, aborting.'
+ raise HandledException(msg) from e
+ if version in needed:
+ migrations[version] = entry.path
+ missing = [n for n in needed if n not in migrations]
+ if missing:
+ raise HandledException(f'Needed migrations missing: {missing}')
+ with sql_connect(PATH_DB) as conn:
+ for version_number, migration_path in migrations.items():
+ print(f'Applying migration {version_number}: {migration_path}')
+ with open(migration_path, 'r', encoding='utf8') as f:
+ conn.executescript(f.read())
+ conn.execute(f'{SQL_DB_VERSION} = {version_number}')
+
+
+if __name__ == '__main__':
+ main()