home · contact · privacy
More DB management code refactoring.
authorChristian Heller <c.heller@plomlompom.de>
Mon, 6 Jan 2025 21:18:55 +0000 (22:18 +0100)
committerChristian Heller <c.heller@plomlompom.de>
Mon, 6 Jan 2025 21:18:55 +0000 (22:18 +0100)
plomtask/db.py

index 7a80f9fafae733906d321b9b932c712ed64311a9..e505b0e8c6adb9a0329f2675fab2c0394e147414 100644 (file)
@@ -39,24 +39,25 @@ class DatabaseFile:
     @classmethod
     def migrate(cls, path: str) -> DatabaseFile:
         """Apply migrations from_version to EXPECTED_DB_VERSION."""
-        migrations = cls._available_migrations()
         from_version = cls._get_version_of_db(path)
+        if from_version >= EXPECTED_DB_VERSION:
+            raise HandledException(
+                    f'Cannot migrate {from_version} to {EXPECTED_DB_VERSION}')
+        migrations = cls._available_migrations()
         migrations_todo = migrations[from_version+1:]
-        for j, filename in enumerate(migrations_todo):
-            with sql_connect(path) as conn:
+        with sql_connect(path) as conn:
+            for j, filename in enumerate(migrations_todo):
                 with open(f'{MIGRATIONS_DIR}/{filename}', 'r',
                           encoding='utf-8') as f:
                     conn.executescript(f.read())
-            user_version = from_version + j + 1
-            with sql_connect(path) as conn:
-                conn.execute(f'PRAGMA user_version = {user_version}')
+                conn.execute(f'PRAGMA user_version = {from_version + j + 1}')
         return cls(path)
 
     def _check(self) -> None:
         """Check file exists, and is of proper DB version and schema."""
         if not isfile(self.path):
             raise NotFoundException
-        if self._user_version != EXPECTED_DB_VERSION:
+        if self._get_version_of_db(self.path) != EXPECTED_DB_VERSION:
             raise UnmigratedDbException()
         self._validate_schema()
 
@@ -93,18 +94,9 @@ class DatabaseFile:
         sql_for_db_version = 'PRAGMA user_version'
         with sql_connect(path) as conn:
             db_version = list(conn.execute(sql_for_db_version))[0][0]
-        if db_version > EXPECTED_DB_VERSION:
-            msg = f'Wrong DB version, expected '\
-                    f'{EXPECTED_DB_VERSION}, got unknown {db_version}.'
-            raise HandledException(msg)
         assert isinstance(db_version, int)
         return db_version
 
-    @property
-    def _user_version(self) -> int:
-        """Get DB user_version."""
-        return self._get_version_of_db(self.path)
-
     def _validate_schema(self) -> None:
         """Compare found schema with what's stored at PATH_DB_SCHEMA."""
         # pylint: disable=too-many-locals