From: Christian Heller Date: Mon, 2 Dec 2024 13:30:29 +0000 (+0100) Subject: Reorganize DB code and especially migrations handling. X-Git-Url: https://plomlompom.com/repos/%7B%7Bdb.prefix%7D%7D/%7B%7B%20web_path%20%7D%7D/static/%7B%7Bprefix%7D%7D/balance?a=commitdiff_plain;h=6ccda46d6c74bdb2bcdfd61942051d013c268a31;p=ytplom Reorganize DB code and especially migrations handling. --- diff --git a/src/migrate.py b/src/migrate.py index e1ba4de..cc5e6cf 100755 --- a/src/migrate.py +++ b/src/migrate.py @@ -1,69 +1,7 @@ #!/usr/bin/env python3 """Script to migrate DB to most recent schema.""" -from importlib.util import spec_from_file_location, module_from_spec -from pathlib import Path -from sys import exit as sys_exit -from ytplom.misc import ( - EXPECTED_DB_VERSION, PATH_DB, PATH_DB_SCHEMA, PATH_MIGRATIONS, - SQL_DB_VERSION, get_db_version, DbConn, HandledException, SqlText) - - -_SUFFIX_PY = '.py' -_SUFFIX_SQL = '.sql' - - -def main() -> None: - """Try to migrate DB towards EXPECTED_DB_VERSION.""" - start_version = get_db_version(PATH_DB) - if start_version == EXPECTED_DB_VERSION: - print('Database at expected version, no migrations to do.') - sys_exit(0) - elif start_version > EXPECTED_DB_VERSION: - raise HandledException( - f'Cannot migrate backward from version {start_version} to ' - f'{EXPECTED_DB_VERSION}.') - print(f'Trying to migrate from DB version {start_version} to ' - f'{EXPECTED_DB_VERSION} …') - migrations: dict[int, list[Path]] = { - n+1: [] for n in range(start_version, EXPECTED_DB_VERSION)} - for path in [p for p in PATH_MIGRATIONS.iterdir() - if p.is_file() and p != PATH_DB_SCHEMA]: - toks = path.name.split('_') - try: - version = int(toks[0]) - if path.suffix not in {_SUFFIX_PY, _SUFFIX_SQL}: - raise ValueError - except ValueError as e: - msg = f'Found illegal migration path {path}, aborting.' - raise HandledException(msg) from e - if version in migrations: - migrations[version] += [path] - missing = [n for n in migrations.keys() if not migrations[n]] - if missing: - raise HandledException(f'Needed migrations missing: {missing}') - with DbConn(check_version=False) as conn: - for version, migration_paths in migrations.items(): - sorted_paths = sorted(migration_paths) - msg_apply_prefix = f'Applying migration {version}: ' - for path in [p for p in sorted_paths if _SUFFIX_SQL == p.suffix]: - print(f'{msg_apply_prefix}{path}') - conn.exec_script(path) - for path in [p for p in sorted_paths if _SUFFIX_PY == p.suffix]: - spec = spec_from_file_location(str(path), path) - assert spec is not None - assert spec.loader is not None - module = module_from_spec(spec) - assert module is not None - spec.loader.exec_module(module) - if hasattr(module, 'migrate'): - print(f'{msg_apply_prefix}{path}') - module.migrate(conn) - else: - raise HandledException( - f'Suspected migration file {path} missing migrate().') - conn.exec(SqlText(f'{SQL_DB_VERSION} = {version}')) - conn.commit() +from ytplom.migrations import run_migrations if __name__ == '__main__': - main() + run_migrations() diff --git a/src/migrations/2_add_files_sha512.py b/src/migrations/2_add_files_sha512.py deleted file mode 100644 index 0e10011..0000000 --- a/src/migrations/2_add_files_sha512.py +++ /dev/null @@ -1,28 +0,0 @@ -from hashlib import file_digest -from ytplom.misc import DbConn, HandledException, HashStr, SqlText, VideoFile - - -_LEGIT_YES = 'YES!' - - -def migrate(conn: DbConn) -> None: - file_entries = VideoFile.get_all(conn) - missing = [f for f in file_entries if not f.present] - if missing: - print('WARNING: Cannot find files to following paths') - for f in missing: - print(f.full_path) - reply = input( - 'WARNING: To continue migration, will have to delete above ' - f'rows from DB. To continue, type (exactly) "{_LEGIT_YES}": ') - if "YES!" != reply: - raise HandledException('Migration aborted!') - for f in missing: - conn.exec(SqlText('DELETE FROM files WHERE rel_path = ?'), - (str(f.rel_path),)) - for file in VideoFile.get_all(conn): - print(f'Calculating digest for: {file.rel_path}') - with open(file.full_path, 'rb') as f: - file.sha512_digest = HashStr( - file_digest(f, 'sha512').hexdigest()) - file.save(conn) diff --git a/src/migrations/init_3.sql b/src/migrations/init_3.sql deleted file mode 100644 index d223bef..0000000 --- a/src/migrations/init_3.sql +++ /dev/null @@ -1,33 +0,0 @@ -CREATE TABLE yt_queries ( - id TEXT PRIMARY KEY, - text TEXT NOT NULL, - retrieved_at TEXT NOT NULL -); -CREATE TABLE yt_videos ( - id TEXT PRIMARY KEY, - title TEXT NOT NULL, - description TEXT NOT NULL, - published_at TEXT NOT NULL, - duration TEXT NOT NULL, - definition TEXT NOT NULL -); -CREATE TABLE yt_query_results ( - query_id TEXT NOT NULL, - video_id TEXT NOT NULL, - PRIMARY KEY (query_id, video_id), - FOREIGN KEY (query_id) REFERENCES yt_queries(id), - FOREIGN KEY (video_id) REFERENCES yt_videos(id) -); -CREATE TABLE quota_costs ( - id TEXT PRIMARY KEY, - timestamp TEXT NOT NULL, - cost INT NOT NULL -); -CREATE TABLE files ( - sha512_digest TEXT PRIMARY KEY, - rel_path TEXT NOT NULL, - flags INTEGER NOT NULL DEFAULT 0, - yt_id TEXT, - last_update TEXT NOT NULL, - FOREIGN KEY (yt_id) REFERENCES yt_videos(id) -); diff --git a/src/migrations/new_init.sql b/src/migrations/new_init.sql new file mode 100644 index 0000000..d223bef --- /dev/null +++ b/src/migrations/new_init.sql @@ -0,0 +1,33 @@ +CREATE TABLE yt_queries ( + id TEXT PRIMARY KEY, + text TEXT NOT NULL, + retrieved_at TEXT NOT NULL +); +CREATE TABLE yt_videos ( + id TEXT PRIMARY KEY, + title TEXT NOT NULL, + description TEXT NOT NULL, + published_at TEXT NOT NULL, + duration TEXT NOT NULL, + definition TEXT NOT NULL +); +CREATE TABLE yt_query_results ( + query_id TEXT NOT NULL, + video_id TEXT NOT NULL, + PRIMARY KEY (query_id, video_id), + FOREIGN KEY (query_id) REFERENCES yt_queries(id), + FOREIGN KEY (video_id) REFERENCES yt_videos(id) +); +CREATE TABLE quota_costs ( + id TEXT PRIMARY KEY, + timestamp TEXT NOT NULL, + cost INT NOT NULL +); +CREATE TABLE files ( + sha512_digest TEXT PRIMARY KEY, + rel_path TEXT NOT NULL, + flags INTEGER NOT NULL DEFAULT 0, + yt_id TEXT, + last_update TEXT NOT NULL, + FOREIGN KEY (yt_id) REFERENCES yt_videos(id) +); diff --git a/src/sync.py b/src/sync.py index 4941188..63108b7 100755 --- a/src/sync.py +++ b/src/sync.py @@ -9,9 +9,9 @@ from urllib.request import urlopen # non-included libs from paramiko import SSHClient # type: ignore from scp import SCPClient # type: ignore -from ytplom.misc import ( - PATH_DB, PATH_DOWNLOADS, PATH_TEMP, - Config, DbConn, QuotaLog, VideoFile, YoutubeQuery, YoutubeVideo) +from ytplom.db import DbConn, PATH_DB +from ytplom.misc import (PATH_DOWNLOADS, PATH_TEMP, Config, QuotaLog, + VideoFile, YoutubeQuery, YoutubeVideo) from ytplom.http import PAGE_NAMES diff --git a/src/ytplom/db.py b/src/ytplom/db.py new file mode 100644 index 0000000..9edf5c6 --- /dev/null +++ b/src/ytplom/db.py @@ -0,0 +1,118 @@ +"""Database access and management code.""" +from pathlib import Path +from sqlite3 import ( + connect as sql_connect, Connection as SqlConnection, Cursor, Row) +from typing import Any, Literal, NewType, Self +from ytplom.primitives import ( + HandledException, NotFoundException, PATH_APP_DATA) + +SqlText = NewType('SqlText', str) + +EXPECTED_DB_VERSION = 3 +PATH_DB = PATH_APP_DATA.joinpath('db.sql') +SQL_DB_VERSION = SqlText('PRAGMA user_version') +PATH_MIGRATIONS = PATH_APP_DATA.joinpath('migrations') +_PATH_DB_SCHEMA = PATH_MIGRATIONS.joinpath('new_init.sql') +_NAME_INSTALLER = Path('install.sh') + + +def get_db_version(db_path: Path) -> int: + """Return user_version value of DB at db_path.""" + with sql_connect(db_path) as conn: + return list(conn.execute(SQL_DB_VERSION))[0][0] + + +class BaseDbConn: + """Wrapper for pre-established sqlite3.Connection.""" + + def __init__(self, sql_conn: SqlConnection) -> None: + self._conn = sql_conn + + def exec(self, sql: SqlText, inputs: tuple[Any, ...] = tuple()) -> Cursor: + """Wrapper around sqlite3.Connection.execute.""" + return self._conn.execute(sql, inputs) + + def commit(self) -> None: + """Commit changes (i.e. DbData.save() calls) to database.""" + self._conn.commit() + + +class DbConn(BaseDbConn): + """Like parent, but opening (and as context mgr: closing) connection.""" + + def __init__(self, path: Path = PATH_DB) -> None: + if not path.is_file(): + if path.exists(): + raise HandledException(f'no DB at {path}; would create, ' + 'but something\'s already there?') + if not path.parent.is_dir(): + raise HandledException( + f'cannot find {path.parent} as directory to put ' + f'DB into, did you run {_NAME_INSTALLER}?') + with sql_connect(path) as conn: + conn.executescript(_PATH_DB_SCHEMA.read_text(encoding='utf8')) + conn.execute(f'{SQL_DB_VERSION} = {EXPECTED_DB_VERSION}') + cur_version = get_db_version(path) + if cur_version != EXPECTED_DB_VERSION: + raise HandledException( + f'wrong database version {cur_version}, expected: ' + f'{EXPECTED_DB_VERSION} – run "migrate"?') + super().__init__(sql_connect(path, autocommit=False)) + + def __enter__(self) -> Self: + return self + + def __exit__(self, exc_type, exc_value, traceback) -> Literal[False]: + self._conn.close() + return False + + +class DbData: + """Abstraction of common DB operation.""" + id_name: str = 'id' + _table_name: str + _cols: tuple[str, ...] + + def __eq__(self, other: Any) -> bool: + if not isinstance(other, self.__class__): + return False + for attr_name in self._cols: + if getattr(self, attr_name) != getattr(other, attr_name): + return False + return True + + @classmethod + def _from_table_row(cls, row: Row) -> Self: + kwargs = {} + for i, col_name in enumerate(cls._cols): + kwargs[col_name] = row[i] + for attr_name, type_ in cls.__annotations__.items(): + if attr_name in kwargs: + kwargs[attr_name] = type_(kwargs[attr_name]) + return cls(**kwargs) + + @classmethod + def get_one(cls, conn: BaseDbConn, id_: str) -> Self: + """Return single entry of id_ from DB.""" + sql = SqlText(f'SELECT * FROM {cls._table_name} ' + f'WHERE {cls.id_name} = ?') + row = conn.exec(sql, (id_,)).fetchone() + if not row: + msg = f'no entry found for ID "{id_}" in table {cls._table_name}' + raise NotFoundException(msg) + return cls._from_table_row(row) + + @classmethod + def get_all(cls, conn: BaseDbConn) -> list[Self]: + """Return all entries from DB.""" + sql = SqlText(f'SELECT * FROM {cls._table_name}') + rows = conn.exec(sql).fetchall() + return [cls._from_table_row(row) for row in rows] + + def save(self, conn: BaseDbConn) -> Cursor: + """Save entry to DB.""" + vals = [getattr(self, col_name) for col_name in self._cols] + q_marks = '(' + ','.join(['?'] * len(vals)) + ')' + sql = SqlText(f'REPLACE INTO {self._table_name} VALUES {q_marks}') + return conn.exec(sql, tuple(str(v) if isinstance(v, Path) else v + for v in vals)) diff --git a/src/ytplom/http.py b/src/ytplom/http.py index 29f06ee..4fa2754 100644 --- a/src/ytplom/http.py +++ b/src/ytplom/http.py @@ -9,14 +9,16 @@ from urllib.request import urlretrieve from urllib.error import HTTPError from jinja2 import ( # type: ignore Environment as JinjaEnv, FileSystemLoader as JinjaFSLoader) +from ytplom.db import DbConn from ytplom.misc import ( - HashStr, FilesWithIndex, FlagName, NotFoundException, PlayerUpdateId, - QueryId, QueryText, QuotaCost, UrlStr, YoutubeId, - FILE_FLAGS, PATH_APP_DATA, PATH_THUMBNAILS, YOUTUBE_URL_PREFIX, + HashStr, FilesWithIndex, FlagName, PlayerUpdateId, QueryId, QueryText, + QuotaCost, UrlStr, YoutubeId, + FILE_FLAGS, PATH_THUMBNAILS, YOUTUBE_URL_PREFIX, ensure_expected_dirs, - Config, DbConn, DownloadsManager, Player, QuotaLog, VideoFile, - YoutubeQuery, YoutubeVideo + Config, DownloadsManager, Player, QuotaLog, VideoFile, YoutubeQuery, + YoutubeVideo ) +from ytplom.primitives import NotFoundException, PATH_APP_DATA # type definitions for mypy _PageNames: TypeAlias = dict[str, Path] diff --git a/src/ytplom/migrations.py b/src/ytplom/migrations.py new file mode 100644 index 0000000..73f2406 --- /dev/null +++ b/src/ytplom/migrations.py @@ -0,0 +1,95 @@ +"""Anything pertaining specifically to DB migrations.""" +from pathlib import Path +from sqlite3 import connect as sql_connect, Connection as SqlConnection +from typing import Callable, Optional +from ytplom.db import ( + get_db_version, BaseDbConn, SqlText, EXPECTED_DB_VERSION, PATH_DB, + PATH_MIGRATIONS, SQL_DB_VERSION) +from ytplom.primitives import HandledException + + +_LEGIT_YES = 'YES!!' + + +class _Migration: + """Wrapper for SQL and Python code to apply on migrating.""" + + def __init__(self, + version: int, + filename_sql: Optional[Path] = None, + after_sql_steps: Optional[Callable] = None + ) -> None: + self.version = version + self._sql_code = None + if filename_sql: + path_sql = PATH_MIGRATIONS.joinpath(filename_sql) + self._sql_code = path_sql.read_text(encoding='utf8') + self._after_sql_steps = after_sql_steps + + def apply_to(self, path_db: Path): + """Apply to DB at path_db migration code stored in self.""" + with sql_connect(path_db, autocommit=False) as conn: + if self._sql_code: + conn.executescript(self._sql_code) + if self._after_sql_steps: + self._after_sql_steps(conn) + conn.execute(SqlText(f'{SQL_DB_VERSION} = {self.version}')) + + +def run_migrations() -> None: + """Try to migrate DB towards EXPECTED_DB_VERSION.""" + start_version = get_db_version(PATH_DB) + if start_version == EXPECTED_DB_VERSION: + print('Database at expected version, no migrations to do.') + return + if start_version > EXPECTED_DB_VERSION: + raise HandledException( + f'Cannot migrate backward from version {start_version} to ' + f'{EXPECTED_DB_VERSION}.') + print(f'Trying to migrate from DB version {start_version} to ' + f'{EXPECTED_DB_VERSION} …') + migs_to_do = [] + migs_by_n = {mig.version: mig for mig in MIGRATIONS} + for n in [n+1 for n in range(start_version, EXPECTED_DB_VERSION)]: + if n not in migs_by_n: + raise HandledException(f'Needed migration missing: {n}') + migs_to_do += [migs_by_n[n]] + for mig in migs_to_do: + print(f'Running migration towards: {mig.version}') + mig.apply_to(PATH_DB) + print('Finished migrations.') + + +def _mig_2_calc_digests(sql_conn: SqlConnection) -> None: + """Calculate sha512 digests to all known video files.""" + from hashlib import file_digest + from ytplom.misc import HashStr, VideoFile + conn = BaseDbConn(sql_conn) + file_entries = VideoFile.get_all(conn) + missing = [f for f in file_entries if not f.present] + if missing: + print('WARNING: Cannot find files to following paths') + for f in missing: + print(f.full_path) + reply = input( + 'WARNING: To continue migration, will have to delete above ' + f'rows from DB. To continue, type (exactly) "{_LEGIT_YES}": ') + if _LEGIT_YES != reply: + raise HandledException('Migration aborted!') + for f in missing: + conn.exec(SqlText('DELETE FROM files WHERE rel_path = ?'), + (str(f.rel_path),)) + for video_file in VideoFile.get_all(conn): + print(f'Calculating digest for: {video_file.rel_path}') + with open(video_file.full_path, 'rb') as vf: + video_file.sha512_digest = HashStr( + file_digest(vf, 'sha512').hexdigest()) + video_file.save(conn) + + +MIGRATIONS = [ + _Migration(0, Path('0_init.sql')), + _Migration(1, Path('1_add_files_last_updated.sql')), + _Migration(2, Path('2_add_files_sha512.sql'), _mig_2_calc_digests), + _Migration(3, Path('3_files_redo.sql')) +] diff --git a/src/ytplom/misc.py b/src/ytplom/misc.py index 4c26b4b..4aa1fed 100644 --- a/src/ytplom/misc.py +++ b/src/ytplom/misc.py @@ -1,7 +1,7 @@ """Main ytplom lib.""" # included libs -from typing import Any, Literal, NewType, Optional, Self, TypeAlias +from typing import NewType, Optional, Self, TypeAlias from os import chdir, environ from hashlib import file_digest from random import shuffle @@ -11,13 +11,16 @@ from json import loads as json_loads from urllib.request import urlretrieve from uuid import uuid4 from pathlib import Path -from sqlite3 import connect as sql_connect, Cursor, Row from threading import Thread from queue import Queue # non-included libs import googleapiclient.discovery # type: ignore from mpv import MPV # type: ignore from yt_dlp import YoutubeDL # type: ignore +# ourselves +from ytplom.db import BaseDbConn, DbConn, DbData, SqlText +from ytplom.primitives import HandledException, NotFoundException + # default configuration DEFAULTS = { @@ -33,7 +36,6 @@ YoutubeId = NewType('YoutubeId', str) QueryId = NewType('QueryId', str) QueryText = NewType('QueryText', str) ProseText = NewType('ProseText', str) -SqlText = NewType('SqlText', str) FlagName = NewType('FlagName', str) FlagsInt = NewType('FlagsInt', int) HashStr = NewType('HashStr', str) @@ -43,15 +45,11 @@ UrlStr = NewType('UrlStr', str) FilesWithIndex: TypeAlias = list[tuple[int, 'VideoFile']] # major expected directories -PATH_APP_DATA = Path.home().joinpath('.local/share/ytplom') -PATH_CACHE = Path.home().joinpath('.cache/ytplom') - -# paths for rather dynamic data PATH_DOWNLOADS = Path.home().joinpath('ytplom_downloads') -PATH_DB = PATH_APP_DATA.joinpath('db.sql') +PATH_CONFFILE = Path.home().joinpath('.config/ytplom/config.json') +PATH_CACHE = Path.home().joinpath('.cache/ytplom') PATH_TEMP = PATH_CACHE.joinpath('temp') PATH_THUMBNAILS = PATH_CACHE.joinpath('thumbnails') -PATH_CONFFILE = Path.home().joinpath('.config/ytplom/config.json') # yt_dlp config YT_DOWNLOAD_FORMAT = 'bestvideo[height<=1080][width<=1920]+bestaudio'\ @@ -65,30 +63,15 @@ YOUTUBE_URL_PREFIX = UrlStr('https://www.youtube.com/watch?v=') QUOTA_COST_YOUTUBE_SEARCH = QuotaCost(100) QUOTA_COST_YOUTUBE_DETAILS = QuotaCost(1) -# database stuff -EXPECTED_DB_VERSION = 3 -SQL_DB_VERSION = SqlText('PRAGMA user_version') -PATH_MIGRATIONS = PATH_APP_DATA.joinpath('migrations') -PATH_DB_SCHEMA = PATH_MIGRATIONS.joinpath(f'init_{EXPECTED_DB_VERSION}.sql') - # other ENVIRON_PREFIX = 'YTPLOM_' TIMESTAMP_FMT = '%Y-%m-%d %H:%M:%S.%f' LEGAL_EXTENSIONS = {'webm', 'mp4', 'mkv'} -NAME_INSTALLER = Path('install.sh') FILE_FLAGS: dict[FlagName, FlagsInt] = { FlagName('delete'): FlagsInt(-(1 << 63)) } -class NotFoundException(Exception): - """Raise on expected data missing, e.g. DB fetches finding nothing.""" - - -class HandledException(Exception): - """Raise in any other case where we know what's happening.""" - - def ensure_expected_dirs(expected_dirs: list[Path]) -> None: """Ensure existance of expected_dirs _as_ directories.""" for dir_path in [p for p in expected_dirs if not p.is_dir()]: @@ -99,12 +82,6 @@ def ensure_expected_dirs(expected_dirs: list[Path]) -> None: dir_path.mkdir(parents=True, exist_ok=True) -def get_db_version(db_path: Path) -> int: - """Return user_version value of DB at db_path.""" - with sql_connect(db_path) as conn: - return list(conn.execute(SQL_DB_VERSION))[0][0] - - class Config: """Collects user-configurable settings.""" host: str @@ -127,104 +104,6 @@ class Config: if k.isupper() and k.startswith(ENVIRON_PREFIX)}) -class DbConn: - """Wrapped sqlite3.Connection.""" - - def __init__(self, - path: Path = PATH_DB, - check_version: bool = True - ) -> None: - self._path = path - if not self._path.is_file(): - if self._path.exists(): - raise HandledException(f'no DB at {self._path}; would create, ' - 'but something\'s already there?') - if not self._path.parent.is_dir(): - raise NotFoundException( - f'cannot find {self._path.parent} as directory to put ' - f'DB into, did you run {NAME_INSTALLER}?') - with sql_connect(self._path) as conn: - conn.executescript(PATH_DB_SCHEMA.read_text(encoding='utf8')) - conn.execute(f'{SQL_DB_VERSION} = {EXPECTED_DB_VERSION}') - if check_version: - cur_version = get_db_version(self._path) - if cur_version != EXPECTED_DB_VERSION: - raise HandledException( - f'wrong database version {cur_version}, expected: ' - f'{EXPECTED_DB_VERSION} – run "migrate"?') - self._conn = sql_connect(self._path, autocommit=False) - - def __enter__(self) -> Self: - return self - - def __exit__(self, exc_type, exc_value, traceback) -> Literal[False]: - self._conn.close() - return False - - def exec(self, sql: SqlText, inputs: tuple[Any, ...] = tuple()) -> Cursor: - """Wrapper around sqlite3.Connection.execute.""" - return self._conn.execute(sql, inputs) - - def exec_script(self, path: Path) -> None: - """Simplified sqlite3.Connection.executescript.""" - self._conn.executescript(path.read_text(encoding='utf8')) - - def commit(self) -> None: - """Commit changes (i.e. DbData.save() calls) to database.""" - self._conn.commit() - - -class DbData: - """Abstraction of common DB operation.""" - id_name: str = 'id' - _table_name: str - _cols: tuple[str, ...] - - def __eq__(self, other: Any) -> bool: - if not isinstance(other, self.__class__): - return False - for attr_name in self._cols: - if getattr(self, attr_name) != getattr(other, attr_name): - return False - return True - - @classmethod - def _from_table_row(cls, row: Row) -> Self: - kwargs = {} - for i, col_name in enumerate(cls._cols): - kwargs[col_name] = row[i] - for attr_name, type_ in cls.__annotations__.items(): - if attr_name in kwargs: - kwargs[attr_name] = type_(kwargs[attr_name]) - return cls(**kwargs) - - @classmethod - def get_one(cls, conn: DbConn, id_: str) -> Self: - """Return single entry of id_ from DB.""" - sql = SqlText(f'SELECT * FROM {cls._table_name} ' - f'WHERE {cls.id_name} = ?') - row = conn.exec(sql, (id_,)).fetchone() - if not row: - msg = f'no entry found for ID "{id_}" in table {cls._table_name}' - raise NotFoundException(msg) - return cls._from_table_row(row) - - @classmethod - def get_all(cls, conn: DbConn) -> list[Self]: - """Return all entries from DB.""" - sql = SqlText(f'SELECT * FROM {cls._table_name}') - rows = conn.exec(sql).fetchall() - return [cls._from_table_row(row) for row in rows] - - def save(self, conn: DbConn) -> Cursor: - """Save entry to DB.""" - vals = [getattr(self, col_name) for col_name in self._cols] - q_marks = '(' + ','.join(['?'] * len(vals)) + ')' - sql = SqlText(f'REPLACE INTO {self._table_name} VALUES {q_marks}') - return conn.exec(sql, tuple(str(v) if isinstance(v, Path) else v - for v in vals)) - - class YoutubeQuery(DbData): """Representation of YouTube query (without results).""" _table_name = 'yt_queries' @@ -241,7 +120,7 @@ class YoutubeQuery(DbData): @classmethod def new_by_request_saved(cls, - conn: DbConn, + conn: BaseDbConn, config: Config, query_txt: QueryText ) -> Self: @@ -292,7 +171,7 @@ class YoutubeQuery(DbData): @classmethod def get_all_for_video(cls, - conn: DbConn, + conn: BaseDbConn, video_id: YoutubeId ) -> list[Self]: """Return YoutubeQueries containing YoutubeVideo's ID in results.""" @@ -349,7 +228,7 @@ class YoutubeVideo(DbData): @classmethod def get_all_for_query(cls, - conn: DbConn, + conn: BaseDbConn, query_id: QueryId ) -> list[Self]: """Return all videos for query of query_id.""" @@ -359,7 +238,7 @@ class YoutubeVideo(DbData): return [cls.get_one(conn, video_id_tup[0]) for video_id_tup in video_ids] - def save_to_query(self, conn: DbConn, query_id: QueryId) -> None: + def save_to_query(self, conn: BaseDbConn, query_id: QueryId) -> None: """Save inclusion of self in results to query of query_id.""" conn.exec(SqlText('REPLACE INTO yt_query_results VALUES (?, ?)'), (query_id, self.id_)) @@ -398,7 +277,7 @@ class VideoFile(DbData): self.last_update = DatetimeStr(datetime.now().strftime(TIMESTAMP_FMT)) @classmethod - def get_by_yt_id(cls, conn: DbConn, yt_id: YoutubeId) -> Self: + def get_by_yt_id(cls, conn: BaseDbConn, yt_id: YoutubeId) -> Self: """Return VideoFile of .yt_id.""" sql = SqlText(f'SELECT * FROM {cls._table_name} WHERE yt_id = ?') row = conn.exec(sql, (yt_id,)).fetchone() @@ -462,7 +341,7 @@ class QuotaLog(DbData): self.cost = cost @classmethod - def update(cls, conn: DbConn, cost: QuotaCost) -> None: + def update(cls, conn: BaseDbConn, cost: QuotaCost) -> None: """Adds cost mapped to current datetime.""" cls._remove_old(conn) new = cls(None, @@ -471,14 +350,14 @@ class QuotaLog(DbData): new.save(conn) @classmethod - def current(cls, conn: DbConn) -> QuotaCost: + def current(cls, conn: BaseDbConn) -> QuotaCost: """Returns quota cost total for last 24 hours, purges old data.""" cls._remove_old(conn) quota_costs = cls.get_all(conn) return QuotaCost(sum(c.cost for c in quota_costs)) @classmethod - def _remove_old(cls, conn: DbConn) -> None: + def _remove_old(cls, conn: BaseDbConn) -> None: cutoff = datetime.now() - timedelta(days=1) sql = SqlText(f'DELETE FROM {cls._table_name} WHERE timestamp < ?') conn.exec(SqlText(sql), (cutoff.strftime(TIMESTAMP_FMT),)) diff --git a/src/ytplom/primitives.py b/src/ytplom/primitives.py new file mode 100644 index 0000000..ddc64d5 --- /dev/null +++ b/src/ytplom/primitives.py @@ -0,0 +1,12 @@ +from pathlib import Path + + +PATH_APP_DATA = Path.home().joinpath('.local/share/ytplom') + + +class NotFoundException(Exception): + """Raise on expected data missing, e.g. DB fetches finding nothing.""" + + +class HandledException(Exception): + """Raise in any other case where we know what's happening."""