X-Git-Url: https://plomlompom.com/repos/feed.xml?a=blobdiff_plain;f=plomtask%2Fdb.py;h=7962eabeffd28964c0892b87f7ce35e6052a2f3e;hb=83266154e9140151c975586d21f393a5eb3f4ef4;hp=e4d5f6e9a42fe4726c5dff35e9488e13b61eea13;hpb=206a9111fdc95fcb24ae4793a7536e1facf82b71;p=plomtask diff --git a/plomtask/db.py b/plomtask/db.py index e4d5f6e..7962eab 100644 --- a/plomtask/db.py +++ b/plomtask/db.py @@ -1,13 +1,20 @@ """Database management.""" from __future__ import annotations +from os import listdir from os.path import isfile from difflib import Differ from sqlite3 import connect as sql_connect, Cursor, Row from typing import Any, Self, TypeVar, Generic from plomtask.exceptions import HandledException, NotFoundException -PATH_DB_SCHEMA = 'scripts/init.sql' -EXPECTED_DB_VERSION = 0 +EXPECTED_DB_VERSION = 1 +MIGRATIONS_DIR = 'migrations' +FILENAME_DB_SCHEMA = f'init_{EXPECTED_DB_VERSION}.sql' +PATH_DB_SCHEMA = f'{MIGRATIONS_DIR}/{FILENAME_DB_SCHEMA}' + + +class UnmigratedDbException(HandledException): + """To identify case of unmigrated DB file.""" class DatabaseFile: # pylint: disable=too-few-public-methods @@ -17,43 +24,128 @@ class DatabaseFile: # pylint: disable=too-few-public-methods self.path = path self._check() - def remake(self) -> None: - """Create tables in self.path file as per PATH_DB_SCHEMA sql file.""" - with sql_connect(self.path) as conn: + @classmethod + def create_at(cls, path: str) -> DatabaseFile: + """Make new DB file at path.""" + with sql_connect(path) as conn: with open(PATH_DB_SCHEMA, 'r', encoding='utf-8') as f: conn.executescript(f.read()) - self._check() + conn.execute(f'PRAGMA user_version = {EXPECTED_DB_VERSION}') + return cls(path) + + @classmethod + def migrate(cls, path: str) -> DatabaseFile: + """Apply migrations from_version to EXPECTED_DB_VERSION.""" + migrations = cls._available_migrations() + from_version = cls.get_version_of_db(path) + migrations_todo = migrations[from_version+1:] + for j, filename in enumerate(migrations_todo): + with sql_connect(path) as conn: + with open(f'{MIGRATIONS_DIR}/{filename}', 'r', + encoding='utf-8') as f: + conn.executescript(f.read()) + user_version = from_version + j + 1 + with sql_connect(path) as conn: + conn.execute(f'PRAGMA user_version = {user_version}') + return cls(path) def _check(self) -> None: """Check file exists, and is of proper DB version and schema.""" - self.exists = isfile(self.path) - if self.exists: - self._validate_user_version() - self._validate_schema() + if not isfile(self.path): + raise NotFoundException + if self.user_version != EXPECTED_DB_VERSION: + raise UnmigratedDbException() + self._validate_schema() + + @staticmethod + def _available_migrations() -> list[str]: + """Validate migrations directory and return sorted entries.""" + msg_too_big = 'Migration directory points beyond expected DB version.' + msg_bad_entry = 'Migration directory contains unexpected entry: ' + msg_missing = 'Migration directory misses migration of number: ' + migrations = {} + for entry in listdir(MIGRATIONS_DIR): + if entry == FILENAME_DB_SCHEMA: + continue + toks = entry.split('_', 1) + if len(toks) < 2: + raise HandledException(msg_bad_entry + entry) + try: + i = int(toks[0]) + except ValueError as e: + raise HandledException(msg_bad_entry + entry) from e + if i > EXPECTED_DB_VERSION: + raise HandledException(msg_too_big) + migrations[i] = toks[1] + migrations_list = [] + for i in range(EXPECTED_DB_VERSION + 1): + if i not in migrations: + raise HandledException(msg_missing + str(i)) + migrations_list += [f'{i}_{migrations[i]}'] + return migrations_list - def _validate_user_version(self) -> None: - """Compare DB user_version with EXPECTED_DB_VERSION.""" + @staticmethod + def get_version_of_db(path: str) -> int: + """Get DB user_version, fail if outside expected range.""" sql_for_db_version = 'PRAGMA user_version' - with sql_connect(self.path) as conn: + with sql_connect(path) as conn: db_version = list(conn.execute(sql_for_db_version))[0][0] - if db_version != EXPECTED_DB_VERSION: - msg = f'Wrong DB version, expected '\ - f'{EXPECTED_DB_VERSION}, got {db_version}.' - raise HandledException(msg) + if db_version > EXPECTED_DB_VERSION: + msg = f'Wrong DB version, expected '\ + f'{EXPECTED_DB_VERSION}, got unknown {db_version}.' + raise HandledException(msg) + assert isinstance(db_version, int) + return db_version + + @property + def user_version(self) -> int: + """Get DB user_version.""" + return self.__class__.get_version_of_db(self.path) def _validate_schema(self) -> None: """Compare found schema with what's stored at PATH_DB_SCHEMA.""" + + def reformat_rows(rows: list[str]) -> list[str]: + new_rows = [] + for row in rows: + new_row = [] + for subrow in row.split('\n'): + subrow = subrow.rstrip() + in_parentheses = 0 + split_at = [] + for i, c in enumerate(subrow): + if '(' == c: + in_parentheses += 1 + elif ')' == c: + in_parentheses -= 1 + elif ',' == c and 0 == in_parentheses: + split_at += [i + 1] + prev_split = 0 + for i in split_at: + segment = subrow[prev_split:i].strip() + if len(segment) > 0: + new_row += [f' {segment}'] + prev_split = i + segment = subrow[prev_split:].strip() + if len(segment) > 0: + new_row += [f' {segment}'] + new_row[0] = new_row[0].lstrip() + new_row[-1] = new_row[-1].lstrip() + new_rows += ['\n'.join(new_row)] + return new_rows + sql_for_schema = 'SELECT sql FROM sqlite_master ORDER BY sql' msg_err = 'Database has wrong tables schema. Diff:\n' with sql_connect(self.path) as conn: schema_rows = [r[0] for r in conn.execute(sql_for_schema) if r[0]] - retrieved_schema = ';\n'.join(schema_rows) + ';' - with open(PATH_DB_SCHEMA, 'r', encoding='utf-8') as f: - stored_schema = f.read().rstrip() - if stored_schema != retrieved_schema: - diff_msg = Differ().compare(retrieved_schema.splitlines(), - stored_schema.splitlines()) - raise HandledException(msg_err + '\n'.join(diff_msg)) + schema_rows = reformat_rows(schema_rows) + retrieved_schema = ';\n'.join(schema_rows) + ';' + with open(PATH_DB_SCHEMA, 'r', encoding='utf-8') as f: + stored_schema = f.read().rstrip() + if stored_schema != retrieved_schema: + diff_msg = Differ().compare(retrieved_schema.splitlines(), + stored_schema.splitlines()) + raise HandledException(msg_err + '\n'.join(diff_msg)) class DatabaseConnection: