"""Database management."""
from __future__ import annotations
from os import listdir
-from os.path import isfile
+from os.path import basename, isfile
from difflib import Differ
from sqlite3 import (
connect as sql_connect, Connection as SqlConnection, Cursor, Row)
post_sql_steps: Callable[[SqlConnection], None] | None
) -> None:
if sql_path:
- start_tok = str(sql_path).split('_', maxsplit=1)[0]
+ start_tok = str(basename(sql_path)).split('_', maxsplit=1)[0]
if (not start_tok.isdigit()) or int(start_tok) != target_version:
raise HandledException(f'migration to {target_version} mapped '
f'to bad path {sql_path}')
return migs
def perform(self, conn: SqlConnection) -> None:
- """Do 1) script at sql_path, 2) after_sql_steps, 3) version setting."""
+ """Do 1) script at sql_path, 2) post_sql_steps, 3) version setting."""
if self._sql_path:
with open(self._sql_path, 'r', encoding='utf8') as f:
conn.executescript(f.read())