From b50f524a86f37ecb48210ca37918a8fabb0a01c8 Mon Sep 17 00:00:00 2001
From: Christian Heller <c.heller@plomlompom.de>
Date: Mon, 6 Jan 2025 22:18:55 +0100
Subject: [PATCH] More DB management code refactoring.

---
 plomtask/db.py | 24 ++++++++----------------
 1 file changed, 8 insertions(+), 16 deletions(-)

diff --git a/plomtask/db.py b/plomtask/db.py
index 7a80f9f..e505b0e 100644
--- a/plomtask/db.py
+++ b/plomtask/db.py
@@ -39,24 +39,25 @@ class DatabaseFile:
     @classmethod
     def migrate(cls, path: str) -> DatabaseFile:
         """Apply migrations from_version to EXPECTED_DB_VERSION."""
-        migrations = cls._available_migrations()
         from_version = cls._get_version_of_db(path)
+        if from_version >= EXPECTED_DB_VERSION:
+            raise HandledException(
+                    f'Cannot migrate {from_version} to {EXPECTED_DB_VERSION}')
+        migrations = cls._available_migrations()
         migrations_todo = migrations[from_version+1:]
-        for j, filename in enumerate(migrations_todo):
-            with sql_connect(path) as conn:
+        with sql_connect(path) as conn:
+            for j, filename in enumerate(migrations_todo):
                 with open(f'{MIGRATIONS_DIR}/{filename}', 'r',
                           encoding='utf-8') as f:
                     conn.executescript(f.read())
-            user_version = from_version + j + 1
-            with sql_connect(path) as conn:
-                conn.execute(f'PRAGMA user_version = {user_version}')
+                conn.execute(f'PRAGMA user_version = {from_version + j + 1}')
         return cls(path)
 
     def _check(self) -> None:
         """Check file exists, and is of proper DB version and schema."""
         if not isfile(self.path):
             raise NotFoundException
-        if self._user_version != EXPECTED_DB_VERSION:
+        if self._get_version_of_db(self.path) != EXPECTED_DB_VERSION:
             raise UnmigratedDbException()
         self._validate_schema()
 
@@ -93,18 +94,9 @@ class DatabaseFile:
         sql_for_db_version = 'PRAGMA user_version'
         with sql_connect(path) as conn:
             db_version = list(conn.execute(sql_for_db_version))[0][0]
-        if db_version > EXPECTED_DB_VERSION:
-            msg = f'Wrong DB version, expected '\
-                    f'{EXPECTED_DB_VERSION}, got unknown {db_version}.'
-            raise HandledException(msg)
         assert isinstance(db_version, int)
         return db_version
 
-    @property
-    def _user_version(self) -> int:
-        """Get DB user_version."""
-        return self._get_version_of_db(self.path)
-
     def _validate_schema(self) -> None:
         """Compare found schema with what's stored at PATH_DB_SCHEMA."""
         # pylint: disable=too-many-locals
-- 
2.30.2