From 83266154e9140151c975586d21f393a5eb3f4ef4 Mon Sep 17 00:00:00 2001
From: Christian Heller <c.heller@plomlompom.de>
Date: Wed, 15 May 2024 08:21:07 +0200
Subject: [PATCH] Add Todo.comment, and for that purpose basic SQL migration
 infrastructure.

---
 scripts/init.sql => migrations/0_init.sql |   0
 migrations/1_add_Todo_comment.sql         |   1 +
 migrations/init_1.sql                     | 113 +++++++++++++++++
 plomtask/db.py                            | 140 ++++++++++++++++++----
 plomtask/todos.py                         |   6 +-
 run.py                                    |  30 +++--
 tests/utils.py                            |   3 +-
 7 files changed, 254 insertions(+), 39 deletions(-)
 rename scripts/init.sql => migrations/0_init.sql (100%)
 create mode 100644 migrations/1_add_Todo_comment.sql
 create mode 100644 migrations/init_1.sql

diff --git a/scripts/init.sql b/migrations/0_init.sql
similarity index 100%
rename from scripts/init.sql
rename to migrations/0_init.sql
diff --git a/migrations/1_add_Todo_comment.sql b/migrations/1_add_Todo_comment.sql
new file mode 100644
index 0000000..0c58335
--- /dev/null
+++ b/migrations/1_add_Todo_comment.sql
@@ -0,0 +1 @@
+ALTER TABLE todos ADD COLUMN comment TEXT NOT NULL DEFAULT "";
diff --git a/migrations/init_1.sql b/migrations/init_1.sql
new file mode 100644
index 0000000..c30121f
--- /dev/null
+++ b/migrations/init_1.sql
@@ -0,0 +1,113 @@
+CREATE TABLE condition_descriptions (
+    parent INTEGER NOT NULL,
+    timestamp TEXT NOT NULL,
+    description TEXT NOT NULL,
+    PRIMARY KEY (parent, timestamp),
+    FOREIGN KEY (parent) REFERENCES conditions(id)
+);
+CREATE TABLE condition_titles (
+    parent INTEGER NOT NULL,
+    timestamp TEXT NOT NULL,
+    title TEXT NOT NULL,
+    PRIMARY KEY (parent, timestamp),
+    FOREIGN KEY (parent) REFERENCES conditions(id)
+);
+CREATE TABLE conditions (
+    id INTEGER PRIMARY KEY,
+    is_active BOOLEAN NOT NULL
+);
+CREATE TABLE days (
+    id TEXT PRIMARY KEY,
+    comment TEXT NOT NULL
+);
+CREATE TABLE process_conditions (
+    process INTEGER NOT NULL,
+    condition INTEGER NOT NULL,
+    PRIMARY KEY (process, condition),
+    FOREIGN KEY (process) REFERENCES processes(id),
+    FOREIGN KEY (condition) REFERENCES conditions(id)
+);
+CREATE TABLE process_descriptions (
+    parent INTEGER NOT NULL,
+    timestamp TEXT NOT NULL,
+    description TEXT NOT NULL,
+    PRIMARY KEY (parent, timestamp),
+    FOREIGN KEY (parent) REFERENCES processes(id)
+);
+CREATE TABLE process_disables (
+    process INTEGER NOT NULL,
+    condition INTEGER NOT NULL,
+    PRIMARY KEY(process, condition),
+    FOREIGN KEY (process) REFERENCES processes(id),
+    FOREIGN KEY (condition) REFERENCES conditions(id)
+);
+CREATE TABLE process_efforts (
+    parent INTEGER NOT NULL,
+    timestamp TEXT NOT NULL,
+    effort REAL NOT NULL,
+    PRIMARY KEY (parent, timestamp),
+    FOREIGN KEY (parent) REFERENCES processes(id)
+);
+CREATE TABLE process_enables (
+    process INTEGER NOT NULL,
+    condition INTEGER NOT NULL,
+    PRIMARY KEY(process, condition),
+    FOREIGN KEY (process) REFERENCES processes(id),
+    FOREIGN KEY (condition) REFERENCES conditions(id)
+);
+CREATE TABLE process_steps (
+    id INTEGER PRIMARY KEY,
+    owner INTEGER NOT NULL,
+    step_process INTEGER NOT NULL,
+    parent_step INTEGER,
+    FOREIGN KEY (owner) REFERENCES processes(id),
+    FOREIGN KEY (step_process) REFERENCES processes(id),
+    FOREIGN KEY (parent_step) REFERENCES process_steps(step_id)
+);
+CREATE TABLE process_titles (
+    parent INTEGER NOT NULL,
+    timestamp TEXT NOT NULL,
+    title TEXT NOT NULL,
+    PRIMARY KEY (parent, timestamp),
+    FOREIGN KEY (parent) REFERENCES processes(id)
+);
+CREATE TABLE processes (
+    id INTEGER PRIMARY KEY
+);
+CREATE TABLE todo_children (
+    parent INTEGER NOT NULL,
+    child INTEGER NOT NULL,
+    PRIMARY KEY (parent, child),
+    FOREIGN KEY (parent) REFERENCES todos(id),
+    FOREIGN KEY (child) REFERENCES todos(id)
+);
+CREATE TABLE todo_conditions (
+    todo INTEGER NOT NULL,
+    condition INTEGER NOT NULL,
+    PRIMARY KEY(todo, condition),
+    FOREIGN KEY (todo) REFERENCES todos(id),
+    FOREIGN KEY (condition) REFERENCES conditions(id)
+);
+CREATE TABLE todo_disables (
+    todo INTEGER NOT NULL,
+    condition INTEGER NOT NULL,
+    PRIMARY KEY(todo, condition),
+    FOREIGN KEY (todo) REFERENCES todos(id),
+    FOREIGN KEY (condition) REFERENCES conditions(id)
+);
+CREATE TABLE todo_enables (
+    todo INTEGER NOT NULL,
+    condition INTEGER NOT NULL,
+    PRIMARY KEY(todo, condition),
+    FOREIGN KEY (todo) REFERENCES todos(id),
+    FOREIGN KEY (condition) REFERENCES conditions(id)
+);
+CREATE TABLE todos (
+    id INTEGER PRIMARY KEY,
+    process INTEGER NOT NULL,
+    is_done BOOLEAN NOT NULL,
+    day TEXT NOT NULL,
+    comment TEXT NOT NULL DEFAULT "",
+    FOREIGN KEY (process) REFERENCES processes(id),
+    FOREIGN KEY (day) REFERENCES days(id)
+);
diff --git a/plomtask/db.py b/plomtask/db.py
index e4d5f6e..7962eab 100644
--- a/plomtask/db.py
+++ b/plomtask/db.py
@@ -1,13 +1,20 @@
 """Database management."""
 from __future__ import annotations
