From: Christian Heller <c.heller@plomlompom.de>
Date: Mon, 6 Jan 2025 21:09:27 +0000 (+0100)
Subject: Refactor schema validation.
X-Git-Url: https://plomlompom.com/repos/%7B%7Bprefix%7D%7D/%7B%7B%20web_path%20%7D%7D/static/%7B%7Bdb.prefix%7D%7D/%7Broute%7D?a=commitdiff_plain;h=9f497db6117671ff7c2d89e38b0c20ce590c8fef;p=plomtask

Refactor schema validation.
---

diff --git a/plomtask/db.py b/plomtask/db.py
index a8e11ba..7a80f9f 100644
--- a/plomtask/db.py
+++ b/plomtask/db.py
@@ -107,51 +107,46 @@ class DatabaseFile:
 
     def _validate_schema(self) -> None:
         """Compare found schema with what's stored at PATH_DB_SCHEMA."""
-
-        def reformat_rows(rows: list[str]) -> list[str]:
-            new_rows = []
-            for row in rows:
-                new_row = []
-                for subrow in row.split('\n'):
-                    subrow = subrow.rstrip()
-                    in_parentheses = 0
-                    split_at = []
-                    for i, c in enumerate(subrow):
-                        if '(' == c:
-                            in_parentheses += 1
-                        elif ')' == c:
-                            in_parentheses -= 1
-                        elif ',' == c and 0 == in_parentheses:
-                            split_at += [i + 1]
-                    prev_split = 0
-                    for i in split_at:
-                        segment = subrow[prev_split:i].strip()
-                        if len(segment) > 0:
-                            new_row += [f'    {segment}']
-                        prev_split = i
-                    segment = subrow[prev_split:].strip()
-                    if len(segment) > 0:
-                        new_row += [f'    {segment}']
-                new_row[0] = new_row[0].lstrip()
-                new_row[-1] = new_row[-1].lstrip()
-                if new_row[-1] != ')' and new_row[-3][-1] != ',':
-                    new_row[-3] = new_row[-3] + ','
-                    new_row[-2:] = ['    ' + new_row[-1][:-1]] + [')']
-                new_rows += ['\n'.join(new_row)]
-            return new_rows
-
-        sql_for_schema = 'SELECT sql FROM sqlite_master ORDER BY sql'
-        msg_err = 'Database has wrong tables schema. Diff:\n'
+        # pylint: disable=too-many-locals
+        schema_rows_normed = []
+        indent = '    '
         with sql_connect(self.path) as conn:
-            schema_rows = [r[0] for r in conn.execute(sql_for_schema) if r[0]]
-        schema_rows = reformat_rows(schema_rows)
-        retrieved_schema = ';\n'.join(schema_rows) + ';'
+            schema_rows = [
+                    r[0] for r in conn.execute(
+                        'SELECT sql FROM sqlite_master ORDER BY sql')
+                    if r[0]]
+        for row in schema_rows:
+            row_normed = []
+            for subrow in [sr.rstrip() for sr in row.split('\n')]:
+                in_parentheses = 0
+                split_at = []
+                for i, c in enumerate(subrow):
+                    if '(' == c:
+                        in_parentheses += 1
+                    elif ')' == c:
+                        in_parentheses -= 1
+                    elif ',' == c and 0 == in_parentheses:
+                        split_at += [i + 1]
+                prev_split = 0
+                for i in split_at:
+                    if segment := subrow[prev_split:i].strip():
+                        row_normed += [f'{indent}{segment}']
+                    prev_split = i
+                if segment := subrow[prev_split:].strip():
+                    row_normed += [f'{indent}{segment}']
+            row_normed[0] = row_normed[0].lstrip()  # no indent for opening …
+            row_normed[-1] = row_normed[-1].lstrip()  # … and closing line
+            if row_normed[-1] != ')' and row_normed[-3][-1] != ',':
+                row_normed[-3] = row_normed[-3] + ','
+                row_normed[-2:] = [indent + row_normed[-1][:-1]] + [')']
+            row_normed[-1] = row_normed[-1] + ';'
+            schema_rows_normed += row_normed
         with open(PATH_DB_SCHEMA, 'r', encoding='utf-8') as f:
-            stored_schema = f.read().rstrip()
-        if stored_schema != retrieved_schema:
-            diff_msg = Differ().compare(retrieved_schema.splitlines(),
-                                        stored_schema.splitlines())
-            raise HandledException(msg_err + '\n'.join(diff_msg))
+            expected_rows = f.read().rstrip().splitlines()
+        if expected_rows != schema_rows_normed:
+            raise HandledException(
+                'Unexpected tables schema. Diff to {path_expected_schema}:\n' +
+                '\n'.join(Differ().compare(schema_rows_normed, expected_rows)))
 
 
 class DatabaseConnection: