home · contact · privacy
Add Todo/Process.blockers for Conditions that block rather than enable.
[plomtask] / plomtask / db.py
index 8e5529062ae18c60a2ed222deb27305f9c18e54c..d2791b1bdf43da11d389b805fffd6b2791ae168d 100644 (file)
@@ -1,13 +1,20 @@
 """Database management."""
 from __future__ import annotations
+from os import listdir
 from os.path import isfile
 from difflib import Differ
 from sqlite3 import connect as sql_connect, Cursor, Row
 from typing import Any, Self, TypeVar, Generic
 from plomtask.exceptions import HandledException, NotFoundException
 
-PATH_DB_SCHEMA = 'scripts/init.sql'
-EXPECTED_DB_VERSION = 0
+EXPECTED_DB_VERSION = 4
+MIGRATIONS_DIR = 'migrations'
+FILENAME_DB_SCHEMA = f'init_{EXPECTED_DB_VERSION}.sql'
+PATH_DB_SCHEMA = f'{MIGRATIONS_DIR}/{FILENAME_DB_SCHEMA}'
+
+
+class UnmigratedDbException(HandledException):
+    """To identify case of unmigrated DB file."""
 
 
 class DatabaseFile:  # pylint: disable=too-few-public-methods
@@ -17,43 +24,131 @@ class DatabaseFile:  # pylint: disable=too-few-public-methods
         self.path = path
         self._check()
 
-    def remake(self) -> None:
-        """Create tables in self.path file as per PATH_DB_SCHEMA sql file."""
-        with sql_connect(self.path) as conn:
+    @classmethod
+    def create_at(cls, path: str) -> DatabaseFile:
+        """Make new DB file at path."""
+        with sql_connect(path) as conn:
             with open(PATH_DB_SCHEMA, 'r', encoding='utf-8') as f:
                 conn.executescript(f.read())
-        self._check()
+            conn.execute(f'PRAGMA user_version = {EXPECTED_DB_VERSION}')
+        return cls(path)
+
+    @classmethod
+    def migrate(cls, path: str) -> DatabaseFile:
+        """Apply migrations from_version to EXPECTED_DB_VERSION."""
+        migrations = cls._available_migrations()
+        from_version = cls.get_version_of_db(path)
+        migrations_todo = migrations[from_version+1:]
+        for j, filename in enumerate(migrations_todo):
+            with sql_connect(path) as conn:
+                with open(f'{MIGRATIONS_DIR}/{filename}', 'r',
+                          encoding='utf-8') as f:
+                    conn.executescript(f.read())
+            user_version = from_version + j + 1
+            with sql_connect(path) as conn:
+                conn.execute(f'PRAGMA user_version = {user_version}')
+        return cls(path)
 
     def _check(self) -> None:
         """Check file exists, and is of proper DB version and schema."""
-        self.exists = isfile(self.path)
-        if self.exists:
-            self._validate_user_version()
-            self._validate_schema()
+        if not isfile(self.path):
+            raise NotFoundException
+        if self.user_version != EXPECTED_DB_VERSION:
+            raise UnmigratedDbException()
+        self._validate_schema()
 
-    def _validate_user_version(self) -> None:
-        """Compare DB user_version with EXPECTED_DB_VERSION."""
+    @staticmethod
+    def _available_migrations() -> list[str]:
+        """Validate migrations directory and return sorted entries."""
+        msg_too_big = 'Migration directory points beyond expected DB version.'
+        msg_bad_entry = 'Migration directory contains unexpected entry: '
+        msg_missing = 'Migration directory misses migration of number: '
+        migrations = {}
+        for entry in listdir(MIGRATIONS_DIR):
+            if entry == FILENAME_DB_SCHEMA:
+                continue
+            toks = entry.split('_', 1)
+            if len(toks) < 2:
+                raise HandledException(msg_bad_entry + entry)
+            try:
+                i = int(toks[0])
+            except ValueError as e:
+                raise HandledException(msg_bad_entry + entry) from e
+            if i > EXPECTED_DB_VERSION:
+                raise HandledException(msg_too_big)
+            migrations[i] = toks[1]
+        migrations_list = []
+        for i in range(EXPECTED_DB_VERSION + 1):
+            if i not in migrations:
+                raise HandledException(msg_missing + str(i))
+            migrations_list += [f'{i}_{migrations[i]}']
+        return migrations_list
+
+    @staticmethod
+    def get_version_of_db(path: str) -> int:
+        """Get DB user_version, fail if outside expected range."""
         sql_for_db_version = 'PRAGMA user_version'
-        with sql_connect(self.path) as conn:
+        with sql_connect(path) as conn:
             db_version = list(conn.execute(sql_for_db_version))[0][0]
-            if db_version != EXPECTED_DB_VERSION:
-                msg = f'Wrong DB version, expected '\
-                        f'{EXPECTED_DB_VERSION}, got {db_version}.'
-                raise HandledException(msg)
+        if db_version > EXPECTED_DB_VERSION:
+            msg = f'Wrong DB version, expected '\
+                    f'{EXPECTED_DB_VERSION}, got unknown {db_version}.'
+            raise HandledException(msg)
+        assert isinstance(db_version, int)
+        return db_version
+
+    @property
+    def user_version(self) -> int:
+        """Get DB user_version."""
+        return self.__class__.get_version_of_db(self.path)
 
     def _validate_schema(self) -> None:
         """Compare found schema with what's stored at PATH_DB_SCHEMA."""
+
+        def reformat_rows(rows: list[str]) -> list[str]:
+            new_rows = []
+            for row in rows:
+                new_row = []
+                for subrow in row.split('\n'):
+                    subrow = subrow.rstrip()
+                    in_parentheses = 0
+                    split_at = []
+                    for i, c in enumerate(subrow):
+                        if '(' == c:
+                            in_parentheses += 1
+                        elif ')' == c:
+                            in_parentheses -= 1
+                        elif ',' == c and 0 == in_parentheses:
+                            split_at += [i + 1]
+                    prev_split = 0
+                    for i in split_at:
+                        segment = subrow[prev_split:i].strip()
+                        if len(segment) > 0:
+                            new_row += [f'    {segment}']
+                        prev_split = i
+                    segment = subrow[prev_split:].strip()
+                    if len(segment) > 0:
+                        new_row += [f'    {segment}']
+                new_row[0] = new_row[0].lstrip()
+                new_row[-1] = new_row[-1].lstrip()
+                if new_row[-1] != ')' and new_row[-3][-1] != ',':
+                    new_row[-3] = new_row[-3] + ','
+                    new_row[-2:] = ['    ' + new_row[-1][:-1]] + [')']
+                new_rows += ['\n'.join(new_row)]
+            return new_rows
+
         sql_for_schema = 'SELECT sql FROM sqlite_master ORDER BY sql'
         msg_err = 'Database has wrong tables schema. Diff:\n'
         with sql_connect(self.path) as conn:
             schema_rows = [r[0] for r in conn.execute(sql_for_schema) if r[0]]
-            retrieved_schema = ';\n'.join(schema_rows) + ';'
-            with open(PATH_DB_SCHEMA, 'r', encoding='utf-8') as f:
-                stored_schema = f.read().rstrip()
-                if stored_schema != retrieved_schema:
-                    diff_msg = Differ().compare(retrieved_schema.splitlines(),
-                                                stored_schema.splitlines())
-                    raise HandledException(msg_err + '\n'.join(diff_msg))
+        schema_rows = reformat_rows(schema_rows)
+        retrieved_schema = ';\n'.join(schema_rows) + ';'
+        with open(PATH_DB_SCHEMA, 'r', encoding='utf-8') as f:
+            stored_schema = f.read().rstrip()
+        if stored_schema != retrieved_schema:
+            diff_msg = Differ().compare(retrieved_schema.splitlines(),
+                                        stored_schema.splitlines())
+            raise HandledException(msg_err + '\n'.join(diff_msg))
 
 
 class DatabaseConnection:
@@ -75,7 +170,7 @@ class DatabaseConnection:
         """Close DB connection."""
         self.conn.close()
 
