From ae54b8a9a187fcb0f3a89917d8bcaa856882cb71 Mon Sep 17 00:00:00 2001
From: Christian Heller <c.heller@plomlompom.de>
Date: Tue, 7 Jan 2025 01:26:09 +0100
Subject: [PATCH] Before committing migrations check schema validation.

---
 plomtask/db.py | 21 +++++++++++----------
 1 file changed, 11 insertions(+), 10 deletions(-)

diff --git a/plomtask/db.py b/plomtask/db.py
index 2695bda..07c31e9 100644
--- a/plomtask/db.py
+++ b/plomtask/db.py
@@ -3,7 +3,8 @@ from __future__ import annotations
 from os import listdir
 from os.path import isfile
 from difflib import Differ
-from sqlite3 import connect as sql_connect, Cursor, Row
+from sqlite3 import (
+        connect as sql_connect, Connection as SqlConnection, Cursor, Row)
 from typing import Any, Self, Callable
 from plomtask.exceptions import (HandledException, NotFoundException,
                                  BadFormatException)
@@ -51,6 +52,7 @@ class DatabaseFile:
                           encoding='utf-8') as f:
                     conn.executescript(f.read())
                 conn.execute(f'PRAGMA user_version = {from_version + j + 1}')
+            cls._validate_schema(conn)
             conn.commit()
         return cls(path)
 
@@ -60,7 +62,8 @@ class DatabaseFile:
             raise NotFoundException
         if self._get_version_of_db(self.path) != EXPECTED_DB_VERSION:
             raise UnmigratedDbException()
-        self._validate_schema()
+        with sql_connect(self.path) as conn:
+            self._validate_schema(conn)
 
     @staticmethod
     def _available_migrations() -> list[str]:
@@ -98,17 +101,15 @@ class DatabaseFile:
         assert isinstance(db_version, int)
         return db_version
 
-    def _validate_schema(self) -> None:
+    @staticmethod
+    def _validate_schema(conn: SqlConnection) -> None:
         """Compare found schema with what's stored at PATH_DB_SCHEMA."""
-        # pylint: disable=too-many-locals
         schema_rows_normed = []
         indent = '    '
-        with sql_connect(self.path) as conn:
-            schema_rows = [
-                    r[0] for r in conn.execute(
-                        'SELECT sql FROM sqlite_master ORDER BY sql')
-                    if r[0]]
-        for row in schema_rows:
+        for row in [
+                r[0] for r in conn.execute(
+                    'SELECT sql FROM sqlite_master ORDER BY sql')
+                if r[0]]:
             row_normed = []
             for subrow in [sr.rstrip() for sr in row.split('\n')]:
                 in_parentheses = 0
-- 
2.30.2