-class DbFile:
- """Wrapper around the file of a sqlite3 database."""
-
- def __init__(self,
- path: Path = PATH_DB,
- version_to_validate: int = EXPECTED_DB_VERSION
- ) -> None:
- self.path = path
- if not self.path.is_file():
- raise HandledException(f'no DB file at {self.path}')
- if version_to_validate < 0:
- return
- if (user_version := self._get_user_version()) != version_to_validate:
- raise HandledException(
- f'wrong DB version {user_version} (!= {version_to_validate})')
- with DbConn(self) as conn:
- self._validate_schema(conn)
-
- @staticmethod
- def _validate_schema(conn: 'DbConn'):
- schema_rows_normed = []
- indent = ' '
- for row in [
- r[0] for r in conn.exec(SqlText(
- 'SELECT sql FROM sqlite_master ORDER BY sql'))
- if r[0]]:
- row_normed = []
- for subrow in [sr.rstrip() for sr in row.split('\n')]:
- in_parentheses = 0
- split_at = []
- for i, c in enumerate(subrow):
- if '(' == c:
- in_parentheses += 1
- elif ')' == c:
- in_parentheses -= 1
- elif ',' == c and 0 == in_parentheses:
- split_at += [i + 1]
- prev_split = 0
- for i in split_at:
- if segment := subrow[prev_split:i].strip():
- row_normed += [f'{indent}{segment}']
- prev_split = i
- if segment := subrow[prev_split:].strip():
- row_normed += [f'{indent}{segment}']
- row_normed[0] = row_normed[0].lstrip() # no indent for opening …
- row_normed[-1] = row_normed[-1].lstrip() # … and closing line
- if row_normed[-1] != ')' and row_normed[-3][-1] != ',':
- row_normed[-3] = row_normed[-3] + ','
- row_normed[-2:] = [indent + row_normed[-1][:-1]] + [')']
- row_normed[-1] = row_normed[-1] + ';'
- schema_rows_normed += row_normed
- if ((expected_rows :=
- _PATH_DB_SCHEMA.read_text(encoding='utf8').rstrip().splitlines()
- ) != schema_rows_normed):
- raise HandledException(
- 'Unexpected tables schema. Diff to {_PATH_DB_SCHEMA}:\n' +
- '\n'.join(Differ().compare(schema_rows_normed, expected_rows)))
-
- def _get_user_version(self) -> int:
- with sql_connect(self.path) as conn:
- return list(conn.execute(_SQL_DB_VERSION))[0][0]
-
- @staticmethod
- def create(path: Path = PATH_DB) -> None:
- """Create DB file at path according to _PATH_DB_SCHEMA."""
- if path.exists():
- raise HandledException(
- f'There already exists a node at {path}.')
- if not path.parent.is_dir():
- raise HandledException(
- f'No directory {path.parent} found to write into.')
- with sql_connect(path) as conn:
- conn.executescript(_PATH_DB_SCHEMA.read_text(encoding='utf8'))
- conn.execute(f'{_SQL_DB_VERSION} = {EXPECTED_DB_VERSION}')
-
- def migrate(self, migrations: set['DbMigration']) -> None:
- """Migrate self towards EXPECTED_DB_VERSION"""
- start_version = self._get_user_version()
- if start_version == EXPECTED_DB_VERSION:
- raise HandledException(
- f'Already at {EXPECTED_DB_VERSION}, nothing to migrate.')
- if start_version > EXPECTED_DB_VERSION:
- raise HandledException(
- f'Cannot migrate backwards from {start_version}'
- f'to {EXPECTED_DB_VERSION}.')
- with DbConn(self) as conn:
- for migration in DbMigration.from_to_in_set(
- start_version, EXPECTED_DB_VERSION, migrations):
- migration.perform(conn)
- self._validate_schema(conn)
- conn.commit()
-
-
-class DbMigration:
- """Representation of DbFile migration data."""
-
- def __init__(self,
- version: int,
- sql_path: Optional[Path] = None,
- after_sql_steps: Optional[Callable[['DbConn'], None]] = None
- ) -> None:
- if sql_path:
- start_tok = str(sql_path).split('_', maxsplit=1)[0]
- if (not start_tok.isdigit()) or int(start_tok) != version:
- raise HandledException(
- f'migration {version} mapped to bad path {sql_path}')
- self._version = version
- self._sql_path = sql_path
- self._after_sql_steps = after_sql_steps