From 5880f767fb8d2ca25a70f9ebaee9b4268596a60b Mon Sep 17 00:00:00 2001
From: Christian Heller <c.heller@plomlompom.de>
Date: Mon, 25 Nov 2024 06:40:12 +0100
Subject: [PATCH] Add files.last_update field to renew on file.flags changes,
 and migration mechanism to allow addition of field to table.

---
 install.sh                                  |  3 +
 src/migrate.py                              | 48 ++++++++++++++++
 src/migrations/{init_0.sql => 0_init.sql}   |  0
 src/migrations/1_add_files_last_updated.sql |  1 +
 src/migrations/init_1.sql                   | 32 +++++++++++
 src/templates/video.tmpl                    |  2 +-
 src/ytplom/misc.py                          | 62 +++++++++++++++------
 ytplom                                      |  4 +-
 8 files changed, 133 insertions(+), 19 deletions(-)
 create mode 100755 src/migrate.py
 rename src/migrations/{init_0.sql => 0_init.sql} (100%)
 create mode 100644 src/migrations/1_add_files_last_updated.sql
 create mode 100644 src/migrations/init_1.sql

diff --git a/install.sh b/install.sh
index a2bd9c1..0ee48b1 100755
--- a/install.sh
+++ b/install.sh
@@ -6,6 +6,9 @@ PATH_LOCAL_BIN=~/.local/bin
 NAME_EXECUTABLE=ytplom
 
 mkdir -p "${PATH_APP_SHARE}" "${PATH_LOCAL_BIN}"
