From: Christian Heller Date: Tue, 7 Jan 2025 00:26:09 +0000 (+0100) Subject: Before committing migrations check schema validation. X-Git-Url: https://plomlompom.com/repos/%7B%7Bprefix%7D%7D/static/%7B%7B%20web_path%20%7D%7D/ledger?a=commitdiff_plain;h=ae54b8a9a187fcb0f3a89917d8bcaa856882cb71;p=plomtask Before committing migrations check schema validation. --- diff --git a/plomtask/db.py b/plomtask/db.py index 2695bda..07c31e9 100644 --- a/plomtask/db.py +++ b/plomtask/db.py @@ -3,7 +3,8 @@ from __future__ import annotations from os import listdir from os.path import isfile from difflib import Differ -from sqlite3 import connect as sql_connect, Cursor, Row +from sqlite3 import ( + connect as sql_connect, Connection as SqlConnection, Cursor, Row) from typing import Any, Self, Callable from plomtask.exceptions import (HandledException, NotFoundException, BadFormatException) @@ -51,6 +52,7 @@ class DatabaseFile: encoding='utf-8') as f: conn.executescript(f.read()) conn.execute(f'PRAGMA user_version = {from_version + j + 1}') + cls._validate_schema(conn) conn.commit() return cls(path) @@ -60,7 +62,8 @@ class DatabaseFile: raise NotFoundException if self._get_version_of_db(self.path) != EXPECTED_DB_VERSION: raise UnmigratedDbException() - self._validate_schema() + with sql_connect(self.path) as conn: + self._validate_schema(conn) @staticmethod def _available_migrations() -> list[str]: @@ -98,17 +101,15 @@ class DatabaseFile: assert isinstance(db_version, int) return db_version - def _validate_schema(self) -> None: + @staticmethod + def _validate_schema(conn: SqlConnection) -> None: """Compare found schema with what's stored at PATH_DB_SCHEMA.""" - # pylint: disable=too-many-locals schema_rows_normed = [] indent = ' ' - with sql_connect(self.path) as conn: - schema_rows = [ - r[0] for r in conn.execute( - 'SELECT sql FROM sqlite_master ORDER BY sql') - if r[0]] - for row in schema_rows: + for row in [ + r[0] for r in conn.execute( + 'SELECT sql FROM sqlite_master ORDER BY sql') + if r[0]]: row_normed = [] for subrow in [sr.rstrip() for sr in row.split('\n')]: in_parentheses = 0