+from os import listdir
 from os.path import isfile
 from difflib import Differ
 from sqlite3 import connect as sql_connect, Cursor, Row
 from typing import Any, Self, TypeVar, Generic
 from plomtask.exceptions import HandledException, NotFoundException
 
-PATH_DB_SCHEMA = 'scripts/init.sql'
-EXPECTED_DB_VERSION = 0
+EXPECTED_DB_VERSION = 1
+MIGRATIONS_DIR = 'migrations'
+FILENAME_DB_SCHEMA = f'init_{EXPECTED_DB_VERSION}.sql'
+PATH_DB_SCHEMA = f'{MIGRATIONS_DIR}/{FILENAME_DB_SCHEMA}'
+
+
+class UnmigratedDbException(HandledException):
+    """To identify case of unmigrated DB file."""
 
 
 class DatabaseFile:  # pylint: disable=too-few-public-methods
@@ -17,43 +24,128 @@ class DatabaseFile:  # pylint: disable=too-few-public-methods
         self.path = path
         self._check()
 
-    def remake(self) -> None:
-        """Create tables in self.path file as per PATH_DB_SCHEMA sql file."""
-        with sql_connect(self.path) as conn:
+    @classmethod
+    def create_at(cls, path: str) -> DatabaseFile:
+        """Make new DB file at path."""
+        with sql_connect(path) as conn:
             with open(PATH_DB_SCHEMA, 'r', encoding='utf-8') as f:
                 conn.executescript(f.read())
-        self._check()
+            conn.execute(f'PRAGMA user_version = {EXPECTED_DB_VERSION}')
+        return cls(path)
+
+    @classmethod
+    def migrate(cls, path: str) -> DatabaseFile:
+        """Apply migrations from_version to EXPECTED_DB_VERSION."""
+        migrations = cls._available_migrations()
+        from_version = cls.get_version_of_db(path)
+        migrations_todo = migrations[from_version+1:]
+        for j, filename in enumerate(migrations_todo):
+            with sql_connect(path) as conn:
+                with open(f'{MIGRATIONS_DIR}/{filename}', 'r',
+                          encoding='utf-8') as f:
+                    conn.executescript(f.read())
+            user_version = from_version + j + 1
+            with sql_connect(path) as conn:
+                conn.execute(f'PRAGMA user_version = {user_version}')
+        return cls(path)
 
     def _check(self) -> None:
         """Check file exists, and is of proper DB version and schema."""
-        self.exists = isfile(self.path)
-        if self.exists:
-            self._validate_user_version()
-            self._validate_schema()
+        if not isfile(self.path):
+            raise NotFoundException
+        if self.user_version != EXPECTED_DB_VERSION:
+            raise UnmigratedDbException()
+        self._validate_schema()
+
+    @staticmethod
+    def _available_migrations() -> list[str]:
+        """Validate migrations directory and return sorted entries."""
+        msg_too_big = 'Migration directory points beyond expected DB version.'
+        msg_bad_entry = 'Migration directory contains unexpected entry: '
+        msg_missing = 'Migration directory misses migration of number: '
+        migrations = {}
+        for entry in listdir(MIGRATIONS_DIR):
+            if entry == FILENAME_DB_SCHEMA:
+                continue
+            toks = entry.split('_', 1)
+            if len(toks) < 2:
+                raise HandledException(msg_bad_entry + entry)
+            try:
+                i = int(toks[0])
+            except ValueError as e:
+                raise HandledException(msg_bad_entry + entry) from e
+            if i > EXPECTED_DB_VERSION:
+                raise HandledException(msg_too_big)
+            migrations[i] = toks[1]
+        migrations_list = []
+        for i in range(EXPECTED_DB_VERSION + 1):
+            if i not in migrations:
+                raise HandledException(msg_missing + str(i))
+            migrations_list += [f'{i}_{migrations[i]}']
+        return migrations_list
 
-    def _validate_user_version(self) -> None:
-        """Compare DB user_version with EXPECTED_DB_VERSION."""
+    @staticmethod
+    def get_version_of_db(path: str) -> int:
+        """Get DB user_version, fail if outside expected range."""
         sql_for_db_version = 'PRAGMA user_version'
-        with sql_connect(self.path) as conn:
+        with sql_connect(path) as conn:
             db_version = list(conn.execute(sql_for_db_version))[0][0]
-            if db_version != EXPECTED_DB_VERSION:
-                msg = f'Wrong DB version, expected '\
-                        f'{EXPECTED_DB_VERSION}, got {db_version}.'
-                raise HandledException(msg)
+        if db_version > EXPECTED_DB_VERSION:
+            msg = f'Wrong DB version, expected '\
+                    f'{EXPECTED_DB_VERSION}, got unknown {db_version}.'
+            raise HandledException(msg)
+        assert isinstance(db_version, int)
+        return db_version
+
+    @property
+    def user_version(self) -> int:
+        """Get DB user_version."""
+        return self.__class__.get_version_of_db(self.path)
 
     def _validate_schema(self) -> None:
         """Compare found schema with what's stored at PATH_DB_SCHEMA."""
+
+        def reformat_rows(rows: list[str]) -> list[str]:
+            new_rows = []
+            for row in rows:
+                new_row = []
+                for subrow in row.split('\n'):
+                    subrow = subrow.rstrip()
+                    in_parentheses = 0
+                    split_at = []
+                    for i, c in enumerate(subrow):
+                        if '(' == c:
+                            in_parentheses += 1
+                        elif ')' == c:
+                            in_parentheses -= 1
+                        elif ',' == c and 0 == in_parentheses:
+                            split_at += [i + 1]
+                    prev_split = 0
+                    for i in split_at:
+                        segment = subrow[prev_split:i].strip()
+                        if len(segment) > 0:
+                            new_row += [f'    {segment}']
+                        prev_split = i
+                    segment = subrow[prev_split:].strip()
+                    if len(segment) > 0:
+                        new_row += [f'    {segment}']
+                new_row[0] = new_row[0].lstrip()
+                new_row[-1] = new_row[-1].lstrip()
+                new_rows += ['\n'.join(new_row)]
+            return new_rows
+
         sql_for_schema = 'SELECT sql FROM sqlite_master ORDER BY sql'
         msg_err = 'Database has wrong tables schema. Diff:\n'
         with sql_connect(self.path) as conn:
             schema_rows = [r[0] for r in conn.execute(sql_for_schema) if r[0]]
-            retrieved_schema = ';\n'.join(schema_rows) + ';'
-            with open(PATH_DB_SCHEMA, 'r', encoding='utf-8') as f:
-                stored_schema = f.read().rstrip()
-                if stored_schema != retrieved_schema:
-                    diff_msg = Differ().compare(retrieved_schema.splitlines(),
-                                                stored_schema.splitlines())
-                    raise HandledException(msg_err + '\n'.join(diff_msg))
+        schema_rows = reformat_rows(schema_rows)
+        retrieved_schema = ';\n'.join(schema_rows) + ';'
+        with open(PATH_DB_SCHEMA, 'r', encoding='utf-8') as f:
+            stored_schema = f.read().rstrip()
+        if stored_schema != retrieved_schema:
+            diff_msg = Differ().compare(retrieved_schema.splitlines(),
+                                        stored_schema.splitlines())
+            raise HandledException(msg_err + '\n'.join(diff_msg))
 
 
 class DatabaseConnection:
diff --git a/plomtask/todos.py b/plomtask/todos.py
index 4e3a4db..a9bd94c 100644
--- a/plomtask/todos.py
+++ b/plomtask/todos.py
@@ -22,21 +22,23 @@ class Todo(BaseModel[int], ConditionsRelations):
     """Individual actionable."""
     # pylint: disable=too-many-instance-attributes
     table_name = 'todos'
-    to_save = ['process_id', 'is_done', 'date']
+    to_save = ['process_id', 'is_done', 'date', 'comment']
     to_save_relations = [('todo_conditions', 'todo', 'conditions'),
                          ('todo_enables', 'todo', 'enables'),
                          ('todo_disables', 'todo', 'disables'),
                          ('todo_children', 'parent', 'children'),
                          ('todo_children', 'child', 'parents')]
 
+    # pylint: disable=too-many-arguments
     def __init__(self, id_: int | None, process: Process,
-                 is_done: bool, date: str) -> None:
+                 is_done: bool, date: str, comment: str = '') -> None:
         super().__init__(id_)
         if process.id_ is None:
             raise NotFoundException('Process of Todo without ID (not saved?)')
         self.process = process
         self._is_done = is_done
         self.date = date
+        self.comment = comment
         self.children: list[Todo] = []
         self.parents: list[Todo] = []
         self.conditions: list[Condition] = []
diff --git a/run.py b/run.py
index e1bbe5d..c69dc6a 100755
--- a/run.py
+++ b/run.py
@@ -2,28 +2,36 @@
 """Call this to start the application."""
 from sys import exit as sys_exit
 from os import environ
-from plomtask.exceptions import HandledException
+from plomtask.exceptions import HandledException, NotFoundException
 from plomtask.http import TaskHandler, TaskServer
-from plomtask.db import DatabaseFile
+from plomtask.db import DatabaseFile, UnmigratedDbException
 
 PLOMTASK_DB_PATH = environ.get('PLOMTASK_DB_PATH')
 HTTP_PORT = 8082
 DB_CREATION_ASK = 'Database file not found. Create? Y/n\n'
+DB_MIGRATE_ASK = 'Database file needs migration. Migrate? Y/n\n'
+
+
+def yes_or_fail(question: str, fail_msg: str) -> None:
+    """Ask question, raise HandledException(fail_msg) if reply not yes."""
+    reply = input(question)
+    if not reply.lower() in {'y', 'yes', 'yes.', 'yes!'}:
+        print('Not recognizing reply as "yes".')
+        raise HandledException(fail_msg)
 
 
 if __name__ == '__main__':
     try:
         if not PLOMTASK_DB_PATH:
             raise HandledException('PLOMTASK_DB_PATH not set.')
-        db_file = DatabaseFile(PLOMTASK_DB_PATH)
-        if not db_file.exists:
-            legal_yesses = {'y', 'yes', 'yes.', 'yes!'}
-            reply = input(DB_CREATION_ASK)
-            if reply.lower() in legal_yesses:
-                db_file.remake()
-            else:
-                print('Not recognizing reply as "yes".')
-                raise HandledException('Cannot run without database.')
+        try:
+            db_file = DatabaseFile(PLOMTASK_DB_PATH)
+        except NotFoundException:
+            yes_or_fail(DB_CREATION_ASK, 'Cannot run without DB.')
+            db_file = DatabaseFile.create_at(PLOMTASK_DB_PATH)
+        except UnmigratedDbException:
+            yes_or_fail(DB_MIGRATE_ASK, 'Cannot run with unmigrated DB.')
+            db_file = DatabaseFile.migrate(PLOMTASK_DB_PATH)
         server = TaskServer(db_file, ('localhost', HTTP_PORT), TaskHandler)
         print(f'running at port {HTTP_PORT}')
         try:
diff --git a/tests/utils.py b/tests/utils.py
index fb7e227..c1a22b6 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -45,8 +45,7 @@ class TestCaseWithDB(TestCase):
         ProcessStep.empty_cache()
         Todo.empty_cache()
         timestamp = datetime.now().timestamp()
-        self.db_file = DatabaseFile(f'test_db:{timestamp}')
-        self.db_file.remake()
+        self.db_file = DatabaseFile.create_at(f'test_db:{timestamp}')
         self.db_conn = DatabaseConnection(self.db_file)
 
     def tearDown(self) -> None:
-- 
2.30.2