From 3b30254cdec658814a7e59b18f790103a59e136f Mon Sep 17 00:00:00 2001
From: Christian Heller <c.heller@plomlompom.de>
Date: Sun, 1 Dec 2024 08:07:59 +0100
Subject: [PATCH] To files table, add sha512 checksum field.

---
 install.sh                                |  2 +-
 src/migrate.py                            | 49 +++++++++++++++++------
 src/migrations/2_add_files_sha512.py      | 28 +++++++++++++
 src/migrations/2_add_files_sha512.sql     |  1 +
 src/migrations/{init_1.sql => init_2.sql} |  1 +
 src/ytplom/misc.py                        | 33 ++++++++++-----
 6 files changed, 91 insertions(+), 23 deletions(-)
 create mode 100644 src/migrations/2_add_files_sha512.py
 create mode 100644 src/migrations/2_add_files_sha512.sql
 rename src/migrations/{init_1.sql => init_2.sql} (95%)

diff --git a/install.sh b/install.sh
index 18b764b..8c0c6d7 100755
--- a/install.sh
+++ b/install.sh
@@ -7,7 +7,7 @@ NAME_EXECUTABLE=ytplom
 
 mkdir -p "${PATH_APP_SHARE}" "${PATH_LOCAL_BIN}"
 
-rm -f ${PATH_APP_SHARE}/migrations/*
+rm -rf ${PATH_APP_SHARE}/migrations/*
 
 cp -r ./src/* "${PATH_APP_SHARE}/"
 cp "${NAME_EXECUTABLE}" "${PATH_LOCAL_BIN}/"
diff --git a/src/migrate.py b/src/migrate.py
index 85f4af4..fc63965 100755
--- a/src/migrate.py
+++ b/src/migrate.py
@@ -1,10 +1,15 @@
 #!/usr/bin/env python3
 """Script to migrate DB to most recent schema."""
+from importlib.util import spec_from_file_location, module_from_spec
+from pathlib import Path
 from sys import exit as sys_exit
-from sqlite3 import connect as sql_connect
 from ytplom.misc import (
         EXPECTED_DB_VERSION, PATH_DB, PATH_DB_SCHEMA, PATH_MIGRATIONS,
-        SQL_DB_VERSION, HandledException, get_db_version)
+        SQL_DB_VERSION, get_db_version, DbConn, HandledException, SqlText)
+
+
+_SUFFIX_PY = '.py'
+_SUFFIX_SQL = '.sql'
 
 
 def main() -> None:
@@ -19,26 +24,46 @@ def main() -> None:
                 f'{EXPECTED_DB_VERSION}.')
     print(f'Trying to migrate from DB version {start_version} to '
           f'{EXPECTED_DB_VERSION} …')
-    needed = [n+1 for n in range(start_version, EXPECTED_DB_VERSION)]
-    migrations = {}
+    migrations: dict[int, list[Path]] = {
+            n+1: [] for n in range(start_version, EXPECTED_DB_VERSION)}
     for path in [p for p in PATH_MIGRATIONS.iterdir()
                  if p.is_file() and p != PATH_DB_SCHEMA]:
         toks = path.name.split('_')
         try:
             version = int(toks[0])
+            if path.suffix not in {_SUFFIX_PY, _SUFFIX_SQL}:
+                raise ValueError
         except ValueError as e:
             msg = f'Found illegal migration path {path}, aborting.'
             raise HandledException(msg) from e
-        if version in needed:
-            migrations[version] = path
-    missing = [n for n in needed if n not in migrations]
+        if version in migrations:
+            migrations[version] += [path]
+    missing = [n for n in migrations.keys() if not migrations[n]]
     if missing:
         raise HandledException(f'Needed migrations missing: {missing}')
-    with sql_connect(PATH_DB) as conn:
-        for version_number, migration_path in migrations.items():
-            print(f'Applying migration {version_number}: {migration_path}')
-            conn.executescript(migration_path.read_text(encoding='utf8'))
-            conn.execute(f'{SQL_DB_VERSION} = {version_number}')
+    with DbConn(check_version=False) as conn:
+        for version, migration_paths in migrations.items():
+            sorted_paths = sorted(migration_paths)
+            msg_apply_prefix = f'Applying migration {version}: '
+            for path in [p for p in sorted_paths if _SUFFIX_SQL == p.suffix]:
+                print(f'{msg_apply_prefix}{path}')
+                sql = SqlText(path.read_text(encoding='utf8'))
+                conn.exec(sql)
+            for path in [p for p in sorted_paths if _SUFFIX_PY == p.suffix]:
+                spec = spec_from_file_location(str(path), path)
+                assert spec is not None
+                assert spec.loader is not None
+                module = module_from_spec(spec)
+                assert module is not None
+                spec.loader.exec_module(module)
+                if hasattr(module, 'migrate'):
+                    print(f'{msg_apply_prefix}{path}')
+                    module.migrate(conn)
+                else:
+                    raise HandledException(
+                        f'Suspected migration file {path} missing migrate().')
+        conn.exec(SqlText(f'{SQL_DB_VERSION} = {version}'))
+        conn.commit()
 
 
 if __name__ == '__main__':
diff --git a/src/migrations/2_add_files_sha512.py b/src/migrations/2_add_files_sha512.py
new file mode 100644
index 0000000..329286f
--- /dev/null
+++ b/src/migrations/2_add_files_sha512.py
@@ -0,0 +1,28 @@
+from hashlib import file_digest
+from ytplom.misc import DbConn, HandledException, HashStr, SqlText, VideoFile
+
+
+_LEGIT_YES = 'YES!'
+
+
+def migrate(conn: DbConn) -> None:
+    file_entries = VideoFile.get_all(conn)
+    missing = [f for f in file_entries if not f.present]
+    if missing:
+        print('WARNING: Cannot find files to following paths')
+        for f in missing:
+            print(f.full_path)
+        reply = input(
+                'WARNING: To continue migration, will have to delete above '
+                f'rows from DB. To continue, type (exactly) "{_LEGIT_YES}": ')
+        if "YES!" != reply:
+            raise HandledException('Migration aborted!')
+        for f in missing:
+            conn.exec(SqlText('DELETE FROM files WHERE rel_path = ?'),
+                      (str(f.rel_path),))
+    for file in VideoFile.get_all(conn):
+        print(f'Calculating digest for: {file.rel_path}')
+        with open(file.full_path, 'rb') as x:
+            file.sha512_digest = HashStr(
+                    file_digest(x, 'sha512').hexdigest())
+        file.save(conn)
diff --git a/src/migrations/2_add_files_sha512.sql b/src/migrations/2_add_files_sha512.sql
new file mode 100644
index 0000000..36d99e1
--- /dev/null
+++ b/src/migrations/2_add_files_sha512.sql
@@ -0,0 +1 @@
+ALTER TABLE files ADD COLUMN sha512_digest TEXT NOT NULL DEFAULT "";
diff --git a/src/migrations/init_1.sql b/src/migrations/init_2.sql
similarity index 95%
rename from src/migrations/init_1.sql
rename to src/migrations/init_2.sql
index 6d90d23..aaa866b 100644
--- a/src/migrations/init_1.sql
+++ b/src/migrations/init_2.sql
@@ -28,5 +28,6 @@ CREATE TABLE files (
   yt_id TEXT NOT NULL DEFAULT "",
   flags INTEGER NOT NULL DEFAULT 0,
   last_update TEXT NOT NULL DEFAULT "2000-01-01 12:00:00.123456",
+  sha512_digest TEXT NOT NULL DEFAULT "",
   FOREIGN KEY (yt_id) REFERENCES yt_videos(id)
 );
diff --git a/src/ytplom/misc.py b/src/ytplom/misc.py
index 3fd8218..512c755 100644
--- a/src/ytplom/misc.py
+++ b/src/ytplom/misc.py
@@ -4,6 +4,7 @@
 from typing import Any, Literal, NewType, Optional, Self, TypeAlias
 from os import chdir, environ
 from base64 import urlsafe_b64encode, urlsafe_b64decode
+from hashlib import file_digest
 from random import shuffle
 from time import time, sleep
 from datetime import datetime, timedelta
@@ -36,6 +37,7 @@ ProseText = NewType('ProseText', str)
 SqlText = NewType('SqlText', str)
 FlagName = NewType('FlagName', str)
 FlagsInt = NewType('FlagsInt', int)
+HashStr = NewType('HashStr', str)
 AmountDownloads = NewType('AmountDownloads', int)
 PlayerUpdateId = NewType('PlayerUpdateId', str)
 B64Str = NewType('B64Str', str)
@@ -66,7 +68,7 @@ QUOTA_COST_YOUTUBE_SEARCH = QuotaCost(100)
 QUOTA_COST_YOUTUBE_DETAILS = QuotaCost(1)
 
 # database stuff
-EXPECTED_DB_VERSION = 1
+EXPECTED_DB_VERSION = 2
 SQL_DB_VERSION = SqlText('PRAGMA user_version')
 PATH_MIGRATIONS = PATH_APP_DATA.joinpath('migrations')
 PATH_DB_SCHEMA = PATH_MIGRATIONS.joinpath(f'init_{EXPECTED_DB_VERSION}.sql')
@@ -130,7 +132,10 @@ class Config:
 class DbConn:
     """Wrapped sqlite3.Connection."""
 
-    def __init__(self, path: Path = PATH_DB) -> None:
+    def __init__(self,
+                 path: Path = PATH_DB,
+                 check_version: bool = True
+                 ) -> None:
         self._path = path
         if not self._path.is_file():
             if self._path.exists():
@@ -143,11 +148,12 @@ class DbConn:
             with sql_connect(self._path) as conn:
                 conn.executescript(PATH_DB_SCHEMA.read_text(encoding='utf8'))
                 conn.execute(f'{SQL_DB_VERSION} = {EXPECTED_DB_VERSION}')
-        cur_version = get_db_version(self._path)
-        if cur_version != EXPECTED_DB_VERSION:
-            raise HandledException(
-                    f'wrong database version {cur_version}, expected: '
-                    f'{EXPECTED_DB_VERSION} – run "migrate"?')
+        if check_version:
+            cur_version = get_db_version(self._path)
+            if cur_version != EXPECTED_DB_VERSION:
+                raise HandledException(
+                        f'wrong database version {cur_version}, expected: '
+                        f'{EXPECTED_DB_VERSION} – run "migrate"?')
         self._conn = sql_connect(self._path)
 
     def __enter__(self) -> Self:
@@ -361,7 +367,7 @@ class VideoFile(DbData):
     """Collects data about downloaded files."""
     id_name = 'rel_path'
     _table_name = 'files'
-    _cols = ('rel_path', 'yt_id', 'flags', 'last_update')
+    _cols = ('rel_path', 'yt_id', 'flags', 'last_update', 'sha512_digest')
     last_update: DatetimeStr
     rel_path: Path
 
@@ -369,7 +375,8 @@ class VideoFile(DbData):
                  rel_path: Path,
                  yt_id: YoutubeId,
                  flags: FlagsInt = FlagsInt(0),
-                 last_update: Optional[DatetimeStr] = None
+                 last_update: Optional[DatetimeStr] = None,
+                 sha512_digest: Optional[HashStr] = None
                  ) -> None:
         self.rel_path = rel_path
         self.yt_id = yt_id
@@ -378,6 +385,12 @@ class VideoFile(DbData):
             self._renew_last_update()
         else:
             self.last_update = last_update
+        if sha512_digest is None:
+            with self.full_path.open('rb') as f:
+                self.sha512_digest = HashStr(
+                        file_digest(f, 'sha512').hexdigest())
+        else:
+            self.sha512_digest = sha512_digest
 
     def _renew_last_update(self):
         self.last_update = DatetimeStr(datetime.now().strftime(TIMESTAMP_FMT))
@@ -679,8 +692,8 @@ class DownloadsManager:
             for path in [p for p in Path('.').iterdir()
                          if p.is_file() and p not in known_paths]:
                 yt_id = self._id_from_filename(path)
+                print(f'SYNC: new file {path}, saving to YT ID "{yt_id}".')
                 file = VideoFile(path, yt_id)
-                print(f'SYNC: new file {path}, saving with YT ID "{yt_id}".')
                 file.save(conn)
             self._files = VideoFile.get_all(conn)
             for file in self._files:
-- 
2.30.2