From: Christian Heller Date: Mon, 6 Jan 2025 18:36:45 +0000 (+0100) Subject: Simplify schema validation code. X-Git-Url: https://plomlompom.com/repos/%7B%7Bdb.prefix%7D%7D/%7B%7B%20web_path%20%7D%7D/%7B%7Bprefix%7D%7D/static/test?a=commitdiff_plain;ds=sidebyside;p=ytplom Simplify schema validation code. --- diff --git a/src/ytplom/db.py b/src/ytplom/db.py index b330d7e..816e401 100644 --- a/src/ytplom/db.py +++ b/src/ytplom/db.py @@ -78,13 +78,14 @@ class DbFile: @staticmethod def _validate_schema(conn: 'DbConn'): - schema_sql = SqlText('SELECT sql FROM sqlite_master ORDER BY sql') schema_rows_normed = [] indent = ' ' - for row in [r[0] for r in conn.exec(schema_sql) if r[0]]: + for row in [ + r[0] for r in conn.exec(SqlText( + 'SELECT sql FROM sqlite_master ORDER BY sql')) + if r[0]]: row_normed = [] - for subrow in row.split('\n'): - subrow = subrow.rstrip() + for subrow in [sr.rstrip() for sr in row.split('\n')]: in_parentheses = 0 split_at = [] for i, c in enumerate(subrow): @@ -96,26 +97,24 @@ class DbFile: split_at += [i + 1] prev_split = 0 for i in split_at: - segment = subrow[prev_split:i].strip() - if len(segment) > 0: + if segment := subrow[prev_split:i].strip(): row_normed += [f'{indent}{segment}'] prev_split = i - segment = subrow[prev_split:].strip() - if len(segment) > 0: + if segment := subrow[prev_split:].strip(): row_normed += [f'{indent}{segment}'] - row_normed[0] = row_normed[0].lstrip() - row_normed[-1] = row_normed[-1].lstrip() + row_normed[0] = row_normed[0].lstrip() # no indent for opening … + row_normed[-1] = row_normed[-1].lstrip() # … and closing line if row_normed[-1] != ')' and row_normed[-3][-1] != ',': row_normed[-3] = row_normed[-3] + ',' row_normed[-2:] = [indent + row_normed[-1][:-1]] + [')'] - schema_rows_normed += ['\n'.join(row_normed)] - retrieved_schema = ';\n'.join(schema_rows_normed) + ';' - stored_schema = _PATH_DB_SCHEMA.read_text(encoding='utf-8').rstrip() - if stored_schema != retrieved_schema: - diff_msg = Differ().compare(retrieved_schema.splitlines(), - stored_schema.splitlines()) - raise HandledException('DB has wrong tables schema. Diff:\n' - + '\n'.join(diff_msg)) + row_normed[-1] = row_normed[-1] + ';' + schema_rows_normed += row_normed + if ((expected_rows := + _PATH_DB_SCHEMA.read_text(encoding='utf8').rstrip().splitlines() + ) != schema_rows_normed): + raise HandledException( + 'Unexpected tables schema. Diff to {_PATH_DB_SCHEMA}:\n' + + '\n'.join(Differ().compare(schema_rows_normed, expected_rows))) def _get_user_version(self) -> int: with sql_connect(self.path) as conn: