from __future__ import annotations
from datetime import date as dt_date
from os import listdir
-from os.path import basename, isfile
-from difflib import Differ
-from sqlite3 import (
- connect as sql_connect, Connection as SqlConnection, Cursor, Row)
-from typing import Any, Self, Callable
+from pathlib import Path
+from sqlite3 import Row
+from typing import cast, Any, Self, Callable
from plomtask.exceptions import (HandledException, NotFoundException,
BadFormatException)
+from plomlib.db import (
+ PlomDbConn, PlomDbFile, PlomDbMigration, TypePlomDbMigration)
-EXPECTED_DB_VERSION = 7
-MIGRATIONS_DIR = 'migrations'
-FILENAME_DB_SCHEMA = f'init_{EXPECTED_DB_VERSION}.sql'
-PATH_DB_SCHEMA = f'{MIGRATIONS_DIR}/{FILENAME_DB_SCHEMA}'
-SQL_FOR_DB_VERSION = 'PRAGMA user_version'
+_EXPECTED_DB_VERSION = 7
+_MIGRATIONS_DIR = Path('migrations')
+_FILENAME_DB_SCHEMA = f'init_{_EXPECTED_DB_VERSION}.sql'
+_PATH_DB_SCHEMA = _MIGRATIONS_DIR.joinpath(_FILENAME_DB_SCHEMA)
-class UnmigratedDbException(HandledException):
- """To identify case of unmigrated DB file."""
+def _mig_6_calc_days_since_millennium(conn: PlomDbConn) -> None:
+ rows = conn.exec('SELECT * FROM days').fetchall()
+ for row in [list(r) for r in rows]:
+ row[-1] = (dt_date.fromisoformat(row[0]) - dt_date(2000, 1, 1)).days
+ conn.exec('REPLACE INTO days VALUES', tuple(row))
+
+MIGRATION_STEPS_POST_SQL: dict[int, Callable[[PlomDbConn], None]] = {
+ 6: _mig_6_calc_days_since_millennium
+}
-class DatabaseMigration:
- """Collects Database migration data."""
- def __init__(self,
- target_version: int,
- sql_path: str,
- post_sql_steps: Callable[[SqlConnection], None] | None
- ) -> None:
- if sql_path:
- start_tok = str(basename(sql_path)).split('_', maxsplit=1)[0]
- if (not start_tok.isdigit()) or int(start_tok) != target_version:
- raise HandledException(f'migration to {target_version} mapped '
- f'to bad path {sql_path}')
- self._target_version = target_version
- self._sql_path = sql_path
- self._post_sql_steps = post_sql_steps
+class DatabaseMigration(PlomDbMigration):
+ """Collects and enacts DatabaseFile migration commands."""
+ migs_dir_path = _MIGRATIONS_DIR
@classmethod
- def migrations_after(cls, starting_from: int) -> list[Self]:
- """Make sorted unbroken list of available migrations >starting_from."""
+ def gather(cls, from_version: int, base_set: set[TypePlomDbMigration]
+ ) -> list[TypePlomDbMigration]:
msg_prefix = 'Migration directory contains'
msg_bad_entry = f'{msg_prefix} unexpected entry: '
migs = []
total_migs = set()
post_sql_steps_added = set()
- for entry in [e for e in listdir(MIGRATIONS_DIR)
- if e != FILENAME_DB_SCHEMA]:
+ for entry in [e for e in listdir(cls.migs_dir_path)
+ if e != _FILENAME_DB_SCHEMA]:
+ path = cls.migs_dir_path.joinpath(entry)
+ if not path.is_file():
+ continue
toks = entry.split('_', maxsplit=1)
if len(toks) < 2 or (not toks[0].isdigit()):
raise HandledException(f'{msg_bad_entry}{entry}')
i = int(toks[0])
- if i <= starting_from:
+ if i <= from_version:
continue
- if i > EXPECTED_DB_VERSION:
- raise HandledException(f'{msg_prefix} uexpected version {i}')
+ if i > _EXPECTED_DB_VERSION:
+ raise HandledException(f'{msg_prefix} unexpected version {i}')
post_sql_steps = MIGRATION_STEPS_POST_SQL.get(i, None)
if post_sql_steps:
post_sql_steps_added.add(i)
- total_migs.add(
- cls(i, f'{MIGRATIONS_DIR}/{entry}', post_sql_steps))
+ total_migs.add(cls(i, Path(entry), post_sql_steps))
for k in [k for k in MIGRATION_STEPS_POST_SQL
- if k > starting_from
+ if k > from_version
and k not in post_sql_steps_added]:
- total_migs.add(cls(k, '', MIGRATION_STEPS_POST_SQL[k]))
- for i in range(starting_from + 1, EXPECTED_DB_VERSION + 1):
- # pylint: disable=protected-access
- migs_found = [m for m in total_migs if m._target_version == i]
+ total_migs.add(cls(k, None, MIGRATION_STEPS_POST_SQL[k]))
+ for i in range(from_version + 1, _EXPECTED_DB_VERSION + 1):
+ migs_found = [m for m in total_migs if m.target_version == i]
if not migs_found:
raise HandledException(f'{msg_prefix} no migration of v. {i}')
if len(migs_found) > 1:
raise HandledException(f'{msg_prefix} >1 migration of v. {i}')
migs += migs_found
- return migs
-
- def perform(self, conn: SqlConnection) -> None:
- """Do 1) script at sql_path, 2) post_sql_steps, 3) version setting."""
- if self._sql_path:
- with open(self._sql_path, 'r', encoding='utf8') as f:
- conn.executescript(f.read())
- if self._post_sql_steps:
- self._post_sql_steps(conn)
- conn.execute(f'{SQL_FOR_DB_VERSION} = {self._target_version}')
+ return cast(list[TypePlomDbMigration], migs)
-def _mig_6_calc_days_since_millennium(conn: SqlConnection) -> None:
- rows = conn.execute('SELECT * FROM days').fetchall()
- for row in [list(r) for r in rows]:
- row[-1] = (dt_date.fromisoformat(row[0]) - dt_date(2000, 1, 1)).days
- conn.execute('REPLACE INTO days VALUES (?, ?, ?)', tuple(row))
+class DatabaseFile(PlomDbFile):
+ """File readable as DB of expected schema, user version."""
+ target_version = _EXPECTED_DB_VERSION
+ path_schema = _PATH_DB_SCHEMA
+ mig_class = DatabaseMigration
-MIGRATION_STEPS_POST_SQL: dict[int, Callable[[SqlConnection], None]] = {
- 6: _mig_6_calc_days_since_millennium
-}
-
-
-class DatabaseFile:
- """Represents the sqlite3 database's file."""
- # pylint: disable=too-few-public-methods
-
- def __init__(self, path: str) -> None:
- self.path = path
- self._check()
-
- @classmethod
- def create_at(cls, path: str) -> DatabaseFile:
- """Make new DB file at path."""
- with sql_connect(path) as conn:
- with open(PATH_DB_SCHEMA, 'r', encoding='utf-8') as f:
- conn.executescript(f.read())
- conn.execute(f'{SQL_FOR_DB_VERSION} = {EXPECTED_DB_VERSION}')
- return cls(path)
-
- @classmethod
- def migrate(cls, path: str) -> DatabaseFile:
- """Apply migrations from current version to EXPECTED_DB_VERSION."""
- from_version = cls._get_version_of_db(path)
- if from_version >= EXPECTED_DB_VERSION:
- raise HandledException(
- f'Cannot migrate {from_version} to {EXPECTED_DB_VERSION}')
- with sql_connect(path, autocommit=False) as conn:
- for mig in DatabaseMigration.migrations_after(from_version):
- mig.perform(conn)
- cls._validate_schema(conn)
- conn.commit()
- return cls(path)
-
- def _check(self) -> None:
- """Check file exists, and is of proper DB version and schema."""
- if not isfile(self.path):
- raise NotFoundException
- if self._get_version_of_db(self.path) != EXPECTED_DB_VERSION:
- raise UnmigratedDbException()
- with sql_connect(self.path) as conn:
- self._validate_schema(conn)
-
- @staticmethod
- def _get_version_of_db(path: str) -> int:
- """Get DB user_version, fail if outside expected range."""
- with sql_connect(path) as conn:
- db_version = list(conn.execute(SQL_FOR_DB_VERSION))[0][0]
- assert isinstance(db_version, int)
- return db_version
-
- @staticmethod
- def _validate_schema(conn: SqlConnection) -> None:
- """Compare found schema with what's stored at PATH_DB_SCHEMA."""
- schema_rows_normed = []
- indent = ' '
- for row in [
- r[0] for r in conn.execute(
- 'SELECT sql FROM sqlite_master ORDER BY sql')
- if r[0]]:
- row_normed = []
- for subrow in [sr.rstrip() for sr in row.split('\n')]:
- in_parentheses = 0
- split_at = []
- for i, c in enumerate(subrow):
- if '(' == c:
- in_parentheses += 1
- elif ')' == c:
- in_parentheses -= 1
- elif ',' == c and 0 == in_parentheses:
- split_at += [i + 1]
- prev_split = 0
- for i in split_at:
- if segment := subrow[prev_split:i].strip():
- row_normed += [f'{indent}{segment}']
- prev_split = i
- if segment := subrow[prev_split:].strip():
- row_normed += [f'{indent}{segment}']
- row_normed[0] = row_normed[0].lstrip() # no indent for opening …
- row_normed[-1] = row_normed[-1].lstrip() # … and closing line
- if row_normed[-1] != ')' and row_normed[-3][-1] != ',':
- row_normed[-3] = row_normed[-3] + ','
- row_normed[-2:] = [indent + row_normed[-1][:-1]] + [')']
- row_normed[-1] = row_normed[-1] + ';'
- schema_rows_normed += row_normed
- with open(PATH_DB_SCHEMA, 'r', encoding='utf-8') as f:
- expected_rows = f.read().rstrip().splitlines()
- if expected_rows != schema_rows_normed:
- raise HandledException(
- 'Unexpected tables schema. Diff to {path_expected_schema}:\n' +
- '\n'.join(Differ().compare(schema_rows_normed, expected_rows)))
-
-
-class DatabaseConnection:
+class DatabaseConnection(PlomDbConn):
"""A single connection to the database."""
- def __init__(self, db_file: DatabaseFile) -> None:
- self._conn = sql_connect(db_file.path, autocommit=False)
- self.commit = self._conn.commit
- self.close = self._conn.close
-
- def exec(self,
- code: str,
- inputs: tuple[Any, ...] = tuple(),
- build_q_marks: bool = True
- ) -> Cursor:
- """Wrapper around sqlite3.Connection.execute, building '?' if inputs"""
- if len(inputs) > 0:
- if build_q_marks:
- q_marks = ('?' if len(inputs) == 1
- else '(' + ','.join(['?'] * len(inputs)) + ')')
- return self._conn.execute(f'{code} {q_marks}', inputs)
- return self._conn.execute(code, inputs)
- return self._conn.execute(code)
+ def close(self) -> None:
+ """Shortcut to sqlite3.Connection.close()."""
+ self._conn.close()
def rewrite_relations(self, table_name: str, key: str, target: int | str,
rows: list[list[Any]], key_index: int = 0) -> None:
"""Call this to start the application."""
from sys import exit as sys_exit
from os import environ
-from plomtask.exceptions import HandledException, NotFoundException
+from pathlib import Path
+from plomtask.exceptions import HandledException
from plomtask.http import TaskHandler, TaskServer
-from plomtask.db import DatabaseFile, UnmigratedDbException
+from plomtask.db import DatabaseFile
+from plomlib.db import PlomDbException
PLOMTASK_DB_PATH = environ.get('PLOMTASK_DB_PATH')
HTTP_PORT = 8082
try:
if not PLOMTASK_DB_PATH:
raise HandledException('PLOMTASK_DB_PATH not set.')
+ db_path = Path(PLOMTASK_DB_PATH)
try:
- db_file = DatabaseFile(PLOMTASK_DB_PATH)
- except NotFoundException:
- yes_or_fail(DB_CREATION_ASK, 'Cannot run without DB.')
- db_file = DatabaseFile.create_at(PLOMTASK_DB_PATH)
- except UnmigratedDbException:
- yes_or_fail(DB_MIGRATE_ASK, 'Cannot run with unmigrated DB.')
- db_file = DatabaseFile.migrate(PLOMTASK_DB_PATH)
- server = TaskServer(db_file, ('localhost', HTTP_PORT), TaskHandler)
- print(f'running at port {HTTP_PORT}')
- try:
- server.serve_forever()
- except KeyboardInterrupt:
- print('aborting due to keyboard interrupt')
- server.server_close()
+ db_file = DatabaseFile(db_path)
+ except PlomDbException as e:
+ if e.name == 'no_is_file':
+ yes_or_fail(DB_CREATION_ASK, 'Cannot run without DB.')
+ DatabaseFile.create(db_path)
+ elif e.name == 'bad_version':
+ yes_or_fail(DB_MIGRATE_ASK, 'Cannot run with unmigrated DB.')
+ db_file = DatabaseFile(db_path, skip_validations=True)
+ db_file.migrate(set())
+ else:
+ raise e
+ else:
+ server = TaskServer(db_file, ('localhost', HTTP_PORT), TaskHandler)
+ print(f'running at port {HTTP_PORT}')
+ try:
+ server.serve_forever()
+ except KeyboardInterrupt:
+ print('aborting due to keyboard interrupt')
+ server.server_close()
except HandledException as e:
print(f'Aborting because: {e}')
sys_exit(1)