from difflib import Differ
from hashlib import file_digest
from pathlib import Path
-from sqlite3 import (connect as sql_connect, Connection as SqlConnection,
- Cursor as SqlCursor, Row as SqlRow)
+from sqlite3 import (
+ connect as sql_connect, Cursor as SqlCursor, Row as SqlRow)
from typing import Callable, Literal, NewType, Optional, Self
# ourselves
from ytplom.primitives import (
PATH_DB = PATH_APP_DATA.joinpath('db.sql')
SqlText = NewType('SqlText', str)
-MigrationsList = list[tuple[Path, Optional[Callable]]]
+MigrationsDict = dict[int, tuple[Optional[Path], Optional[Callable]]]
_SQL_DB_VERSION = SqlText('PRAGMA user_version')
_PATH_MIGRATIONS = PATH_APP_DATA.joinpath('migrations')
path: Path = PATH_DB,
version_to_validate: int = EXPECTED_DB_VERSION
) -> None:
- self._path = path
- if not self._path.is_file():
- raise HandledException(f'no DB file at {self._path}')
+ self.path = path
+ if not self.path.is_file():
+ raise HandledException(f'no DB file at {self.path}')
if version_to_validate < 0:
return
f'wrong DB version {user_version} (!= {version_to_validate})')
# ensure schema
- with sql_connect(self._path) as conn:
+ with sql_connect(self.path) as conn:
schema_rows = [
r[0] for r in
conn.execute('SELECT sql FROM sqlite_master ORDER BY sql')
+ '\n'.join(diff_msg))
def _get_user_version(self) -> int:
- with sql_connect(self._path) as conn:
+ with sql_connect(self.path) as conn:
return list(conn.execute(_SQL_DB_VERSION))[0][0]
@staticmethod
conn.executescript(_PATH_DB_SCHEMA.read_text(encoding='utf8'))
conn.execute(f'{_SQL_DB_VERSION} = {EXPECTED_DB_VERSION}')
- def connect(self) -> SqlConnection:
- """Open database file into SQL connection, with autocommit off."""
- return sql_connect(self._path, autocommit=False)
-
- def migrate(self, migrations: MigrationsList) -> None:
+ def migrate(self, migrations: MigrationsDict) -> None:
"""Migrate self towards EXPECTED_DB_VERSION"""
start_version = self._get_user_version()
- if start_version == EXPECTED_DB_VERSION:
- print('Database at expected version, no migrations to do.')
- return
- if start_version > EXPECTED_DB_VERSION:
+ if start_version >= EXPECTED_DB_VERSION:
raise HandledException(
- f'Cannot migrate backward from version {start_version} to '
- f'{EXPECTED_DB_VERSION}.')
- print(f'Trying to migrate from DB version {start_version} to '
- f'{EXPECTED_DB_VERSION} …')
+ f'Cannot migrate {start_version} to {EXPECTED_DB_VERSION}.')
migs_to_do = []
- migs_by_n = dict(enumerate(migrations))
for n in [n+1 for n in range(start_version, EXPECTED_DB_VERSION)]:
- if n not in migs_by_n:
+ if n not in migrations:
raise HandledException(f'Needed migration missing: {n}')
- migs_to_do += [(n, *migs_by_n[n])]
- for version, filename_sql, after_sql_steps in migs_to_do:
- print(f'Running migration towards: {version}')
- with DbConn(self.connect()) as conn:
+ mig_tuple = migrations[n]
+ if path := mig_tuple[0]:
+ start_tok = str(path).split('_', maxsplit=1)[0]
+ if (not start_tok.isdigit()) or int(start_tok) != n:
+ raise HandledException(
+ f'migration {n} mapped to bad path {path}')
+ migs_to_do += [(n, *mig_tuple)]
+ with DbConn(self) as conn:
+ for version, filename_sql, after_sql_steps in migs_to_do:
if filename_sql:
- print(f'Executing {filename_sql}')
- path_sql = _PATH_MIGRATIONS.joinpath(filename_sql)
conn.exec_script(
- SqlText(path_sql.read_text(encoding='utf8')))
+ SqlText(_PATH_MIGRATIONS.joinpath(filename_sql)
+ .read_text(encoding='utf8')))
if after_sql_steps:
- print('Running additional steps')
after_sql_steps(conn)
conn.exec(SqlText(f'{_SQL_DB_VERSION} = {version}'))
- conn.commit()
- print('Finished migrations.')
+ conn.commit()
class DbConn:
"""Wrapper for sqlite3 connections."""
- def __init__(self, sql_conn: Optional[SqlConnection] = None) -> None:
- self._conn = sql_conn or DbFile().connect()
+ def __init__(self, db_file: Optional[DbFile] = None) -> None:
+ self._conn = sql_connect((db_file or DbFile()).path, autocommit=False)
+ self.commit = self._conn.commit
def __enter__(self) -> Self:
return self
"""Wrapper around sqlite3.Connection.executescript."""
self._conn.executescript(sql)
- def commit(self) -> None:
- """Commit changes (i.e. DbData.save() calls) to database."""
- self._conn.commit()
-
class DbData:
"""Abstraction of common DB operation."""
from pathlib import Path
from typing import Callable
# ourselves
-from ytplom.db import DbConn, DbFile, MigrationsList, SqlText
+from ytplom.db import DbConn, DbFile, MigrationsDict, SqlText
from ytplom.primitives import HandledException
_rewrite_files_last_field_processing_first_field(conn, bytes.fromhex)
-_MIGRATIONS: MigrationsList = [
- (Path('0_init.sql'), None),
- (Path('1_add_files_last_updated.sql'), None),
- (Path('2_add_files_sha512.sql'), _mig_2_calc_digests),
- (Path('3_files_redo.sql'), None),
- (Path('4_add_files_sha512_blob.sql'), _mig_4_convert_digests),
- (Path('5_files_redo.sql'), None),
- (Path('6_add_files_tags.sql'), None)
-]
+_MIGRATIONS: MigrationsDict = {
+ 0: (Path('0_init.sql'), None),
+ 1: (Path('1_add_files_last_updated.sql'), None),
+ 2: (Path('2_add_files_sha512.sql'), _mig_2_calc_digests),
+ 3: (Path('3_files_redo.sql'), None),
+ 4: (Path('4_add_files_sha512_blob.sql'), _mig_4_convert_digests),
+ 5: (Path('5_files_redo.sql'), None),
+ 6: (Path('6_add_files_tags.sql'), None)
+}
def migrate():