+
+rm ${PATH_APP_SHARE}/migrations/*
+
 cp -r ./src/* "${PATH_APP_SHARE}/"
 cp "${NAME_EXECUTABLE}" "${PATH_LOCAL_BIN}/"
 
diff --git a/src/migrate.py b/src/migrate.py
new file mode 100755
index 0000000..7d712d9
--- /dev/null
+++ b/src/migrate.py
@@ -0,0 +1,48 @@
+#!/usr/bin/env python3
+"""Script to migrate DB to most recent schema."""
+from sys import exit as sys_exit
+from os import scandir
+from os.path import basename, isfile
+from sqlite3 import connect as sql_connect
+from ytplom.misc import (
+        EXPECTED_DB_VERSION, PATH_DB, PATH_DB_SCHEMA, PATH_MIGRATIONS,
+        SQL_DB_VERSION, HandledException, get_db_version)
+
+
+def main() -> None:
+    """Try to migrate DB towards EXPECTED_DB_VERSION."""
+    start_version = get_db_version(PATH_DB)
+    if start_version == EXPECTED_DB_VERSION:
+        print('Database at expected version, no migrations to do.')
+        sys_exit(0)
+    elif start_version > EXPECTED_DB_VERSION:
+        raise HandledException(
+                f'Cannot migrate backward from version {start_version} to '
+                f'{EXPECTED_DB_VERSION}.')
+    print(f'Trying to migrate from DB version {start_version} to '
+          f'{EXPECTED_DB_VERSION} …')
+    needed = [n+1 for n in range(start_version, EXPECTED_DB_VERSION)]
+    migrations = {}
+    for entry in [entry for entry in scandir(PATH_MIGRATIONS)
+                  if isfile(entry) and entry.path != PATH_DB_SCHEMA]:
+        toks = basename(entry.path).split('_')
+        try:
+            version = int(toks[0])
+        except ValueError as e:
+            msg = f'Found illegal migration path {entry.path}, aborting.'
+            raise HandledException(msg) from e
+        if version in needed:
+            migrations[version] = entry.path
+    missing = [n for n in needed if n not in migrations]
+    if missing:
+        raise HandledException(f'Needed migrations missing: {missing}')
+    with sql_connect(PATH_DB) as conn:
+        for version_number, migration_path in migrations.items():
+            print(f'Applying migration {version_number}: {migration_path}')
+            with open(migration_path, 'r', encoding='utf8') as f:
+                conn.executescript(f.read())
+            conn.execute(f'{SQL_DB_VERSION} = {version_number}')
+
+
+if __name__ == '__main__':
+    main()
diff --git a/src/migrations/init_0.sql b/src/migrations/0_init.sql
similarity index 100%
rename from src/migrations/init_0.sql
rename to src/migrations/0_init.sql
diff --git a/src/migrations/1_add_files_last_updated.sql b/src/migrations/1_add_files_last_updated.sql
new file mode 100644
index 0000000..678d843
--- /dev/null
+++ b/src/migrations/1_add_files_last_updated.sql
@@ -0,0 +1 @@
+ALTER TABLE files ADD COLUMN last_update TEXT NOT NULL DEFAULT "2000-01-01 12:00:00.123456";
diff --git a/src/migrations/init_1.sql b/src/migrations/init_1.sql
new file mode 100644
index 0000000..6d90d23
--- /dev/null
+++ b/src/migrations/init_1.sql
@@ -0,0 +1,32 @@
+CREATE TABLE yt_queries (
+  id TEXT PRIMARY KEY,
+  text TEXT NOT NULL,
+  retrieved_at TEXT NOT NULL
+);
+CREATE TABLE yt_videos (
+  id TEXT PRIMARY KEY,
+  title TEXT NOT NULL,
+  description TEXT NOT NULL,
+  published_at TEXT NOT NULL,
+  duration TEXT NOT NULL,
+  definition TEXT NOT NULL
+);
+CREATE TABLE yt_query_results (
+  query_id TEXT NOT NULL,
+  video_id TEXT NOT NULL,
+  PRIMARY KEY (query_id, video_id),
+  FOREIGN KEY (query_id) REFERENCES yt_queries(id),
+  FOREIGN KEY (video_id) REFERENCES yt_videos(id)
+);
+CREATE TABLE quota_costs (
+  id TEXT PRIMARY KEY,
+  timestamp TEXT NOT NULL,
+  cost INT NOT NULL
+);
+CREATE TABLE files (
+  rel_path TEXT PRIMARY KEY,
+  yt_id TEXT NOT NULL DEFAULT "",
+  flags INTEGER NOT NULL DEFAULT 0,
+  last_update TEXT NOT NULL DEFAULT "2000-01-01 12:00:00.123456",
+  FOREIGN KEY (yt_id) REFERENCES yt_videos(id)
+);
diff --git a/src/templates/video.tmpl b/src/templates/video.tmpl
index 54d005f..032a9b9 100644
--- a/src/templates/video.tmpl
+++ b/src/templates/video.tmpl
@@ -10,7 +10,7 @@
 </table>
 <form action="/video/{{file.yt_id}}" method="POST" />
 {% for flag_name in flag_names %}
-{{ flag_name }}: <input type="checkbox" name="{{flag_name}}" {% if file.flag_set(flag_name) %}checked {% endif %} /><br />
+{{ flag_name }}: <input type="checkbox" name="{{flag_name}}" {% if file.is_flag_set(flag_name) %}checked {% endif %} /><br />
 {% endfor %}
 <input type="submit" />
 </form>
diff --git a/src/ytplom/misc.py b/src/ytplom/misc.py
index f57eee2..f65351c 100644
--- a/src/ytplom/misc.py
+++ b/src/ytplom/misc.py
@@ -85,7 +85,7 @@ TIMESTAMP_FMT = '%Y-%m-%d %H:%M:%S.%f'
 LEGAL_EXTENSIONS = {'webm', 'mp4', 'mkv'}
 
 # database stuff
-EXPECTED_DB_VERSION = 0
+EXPECTED_DB_VERSION = 1
 SQL_DB_VERSION = SqlText('PRAGMA user_version')
 PATH_MIGRATIONS = PathStr(path_join(PATH_APP_DATA, 'migrations'))
 PATH_DB_SCHEMA = PathStr(path_join(PATH_MIGRATIONS,
@@ -117,6 +117,12 @@ def _ensure_expected_dirs(expected_dirs: list[PathStr]) -> None:
             makedirs(dir_name)
 
 
+def get_db_version(db_path: PathStr) -> int:
+    """Return user_version value of DB at db_path."""
+    with sql_connect(db_path) as conn:
+        return list(conn.execute(SQL_DB_VERSION))[0][0]
+
+
 class DatabaseConnection:
     """Wrapped sqlite3.Connection."""
 
@@ -135,11 +141,11 @@ class DatabaseConnection:
                 with open(PATH_DB_SCHEMA, 'r', encoding='utf8') as f:
                     conn.executescript(f.read())
                 conn.execute(f'{SQL_DB_VERSION} = {EXPECTED_DB_VERSION}')
-        with sql_connect(self._path) as conn:
-            db_version = list(conn.execute(SQL_DB_VERSION))[0][0]
-        if db_version != EXPECTED_DB_VERSION:
-            raise HandledException(f'wrong database version {db_version}, '
-                                   f'expected: {EXPECTED_DB_VERSION}')
+        cur_version = get_db_version(self._path)
+        if cur_version != EXPECTED_DB_VERSION:
+            raise HandledException(
+                    f'wrong database version {cur_version}, expected: '
+                    f'{EXPECTED_DB_VERSION} – run "migrate"?')
         self._conn = sql_connect(self._path)
 
     def exec(self, sql: SqlText, inputs: tuple[Any, ...] = tuple()) -> Cursor:
@@ -292,13 +298,25 @@ class YoutubeVideo(DbData):
 class VideoFile(DbData):
     """Collects data about downloaded files."""
     _table_name = 'files'
-    _cols = ('rel_path', 'yt_id', 'flags')
+    _cols = ('rel_path', 'yt_id', 'flags', 'last_update')
+    last_update: DatetimeStr
 
-    def __init__(self, rel_path: PathStr, yt_id: YoutubeId, flags=FlagsInt(0)
+    def __init__(self,
+                 rel_path: PathStr,
+                 yt_id: YoutubeId,
+                 flags: FlagsInt = FlagsInt(0),
+                 last_update: Optional[DatetimeStr] = None
                  ) -> None:
         self.rel_path = rel_path
         self.yt_id = yt_id
-        self.flags = flags
+        self._flags = flags
+        if last_update is None:
+            self._renew_last_update()
+        else:
+            self.last_update = last_update
+
+    def _renew_last_update(self):
+        self.last_update = DatetimeStr(datetime.now().strftime(TIMESTAMP_FMT))
 
     @classmethod
     def get_by_yt_id(cls, conn: DatabaseConnection, yt_id: YoutubeId) -> Self:
@@ -322,15 +340,26 @@ class VideoFile(DbData):
     @property
     def missing(self) -> bool:
         """Return if file absent despite absence of 'delete' flag."""
-        return not (self.flag_set(FlagName('delete')) or self.present)
+        return not (self.is_flag_set(FlagName('delete')) or self.present)
+
+    @property
+    def flags(self) -> FlagsInt:
+        """Return value of flags field."""
+        return self._flags
+
+    @flags.setter
+    def flags(self, flags: FlagsInt) -> None:
+        self._renew_last_update()
+        self._flags = flags
 
-    def flag_set(self, flag_name: FlagName) -> bool:
-        """Return if flag of flag_name is set in self.flags."""
-        return self.flags & VIDEO_FLAGS[flag_name]
+    def is_flag_set(self, flag_name: FlagName) -> bool:
+        """Return if flag of flag_name is set in flags field."""
+        return bool(self._flags & VIDEO_FLAGS[flag_name])
 
     def ensure_absence_if_deleted(self) -> None:
         """If 'delete' flag set, ensure no actual file in filesystem."""
-        if self.flag_set(FlagName('delete')) and path_exists(self.full_path):
+        if (self.is_flag_set(FlagName('delete'))
+                and path_exists(self.full_path)):
             print(f'SYNC: {self.rel_path} set "delete", '
                   'removing from filesystem.')
             os_remove(self.full_path)
@@ -647,9 +676,10 @@ class TaskHandler(BaseHTTPRequestHandler):
                          ) -> None:
         conn = DatabaseConnection()
         file = VideoFile.get_by_yt_id(conn, yt_id)
-        file.flags = 0
+        flags = FlagsInt(0)
         for flag_name in flag_names:
-            file.flags |= VIDEO_FLAGS[flag_name]
+            flags = FlagsInt(file.flags | VIDEO_FLAGS[flag_name])
+        file.flags = flags
         file.save(conn)
         conn.commit_close()
         file.ensure_absence_if_deleted()
diff --git a/ytplom b/ytplom
index 64d1981..335023a 100755
--- a/ytplom
+++ b/ytplom
@@ -4,8 +4,8 @@ set -e
 PATH_APP_SHARE=~/.local/share/ytplom
 PATH_VENV="${PATH_APP_SHARE}/venv"
 
-if [ ! "$1" = 'serve' ] && [ ! "$1" = 'sync' ]; then
-    echo "Need argument (either 'serve' or 'sync')."
+if [ ! "$1" = 'serve' ] && [ ! "$1" = 'sync' ] && [ ! "$1" = 'migrate' ]; then
+    echo "Need argument (serve' or 'sync' or 'migrate')."
     false
 fi
 
-- 
2.30.2