-MIGRATION_STEPS_POST_SQL: dict[int, Callable[[SqlConnection], None]] = {
- 6: _mig_6_calc_days_since_millennium
-}
-
-
-class DatabaseFile:
- """Represents the sqlite3 database's file."""
- # pylint: disable=too-few-public-methods
-
- def __init__(self, path: str) -> None:
- self.path = path
- self._check()
-
- @classmethod
- def create_at(cls, path: str) -> DatabaseFile:
- """Make new DB file at path."""
- with sql_connect(path) as conn:
- with open(PATH_DB_SCHEMA, 'r', encoding='utf-8') as f:
- conn.executescript(f.read())
- conn.execute(f'{SQL_FOR_DB_VERSION} = {EXPECTED_DB_VERSION}')
- return cls(path)
-
- @classmethod
- def migrate(cls, path: str) -> DatabaseFile:
- """Apply migrations from current version to EXPECTED_DB_VERSION."""
- from_version = cls._get_version_of_db(path)
- if from_version >= EXPECTED_DB_VERSION:
- raise HandledException(
- f'Cannot migrate {from_version} to {EXPECTED_DB_VERSION}')
- with sql_connect(path, autocommit=False) as conn:
- for mig in DatabaseMigration.migrations_after(from_version):
- mig.perform(conn)
- cls._validate_schema(conn)
- conn.commit()
- return cls(path)
-
- def _check(self) -> None:
- """Check file exists, and is of proper DB version and schema."""
- if not isfile(self.path):
- raise NotFoundException
- if self._get_version_of_db(self.path) != EXPECTED_DB_VERSION:
- raise UnmigratedDbException()
- with sql_connect(self.path) as conn:
- self._validate_schema(conn)
-
- @staticmethod
- def _get_version_of_db(path: str) -> int:
- """Get DB user_version, fail if outside expected range."""
- with sql_connect(path) as conn:
- db_version = list(conn.execute(SQL_FOR_DB_VERSION))[0][0]
- assert isinstance(db_version, int)
- return db_version
-
- @staticmethod
- def _validate_schema(conn: SqlConnection) -> None:
- """Compare found schema with what's stored at PATH_DB_SCHEMA."""
- schema_rows_normed = []
- indent = ' '
- for row in [
- r[0] for r in conn.execute(
- 'SELECT sql FROM sqlite_master ORDER BY sql')
- if r[0]]:
- row_normed = []
- for subrow in [sr.rstrip() for sr in row.split('\n')]:
- in_parentheses = 0
- split_at = []
- for i, c in enumerate(subrow):
- if '(' == c:
- in_parentheses += 1
- elif ')' == c:
- in_parentheses -= 1
- elif ',' == c and 0 == in_parentheses:
- split_at += [i + 1]
- prev_split = 0
- for i in split_at:
- if segment := subrow[prev_split:i].strip():
- row_normed += [f'{indent}{segment}']
- prev_split = i
- if segment := subrow[prev_split:].strip():
- row_normed += [f'{indent}{segment}']
- row_normed[0] = row_normed[0].lstrip() # no indent for opening …
- row_normed[-1] = row_normed[-1].lstrip() # … and closing line
- if row_normed[-1] != ')' and row_normed[-3][-1] != ',':
- row_normed[-3] = row_normed[-3] + ','
- row_normed[-2:] = [indent + row_normed[-1][:-1]] + [')']
- row_normed[-1] = row_normed[-1] + ';'
- schema_rows_normed += row_normed
- with open(PATH_DB_SCHEMA, 'r', encoding='utf-8') as f:
- expected_rows = f.read().rstrip().splitlines()
- if expected_rows != schema_rows_normed:
- raise HandledException(
- 'Unexpected tables schema. Diff to {path_expected_schema}:\n' +
- '\n'.join(Differ().compare(schema_rows_normed, expected_rows)))
-
-
-class DatabaseConnection: