From: Christian Heller Date: Thu, 2 Jan 2025 14:59:08 +0000 (+0100) Subject: Simplify DB management code. X-Git-Url: https://plomlompom.com/repos/%7B%7Bprefix%7D%7D/%7B%7Bdb.prefix%7D%7D/%7B%7B%20web_path%20%7D%7D/condition?a=commitdiff_plain;h=71ce6a01ed6f2aed314f86f8b96b0aeda68d9df4;p=ytplom Simplify DB management code. --- diff --git a/src/ytplom/db.py b/src/ytplom/db.py index 0fa9d7f..0703b05 100644 --- a/src/ytplom/db.py +++ b/src/ytplom/db.py @@ -2,8 +2,7 @@ from base64 import urlsafe_b64decode, urlsafe_b64encode from hashlib import file_digest from pathlib import Path -from sqlite3 import ( - connect as sql_connect, Connection as SqlConnection, Cursor, Row) +from sqlite3 import connect as sql_connect, Cursor as DbCursor, Row from typing import Any, Literal, NewType, Self from ytplom.primitives import ( HandledException, NotFoundException, PATH_APP_DATA) @@ -59,30 +58,13 @@ class Hash: return urlsafe_b64encode(self.bytes).decode('utf8') -class BaseDbConn: - """Wrapper for pre-established sqlite3.Connection.""" +class DbConn: + """Wrapper for sqlite3 connections.""" - def __init__(self, sql_conn: SqlConnection) -> None: - self._conn = sql_conn - - def exec(self, sql: SqlText, inputs: tuple[Any, ...] = tuple()) -> Cursor: - """Wrapper around sqlite3.Connection.execute.""" - return self._conn.execute(sql, inputs) - - def exec_on_values(self, sql: SqlText, inputs: tuple[Any, ...]) -> Cursor: - """Wraps .exec on inputs, affixes to sql proper ' VALUES (?, …)'.""" - q_marks = '(' + ','.join(['?'] * len(inputs)) + ')' - return self._conn.execute(f'{sql} VALUES {q_marks}', inputs) - - def commit(self) -> None: - """Commit changes (i.e. DbData.save() calls) to database.""" - self._conn.commit() - - -class DbConn(BaseDbConn): - """Like parent, but opening (and as context mgr: closing) connection.""" - - def __init__(self, path: Path = PATH_DB) -> None: + def __init__(self, + path: Path = PATH_DB, + expected_version: int = EXPECTED_DB_VERSION + ) -> None: if not path.is_file(): if path.exists(): raise HandledException(f'no DB at {path}; would create, ' @@ -92,14 +74,16 @@ class DbConn(BaseDbConn): f'cannot find {path.parent} as directory to put ' f'DB into, did you run {_NAME_INSTALLER}?') with sql_connect(path) as conn: + print(f'No DB found at {path}, creating …') conn.executescript(_PATH_DB_SCHEMA.read_text(encoding='utf8')) conn.execute(f'{SQL_DB_VERSION} = {EXPECTED_DB_VERSION}') - cur_version = get_db_version(path) - if cur_version != EXPECTED_DB_VERSION: - raise HandledException( - f'wrong database version {cur_version}, expected: ' - f'{EXPECTED_DB_VERSION} – run "migrate"?') - super().__init__(sql_connect(path, autocommit=False)) + if expected_version >= 0: + cur_version = get_db_version(path) + if cur_version != expected_version: + raise HandledException( + f'wrong database version {cur_version}, expected: ' + f'{expected_version} – run "migrate"?') + self._conn = sql_connect(path, autocommit=False) def __enter__(self) -> Self: return self @@ -108,6 +92,25 @@ class DbConn(BaseDbConn): self._conn.close() return False + def exec(self, sql: SqlText, inputs: tuple[Any, ...] = tuple() + ) -> DbCursor: + """Wrapper around sqlite3.Connection.execute.""" + return self._conn.execute(sql, inputs) + + def exec_on_values(self, sql: SqlText, inputs: tuple[Any, ...] + ) -> DbCursor: + """Wraps .exec on inputs, affixes to sql proper ' VALUES (?, …)'.""" + q_marks = '(' + ','.join(['?'] * len(inputs)) + ')' + return self._conn.execute(f'{sql} VALUES {q_marks}', inputs) + + def exec_script(self, sql: SqlText) -> DbCursor: + """Wrapper around sqlite3.Connection.executescript.""" + return self._conn.executescript(sql) + + def commit(self) -> None: + """Commit changes (i.e. DbData.save() calls) to database.""" + self._conn.commit() + class DbData: """Abstraction of common DB operation.""" @@ -138,7 +141,7 @@ class DbData: return cls(**kwargs) @classmethod - def get_one(cls, conn: BaseDbConn, id_: str | Hash) -> Self: + def get_one(cls, conn: DbConn, id_: str | Hash) -> Self: """Return single entry of id_ from DB.""" sql = SqlText(f'SELECT * FROM {cls._table_name} ' f'WHERE {cls.id_name} = ?') @@ -150,13 +153,13 @@ class DbData: return cls._from_table_row(row) @classmethod - def get_all(cls, conn: BaseDbConn) -> list[Self]: + def get_all(cls, conn: DbConn) -> list[Self]: """Return all entries from DB.""" sql = SqlText(f'SELECT * FROM {cls._table_name}') rows = conn.exec(sql).fetchall() return [cls._from_table_row(row) for row in rows] - def save(self, conn: BaseDbConn) -> Cursor: + def save(self, conn: DbConn) -> DbCursor: """Save entry to DB.""" vals = [] for val in [getattr(self, col_name) for col_name in self._cols]: diff --git a/src/ytplom/migrations.py b/src/ytplom/migrations.py index 2c9c1e8..d54bdd8 100644 --- a/src/ytplom/migrations.py +++ b/src/ytplom/migrations.py @@ -1,44 +1,15 @@ """Anything pertaining specifically to DB migrations.""" from pathlib import Path -from sqlite3 import connect as sql_connect, Connection as SqlConnection -from typing import Callable, Optional +from typing import Callable from ytplom.db import ( - get_db_version, BaseDbConn, SqlText, EXPECTED_DB_VERSION, PATH_DB, - PATH_MIGRATIONS, SQL_DB_VERSION) + get_db_version, DbConn, SqlText, + EXPECTED_DB_VERSION, PATH_DB, PATH_MIGRATIONS, SQL_DB_VERSION) from ytplom.primitives import HandledException _LEGIT_YES = 'YES!!' -class _Migration: - """Wrapper for SQL and Python code to apply on migrating.""" - - def __init__(self, - version: int, - filename_sql: Optional[Path] = None, - after_sql_steps: Optional[Callable] = None - ) -> None: - self.version = version - self._filename_sql = filename_sql - self._sql_code = None - if filename_sql: - path_sql = PATH_MIGRATIONS.joinpath(filename_sql) - self._sql_code = path_sql.read_text(encoding='utf8') - self._after_sql_steps = after_sql_steps - - def apply_to(self, path_db: Path): - """Apply to DB at path_db migration code stored in self.""" - with sql_connect(path_db, autocommit=False) as conn: - if self._sql_code: - print(f'Executing {self._filename_sql}') - conn.executescript(self._sql_code) - if self._after_sql_steps: - print('Running additional steps') - self._after_sql_steps(conn) - conn.execute(SqlText(f'{SQL_DB_VERSION} = {self.version}')) - - def run_migrations() -> None: """Try to migrate DB towards EXPECTED_DB_VERSION.""" start_version = get_db_version(PATH_DB) @@ -52,18 +23,27 @@ def run_migrations() -> None: print(f'Trying to migrate from DB version {start_version} to ' f'{EXPECTED_DB_VERSION} …') migs_to_do = [] - migs_by_n = {mig.version: mig for mig in MIGRATIONS} + migs_by_n = dict(enumerate(MIGRATIONS)) for n in [n+1 for n in range(start_version, EXPECTED_DB_VERSION)]: if n not in migs_by_n: raise HandledException(f'Needed migration missing: {n}') - migs_to_do += [migs_by_n[n]] - for mig in migs_to_do: - print(f'Running migration towards: {mig.version}') - mig.apply_to(PATH_DB) + migs_to_do += [(n, *migs_by_n[n])] + for version, filename_sql, after_sql_steps in migs_to_do: + print(f'Running migration towards: {version}') + with DbConn(expected_version=version-1) as conn: + if filename_sql: + print(f'Executing {filename_sql}') + path_sql = PATH_MIGRATIONS.joinpath(filename_sql) + conn.exec_script(SqlText(path_sql.read_text(encoding='utf8'))) + if after_sql_steps: + print('Running additional steps') + after_sql_steps(conn) + conn.exec(SqlText(f'{SQL_DB_VERSION} = {version}')) + conn.commit() print('Finished migrations.') -def _rewrite_files_last_field_processing_first_field(conn: BaseDbConn, +def _rewrite_files_last_field_processing_first_field(conn: DbConn, cb: Callable ) -> None: rows = conn.exec(SqlText('SELECT * FROM files')).fetchall() @@ -72,12 +52,11 @@ def _rewrite_files_last_field_processing_first_field(conn: BaseDbConn, conn.exec_on_values(SqlText('REPLACE INTO files'), tuple(row)) -def _mig_2_calc_digests(sql_conn: SqlConnection) -> None: +def _mig_2_calc_digests(conn: DbConn) -> None: """Calculate sha512 digests to all known video files.""" # pylint: disable=import-outside-toplevel from hashlib import file_digest from ytplom.misc import PATH_DOWNLOADS - conn = BaseDbConn(sql_conn) rel_paths = [ p[0] for p in conn.exec(SqlText('SELECT rel_path FROM files')).fetchall()] @@ -103,18 +82,17 @@ def _mig_2_calc_digests(sql_conn: SqlConnection) -> None: _rewrite_files_last_field_processing_first_field(conn, hexdigest_file) -def _mig_4_convert_digests(sql_conn: SqlConnection) -> None: +def _mig_4_convert_digests(conn: DbConn) -> None: """Fill new files.sha512_blob field with binary .sha512_digest.""" - _rewrite_files_last_field_processing_first_field( - BaseDbConn(sql_conn), bytes.fromhex) + _rewrite_files_last_field_processing_first_field(conn, bytes.fromhex) MIGRATIONS = [ - _Migration(0, Path('0_init.sql')), - _Migration(1, Path('1_add_files_last_updated.sql')), - _Migration(2, Path('2_add_files_sha512.sql'), _mig_2_calc_digests), - _Migration(3, Path('3_files_redo.sql')), - _Migration(4, Path('4_add_files_sha512_blob.sql'), _mig_4_convert_digests), - _Migration(5, Path('5_files_redo.sql')), - _Migration(6, Path('6_add_files_tags.sql')) + (Path('0_init.sql'), None), + (Path('1_add_files_last_updated.sql'), None), + (Path('2_add_files_sha512.sql'), _mig_2_calc_digests), + (Path('3_files_redo.sql'), None), + (Path('4_add_files_sha512_blob.sql'), _mig_4_convert_digests), + (Path('5_files_redo.sql'), None), + (Path('6_add_files_tags.sql'), None) ] diff --git a/src/ytplom/misc.py b/src/ytplom/misc.py index 7458cb7..c42b71a 100644 --- a/src/ytplom/misc.py +++ b/src/ytplom/misc.py @@ -12,14 +12,13 @@ from uuid import uuid4 from pathlib import Path from threading import Thread from queue import Queue -from sqlite3 import Cursor # non-included libs from ffmpeg import probe as ffprobe # type: ignore import googleapiclient.discovery # type: ignore from mpv import MPV # type: ignore from yt_dlp import YoutubeDL # type: ignore # ourselves -from ytplom.db import BaseDbConn, DbConn, DbData, Hash, SqlText +from ytplom.db import DbConn, DbCursor, DbData, Hash, SqlText from ytplom.primitives import HandledException, NotFoundException @@ -202,7 +201,7 @@ class YoutubeQuery(DbData): @classmethod def new_by_request_saved(cls, - conn: BaseDbConn, + conn: DbConn, config: Config, query_txt: QueryText ) -> Self: @@ -252,9 +251,7 @@ class YoutubeQuery(DbData): return query @classmethod - def get_all_for_video(cls, - conn: BaseDbConn, - video_id: YoutubeId + def get_all_for_video(cls, conn: DbConn, video_id: YoutubeId ) -> list[Self]: """Return YoutubeQueries containing YoutubeVideo's ID in results.""" sql = SqlText('SELECT query_id FROM ' @@ -306,10 +303,7 @@ class YoutubeVideo(DbData): self.duration = _readable_seconds(seconds) @classmethod - def get_all_for_query(cls, - conn: BaseDbConn, - query_id: QueryId - ) -> list[Self]: + def get_all_for_query(cls, conn: DbConn, query_id: QueryId) -> list[Self]: """Return all videos for query of query_id.""" sql = SqlText('SELECT video_id ' 'FROM yt_query_results WHERE query_id = ?') @@ -317,7 +311,7 @@ class YoutubeVideo(DbData): return [cls.get_one(conn, video_id_tup[0]) for video_id_tup in video_ids] - def save_to_query(self, conn: BaseDbConn, query_id: QueryId) -> None: + def save_to_query(self, conn: DbConn, query_id: QueryId) -> None: """Save inclusion of self in results to query of query_id.""" conn.exec_on_values(SqlText('REPLACE INTO yt_query_results'), (query_id, self.id_)) @@ -361,23 +355,23 @@ class VideoFile(DbData): def _renew_last_update(self): self.last_update = DatetimeStr(datetime.now().strftime(TIMESTAMP_FMT)) - def save(self, conn: BaseDbConn) -> Cursor: + def save(self, conn: DbConn) -> DbCursor: """Extend super().save by new .last_update if sufficient changes.""" if hash(self) != self._hash_on_last_update: self._renew_last_update() return super().save(conn) @classmethod - def get_one_with_whitelist_tags_display(cls, conn: BaseDbConn, id_: Hash, - whitelist_tags_display: TagSet - ) -> Self: + def get_one_with_whitelist_tags_display( + cls, conn: DbConn, id_: Hash, whitelist_tags_display: TagSet + ) -> Self: """Same as .get_one except sets .whitelist_tags_display.""" vf = cls.get_one(conn, id_) vf.whitelist_tags_display = whitelist_tags_display return vf @classmethod - def get_by_yt_id(cls, conn: BaseDbConn, yt_id: YoutubeId) -> Self: + def get_by_yt_id(cls, conn: DbConn, yt_id: YoutubeId) -> Self: """Return VideoFile of .yt_id.""" sql = SqlText(f'SELECT * FROM {cls._table_name} WHERE yt_id = ?') row = conn.exec(sql, (yt_id,)).fetchone() @@ -387,7 +381,7 @@ class VideoFile(DbData): @classmethod def get_filtered(cls, - conn: BaseDbConn, + conn: DbConn, filter_path: FilterStr, needed_tags_dark: TagSet, needed_tags_seen: TagSet, @@ -417,7 +411,7 @@ class VideoFile(DbData): 'canot show display-whitelisted tags on unset whitelist') return self.tags.whitelisted(self.whitelist_tags_display) - def unused_tags(self, conn: BaseDbConn) -> TagSet: + def unused_tags(self, conn: DbConn) -> TagSet: """Return tags used among other VideoFiles, not in self.""" if self.whitelist_tags_display is None: raise HandledException( @@ -488,7 +482,7 @@ class VideoFile(DbData): self.full_path.unlink() @classmethod - def purge_deleteds(cls, conn: BaseDbConn) -> None: + def purge_deleteds(cls, conn: DbConn) -> None: """For all of .is_flag_set("deleted"), remove file _and_ DB entry.""" for file in [f for f in cls.get_all(conn) if f.is_flag_set(FlagName('delete'))]: @@ -516,7 +510,7 @@ class QuotaLog(DbData): self.cost = cost @classmethod - def update(cls, conn: BaseDbConn, cost: QuotaCost) -> None: + def update(cls, conn: DbConn, cost: QuotaCost) -> None: """Adds cost mapped to current datetime.""" cls._remove_old(conn) new = cls(None, @@ -525,14 +519,14 @@ class QuotaLog(DbData): new.save(conn) @classmethod - def current(cls, conn: BaseDbConn) -> QuotaCost: + def current(cls, conn: DbConn) -> QuotaCost: """Returns quota cost total for last 24 hours, purges old data.""" cls._remove_old(conn) quota_costs = cls.get_all(conn) return QuotaCost(sum(c.cost for c in quota_costs)) @classmethod - def _remove_old(cls, conn: BaseDbConn) -> None: + def _remove_old(cls, conn: DbConn) -> None: cutoff = datetime.now() - timedelta(days=1) sql = SqlText(f'DELETE FROM {cls._table_name} WHERE timestamp < ?') conn.exec(SqlText(sql), (cutoff.strftime(TIMESTAMP_FMT),))