-    def rewrite_relations(self, table_name: str, key: str, target: int,
+    def rewrite_relations(self, table_name: str, key: str, target: int | str,
                           rows: list[list[Any]]) -> None:
         """Rewrite relations in table_name to target, with rows values."""
         self.delete_where(table_name, key, target)
@@ -121,6 +216,8 @@ class BaseModel(Generic[BaseModelId]):
     """Template for most of the models we use/derive from the DB."""
     table_name = ''
     to_save: list[str] = []
+    to_save_versioned: list[str] = []
+    to_save_relations: list[tuple[str, str, str]] = []
     id_: None | BaseModelId
     cache_: dict[BaseModelId, Self]
 
@@ -130,6 +227,23 @@ class BaseModel(Generic[BaseModelId]):
             raise HandledException(msg)
         self.id_ = id_
 
+    def __eq__(self, other: object) -> bool:
+        if not isinstance(other, self.__class__):
+            return False
+        to_hash_me = tuple([self.id_] +
+                           [getattr(self, name) for name in self.to_save])
+        to_hash_other = tuple([other.id_] +
+                              [getattr(other, name) for name in other.to_save])
+        return hash(to_hash_me) == hash(to_hash_other)
+
+    def __lt__(self, other: Any) -> bool:
+        if not isinstance(other, self.__class__):
+            msg = 'cannot compare to object of different class'
+            raise HandledException(msg)
+        assert isinstance(self.id_, int)
+        assert isinstance(other.id_, int)
+        return self.id_ < other.id_
+
     @classmethod
     def get_cached(cls: type[BaseModelInstance],
                    id_: BaseModelId) -> BaseModelInstance | None:
@@ -209,7 +323,13 @@ class BaseModel(Generic[BaseModelId]):
     @classmethod
     def all(cls: type[BaseModelInstance],
             db_conn: DatabaseConnection) -> list[BaseModelInstance]:
-        """Collect all objects of class."""
+        """Collect all objects of class into list.
+
+        Note that this primarily returns the contents of the cache, and only
+        _expands_ that by additional findings in the DB. This assumes the
+        cache is always instantly cleaned of any items that would be removed
+        from the DB.
+        """
         items: dict[BaseModelId, BaseModelInstance] = {}
         for k, v in cls.get_cache().items():
             assert isinstance(v, cls)
@@ -222,21 +342,11 @@ class BaseModel(Generic[BaseModelId]):
                 items[item.id_] = item
         return list(items.values())
 
-    def __eq__(self, other: object) -> bool:
-        if not isinstance(other, self.__class__):
-            msg = 'cannot compare to object of different class'
-            raise HandledException(msg)
-        to_hash_me = tuple([self.id_] +
-                           [getattr(self, name) for name in self.to_save])
-        to_hash_other = tuple([other.id_] +
-                              [getattr(other, name) for name in other.to_save])
-        return hash(to_hash_me) == hash(to_hash_other)
-
-    def save_core(self, db_conn: DatabaseConnection) -> None:
-        """Write bare-bones self (sans connected items), ensuring self.id_.
+    def save(self, db_conn: DatabaseConnection) -> None:
+        """Write self to DB and cache and ensure .id_.
 
         Write both to DB, and to cache. To DB, write .id_ and attributes
-        listed in cls.to_save.
+        listed in cls.to_save[_versioned|_relations].
 
         Ensure self.id_ by setting it to what the DB command returns as the
         last saved row's ID (cursor.lastrowid), EXCEPT if self.id_ already
@@ -252,9 +362,21 @@ class BaseModel(Generic[BaseModelId]):
         if not isinstance(self.id_, str):
             self.id_ = cursor.lastrowid  # type: ignore[assignment]
         self.cache()
+        for attr_name in self.to_save_versioned:
+            getattr(self, attr_name).save(db_conn)
+        for table, column, attr_name in self.to_save_relations:
+            assert isinstance(self.id_, (int, str))
+            db_conn.rewrite_relations(table, column, self.id_,
+                                      [[i.id_] for i
+                                       in getattr(self, attr_name)])
 
     def remove(self, db_conn: DatabaseConnection) -> None:
-        """Remove from DB."""
-        assert isinstance(self.id_, int | str)
+        """Remove from DB and cache, including dependencies."""
+        if self.id_ is None or self.__class__.get_cached(self.id_) is None:
+            raise HandledException('cannot remove unsaved item')
+        for attr_name in self.to_save_versioned:
+            getattr(self, attr_name).remove(db_conn)
+        for table, column, attr_name in self.to_save_relations:
+            db_conn.delete_where(table, column, self.id_)
         self.uncache()
         db_conn.delete_where(self.table_name, 'id', self.id_)