From: Christian Heller Date: Wed, 15 Jan 2025 14:21:56 +0000 (+0100) Subject: Include plomlib for its db.py, adapt DB code to it. X-Git-Url: https://plomlompom.com/repos/%7B%7Bprefix%7D%7D/static/unset_cookie?a=commitdiff_plain;h=refs%2Fheads%2Fmaster;p=plomtask Include plomlib for its db.py, adapt DB code to it. --- diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000..42cf7f3 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "plomlib"] + path = plomlib + url = https://plomlompom.com/repos/clone/plomlib diff --git a/plomlib b/plomlib new file mode 160000 index 0000000..743dbe0 --- /dev/null +++ b/plomlib @@ -0,0 +1 @@ +Subproject commit 743dbe0d493ddeb47eca981fa5be6d78e4d754c9 diff --git a/plomtask/db.py b/plomtask/db.py index be849b6..cc138ad 100644 --- a/plomtask/db.py +++ b/plomtask/db.py @@ -2,212 +2,88 @@ from __future__ import annotations from datetime import date as dt_date from os import listdir -from os.path import basename, isfile -from difflib import Differ -from sqlite3 import ( - connect as sql_connect, Connection as SqlConnection, Cursor, Row) -from typing import Any, Self, Callable +from pathlib import Path +from sqlite3 import Row +from typing import cast, Any, Self, Callable from plomtask.exceptions import (HandledException, NotFoundException, BadFormatException) +from plomlib.db import ( + PlomDbConn, PlomDbFile, PlomDbMigration, TypePlomDbMigration) -EXPECTED_DB_VERSION = 7 -MIGRATIONS_DIR = 'migrations' -FILENAME_DB_SCHEMA = f'init_{EXPECTED_DB_VERSION}.sql' -PATH_DB_SCHEMA = f'{MIGRATIONS_DIR}/{FILENAME_DB_SCHEMA}' -SQL_FOR_DB_VERSION = 'PRAGMA user_version' +_EXPECTED_DB_VERSION = 7 +_MIGRATIONS_DIR = Path('migrations') +_FILENAME_DB_SCHEMA = f'init_{_EXPECTED_DB_VERSION}.sql' +_PATH_DB_SCHEMA = _MIGRATIONS_DIR.joinpath(_FILENAME_DB_SCHEMA) -class UnmigratedDbException(HandledException): - """To identify case of unmigrated DB file.""" +def _mig_6_calc_days_since_millennium(conn: PlomDbConn) -> None: + rows = conn.exec('SELECT * FROM days').fetchall() + for row in [list(r) for r in rows]: + row[-1] = (dt_date.fromisoformat(row[0]) - dt_date(2000, 1, 1)).days + conn.exec('REPLACE INTO days VALUES', tuple(row)) + +MIGRATION_STEPS_POST_SQL: dict[int, Callable[[PlomDbConn], None]] = { + 6: _mig_6_calc_days_since_millennium +} -class DatabaseMigration: - """Collects Database migration data.""" - def __init__(self, - target_version: int, - sql_path: str, - post_sql_steps: Callable[[SqlConnection], None] | None - ) -> None: - if sql_path: - start_tok = str(basename(sql_path)).split('_', maxsplit=1)[0] - if (not start_tok.isdigit()) or int(start_tok) != target_version: - raise HandledException(f'migration to {target_version} mapped ' - f'to bad path {sql_path}') - self._target_version = target_version - self._sql_path = sql_path - self._post_sql_steps = post_sql_steps +class DatabaseMigration(PlomDbMigration): + """Collects and enacts DatabaseFile migration commands.""" + migs_dir_path = _MIGRATIONS_DIR @classmethod - def migrations_after(cls, starting_from: int) -> list[Self]: - """Make sorted unbroken list of available migrations >starting_from.""" + def gather(cls, from_version: int, base_set: set[TypePlomDbMigration] + ) -> list[TypePlomDbMigration]: msg_prefix = 'Migration directory contains' msg_bad_entry = f'{msg_prefix} unexpected entry: ' migs = [] total_migs = set() post_sql_steps_added = set() - for entry in [e for e in listdir(MIGRATIONS_DIR) - if e != FILENAME_DB_SCHEMA]: + for entry in [e for e in listdir(cls.migs_dir_path) + if e != _FILENAME_DB_SCHEMA]: + path = cls.migs_dir_path.joinpath(entry) + if not path.is_file(): + continue toks = entry.split('_', maxsplit=1) if len(toks) < 2 or (not toks[0].isdigit()): raise HandledException(f'{msg_bad_entry}{entry}') i = int(toks[0]) - if i <= starting_from: + if i <= from_version: continue - if i > EXPECTED_DB_VERSION: - raise HandledException(f'{msg_prefix} uexpected version {i}') + if i > _EXPECTED_DB_VERSION: + raise HandledException(f'{msg_prefix} unexpected version {i}') post_sql_steps = MIGRATION_STEPS_POST_SQL.get(i, None) if post_sql_steps: post_sql_steps_added.add(i) - total_migs.add( - cls(i, f'{MIGRATIONS_DIR}/{entry}', post_sql_steps)) + total_migs.add(cls(i, Path(entry), post_sql_steps)) for k in [k for k in MIGRATION_STEPS_POST_SQL - if k > starting_from + if k > from_version and k not in post_sql_steps_added]: - total_migs.add(cls(k, '', MIGRATION_STEPS_POST_SQL[k])) - for i in range(starting_from + 1, EXPECTED_DB_VERSION + 1): - # pylint: disable=protected-access - migs_found = [m for m in total_migs if m._target_version == i] + total_migs.add(cls(k, None, MIGRATION_STEPS_POST_SQL[k])) + for i in range(from_version + 1, _EXPECTED_DB_VERSION + 1): + migs_found = [m for m in total_migs if m.target_version == i] if not migs_found: raise HandledException(f'{msg_prefix} no migration of v. {i}') if len(migs_found) > 1: raise HandledException(f'{msg_prefix} >1 migration of v. {i}') migs += migs_found - return migs - - def perform(self, conn: SqlConnection) -> None: - """Do 1) script at sql_path, 2) post_sql_steps, 3) version setting.""" - if self._sql_path: - with open(self._sql_path, 'r', encoding='utf8') as f: - conn.executescript(f.read()) - if self._post_sql_steps: - self._post_sql_steps(conn) - conn.execute(f'{SQL_FOR_DB_VERSION} = {self._target_version}') + return cast(list[TypePlomDbMigration], migs) -def _mig_6_calc_days_since_millennium(conn: SqlConnection) -> None: - rows = conn.execute('SELECT * FROM days').fetchall() - for row in [list(r) for r in rows]: - row[-1] = (dt_date.fromisoformat(row[0]) - dt_date(2000, 1, 1)).days - conn.execute('REPLACE INTO days VALUES (?, ?, ?)', tuple(row)) +class DatabaseFile(PlomDbFile): + """File readable as DB of expected schema, user version.""" + target_version = _EXPECTED_DB_VERSION + path_schema = _PATH_DB_SCHEMA + mig_class = DatabaseMigration -MIGRATION_STEPS_POST_SQL: dict[int, Callable[[SqlConnection], None]] = { - 6: _mig_6_calc_days_since_millennium -} - - -class DatabaseFile: - """Represents the sqlite3 database's file.""" - # pylint: disable=too-few-public-methods - - def __init__(self, path: str) -> None: - self.path = path - self._check() - - @classmethod - def create_at(cls, path: str) -> DatabaseFile: - """Make new DB file at path.""" - with sql_connect(path) as conn: - with open(PATH_DB_SCHEMA, 'r', encoding='utf-8') as f: - conn.executescript(f.read()) - conn.execute(f'{SQL_FOR_DB_VERSION} = {EXPECTED_DB_VERSION}') - return cls(path) - - @classmethod - def migrate(cls, path: str) -> DatabaseFile: - """Apply migrations from current version to EXPECTED_DB_VERSION.""" - from_version = cls._get_version_of_db(path) - if from_version >= EXPECTED_DB_VERSION: - raise HandledException( - f'Cannot migrate {from_version} to {EXPECTED_DB_VERSION}') - with sql_connect(path, autocommit=False) as conn: - for mig in DatabaseMigration.migrations_after(from_version): - mig.perform(conn) - cls._validate_schema(conn) - conn.commit() - return cls(path) - - def _check(self) -> None: - """Check file exists, and is of proper DB version and schema.""" - if not isfile(self.path): - raise NotFoundException - if self._get_version_of_db(self.path) != EXPECTED_DB_VERSION: - raise UnmigratedDbException() - with sql_connect(self.path) as conn: - self._validate_schema(conn) - - @staticmethod - def _get_version_of_db(path: str) -> int: - """Get DB user_version, fail if outside expected range.""" - with sql_connect(path) as conn: - db_version = list(conn.execute(SQL_FOR_DB_VERSION))[0][0] - assert isinstance(db_version, int) - return db_version - - @staticmethod - def _validate_schema(conn: SqlConnection) -> None: - """Compare found schema with what's stored at PATH_DB_SCHEMA.""" - schema_rows_normed = [] - indent = ' ' - for row in [ - r[0] for r in conn.execute( - 'SELECT sql FROM sqlite_master ORDER BY sql') - if r[0]]: - row_normed = [] - for subrow in [sr.rstrip() for sr in row.split('\n')]: - in_parentheses = 0 - split_at = [] - for i, c in enumerate(subrow): - if '(' == c: - in_parentheses += 1 - elif ')' == c: - in_parentheses -= 1 - elif ',' == c and 0 == in_parentheses: - split_at += [i + 1] - prev_split = 0 - for i in split_at: - if segment := subrow[prev_split:i].strip(): - row_normed += [f'{indent}{segment}'] - prev_split = i - if segment := subrow[prev_split:].strip(): - row_normed += [f'{indent}{segment}'] - row_normed[0] = row_normed[0].lstrip() # no indent for opening … - row_normed[-1] = row_normed[-1].lstrip() # … and closing line - if row_normed[-1] != ')' and row_normed[-3][-1] != ',': - row_normed[-3] = row_normed[-3] + ',' - row_normed[-2:] = [indent + row_normed[-1][:-1]] + [')'] - row_normed[-1] = row_normed[-1] + ';' - schema_rows_normed += row_normed - with open(PATH_DB_SCHEMA, 'r', encoding='utf-8') as f: - expected_rows = f.read().rstrip().splitlines() - if expected_rows != schema_rows_normed: - raise HandledException( - 'Unexpected tables schema. Diff to {path_expected_schema}:\n' + - '\n'.join(Differ().compare(schema_rows_normed, expected_rows))) - - -class DatabaseConnection: +class DatabaseConnection(PlomDbConn): """A single connection to the database.""" - def __init__(self, db_file: DatabaseFile) -> None: - self._conn = sql_connect(db_file.path, autocommit=False) - self.commit = self._conn.commit - self.close = self._conn.close - - def exec(self, - code: str, - inputs: tuple[Any, ...] = tuple(), - build_q_marks: bool = True - ) -> Cursor: - """Wrapper around sqlite3.Connection.execute, building '?' if inputs""" - if len(inputs) > 0: - if build_q_marks: - q_marks = ('?' if len(inputs) == 1 - else '(' + ','.join(['?'] * len(inputs)) + ')') - return self._conn.execute(f'{code} {q_marks}', inputs) - return self._conn.execute(code, inputs) - return self._conn.execute(code) + def close(self) -> None: + """Shortcut to sqlite3.Connection.close().""" + self._conn.close() def rewrite_relations(self, table_name: str, key: str, target: int | str, rows: list[list[Any]], key_index: int = 0) -> None: diff --git a/run.py b/run.py index c69dc6a..0d50d25 100755 --- a/run.py +++ b/run.py @@ -2,9 +2,11 @@ """Call this to start the application.""" from sys import exit as sys_exit from os import environ -from plomtask.exceptions import HandledException, NotFoundException +from pathlib import Path +from plomtask.exceptions import HandledException from plomtask.http import TaskHandler, TaskServer -from plomtask.db import DatabaseFile, UnmigratedDbException +from plomtask.db import DatabaseFile +from plomlib.db import PlomDbException PLOMTASK_DB_PATH = environ.get('PLOMTASK_DB_PATH') HTTP_PORT = 8082 @@ -24,21 +26,27 @@ if __name__ == '__main__': try: if not PLOMTASK_DB_PATH: raise HandledException('PLOMTASK_DB_PATH not set.') + db_path = Path(PLOMTASK_DB_PATH) try: - db_file = DatabaseFile(PLOMTASK_DB_PATH) - except NotFoundException: - yes_or_fail(DB_CREATION_ASK, 'Cannot run without DB.') - db_file = DatabaseFile.create_at(PLOMTASK_DB_PATH) - except UnmigratedDbException: - yes_or_fail(DB_MIGRATE_ASK, 'Cannot run with unmigrated DB.') - db_file = DatabaseFile.migrate(PLOMTASK_DB_PATH) - server = TaskServer(db_file, ('localhost', HTTP_PORT), TaskHandler) - print(f'running at port {HTTP_PORT}') - try: - server.serve_forever() - except KeyboardInterrupt: - print('aborting due to keyboard interrupt') - server.server_close() + db_file = DatabaseFile(db_path) + except PlomDbException as e: + if e.name == 'no_is_file': + yes_or_fail(DB_CREATION_ASK, 'Cannot run without DB.') + DatabaseFile.create(db_path) + elif e.name == 'bad_version': + yes_or_fail(DB_MIGRATE_ASK, 'Cannot run with unmigrated DB.') + db_file = DatabaseFile(db_path, skip_validations=True) + db_file.migrate(set()) + else: + raise e + else: + server = TaskServer(db_file, ('localhost', HTTP_PORT), TaskHandler) + print(f'running at port {HTTP_PORT}') + try: + server.serve_forever() + except KeyboardInterrupt: + print('aborting due to keyboard interrupt') + server.server_close() except HandledException as e: print(f'Aborting because: {e}') sys_exit(1) diff --git a/scripts/pre-commit b/scripts/pre-commit index 7abafb9..0dd4d45 100755 --- a/scripts/pre-commit +++ b/scripts/pre-commit @@ -2,7 +2,7 @@ set -e for dir in $(echo '.' 'plomtask' 'tests'); do echo "Running mypy on ${dir}/ …." - python3 -m mypy --strict ${dir}/*.py + python3 -m mypy ${dir}/*.py echo "Running flake8 on ${dir}/ …" python3 -m flake8 ${dir}/*.py echo "Running pylint on ${dir}/ …" diff --git a/tests/utils.py b/tests/utils.py index dd7dddc..4882ab3 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -5,6 +5,7 @@ from datetime import datetime, date as dt_date, timedelta from unittest import TestCase from typing import Mapping, Any, Callable from threading import Thread +from pathlib import Path from http.client import HTTPConnection from time import sleep from json import loads as json_loads, dumps as json_dumps @@ -195,7 +196,9 @@ class TestCaseWithDB(TestCaseAugmented): Process.empty_cache() ProcessStep.empty_cache() Todo.empty_cache() - self.db_file = DatabaseFile.create_at(f'test_db:{uuid4()}') + db_path = Path(f'test_db:{uuid4()}') + DatabaseFile.create(db_path) + self.db_file = DatabaseFile(db_path) self.db_conn = DatabaseConnection(self.db_file) def tearDown(self) -> None: