From 799fe5e97556d1ca5820a13fd0a3daa7f1dd7e7e Mon Sep 17 00:00:00 2001
From: Christian Heller <c.heller@plomlompom.de>
Date: Sat, 4 Jan 2025 17:57:01 +0100
Subject: [PATCH] More DB management code reorganization; add explicit "create"
 script.

---
 src/migrate.py           |   5 +-
 src/sync.py              |  30 ++++++-----
 src/ytplom/db.py         | 114 +++++++++++++++++++++++++++------------
 src/ytplom/migrations.py |  39 +-------------
 ytplom                   |   4 +-
 5 files changed, 104 insertions(+), 88 deletions(-)

diff --git a/src/migrate.py b/src/migrate.py
index cc5e6cf..9d0dcc6 100755
--- a/src/migrate.py
+++ b/src/migrate.py
@@ -1,7 +1,8 @@
 #!/usr/bin/env python3
 """Script to migrate DB to most recent schema."""
-from ytplom.migrations import run_migrations
+from ytplom.db import DbFile
+from ytplom.migrations import MIGRATIONS
 
 
 if __name__ == '__main__':
-    run_migrations()
+    DbFile(expected_version=-1).migrate(MIGRATIONS)
diff --git a/src/sync.py b/src/sync.py
index f9fcb74..a6be397 100755
--- a/src/sync.py
+++ b/src/sync.py
@@ -8,7 +8,7 @@ from urllib.request import Request, urlopen
 # non-included libs
 from paramiko import SSHClient  # type: ignore
 from scp import SCPClient  # type: ignore
-from ytplom.db import DbConn, Hash, PATH_DB
+from ytplom.db import DbConn, DbFile, Hash, PATH_DB
 from ytplom.misc import (PATH_TEMP, Config, FlagName, QuotaLog, VideoFile,
                          YoutubeQuery, YoutubeVideo)
 from ytplom.http import PAGE_NAMES
@@ -75,7 +75,8 @@ def sync_relations(host_names: tuple[str, str],
 def sync_dbs(scp: SCPClient) -> None:
     """Download remote DB, run sync_(objects|relations), put remote DB back."""
     scp.get(PATH_DB, PATH_DB_REMOTE)
-    with DbConn(PATH_DB) as db_local, DbConn(PATH_DB_REMOTE) as db_remote:
+    with DbConn(DbFile(PATH_DB).connect()) as db_local, \
+            DbConn(DbFile(PATH_DB_REMOTE).connect()) as db_remote:
         for cls in (QuotaLog, YoutubeQuery, YoutubeVideo, VideoFile):
             back_and_forth(sync_objects, (db_local, db_remote), cls)
         for yt_video_local in YoutubeVideo.get_all(db_local):
@@ -106,18 +107,19 @@ def fill_missing(scp: SCPClient, config: Config) -> None:
     for url_missing in _urls_here_and_there(config, 'missing'):
         with urlopen(url_missing) as response:
             missings += [list(json_loads(response.read()))]
-    conn = DbConn()
-    for i, direction_mover in enumerate([('local->remote', scp.put),
-                                         ('remote->local', scp.get)]):
-        direction, mover = direction_mover
-        for digest in (d for d in missings[i]
-                       if d not in missings[int(not bool(i))]):
-            vf = VideoFile.get_one(conn, Hash.from_b64(digest))
-            if vf.is_flag_set(FlagName('do not sync')):
-                print(f'SYNC: not sending ("do not sync" set): {vf.full_path}')
-                return
-            print(f'SYNC: sending {direction}: {vf.full_path}')
-            mover(vf.full_path, vf.full_path)
+    with DbConn() as conn:
+        for i, (direction, mover) in enumerate([('local->remote', scp.put),
+                                                ('remote->local', scp.get)]):
+            for digest in (d for d in missings[i]
+                           if d not in missings[int(not bool(i))]):
+                vf = VideoFile.get_one(conn, Hash.from_b64(digest))
+                if vf.is_flag_set(FlagName('do not sync')):
+                    print(f'SYNC: not sending ("do not sync" set)'
+                          f': {vf.full_path}')
+                    # skip only this file; keep syncing the remaining ones
+                    continue
+                print(f'SYNC: sending {direction}: {vf.full_path}')
+                mover(vf.full_path, vf.full_path)
 
 
 def main():
diff --git a/src/ytplom/db.py b/src/ytplom/db.py
index 599b9f8..f503e9b 100644
--- a/src/ytplom/db.py
+++ b/src/ytplom/db.py
@@ -2,26 +2,22 @@
 from base64 import urlsafe_b64decode, urlsafe_b64encode
 from hashlib import file_digest
 from pathlib import Path
-from sqlite3 import connect as sql_connect, Cursor as SqlCursor, Row as SqlRow
-from typing import Any, Literal, NewType, Self
+from sqlite3 import (connect as sql_connect, Connection as SqlConnection,
+                     Cursor as SqlCursor, Row as SqlRow)
+from typing import Callable, Literal, NewType, Optional, Self
 from ytplom.primitives import (
         HandledException, NotFoundException, PATH_APP_DATA)
 
-SqlText = NewType('SqlText', str)
-
 EXPECTED_DB_VERSION = 6
-PATH_DB = PATH_APP_DATA.joinpath('db.sql')
-SQL_DB_VERSION = SqlText('PRAGMA user_version')
-PATH_MIGRATIONS = PATH_APP_DATA.joinpath('migrations')
-_HASH_ALGO = 'sha512'
-_PATH_DB_SCHEMA = PATH_MIGRATIONS.joinpath('new_init.sql')
-_NAME_INSTALLER = Path('install.sh')
+PATH_DB = PATH_APP_DATA.joinpath('db.sql')  # default DB file location
 
+SqlText = NewType('SqlText', str)
+MigrationsList = list[tuple[Path, Optional[Callable]]]
 
-def get_db_version(db_path: Path) -> int:
-    """Return user_version value of DB at db_path."""
-    with sql_connect(db_path) as conn:
-        return list(conn.execute(SQL_DB_VERSION))[0][0]
+_SQL_DB_VERSION = SqlText('PRAGMA user_version')
+_PATH_MIGRATIONS = PATH_APP_DATA.joinpath('migrations')
+_HASH_ALGO = 'sha512'
+_PATH_DB_SCHEMA = _PATH_MIGRATIONS.joinpath('new_init.sql')
 
 
 class Hash:
@@ -58,32 +54,84 @@ class Hash:
         return urlsafe_b64encode(self.bytes).decode('utf8')
 
 
-class DbConn:
-    """Wrapper for sqlite3 connections."""
+class DbFile:
+    """Wrapper around the file of a sqlite3 database."""
 
     def __init__(self,
                  path: Path = PATH_DB,
                  expected_version: int = EXPECTED_DB_VERSION
                  ) -> None:
+        self._path = path
         if not path.is_file():
-            if path.exists():
-                raise HandledException(f'no DB at {path}; would create, '
-                                       'but something\'s already there?')
-            if not path.parent.is_dir():
-                raise HandledException(
-                        f'cannot find {path.parent} as directory to put '
-                        f'DB into, did you run {_NAME_INSTALLER}?')
-            with sql_connect(path) as conn:
-                print(f'No DB found at {path}, creating …')
-                conn.executescript(_PATH_DB_SCHEMA.read_text(encoding='utf8'))
-                conn.execute(f'{SQL_DB_VERSION} = {EXPECTED_DB_VERSION}')
+            raise HandledException(
+                    f'no DB file at {path} – run "create"?')
         if expected_version >= 0:
-            cur_version = get_db_version(path)
-            if cur_version != expected_version:
+            user_version = self._get_user_version()
+            if user_version != expected_version:
                 raise HandledException(
-                        f'wrong database version {cur_version}, expected: '
+                        f'wrong database version {user_version}, expected: '
                         f'{expected_version} – run "migrate"?')
-        self._conn = sql_connect(path, autocommit=False)
+
+    def _get_user_version(self) -> int:
+        with DbConn(self.connect()) as conn:  # DbConn closes on exit
+            return list(conn.exec(_SQL_DB_VERSION))[0][0]
+
+    @staticmethod
+    def create(path: Path = PATH_DB) -> None:
+        """Create DB file at path according to _PATH_DB_SCHEMA."""
+        if path.exists():
+            raise HandledException(f'There already exists a node at {path}.')
+        if not path.parent.is_dir():
+            raise HandledException(
+                    f'No directory {path.parent} found to write into.')
+        with sql_connect(path) as conn:
+            conn.executescript(_PATH_DB_SCHEMA.read_text(encoding='utf8'))
+            conn.execute(f'{_SQL_DB_VERSION} = {EXPECTED_DB_VERSION}')
+        conn.close()  # sqlite3's "with" only commits, it never closes
+
+    def connect(self) -> SqlConnection:
+        """Open database file into SQL connection, with autocommit off."""
+        return sql_connect(self._path, autocommit=False)
+
+    def migrate(self, migrations: MigrationsList) -> None:
+        """Run all migrations needed to reach EXPECTED_DB_VERSION."""
+        version_now = self._get_user_version()
+        if version_now > EXPECTED_DB_VERSION:
+            raise HandledException(
+                    f'Cannot migrate backward from version {version_now} to '
+                    f'{EXPECTED_DB_VERSION}.')
+        if version_now == EXPECTED_DB_VERSION:
+            print('Database at expected version, no migrations to do.')
+            return
+        print(f'Trying to migrate from DB version {version_now} to '
+              f'{EXPECTED_DB_VERSION} …')
+        # collect (and verify presence of) all steps before running any
+        steps = []
+        for target in range(version_now + 1, EXPECTED_DB_VERSION + 1):
+            if not 0 <= target < len(migrations):
+                raise HandledException(f'Needed migration missing: {target}')
+            steps += [(target, *migrations[target])]
+        for target, sql_file, post_sql in steps:
+            print(f'Running migration towards: {target}')
+            with DbConn(self.connect()) as conn:
+                if sql_file:
+                    print(f'Executing {sql_file}')
+                    sql = _PATH_MIGRATIONS.joinpath(sql_file).read_text(
+                            encoding='utf8')
+                    conn.exec_script(SqlText(sql))
+                if post_sql:
+                    print('Running additional steps')
+                    post_sql(conn)
+                conn.exec(SqlText(f'{_SQL_DB_VERSION} = {target}'))
+                conn.commit()
+        print('Finished migrations.')
+
+
+class DbConn:
+    """Wrapper for sqlite3 connections."""
+
+    def __init__(self, sql_conn: Optional[SqlConnection] = None) -> None:
+        self._conn = sql_conn or DbFile().connect()
 
     def __enter__(self) -> Self:
         return self
@@ -92,7 +140,7 @@ class DbConn:
         self._conn.close()
         return False
 
-    def exec(self, sql: SqlText, inputs: tuple[Any, ...] = tuple()
+    def exec(self, sql: SqlText, inputs: tuple = tuple()
              ) -> SqlCursor:
         """Wrapper around sqlite3.Connection.execute, building '?' if inputs"""
         if len(inputs) > 0:
@@ -117,7 +165,7 @@ class DbData:
     _str_field: str
     _cols: tuple[str, ...]
 
-    def __eq__(self, other: Any) -> bool:
+    def __eq__(self, other) -> bool:
         if not isinstance(other, self.__class__):
             return False
         for attr_name in self._cols:
diff --git a/src/ytplom/migrations.py b/src/ytplom/migrations.py
index d5e49a3..59c21f3 100644
--- a/src/ytplom/migrations.py
+++ b/src/ytplom/migrations.py
@@ -1,48 +1,13 @@
 """Anything pertaining specifically to DB migrations."""
 from pathlib import Path
 from typing import Callable
-from ytplom.db import (
-        get_db_version, DbConn, SqlText,
-        EXPECTED_DB_VERSION, PATH_DB, PATH_MIGRATIONS, SQL_DB_VERSION)
+from ytplom.db import DbConn, MigrationsList, SqlText
 from ytplom.primitives import HandledException
 
 
 _LEGIT_YES = 'YES!!'
 
 
-def run_migrations() -> None:
-    """Try to migrate DB towards EXPECTED_DB_VERSION."""
-    start_version = get_db_version(PATH_DB)
-    if start_version == EXPECTED_DB_VERSION:
-        print('Database at expected version, no migrations to do.')
-        return
-    if start_version > EXPECTED_DB_VERSION:
-        raise HandledException(
-                f'Cannot migrate backward from version {start_version} to '
-                f'{EXPECTED_DB_VERSION}.')
-    print(f'Trying to migrate from DB version {start_version} to '
-          f'{EXPECTED_DB_VERSION} …')
-    migs_to_do = []
-    migs_by_n = dict(enumerate(MIGRATIONS))
-    for n in [n+1 for n in range(start_version, EXPECTED_DB_VERSION)]:
-        if n not in migs_by_n:
-            raise HandledException(f'Needed migration missing: {n}')
-        migs_to_do += [(n, *migs_by_n[n])]
-    for version, filename_sql, after_sql_steps in migs_to_do:
-        print(f'Running migration towards: {version}')
-        with DbConn(expected_version=version-1) as conn:
-            if filename_sql:
-                print(f'Executing {filename_sql}')
-                path_sql = PATH_MIGRATIONS.joinpath(filename_sql)
-                conn.exec_script(SqlText(path_sql.read_text(encoding='utf8')))
-            if after_sql_steps:
-                print('Running additional steps')
-                after_sql_steps(conn)
-            conn.exec(SqlText(f'{SQL_DB_VERSION} = {version}'))
-            conn.commit()
-    print('Finished migrations.')
-
-
 def _rewrite_files_last_field_processing_first_field(conn: DbConn,
                                                      cb: Callable
                                                      ) -> None:
@@ -87,7 +52,7 @@ def _mig_4_convert_digests(conn: DbConn) -> None:
     _rewrite_files_last_field_processing_first_field(conn, bytes.fromhex)
 
 
-MIGRATIONS = [
+MIGRATIONS: MigrationsList = [
     (Path('0_init.sql'), None),
     (Path('1_add_files_last_updated.sql'), None),
     (Path('2_add_files_sha512.sql'), _mig_2_calc_digests),
diff --git a/ytplom b/ytplom
index 335023a..7e5c4e2 100755
--- a/ytplom
+++ b/ytplom
@@ -4,8 +4,8 @@ set -e
 PATH_APP_SHARE=~/.local/share/ytplom
 PATH_VENV="${PATH_APP_SHARE}/venv"
 
-if [ ! "$1" = 'serve' ] && [ ! "$1" = 'sync' ] && [ ! "$1" = 'migrate' ]; then
-    echo "Need argument (serve' or 'sync' or 'migrate')."
+if [ ! "$1" = 'serve' ] && [ ! "$1" = 'sync' ] && [ ! "$1" = 'migrate' ] && [ ! "$1" = 'create' ]; then
+    echo "Need argument ('serve' or 'sync' or 'migrate' or 'create')."
     false
 fi
 
-- 
2.30.2