From b50f524a86f37ecb48210ca37918a8fabb0a01c8 Mon Sep 17 00:00:00 2001 From: Christian Heller <c.heller@plomlompom.de> Date: Mon, 6 Jan 2025 22:18:55 +0100 Subject: [PATCH] More DB management code refactoring. --- plomtask/db.py | 24 ++++++++---------------- 1 file changed, 8 insertions(+), 16 deletions(-) diff --git a/plomtask/db.py b/plomtask/db.py index 7a80f9f..e505b0e 100644 --- a/plomtask/db.py +++ b/plomtask/db.py @@ -39,24 +39,25 @@ class DatabaseFile: @classmethod def migrate(cls, path: str) -> DatabaseFile: """Apply migrations from_version to EXPECTED_DB_VERSION.""" - migrations = cls._available_migrations() from_version = cls._get_version_of_db(path) + if from_version >= EXPECTED_DB_VERSION: + raise HandledException( + f'Cannot migrate {from_version} to {EXPECTED_DB_VERSION}') + migrations = cls._available_migrations() migrations_todo = migrations[from_version+1:] - for j, filename in enumerate(migrations_todo): - with sql_connect(path) as conn: + with sql_connect(path) as conn: + for j, filename in enumerate(migrations_todo): with open(f'{MIGRATIONS_DIR}/{filename}', 'r', encoding='utf-8') as f: conn.executescript(f.read()) - user_version = from_version + j + 1 - with sql_connect(path) as conn: - conn.execute(f'PRAGMA user_version = {user_version}') + conn.execute(f'PRAGMA user_version = {from_version + j + 1}') return cls(path) def _check(self) -> None: """Check file exists, and is of proper DB version and schema.""" if not isfile(self.path): raise NotFoundException - if self._user_version != EXPECTED_DB_VERSION: + if self._get_version_of_db(self.path) != EXPECTED_DB_VERSION: raise UnmigratedDbException() self._validate_schema() @@ -93,18 +94,9 @@ class DatabaseFile: sql_for_db_version = 'PRAGMA user_version' with sql_connect(path) as conn: db_version = list(conn.execute(sql_for_db_version))[0][0] - if db_version > EXPECTED_DB_VERSION: - msg = f'Wrong DB version, expected '\ - f'{EXPECTED_DB_VERSION}, got unknown {db_version}.' - raise HandledException(msg) assert isinstance(db_version, int) return db_version - @property - def _user_version(self) -> int: - """Get DB user_version.""" - return self._get_version_of_db(self.path) - def _validate_schema(self) -> None: """Compare found schema with what's stored at PATH_DB_SCHEMA.""" # pylint: disable=too-many-locals -- 2.30.2