From: Christian Heller Date: Sat, 4 Jan 2025 16:57:01 +0000 (+0100) Subject: More DB management code reorganization; add explicit "create" script. X-Git-Url: https://plomlompom.com/repos/%7B%7B%20web_path%20%7D%7D/%7B%7Bprefix%7D%7D/%7B%7Btodo.comment%7D%7D?a=commitdiff_plain;h=799fe5e97556d1ca5820a13fd0a3daa7f1dd7e7e;p=ytplom More DB management code reorganization; add explicit "create" script. --- diff --git a/src/migrate.py b/src/migrate.py index cc5e6cf..9d0dcc6 100755 --- a/src/migrate.py +++ b/src/migrate.py @@ -1,7 +1,8 @@ #!/usr/bin/env python3 """Script to migrate DB to most recent schema.""" -from ytplom.migrations import run_migrations +from ytplom.db import DbFile +from ytplom.migrations import MIGRATIONS if __name__ == '__main__': - run_migrations() + DbFile(expected_version=-1).migrate(MIGRATIONS) diff --git a/src/sync.py b/src/sync.py index f9fcb74..a6be397 100755 --- a/src/sync.py +++ b/src/sync.py @@ -8,7 +8,7 @@ from urllib.request import Request, urlopen # non-included libs from paramiko import SSHClient # type: ignore from scp import SCPClient # type: ignore -from ytplom.db import DbConn, Hash, PATH_DB +from ytplom.db import DbConn, DbFile, Hash, PATH_DB from ytplom.misc import (PATH_TEMP, Config, FlagName, QuotaLog, VideoFile, YoutubeQuery, YoutubeVideo) from ytplom.http import PAGE_NAMES @@ -75,7 +75,8 @@ def sync_relations(host_names: tuple[str, str], def sync_dbs(scp: SCPClient) -> None: """Download remote DB, run sync_(objects|relations), put remote DB back.""" scp.get(PATH_DB, PATH_DB_REMOTE) - with DbConn(PATH_DB) as db_local, DbConn(PATH_DB_REMOTE) as db_remote: + with DbConn(DbFile(PATH_DB).connect()) as db_local, \ + DbConn(DbFile(PATH_DB_REMOTE).connect()) as db_remote: for cls in (QuotaLog, YoutubeQuery, YoutubeVideo, VideoFile): back_and_forth(sync_objects, (db_local, db_remote), cls) for yt_video_local in YoutubeVideo.get_all(db_local): @@ -106,18 +107,19 @@ def fill_missing(scp: SCPClient, config: Config) -> None: for url_missing in _urls_here_and_there(config, 'missing'): with urlopen(url_missing) as response: missings += [list(json_loads(response.read()))] - conn = DbConn() - for i, direction_mover in enumerate([('local->remote', scp.put), - ('remote->local', scp.get)]): - direction, mover = direction_mover - for digest in (d for d in missings[i] - if d not in missings[int(not bool(i))]): - vf = VideoFile.get_one(conn, Hash.from_b64(digest)) - if vf.is_flag_set(FlagName('do not sync')): - print(f'SYNC: not sending ("do not sync" set): {vf.full_path}') - return - print(f'SYNC: sending {direction}: {vf.full_path}') - mover(vf.full_path, vf.full_path) + with DbConn() as conn: + for i, direction_mover in enumerate([('local->remote', scp.put), + ('remote->local', scp.get)]): + direction, mover = direction_mover + for digest in (d for d in missings[i] + if d not in missings[int(not bool(i))]): + vf = VideoFile.get_one(conn, Hash.from_b64(digest)) + if vf.is_flag_set(FlagName('do not sync')): + print(f'SYNC: not sending ("do not sync" set)' + f': {vf.full_path}') + return + print(f'SYNC: sending {direction}: {vf.full_path}') + mover(vf.full_path, vf.full_path) def main(): diff --git a/src/ytplom/db.py b/src/ytplom/db.py index 599b9f8..f503e9b 100644 --- a/src/ytplom/db.py +++ b/src/ytplom/db.py @@ -2,26 +2,22 @@ from base64 import urlsafe_b64decode, urlsafe_b64encode from hashlib import file_digest from pathlib import Path -from sqlite3 import connect as sql_connect, Cursor as SqlCursor, Row as SqlRow -from typing import Any, Literal, NewType, Self +from sqlite3 import (connect as sql_connect, Connection as SqlConnection, + Cursor as SqlCursor, Row as SqlRow) +from typing import Callable, Literal, NewType, Optional, Self from ytplom.primitives import ( HandledException, NotFoundException, PATH_APP_DATA) -SqlText = NewType('SqlText', str) - EXPECTED_DB_VERSION = 6 -PATH_DB = PATH_APP_DATA.joinpath('db.sql') -SQL_DB_VERSION = SqlText('PRAGMA user_version') -PATH_MIGRATIONS = PATH_APP_DATA.joinpath('migrations') -_HASH_ALGO = 'sha512' -_PATH_DB_SCHEMA = PATH_MIGRATIONS.joinpath('new_init.sql') -_NAME_INSTALLER = Path('install.sh') +PATH_DB = PATH_APP_DATA.joinpath('TESTdb.sql') +SqlText = NewType('SqlText', str) +MigrationsList = list[tuple[Path, Optional[Callable]]] -def get_db_version(db_path: Path) -> int: - """Return user_version value of DB at db_path.""" - with sql_connect(db_path) as conn: - return list(conn.execute(SQL_DB_VERSION))[0][0] +_SQL_DB_VERSION = SqlText('PRAGMA user_version') +_PATH_MIGRATIONS = PATH_APP_DATA.joinpath('migrations') +_HASH_ALGO = 'sha512' +_PATH_DB_SCHEMA = _PATH_MIGRATIONS.joinpath('new_init.sql') class Hash: @@ -58,32 +54,84 @@ class Hash: return urlsafe_b64encode(self.bytes).decode('utf8') -class DbConn: - """Wrapper for sqlite3 connections.""" +class DbFile: + """Wrapper around the file of a sqlite3 database.""" def __init__(self, path: Path = PATH_DB, expected_version: int = EXPECTED_DB_VERSION ) -> None: + self._path = path if not path.is_file(): - if path.exists(): - raise HandledException(f'no DB at {path}; would create, ' - 'but something\'s already there?') - if not path.parent.is_dir(): - raise HandledException( - f'cannot find {path.parent} as directory to put ' - f'DB into, did you run {_NAME_INSTALLER}?') - with sql_connect(path) as conn: - print(f'No DB found at {path}, creating …') - conn.executescript(_PATH_DB_SCHEMA.read_text(encoding='utf8')) - conn.execute(f'{SQL_DB_VERSION} = {EXPECTED_DB_VERSION}') + raise HandledException( + f'no DB file at {path} – run "create"?') if expected_version >= 0: - cur_version = get_db_version(path) - if cur_version != expected_version: + user_version = self._get_user_version() + if user_version != expected_version: raise HandledException( - f'wrong database version {cur_version}, expected: ' + f'wrong database version {user_version}, expected: ' f'{expected_version} – run "migrate"?') - self._conn = sql_connect(path, autocommit=False) + + def _get_user_version(self) -> int: + with sql_connect(self._path) as conn: + return list(conn.execute(_SQL_DB_VERSION))[0][0] + + @staticmethod + def create(path: Path = PATH_DB) -> None: + """Create DB file at path according to _PATH_DB_SCHEMA.""" + if path.exists(): + raise HandledException( + f'There already exists a node at {path}.') + if not path.parent.is_dir(): + raise HandledException( + f'No directory {path.parent} found to write into.') + with sql_connect(path) as conn: + conn.executescript(_PATH_DB_SCHEMA.read_text(encoding='utf8')) + conn.execute(f'{_SQL_DB_VERSION} = {EXPECTED_DB_VERSION}') + + def connect(self) -> SqlConnection: + """Open database file into SQL connection, with autocommit off.""" + return sql_connect(self._path, autocommit=False) + + def migrate(self, migrations: MigrationsList) -> None: + """Migrate self towards EXPECTED_DB_VERSION""" + start_version = self._get_user_version() + if start_version == EXPECTED_DB_VERSION: + print('Database at expected version, no migrations to do.') + return + if start_version > EXPECTED_DB_VERSION: + raise HandledException( + f'Cannot migrate backward from version {start_version} to ' + f'{EXPECTED_DB_VERSION}.') + print(f'Trying to migrate from DB version {start_version} to ' + f'{EXPECTED_DB_VERSION} …') + migs_to_do = [] + migs_by_n = dict(enumerate(migrations)) + for n in [n+1 for n in range(start_version, EXPECTED_DB_VERSION)]: + if n not in migs_by_n: + raise HandledException(f'Needed migration missing: {n}') + migs_to_do += [(n, *migs_by_n[n])] + for version, filename_sql, after_sql_steps in migs_to_do: + print(f'Running migration towards: {version}') + with DbConn(self.connect()) as conn: + if filename_sql: + print(f'Executing {filename_sql}') + path_sql = _PATH_MIGRATIONS.joinpath(filename_sql) + conn.exec_script( + SqlText(path_sql.read_text(encoding='utf8'))) + if after_sql_steps: + print('Running additional steps') + after_sql_steps(conn) + conn.exec(SqlText(f'{_SQL_DB_VERSION} = {version}')) + conn.commit() + print('Finished migrations.') + + +class DbConn: + """Wrapper for sqlite3 connections.""" + + def __init__(self, sql_conn: Optional[SqlConnection] = None) -> None: + self._conn = sql_conn or DbFile().connect() def __enter__(self) -> Self: return self @@ -92,7 +140,7 @@ class DbConn: self._conn.close() return False - def exec(self, sql: SqlText, inputs: tuple[Any, ...] = tuple() + def exec(self, sql: SqlText, inputs: tuple = tuple() ) -> SqlCursor: """Wrapper around sqlite3.Connection.execute, building '?' if inputs""" if len(inputs) > 0: @@ -117,7 +165,7 @@ class DbData: _str_field: str _cols: tuple[str, ...] - def __eq__(self, other: Any) -> bool: + def __eq__(self, other) -> bool: if not isinstance(other, self.__class__): return False for attr_name in self._cols: diff --git a/src/ytplom/migrations.py b/src/ytplom/migrations.py index d5e49a3..59c21f3 100644 --- a/src/ytplom/migrations.py +++ b/src/ytplom/migrations.py @@ -1,48 +1,13 @@ """Anything pertaining specifically to DB migrations.""" from pathlib import Path from typing import Callable -from ytplom.db import ( - get_db_version, DbConn, SqlText, - EXPECTED_DB_VERSION, PATH_DB, PATH_MIGRATIONS, SQL_DB_VERSION) +from ytplom.db import DbConn, MigrationsList, SqlText from ytplom.primitives import HandledException _LEGIT_YES = 'YES!!' -def run_migrations() -> None: - """Try to migrate DB towards EXPECTED_DB_VERSION.""" - start_version = get_db_version(PATH_DB) - if start_version == EXPECTED_DB_VERSION: - print('Database at expected version, no migrations to do.') - return - if start_version > EXPECTED_DB_VERSION: - raise HandledException( - f'Cannot migrate backward from version {start_version} to ' - f'{EXPECTED_DB_VERSION}.') - print(f'Trying to migrate from DB version {start_version} to ' - f'{EXPECTED_DB_VERSION} …') - migs_to_do = [] - migs_by_n = dict(enumerate(MIGRATIONS)) - for n in [n+1 for n in range(start_version, EXPECTED_DB_VERSION)]: - if n not in migs_by_n: - raise HandledException(f'Needed migration missing: {n}') - migs_to_do += [(n, *migs_by_n[n])] - for version, filename_sql, after_sql_steps in migs_to_do: - print(f'Running migration towards: {version}') - with DbConn(expected_version=version-1) as conn: - if filename_sql: - print(f'Executing {filename_sql}') - path_sql = PATH_MIGRATIONS.joinpath(filename_sql) - conn.exec_script(SqlText(path_sql.read_text(encoding='utf8'))) - if after_sql_steps: - print('Running additional steps') - after_sql_steps(conn) - conn.exec(SqlText(f'{SQL_DB_VERSION} = {version}')) - conn.commit() - print('Finished migrations.') - - def _rewrite_files_last_field_processing_first_field(conn: DbConn, cb: Callable ) -> None: @@ -87,7 +52,7 @@ def _mig_4_convert_digests(conn: DbConn) -> None: _rewrite_files_last_field_processing_first_field(conn, bytes.fromhex) -MIGRATIONS = [ +MIGRATIONS: MigrationsList = [ (Path('0_init.sql'), None), (Path('1_add_files_last_updated.sql'), None), (Path('2_add_files_sha512.sql'), _mig_2_calc_digests), diff --git a/ytplom b/ytplom index 335023a..7e5c4e2 100755 --- a/ytplom +++ b/ytplom @@ -4,8 +4,8 @@ set -e PATH_APP_SHARE=~/.local/share/ytplom PATH_VENV="${PATH_APP_SHARE}/venv" -if [ ! "$1" = 'serve' ] && [ ! "$1" = 'sync' ] && [ ! "$1" = 'migrate' ]; then - echo "Need argument (serve' or 'sync' or 'migrate')." +if [ ! "$1" = 'serve' ] && [ ! "$1" = 'sync' ] && [ ! "$1" = 'migrate' ] && [ ! "$1" = 'create' ]; then + echo "Need argument ('serve' or 'sync' or 'migrate' or 'create')." false fi