From 71ce6a01ed6f2aed314f86f8b96b0aeda68d9df4 Mon Sep 17 00:00:00 2001
From: Christian Heller <c.heller@plomlompom.de>
Date: Thu, 2 Jan 2025 15:59:08 +0100
Subject: [PATCH] Simplify DB management code.

---
 src/ytplom/db.py         | 71 ++++++++++++++++++------------------
 src/ytplom/migrations.py | 78 +++++++++++++++-------------------------
 src/ytplom/misc.py       | 38 +++++++++-----------
 3 files changed, 81 insertions(+), 106 deletions(-)

diff --git a/src/ytplom/db.py b/src/ytplom/db.py
index 0fa9d7f..0703b05 100644
--- a/src/ytplom/db.py
+++ b/src/ytplom/db.py
@@ -2,8 +2,7 @@
 from base64 import urlsafe_b64decode, urlsafe_b64encode
 from hashlib import file_digest
 from pathlib import Path
-from sqlite3 import (
-        connect as sql_connect, Connection as SqlConnection, Cursor, Row)
+from sqlite3 import connect as sql_connect, Cursor as DbCursor, Row
 from typing import Any, Literal, NewType, Self
 from ytplom.primitives import (
         HandledException, NotFoundException, PATH_APP_DATA)
@@ -59,30 +58,13 @@ class Hash:
         return urlsafe_b64encode(self.bytes).decode('utf8')
 
 
-class BaseDbConn:
-    """Wrapper for pre-established sqlite3.Connection."""
+class DbConn:
+    """Wrapper for sqlite3 connections."""
 
-    def __init__(self, sql_conn: SqlConnection) -> None:
-        self._conn = sql_conn
-
-    def exec(self, sql: SqlText, inputs: tuple[Any, ...] = tuple()) -> Cursor:
-        """Wrapper around sqlite3.Connection.execute."""
-        return self._conn.execute(sql, inputs)
-
-    def exec_on_values(self, sql: SqlText, inputs: tuple[Any, ...]) -> Cursor:
-        """Wraps .exec on inputs, affixes to sql proper ' VALUES (?, …)'."""
-        q_marks = '(' + ','.join(['?'] * len(inputs)) + ')'
-        return self._conn.execute(f'{sql} VALUES {q_marks}', inputs)
-
-    def commit(self) -> None:
-        """Commit changes (i.e. DbData.save() calls) to database."""
-        self._conn.commit()
-
-
-class DbConn(BaseDbConn):
-    """Like parent, but opening (and as context mgr: closing) connection."""
-
-    def __init__(self, path: Path = PATH_DB) -> None:
+    def __init__(self,
+                 path: Path = PATH_DB,
+                 expected_version: int = EXPECTED_DB_VERSION
+                 ) -> None:
         if not path.is_file():
             if path.exists():
                 raise HandledException(f'no DB at {path}; would create, '
@@ -92,14 +74,16 @@ class DbConn(BaseDbConn):
                         f'cannot find {path.parent} as directory to put '
                         f'DB into, did you run {_NAME_INSTALLER}?')
             with sql_connect(path) as conn:
+                print(f'No DB found at {path}, creating …')
                 conn.executescript(_PATH_DB_SCHEMA.read_text(encoding='utf8'))
                 conn.execute(f'{SQL_DB_VERSION} = {EXPECTED_DB_VERSION}')
-        cur_version = get_db_version(path)
-        if cur_version != EXPECTED_DB_VERSION:
-            raise HandledException(
-                    f'wrong database version {cur_version}, expected: '
-                    f'{EXPECTED_DB_VERSION} – run "migrate"?')
-        super().__init__(sql_connect(path, autocommit=False))
+        if expected_version >= 0:
+            cur_version = get_db_version(path)
+            if cur_version != expected_version:
+                raise HandledException(
+                        f'wrong database version {cur_version}, expected: '
+                        f'{expected_version} – run "migrate"?')
+        self._conn = sql_connect(path, autocommit=False)
 
     def __enter__(self) -> Self:
         return self
@@ -108,6 +92,25 @@ class DbConn(BaseDbConn):
         self._conn.close()
         return False
 
+    def exec(self, sql: SqlText, inputs: tuple[Any, ...] = tuple()
+             ) -> DbCursor:
+        """Wrapper around sqlite3.Connection.execute."""
+        return self._conn.execute(sql, inputs)
+
+    def exec_on_values(self, sql: SqlText, inputs: tuple[Any, ...]
+                       ) -> DbCursor:
+        """Wraps .exec on inputs, affixes to sql proper ' VALUES (?, …)'."""
+        q_marks = '(' + ','.join(['?'] * len(inputs)) + ')'
+        return self._conn.execute(f'{sql} VALUES {q_marks}', inputs)
+
+    def exec_script(self, sql: SqlText) -> DbCursor:
+        """Wrapper around sqlite3.Connection.executescript."""
+        return self._conn.executescript(sql)
+
+    def commit(self) -> None:
+        """Commit changes (i.e. DbData.save() calls) to database."""
+        self._conn.commit()
+
 
 class DbData:
     """Abstraction of common DB operation."""
@@ -138,7 +141,7 @@ class DbData:
         return cls(**kwargs)
 
     @classmethod
-    def get_one(cls, conn: BaseDbConn, id_: str | Hash) -> Self:
+    def get_one(cls, conn: DbConn, id_: str | Hash) -> Self:
         """Return single entry of id_ from DB."""
         sql = SqlText(f'SELECT * FROM {cls._table_name} '
                       f'WHERE {cls.id_name} = ?')
@@ -150,13 +153,13 @@ class DbData:
         return cls._from_table_row(row)
 
     @classmethod
-    def get_all(cls, conn: BaseDbConn) -> list[Self]:
+    def get_all(cls, conn: DbConn) -> list[Self]:
         """Return all entries from DB."""
         sql = SqlText(f'SELECT * FROM {cls._table_name}')
         rows = conn.exec(sql).fetchall()
         return [cls._from_table_row(row) for row in rows]
 
-    def save(self, conn: BaseDbConn) -> Cursor:
+    def save(self, conn: DbConn) -> DbCursor:
         """Save entry to DB."""
         vals = []
         for val in [getattr(self, col_name) for col_name in self._cols]:
diff --git a/src/ytplom/migrations.py b/src/ytplom/migrations.py
index 2c9c1e8..d54bdd8 100644
--- a/src/ytplom/migrations.py
+++ b/src/ytplom/migrations.py
@@ -1,44 +1,15 @@
 """Anything pertaining specifically to DB migrations."""
 from pathlib import Path
-from sqlite3 import connect as sql_connect, Connection as SqlConnection
-from typing import Callable, Optional
+from typing import Callable
 from ytplom.db import (
-        get_db_version, BaseDbConn, SqlText, EXPECTED_DB_VERSION, PATH_DB,
-        PATH_MIGRATIONS, SQL_DB_VERSION)
+        get_db_version, DbConn, SqlText,
+        EXPECTED_DB_VERSION, PATH_DB, PATH_MIGRATIONS, SQL_DB_VERSION)
 from ytplom.primitives import HandledException
 
 
 _LEGIT_YES = 'YES!!'
 
 
-class _Migration:
-    """Wrapper for SQL and Python code to apply on migrating."""
-
-    def __init__(self,
-                 version: int,
-                 filename_sql: Optional[Path] = None,
-                 after_sql_steps: Optional[Callable] = None
-                 ) -> None:
-        self.version = version
-        self._filename_sql = filename_sql
-        self._sql_code = None
-        if filename_sql:
-            path_sql = PATH_MIGRATIONS.joinpath(filename_sql)
-            self._sql_code = path_sql.read_text(encoding='utf8')
-        self._after_sql_steps = after_sql_steps
-
-    def apply_to(self, path_db: Path):
-        """Apply to DB at path_db migration code stored in self."""
-        with sql_connect(path_db, autocommit=False) as conn:
-            if self._sql_code:
-                print(f'Executing {self._filename_sql}')
-                conn.executescript(self._sql_code)
-            if self._after_sql_steps:
-                print('Running additional steps')
-                self._after_sql_steps(conn)
-            conn.execute(SqlText(f'{SQL_DB_VERSION} = {self.version}'))
-
-
 def run_migrations() -> None:
     """Try to migrate DB towards EXPECTED_DB_VERSION."""
     start_version = get_db_version(PATH_DB)
@@ -52,18 +23,27 @@ def run_migrations() -> None:
     print(f'Trying to migrate from DB version {start_version} to '
           f'{EXPECTED_DB_VERSION} …')
     migs_to_do = []
-    migs_by_n = {mig.version: mig for mig in MIGRATIONS}
+    migs_by_n = dict(enumerate(MIGRATIONS))
     for n in [n+1 for n in range(start_version, EXPECTED_DB_VERSION)]:
         if n not in migs_by_n:
             raise HandledException(f'Needed migration missing: {n}')
-        migs_to_do += [migs_by_n[n]]
-    for mig in migs_to_do:
-        print(f'Running migration towards: {mig.version}')
-        mig.apply_to(PATH_DB)
+        migs_to_do += [(n, *migs_by_n[n])]
+    for version, filename_sql, after_sql_steps in migs_to_do:
+        print(f'Running migration towards: {version}')
+        with DbConn(expected_version=version-1) as conn:
+            if filename_sql:
+                print(f'Executing {filename_sql}')
+                path_sql = PATH_MIGRATIONS.joinpath(filename_sql)
+                conn.exec_script(SqlText(path_sql.read_text(encoding='utf8')))
+            if after_sql_steps:
+                print('Running additional steps')
+                after_sql_steps(conn)
+            conn.exec(SqlText(f'{SQL_DB_VERSION} = {version}'))
+            conn.commit()
     print('Finished migrations.')
 
 
-def _rewrite_files_last_field_processing_first_field(conn: BaseDbConn,
+def _rewrite_files_last_field_processing_first_field(conn: DbConn,
                                                      cb: Callable
                                                      ) -> None:
     rows = conn.exec(SqlText('SELECT * FROM files')).fetchall()
@@ -72,12 +52,11 @@ def _rewrite_files_last_field_processing_first_field(conn: BaseDbConn,
         conn.exec_on_values(SqlText('REPLACE INTO files'), tuple(row))
 
 
-def _mig_2_calc_digests(sql_conn: SqlConnection) -> None:
+def _mig_2_calc_digests(conn: DbConn) -> None:
     """Calculate sha512 digests to all known video files."""
     # pylint: disable=import-outside-toplevel
     from hashlib import file_digest
     from ytplom.misc import PATH_DOWNLOADS
-    conn = BaseDbConn(sql_conn)
     rel_paths = [
             p[0] for p
             in conn.exec(SqlText('SELECT rel_path FROM files')).fetchall()]
@@ -103,18 +82,17 @@ def _mig_2_calc_digests(sql_conn: SqlConnection) -> None:
     _rewrite_files_last_field_processing_first_field(conn, hexdigest_file)
 
 
-def _mig_4_convert_digests(sql_conn: SqlConnection) -> None:
+def _mig_4_convert_digests(conn: DbConn) -> None:
     """Fill new files.sha512_blob field with binary .sha512_digest."""
-    _rewrite_files_last_field_processing_first_field(
-            BaseDbConn(sql_conn), bytes.fromhex)
+    _rewrite_files_last_field_processing_first_field(conn, bytes.fromhex)
 
 
 MIGRATIONS = [
-    _Migration(0, Path('0_init.sql')),
-    _Migration(1, Path('1_add_files_last_updated.sql')),
-    _Migration(2, Path('2_add_files_sha512.sql'), _mig_2_calc_digests),
-    _Migration(3, Path('3_files_redo.sql')),
-    _Migration(4, Path('4_add_files_sha512_blob.sql'), _mig_4_convert_digests),
-    _Migration(5, Path('5_files_redo.sql')),
-    _Migration(6, Path('6_add_files_tags.sql'))
+    (Path('0_init.sql'), None),
+    (Path('1_add_files_last_updated.sql'), None),
+    (Path('2_add_files_sha512.sql'), _mig_2_calc_digests),
+    (Path('3_files_redo.sql'), None),
+    (Path('4_add_files_sha512_blob.sql'), _mig_4_convert_digests),
+    (Path('5_files_redo.sql'), None),
+    (Path('6_add_files_tags.sql'), None)
 ]
diff --git a/src/ytplom/misc.py b/src/ytplom/misc.py
index 7458cb7..c42b71a 100644
--- a/src/ytplom/misc.py
+++ b/src/ytplom/misc.py
@@ -12,14 +12,13 @@ from uuid import uuid4
 from pathlib import Path
 from threading import Thread
 from queue import Queue
-from sqlite3 import Cursor
 # non-included libs
 from ffmpeg import probe as ffprobe  # type: ignore
 import googleapiclient.discovery  # type: ignore
 from mpv import MPV  # type: ignore
 from yt_dlp import YoutubeDL  # type: ignore
 # ourselves
-from ytplom.db import BaseDbConn, DbConn, DbData, Hash, SqlText
+from ytplom.db import DbConn, DbCursor, DbData, Hash, SqlText
 from ytplom.primitives import HandledException, NotFoundException
 
 
@@ -202,7 +201,7 @@ class YoutubeQuery(DbData):
 
     @classmethod
     def new_by_request_saved(cls,
-                             conn: BaseDbConn,
+                             conn: DbConn,
                              config: Config,
                              query_txt: QueryText
                              ) -> Self:
@@ -252,9 +251,7 @@ class YoutubeQuery(DbData):
         return query
 
     @classmethod
-    def get_all_for_video(cls,
-                          conn: BaseDbConn,
-                          video_id: YoutubeId
+    def get_all_for_video(cls, conn: DbConn, video_id: YoutubeId
                           ) -> list[Self]:
         """Return YoutubeQueries containing YoutubeVideo's ID in results."""
         sql = SqlText('SELECT query_id FROM '
@@ -306,10 +303,7 @@ class YoutubeVideo(DbData):
         self.duration = _readable_seconds(seconds)
 
     @classmethod
-    def get_all_for_query(cls,
-                          conn: BaseDbConn,
-                          query_id: QueryId
-                          ) -> list[Self]:
+    def get_all_for_query(cls, conn: DbConn, query_id: QueryId) -> list[Self]:
         """Return all videos for query of query_id."""
         sql = SqlText('SELECT video_id '
                       'FROM yt_query_results WHERE query_id = ?')
@@ -317,7 +311,7 @@ class YoutubeVideo(DbData):
         return [cls.get_one(conn, video_id_tup[0])
                 for video_id_tup in video_ids]
 
-    def save_to_query(self, conn: BaseDbConn, query_id: QueryId) -> None:
+    def save_to_query(self, conn: DbConn, query_id: QueryId) -> None:
         """Save inclusion of self in results to query of query_id."""
         conn.exec_on_values(SqlText('REPLACE INTO yt_query_results'),
                             (query_id, self.id_))
@@ -361,23 +355,23 @@ class VideoFile(DbData):
     def _renew_last_update(self):
         self.last_update = DatetimeStr(datetime.now().strftime(TIMESTAMP_FMT))
 
-    def save(self, conn: BaseDbConn) -> Cursor:
+    def save(self, conn: DbConn) -> DbCursor:
         """Extend super().save by new .last_update if sufficient changes."""
         if hash(self) != self._hash_on_last_update:
             self._renew_last_update()
         return super().save(conn)
 
     @classmethod
-    def get_one_with_whitelist_tags_display(cls, conn: BaseDbConn, id_: Hash,
-                                            whitelist_tags_display: TagSet
-                                            ) -> Self:
+    def get_one_with_whitelist_tags_display(
+            cls, conn: DbConn, id_: Hash, whitelist_tags_display: TagSet
+            ) -> Self:
         """Same as .get_one except sets .whitelist_tags_display."""
         vf = cls.get_one(conn, id_)
         vf.whitelist_tags_display = whitelist_tags_display
         return vf
 
     @classmethod
-    def get_by_yt_id(cls, conn: BaseDbConn, yt_id: YoutubeId) -> Self:
+    def get_by_yt_id(cls, conn: DbConn, yt_id: YoutubeId) -> Self:
         """Return VideoFile of .yt_id."""
         sql = SqlText(f'SELECT * FROM {cls._table_name} WHERE yt_id = ?')
         row = conn.exec(sql, (yt_id,)).fetchone()
@@ -387,7 +381,7 @@ class VideoFile(DbData):
 
     @classmethod
     def get_filtered(cls,
-                     conn: BaseDbConn,
+                     conn: DbConn,
                      filter_path: FilterStr,
                      needed_tags_dark: TagSet,
                      needed_tags_seen: TagSet,
@@ -417,7 +411,7 @@ class VideoFile(DbData):
                     'canot show display-whitelisted tags on unset whitelist')
         return self.tags.whitelisted(self.whitelist_tags_display)
 
-    def unused_tags(self, conn: BaseDbConn) -> TagSet:
+    def unused_tags(self, conn: DbConn) -> TagSet:
         """Return tags used among other VideoFiles, not in self."""
         if self.whitelist_tags_display is None:
             raise HandledException(
@@ -488,7 +482,7 @@ class VideoFile(DbData):
         self.full_path.unlink()
 
     @classmethod
-    def purge_deleteds(cls, conn: BaseDbConn) -> None:
+    def purge_deleteds(cls, conn: DbConn) -> None:
         """For all of .is_flag_set("deleted"), remove file _and_ DB entry."""
         for file in [f for f in cls.get_all(conn)
                      if f.is_flag_set(FlagName('delete'))]:
@@ -516,7 +510,7 @@ class QuotaLog(DbData):
         self.cost = cost
 
     @classmethod
-    def update(cls, conn: BaseDbConn, cost: QuotaCost) -> None:
+    def update(cls, conn: DbConn, cost: QuotaCost) -> None:
         """Adds cost mapped to current datetime."""
         cls._remove_old(conn)
         new = cls(None,
@@ -525,14 +519,14 @@ class QuotaLog(DbData):
         new.save(conn)
 
     @classmethod
-    def current(cls, conn: BaseDbConn) -> QuotaCost:
+    def current(cls, conn: DbConn) -> QuotaCost:
         """Returns quota cost total for last 24 hours, purges old data."""
         cls._remove_old(conn)
         quota_costs = cls.get_all(conn)
         return QuotaCost(sum(c.cost for c in quota_costs))
 
     @classmethod
-    def _remove_old(cls, conn: BaseDbConn) -> None:
+    def _remove_old(cls, conn: DbConn) -> None:
         cutoff = datetime.now() - timedelta(days=1)
         sql = SqlText(f'DELETE FROM {cls._table_name} WHERE timestamp < ?')
         conn.exec(SqlText(sql), (cutoff.strftime(TIMESTAMP_FMT),))
-- 
2.30.2