From 1c8d467107c916b7d96f3c149943dbc565d31ca1 Mon Sep 17 00:00:00 2001
From: Christian Heller <c.heller@plomlompom.de>
Date: Tue, 3 Dec 2024 03:58:45 +0100
Subject: [PATCH] Store files hash digest as BLOB field .digest, overhaul DB
 and migrations code.

---
 src/migrations/4_add_files_sha512_blob.sql |  1 +
 src/migrations/5_files_redo.sql            | 17 +++++++
 src/migrations/new_init.sql                |  2 +-
 src/templates/file_data.tmpl               |  2 +-
 src/templates/files.tmpl                   |  4 +-
 src/templates/playlist.tmpl                |  4 +-
 src/ytplom/db.py                           | 52 ++++++++++++++++----
 src/ytplom/http.py                         | 22 ++++-----
 src/ytplom/migrations.py                   | 56 +++++++++++++++-------
 src/ytplom/misc.py                         | 28 +++++------
 10 files changed, 130 insertions(+), 58 deletions(-)
 create mode 100644 src/migrations/4_add_files_sha512_blob.sql
 create mode 100644 src/migrations/5_files_redo.sql

diff --git a/src/migrations/4_add_files_sha512_blob.sql b/src/migrations/4_add_files_sha512_blob.sql
new file mode 100644
index 0000000..c382740
--- /dev/null
+++ b/src/migrations/4_add_files_sha512_blob.sql
@@ -0,0 +1 @@
+ALTER TABLE files ADD COLUMN sha512_blob BLOB;  -- filled by _mig_4_convert_digests
diff --git a/src/migrations/5_files_redo.sql b/src/migrations/5_files_redo.sql
new file mode 100644
index 0000000..a9c30fc
--- /dev/null
+++ b/src/migrations/5_files_redo.sql
@@ -0,0 +1,17 @@
+CREATE TABLE files_new (
+  digest BLOB PRIMARY KEY,
+  rel_path TEXT NOT NULL,
+  flags INTEGER NOT NULL DEFAULT 0,
+  yt_id TEXT,
+  last_update TEXT NOT NULL,
+  FOREIGN KEY (yt_id) REFERENCES yt_videos(id)
+);
+INSERT INTO files_new SELECT  -- columns are matched by position, not by name
+  sha512_blob,  -- becomes files_new.digest
+  rel_path,
+  flags,
+  yt_id,
+  last_update
+FROM files;
+DROP TABLE files;  -- all kept columns were copied into files_new above
+ALTER TABLE files_new RENAME TO files;
diff --git a/src/migrations/new_init.sql b/src/migrations/new_init.sql
index d223bef..8153a75 100644
--- a/src/migrations/new_init.sql
+++ b/src/migrations/new_init.sql
@@ -24,7 +24,7 @@ CREATE TABLE quota_costs (
   cost INT NOT NULL
 );
 CREATE TABLE files (
-  sha512_digest TEXT PRIMARY KEY,
+  digest BLOB PRIMARY KEY,
   rel_path TEXT NOT NULL,
   flags INTEGER NOT NULL DEFAULT 0,
   yt_id TEXT,
diff --git a/src/templates/file_data.tmpl b/src/templates/file_data.tmpl
index 589ef26..15637f4 100644
--- a/src/templates/file_data.tmpl
+++ b/src/templates/file_data.tmpl
@@ -8,7 +8,7 @@
 <tr><th>YouTube ID:</th><td><a href="/{{page_names.yt_result}}/{{file.yt_id}}">{{file.yt_id}}</a></tr>
 <tr><th>present:</th><td>{% if file.present %}<a href="/{{page_names.download}}/{{file.yt_id}}">yes</a>{% else %}no{% endif %}</td></tr>
 </table>
-<form action="/{{page_names.file}}/{{file.sha512_digest}}" method="POST" />
+<form action="/{{page_names.file}}/{{file.digest.b64}}" method="POST" />
 {% for flag_name in flag_names %}
 {{ flag_name }}: <input type="checkbox" name="{{flag_name}}" {% if file.is_flag_set(flag_name) %}checked {% endif %} /><br />
 {% endfor %}
diff --git a/src/templates/files.tmpl b/src/templates/files.tmpl
index f2a8024..7db6174 100644
--- a/src/templates/files.tmpl
+++ b/src/templates/files.tmpl
@@ -15,8 +15,8 @@ show absent: <input type="checkbox" name="show_absent" {% if show_absent %}check
 {% for file in files %}
 <tr>
 <td>{{ file.size | round(3) }}</td>
-<td><input type="submit" name="play_{{file.sha512_digest}}" value="play" {% if not file.present %}disabled {% endif %}/></td>
-<td><a href="/{{page_names.file}}/{{file.sha512_digest}}">{{file.rel_path}}</a></td>
+<td><input type="submit" name="play_{{file.digest.b64}}" value="play" {% if not file.present %}disabled {% endif %}/></td>
+<td><a href="/{{page_names.file}}/{{file.digest.b64}}">{{file.rel_path}}</a></td>
 </tr>
 {% endfor %}
 </table>
diff --git a/src/templates/playlist.tmpl b/src/templates/playlist.tmpl
index 494a7a1..507d1c3 100644
--- a/src/templates/playlist.tmpl
+++ b/src/templates/playlist.tmpl
@@ -48,7 +48,7 @@ td.entry_buttons { width: 5em; }
 <input type="submit" name="up_{{idx}}" value="{% if reverse %}v{% else %}^{% endif %}" />
 <input type="submit" name="down_{{idx}}" value="{% if reverse %}^{% else %}v{% endif %}" />
 </td>
-<td><a href="/{{page_names.file}}/{{file.sha512_digest}}">{{ file.rel_path }}</a></td>
+<td><a href="/{{page_names.file}}/{{file.digest.b64}}">{{ file.rel_path }}</a></td>
 </tr>
 {% endfor %}
 </table>
@@ -61,7 +61,7 @@ td.entry_buttons { width: 5em; }
 <table>
 <tr><td id="status" colspan=2>
 {% if running %}{% if pause %}PAUSED{% else %}PLAYING{% endif %}{% else %}STOPPED{% endif %}:<br />
-<a href="/{{page_names.file}}/{{current_video.sha512_digest}}">{{ current_video.rel_path }}</a><br />
+<a href="/{{page_names.file}}/{{current_video.digest.b64}}">{{ current_video.rel_path }}</a><br />
 <form action="/{{page_names.playlist}}" method="POST">
 <input type="submit" name="pause" autofocus value="{% if paused %}resume{% else %}pause{% endif %}">
 <input type="submit" name="prev" value="prev">
diff --git a/src/ytplom/db.py b/src/ytplom/db.py
index 9edf5c6..5387eca 100644
--- a/src/ytplom/db.py
+++ b/src/ytplom/db.py
@@ -1,4 +1,6 @@
 """Database access and management code."""
+from base64 import urlsafe_b64decode, urlsafe_b64encode
+from hashlib import file_digest
 from pathlib import Path
 from sqlite3 import (
         connect as sql_connect, Connection as SqlConnection, Cursor, Row)
@@ -8,10 +10,11 @@ from ytplom.primitives import (
 
 SqlText = NewType('SqlText', str)
 
-EXPECTED_DB_VERSION = 3
+EXPECTED_DB_VERSION = 5
 PATH_DB = PATH_APP_DATA.joinpath('db.sql')
 SQL_DB_VERSION = SqlText('PRAGMA user_version')
 PATH_MIGRATIONS = PATH_APP_DATA.joinpath('migrations')
+_HASH_ALGO = 'sha512'
 _PATH_DB_SCHEMA = PATH_MIGRATIONS.joinpath('new_init.sql')
 _NAME_INSTALLER = Path('install.sh')
 
@@ -22,6 +25,29 @@ def get_db_version(db_path: Path) -> int:
         return list(conn.execute(SQL_DB_VERSION))[0][0]
 
 
+class Hash:
+    """Raw bytes of a file's _HASH_ALGO hash, convertible from/to base64."""
+
+    def __init__(self, as_bytes: bytes) -> None:
+        self.bytes = as_bytes
+
+    @classmethod
+    def from_file(cls, path: Path) -> Self:
+        """Hash-digest file at path, instantiate with hash's bytes."""
+        with path.open('rb') as f:
+            return cls(file_digest(f, _HASH_ALGO).digest())
+
+    @classmethod
+    def from_b64(cls, b64_str: str) -> Self:
+        """Instantiate from URL-safe base64 string encoding of hash value."""
+        return cls(bytes(urlsafe_b64decode(b64_str)))
+
+    @property
+    def b64(self) -> str:
+        """Return hash bytes as URL-safe base64-encoded string."""
+        return urlsafe_b64encode(self.bytes).decode('utf8')
+
+
 class BaseDbConn:
     """Wrapper for pre-established sqlite3.Connection."""
 
@@ -32,6 +58,11 @@ class BaseDbConn:
         """Wrapper around sqlite3.Connection.execute."""
         return self._conn.execute(sql, inputs)
 
+    def exec_on_values(self, sql: SqlText, inputs: tuple[Any, ...]) -> Cursor:
+        """Execute sql with ' VALUES (?, …)' affixed, one '?' per input."""
+        q_marks = '(' + ','.join(['?'] * len(inputs)) + ')'
+        return self._conn.execute(f'{sql} VALUES {q_marks}', inputs)
+
     def commit(self) -> None:
         """Commit changes (i.e. DbData.save() calls) to database."""
         self._conn.commit()
@@ -92,11 +123,12 @@ class DbData:
         return cls(**kwargs)
 
     @classmethod
-    def get_one(cls, conn: BaseDbConn, id_: str) -> Self:
+    def get_one(cls, conn: BaseDbConn, id_: str | Hash) -> Self:
         """Return single entry of id_ from DB."""
         sql = SqlText(f'SELECT * FROM {cls._table_name} '
                       f'WHERE {cls.id_name} = ?')
-        row = conn.exec(sql, (id_,)).fetchone()
+        id__ = id_.bytes if isinstance(id_, Hash) else id_
+        row = conn.exec(sql, (id__,)).fetchone()
         if not row:
             msg = f'no entry found for ID "{id_}" in table {cls._table_name}'
             raise NotFoundException(msg)
@@ -111,8 +143,12 @@ class DbData:
 
     def save(self, conn: BaseDbConn) -> Cursor:
         """Save entry to DB."""
-        vals = [getattr(self, col_name) for col_name in self._cols]
-        q_marks = '(' + ','.join(['?'] * len(vals)) + ')'
-        sql = SqlText(f'REPLACE INTO {self._table_name} VALUES {q_marks}')
-        return conn.exec(sql, tuple(str(v) if isinstance(v, Path) else v
-                                    for v in vals))
+        vals = []
+        for val in [getattr(self, col_name) for col_name in self._cols]:
+            if isinstance(val, Path):
+                val = str(val)
+            elif isinstance(val, Hash):
+                val = val.bytes
+            vals += [val]
+        return conn.exec_on_values(SqlText(f'REPLACE INTO {self._table_name}'),
+                                   tuple(vals))
diff --git a/src/ytplom/http.py b/src/ytplom/http.py
index 4fa2754..f4d394d 100644
--- a/src/ytplom/http.py
+++ b/src/ytplom/http.py
@@ -9,9 +9,9 @@ from urllib.request import urlretrieve
 from urllib.error import HTTPError
 from jinja2 import (  # type: ignore
         Environment as JinjaEnv, FileSystemLoader as JinjaFSLoader)
-from ytplom.db import DbConn
+from ytplom.db import Hash, DbConn
 from ytplom.misc import (
-        HashStr, FilesWithIndex, FlagName, PlayerUpdateId, QueryId, QueryText,
+        FilesWithIndex, FlagName, PlayerUpdateId, QueryId, QueryText,
         QuotaCost, UrlStr, YoutubeId,
         FILE_FLAGS, PATH_THUMBNAILS, YOUTUBE_URL_PREFIX,
         ensure_expected_dirs,
@@ -106,7 +106,7 @@ class _TaskHandler(BaseHTTPRequestHandler):
         if PAGE_NAMES['files'] == page_name:
             self._receive_files_command(list(postvars.keys())[0])
         elif PAGE_NAMES['file'] == page_name:
-            self._receive_video_flag(HashStr(toks_url[2]),
+            self._receive_video_flag(Hash.from_b64(toks_url[2]),
                                      [FlagName(k) for k in postvars])
         elif PAGE_NAMES['yt_queries'] == page_name:
             self._receive_yt_query(QueryText(postvars['query'][0]))
@@ -134,24 +134,24 @@ class _TaskHandler(BaseHTTPRequestHandler):
     def _receive_files_command(self, command: str) -> None:
         if command.startswith('play_'):
             with DbConn() as conn:
-                file = VideoFile.get_one(conn,
-                                         HashStr(command.split('_', 1)[1]))
+                file = VideoFile.get_one(
+                        conn, Hash.from_b64(command.split('_', 1)[1]))
             self.server.player.inject_and_play(file)
         self._redirect(Path('/'))
 
     def _receive_video_flag(self,
-                            sha512_digest: HashStr,
+                            digest: Hash,
                             flag_names: list[FlagName]
                             ) -> None:
         with DbConn() as conn:
-            file = VideoFile.get_one(conn, sha512_digest)
+            file = VideoFile.get_one(conn, digest)
             file.set_flags([FILE_FLAGS[name] for name in flag_names])
             file.save(conn)
             conn.commit()
         file.ensure_absence_if_deleted()
         self._redirect(Path('/')
                        .joinpath(PAGE_NAMES['file'])
-                       .joinpath(sha512_digest))
+                       .joinpath(digest.b64))
 
     def _receive_yt_query(self, query_txt: QueryText) -> None:
         with DbConn() as conn:
@@ -178,7 +178,7 @@ class _TaskHandler(BaseHTTPRequestHandler):
                 show_absent = params.get('show_absent', [False])[0]
                 self._send_files_index(filter_, bool(show_absent))
             elif PAGE_NAMES['file'] == page_name:
-                self._send_file_data(HashStr(toks_url[2]))
+                self._send_file_data(Hash.from_b64(toks_url[2]))
             elif PAGE_NAMES['yt_result'] == page_name:
                 self._send_yt_result(YoutubeId(toks_url[2]))
             elif PAGE_NAMES['missing'] == page_name:
@@ -271,9 +271,9 @@ class _TaskHandler(BaseHTTPRequestHandler):
                  'youtube_prefix': YOUTUBE_URL_PREFIX,
                  'queries': linked_queries})
 
-    def _send_file_data(self, sha512_digest: HashStr) -> None:
+    def _send_file_data(self, digest: Hash) -> None:
         with DbConn() as conn:
-            file = VideoFile.get_one(conn, sha512_digest)
+            file = VideoFile.get_one(conn, digest)
         self._send_rendered_template(
                 _NAME_TEMPLATE_FILE_DATA,
                 {'file': file, 'flag_names': list(FILE_FLAGS)})
diff --git a/src/ytplom/migrations.py b/src/ytplom/migrations.py
index 73f2406..ccda307 100644
--- a/src/ytplom/migrations.py
+++ b/src/ytplom/migrations.py
@@ -20,6 +20,7 @@ class _Migration:
                  after_sql_steps: Optional[Callable] = None
                  ) -> None:
         self.version = version
+        self._filename_sql = filename_sql
         self._sql_code = None
         if filename_sql:
             path_sql = PATH_MIGRATIONS.joinpath(filename_sql)
@@ -30,8 +31,10 @@ class _Migration:
         """Apply to DB at path_db migration code stored in self."""
         with sql_connect(path_db, autocommit=False) as conn:
             if self._sql_code:
+                print(f'Executing {self._filename_sql}')
                 conn.executescript(self._sql_code)
             if self._after_sql_steps:
+                print('Running additional steps')
                 self._after_sql_steps(conn)
             conn.execute(SqlText(f'{SQL_DB_VERSION} = {self.version}'))
 
@@ -60,36 +63,57 @@ def run_migrations() -> None:
     print('Finished migrations.')
 
 
+def _rewrite_files_last_field_processing_first_field(
+        conn: BaseDbConn, cb: Callable) -> None:
+    """REPLACE each files row, its last field set to cb(its first field)."""
+    # All rows are fetched before the first REPLACE is issued.
+    for fields in conn.exec(SqlText('SELECT * FROM files')).fetchall():
+        new_row = (*fields[:-1], cb(fields[0]))
+        conn.exec_on_values(SqlText('REPLACE INTO files'), new_row)
+
+
 def _mig_2_calc_digests(sql_conn: SqlConnection) -> None:
     """Calculate sha512 digests to all known video files."""
+    # pylint: disable=import-outside-toplevel
     from hashlib import file_digest
-    from ytplom.misc import HashStr, VideoFile
+    from ytplom.misc import PATH_DOWNLOADS
     conn = BaseDbConn(sql_conn)
-    file_entries = VideoFile.get_all(conn)
-    missing = [f for f in file_entries if not f.present]
+    rel_paths = [
+            p[0] for p
+            in conn.exec(SqlText('SELECT rel_path FROM files')).fetchall()]
+    missing = [p for p in rel_paths
+               if not Path(PATH_DOWNLOADS).joinpath(p).exists()]
     if missing:
-        print('WARNING: Cannot find files to following paths')
-        for f in missing:
-            print(f.full_path)
+        print('WARNING: Cannot find files to following (relative) paths:')
+        for path in missing:
+            print(path)
         reply = input(
                 'WARNING: To continue migration, will have to delete above '
                 f'rows from DB. To continue, type (exactly) "{_LEGIT_YES}": ')
         if _LEGIT_YES != reply:
             raise HandledException('Migration aborted!')
-        for f in missing:
-            conn.exec(SqlText('DELETE FROM files WHERE rel_path = ?'),
-                      (str(f.rel_path),))
-    for video_file in VideoFile.get_all(conn):
-        print(f'Calculating digest for: {video_file.rel_path}')
-        with open(video_file.full_path, 'rb') as vf:
-            video_file.sha512_digest = HashStr(
-                    file_digest(vf, 'sha512').hexdigest())
-        video_file.save(conn)
+        for path in missing:
+            conn.exec(SqlText('DELETE FROM files WHERE rel_path = ?'), (path,))
+
+    def hexdigest_file(path):
+        print(f'Calculating digest for: {path}')
+        with Path(PATH_DOWNLOADS).joinpath(path).open('rb') as vf:
+            return file_digest(vf, 'sha512').hexdigest()
+
+    _rewrite_files_last_field_processing_first_field(conn, hexdigest_file)
+
+
+def _mig_4_convert_digests(sql_conn: SqlConnection) -> None:
+    """Fill new files.sha512_blob field with bytes of hex .sha512_digest."""
+    _rewrite_files_last_field_processing_first_field(
+            BaseDbConn(sql_conn), bytes.fromhex)
 
 
 MIGRATIONS = [
     _Migration(0, Path('0_init.sql')),
     _Migration(1, Path('1_add_files_last_updated.sql')),
     _Migration(2, Path('2_add_files_sha512.sql'), _mig_2_calc_digests),
-    _Migration(3, Path('3_files_redo.sql'))
+    _Migration(3, Path('3_files_redo.sql')),
+    _Migration(4, Path('4_add_files_sha512_blob.sql'), _mig_4_convert_digests),
+    _Migration(5, Path('5_files_redo.sql'))
 ]
diff --git a/src/ytplom/misc.py b/src/ytplom/misc.py
index 4aa1fed..0dc4a9d 100644
--- a/src/ytplom/misc.py
+++ b/src/ytplom/misc.py
@@ -3,7 +3,6 @@
 # included libs
 from typing import NewType, Optional, Self, TypeAlias
 from os import chdir, environ
-from hashlib import file_digest
 from random import shuffle
 from time import time, sleep
 from datetime import datetime, timedelta
@@ -18,7 +17,7 @@ import googleapiclient.discovery  # type: ignore
 from mpv import MPV  # type: ignore
 from yt_dlp import YoutubeDL  # type: ignore
 # ourselves
-from ytplom.db import BaseDbConn, DbConn, DbData, SqlText
+from ytplom.db import BaseDbConn, DbConn, DbData, Hash, SqlText
 from ytplom.primitives import HandledException, NotFoundException
 
 
@@ -38,7 +37,6 @@ QueryText = NewType('QueryText', str)
 ProseText = NewType('ProseText', str)
 FlagName = NewType('FlagName', str)
 FlagsInt = NewType('FlagsInt', int)
-HashStr = NewType('HashStr', str)
 AmountDownloads = NewType('AmountDownloads', int)
 PlayerUpdateId = NewType('PlayerUpdateId', str)
 UrlStr = NewType('UrlStr', str)
@@ -240,38 +238,34 @@ class YoutubeVideo(DbData):
 
     def save_to_query(self, conn: BaseDbConn, query_id: QueryId) -> None:
         """Save inclusion of self in results to query of query_id."""
-        conn.exec(SqlText('REPLACE INTO yt_query_results VALUES (?, ?)'),
-                  (query_id, self.id_))
+        conn.exec_on_values(SqlText('REPLACE INTO yt_query_results'),
+                            (query_id, self.id_))
 
 
 class VideoFile(DbData):
     """Collects data about downloaded files."""
-    id_name = 'sha512_digest'
+    id_name = 'digest'
     _table_name = 'files'
-    _cols = ('sha512_digest', 'rel_path', 'flags', 'yt_id', 'last_update')
+    _cols = ('digest', 'rel_path', 'flags', 'yt_id', 'last_update')
     last_update: DatetimeStr
     rel_path: Path
+    digest: Hash
 
     def __init__(self,
+                 digest: Optional[Hash],
                  rel_path: Path,
-                 yt_id: Optional[YoutubeId] = None,
                  flags: FlagsInt = FlagsInt(0),
-                 last_update: Optional[DatetimeStr] = None,
-                 sha512_digest: Optional[HashStr] = None
+                 yt_id: Optional[YoutubeId] = None,
+                 last_update: Optional[DatetimeStr] = None
                  ) -> None:
         self.rel_path = rel_path
+        self.digest = digest if digest else Hash.from_file(self.full_path)
         self.yt_id = yt_id
         self.flags = flags
         if last_update is None:
             self._renew_last_update()
         else:
             self.last_update = last_update
-        if sha512_digest is None:
-            with self.full_path.open('rb') as f:
-                self.sha512_digest = HashStr(
-                        file_digest(f, 'sha512').hexdigest())
-        else:
-            self.sha512_digest = sha512_digest
 
     def _renew_last_update(self):
         self.last_update = DatetimeStr(datetime.now().strftime(TIMESTAMP_FMT))
@@ -564,7 +558,7 @@ class DownloadsManager:
                          if p.is_file() and p not in known_paths]:
                 yt_id = self._id_from_filename(path)
                 print(f'SYNC: new file {path}, saving to YT ID "{yt_id}".')
-                file = VideoFile(path, yt_id)
+                file = VideoFile(digest=None, rel_path=path, yt_id=yt_id)
                 file.save(conn)
             self._files = VideoFile.get_all(conn)
             for file in self._files:
-- 
2.30.2