From: Christian Heller Date: Mon, 6 Jan 2025 21:09:27 +0000 (+0100) Subject: Refactor schema validation. X-Git-Url: https://plomlompom.com/repos/%7B%7Bprefix%7D%7D/static/%7B%7B%20web_path%20%7D%7D/%7B%7Bdb.prefix%7D%7D/condition?a=commitdiff_plain;h=9f497db6117671ff7c2d89e38b0c20ce590c8fef;p=plomtask Refactor schema validation. --- diff --git a/plomtask/db.py b/plomtask/db.py index a8e11ba..7a80f9f 100644 --- a/plomtask/db.py +++ b/plomtask/db.py @@ -107,51 +107,46 @@ class DatabaseFile: def _validate_schema(self) -> None: """Compare found schema with what's stored at PATH_DB_SCHEMA.""" - - def reformat_rows(rows: list[str]) -> list[str]: - new_rows = [] - for row in rows: - new_row = [] - for subrow in row.split('\n'): - subrow = subrow.rstrip() - in_parentheses = 0 - split_at = [] - for i, c in enumerate(subrow): - if '(' == c: - in_parentheses += 1 - elif ')' == c: - in_parentheses -= 1 - elif ',' == c and 0 == in_parentheses: - split_at += [i + 1] - prev_split = 0 - for i in split_at: - segment = subrow[prev_split:i].strip() - if len(segment) > 0: - new_row += [f' {segment}'] - prev_split = i - segment = subrow[prev_split:].strip() - if len(segment) > 0: - new_row += [f' {segment}'] - new_row[0] = new_row[0].lstrip() - new_row[-1] = new_row[-1].lstrip() - if new_row[-1] != ')' and new_row[-3][-1] != ',': - new_row[-3] = new_row[-3] + ',' - new_row[-2:] = [' ' + new_row[-1][:-1]] + [')'] - new_rows += ['\n'.join(new_row)] - return new_rows - - sql_for_schema = 'SELECT sql FROM sqlite_master ORDER BY sql' - msg_err = 'Database has wrong tables schema. Diff:\n' + # pylint: disable=too-many-locals + schema_rows_normed = [] + indent = ' ' with sql_connect(self.path) as conn: - schema_rows = [r[0] for r in conn.execute(sql_for_schema) if r[0]] - schema_rows = reformat_rows(schema_rows) - retrieved_schema = ';\n'.join(schema_rows) + ';' + schema_rows = [ + r[0] for r in conn.execute( + 'SELECT sql FROM sqlite_master ORDER BY sql') + if r[0]] + for row in schema_rows: + row_normed = [] + for subrow in [sr.rstrip() for sr in row.split('\n')]: + in_parentheses = 0 + split_at = [] + for i, c in enumerate(subrow): + if '(' == c: + in_parentheses += 1 + elif ')' == c: + in_parentheses -= 1 + elif ',' == c and 0 == in_parentheses: + split_at += [i + 1] + prev_split = 0 + for i in split_at: + if segment := subrow[prev_split:i].strip(): + row_normed += [f'{indent}{segment}'] + prev_split = i + if segment := subrow[prev_split:].strip(): + row_normed += [f'{indent}{segment}'] + row_normed[0] = row_normed[0].lstrip() # no indent for opening … + row_normed[-1] = row_normed[-1].lstrip() # … and closing line + if row_normed[-1] != ')' and row_normed[-3][-1] != ',': + row_normed[-3] = row_normed[-3] + ',' + row_normed[-2:] = [indent + row_normed[-1][:-1]] + [')'] + row_normed[-1] = row_normed[-1] + ';' + schema_rows_normed += row_normed with open(PATH_DB_SCHEMA, 'r', encoding='utf-8') as f: - stored_schema = f.read().rstrip() - if stored_schema != retrieved_schema: - diff_msg = Differ().compare(retrieved_schema.splitlines(), - stored_schema.splitlines()) - raise HandledException(msg_err + '\n'.join(diff_msg)) + expected_rows = f.read().rstrip().splitlines() + if expected_rows != schema_rows_normed: + raise HandledException( + 'Unexpected tables schema. Diff to {path_expected_schema}:\n' + + '\n'.join(Differ().compare(schema_rows_normed, expected_rows))) class DatabaseConnection: