home · contact · privacy
Simplify schema validation code.
authorChristian Heller <c.heller@plomlompom.de>
Mon, 6 Jan 2025 18:36:45 +0000 (19:36 +0100)
committerChristian Heller <c.heller@plomlompom.de>
Mon, 6 Jan 2025 18:36:45 +0000 (19:36 +0100)
src/ytplom/db.py

index b330d7e172baea9b70c8a5c44303e4b9c191b98e..816e4012be1fa3a0456d26303c59853c257f65f6 100644 (file)
@@ -78,13 +78,14 @@ class DbFile:
 
     @staticmethod
     def _validate_schema(conn: 'DbConn'):
-        schema_sql = SqlText('SELECT sql FROM sqlite_master ORDER BY sql')
         schema_rows_normed = []
         indent = '  '
-        for row in [r[0] for r in conn.exec(schema_sql) if r[0]]:
+        for row in [
+                r[0] for r in conn.exec(SqlText(
+                    'SELECT sql FROM sqlite_master ORDER BY sql'))
+                if r[0]]:
             row_normed = []
-            for subrow in row.split('\n'):
-                subrow = subrow.rstrip()
+            for subrow in [sr.rstrip() for sr in row.split('\n')]:
                 in_parentheses = 0
                 split_at = []
                 for i, c in enumerate(subrow):
@@ -96,26 +97,24 @@ class DbFile:
                         split_at += [i + 1]
                 prev_split = 0
                 for i in split_at:
-                    segment = subrow[prev_split:i].strip()
-                    if len(segment) > 0:
+                    if segment := subrow[prev_split:i].strip():
                         row_normed += [f'{indent}{segment}']
                     prev_split = i
-                segment = subrow[prev_split:].strip()
-                if len(segment) > 0:
+                if segment := subrow[prev_split:].strip():
                     row_normed += [f'{indent}{segment}']
-            row_normed[0] = row_normed[0].lstrip()
-            row_normed[-1] = row_normed[-1].lstrip()
+            row_normed[0] = row_normed[0].lstrip()  # no indent for opening …
+            row_normed[-1] = row_normed[-1].lstrip()  # … and closing line
             if row_normed[-1] != ')' and row_normed[-3][-1] != ',':
                 row_normed[-3] = row_normed[-3] + ','
                 row_normed[-2:] = [indent + row_normed[-1][:-1]] + [')']
-            schema_rows_normed += ['\n'.join(row_normed)]
-        retrieved_schema = ';\n'.join(schema_rows_normed) + ';'
-        stored_schema = _PATH_DB_SCHEMA.read_text(encoding='utf-8').rstrip()
-        if stored_schema != retrieved_schema:
-            diff_msg = Differ().compare(retrieved_schema.splitlines(),
-                                        stored_schema.splitlines())
-            raise HandledException('DB has wrong tables schema. Diff:\n'
-                                   + '\n'.join(diff_msg))
+            row_normed[-1] = row_normed[-1] + ';'
+            schema_rows_normed += row_normed
+        if ((expected_rows :=
+             _PATH_DB_SCHEMA.read_text(encoding='utf8').rstrip().splitlines()
+             ) != schema_rows_normed):
+            raise HandledException(
+                'Unexpected tables schema. Diff to {_PATH_DB_SCHEMA}:\n' +
+                '\n'.join(Differ().compare(schema_rows_normed, expected_rows)))
 
     def _get_user_version(self) -> int:
         with sql_connect(self.path) as conn: