home · contact · privacy
Before committing migrations check schema validation.
authorChristian Heller <c.heller@plomlompom.de>
Tue, 7 Jan 2025 00:26:09 +0000 (01:26 +0100)
committerChristian Heller <c.heller@plomlompom.de>
Tue, 7 Jan 2025 00:26:09 +0000 (01:26 +0100)
plomtask/db.py

index 2695bda3964c1933182e338d7e0756615d59518a..07c31e9016ba2fda748aba11388a87e9c3a799e2 100644 (file)
@@ -3,7 +3,8 @@ from __future__ import annotations
 from os import listdir
 from os.path import isfile
 from difflib import Differ
-from sqlite3 import connect as sql_connect, Cursor, Row
+from sqlite3 import (
+        connect as sql_connect, Connection as SqlConnection, Cursor, Row)
 from typing import Any, Self, Callable
 from plomtask.exceptions import (HandledException, NotFoundException,
                                  BadFormatException)
@@ -51,6 +52,7 @@ class DatabaseFile:
                           encoding='utf-8') as f:
                     conn.executescript(f.read())
                 conn.execute(f'PRAGMA user_version = {from_version + j + 1}')
+            cls._validate_schema(conn)
             conn.commit()
         return cls(path)
 
@@ -60,7 +62,8 @@ class DatabaseFile:
             raise NotFoundException
         if self._get_version_of_db(self.path) != EXPECTED_DB_VERSION:
             raise UnmigratedDbException()
-        self._validate_schema()
+        with sql_connect(self.path) as conn:
+            self._validate_schema(conn)
 
     @staticmethod
     def _available_migrations() -> list[str]:
@@ -98,17 +101,15 @@ class DatabaseFile:
         assert isinstance(db_version, int)
         return db_version
 
-    def _validate_schema(self) -> None:
+    @staticmethod
+    def _validate_schema(conn: SqlConnection) -> None:
         """Compare found schema with what's stored at PATH_DB_SCHEMA."""
-        # pylint: disable=too-many-locals
         schema_rows_normed = []
         indent = '    '
-        with sql_connect(self.path) as conn:
-            schema_rows = [
-                    r[0] for r in conn.execute(
-                        'SELECT sql FROM sqlite_master ORDER BY sql')
-                    if r[0]]
-        for row in schema_rows:
+        for row in [
+                r[0] for r in conn.execute(
+                    'SELECT sql FROM sqlite_master ORDER BY sql')
+                if r[0]]:
             row_normed = []
             for subrow in [sr.rstrip() for sr in row.split('\n')]:
                 in_parentheses = 0