From: Christian Heller Date: Wed, 15 Jan 2025 14:27:59 +0000 (+0100) Subject: Include plomlib for its db.py, adapt DB code to it. X-Git-Url: https://plomlompom.com/repos/%7B%7Bdb.prefix%7D%7D/%7B%7B%20web_path%20%7D%7D/%7B%7Bprefix%7D%7D/static/index.html?a=commitdiff_plain;h=HEAD;p=ytplom Include plomlib for its db.py, adapt DB code to it. --- diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000..95be54b --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "src/plomlib"] + path = src/plomlib + url = https://plomlompom.com/repos/clone/plomlib diff --git a/src/plomlib b/src/plomlib new file mode 160000 index 0000000..743dbe0 --- /dev/null +++ b/src/plomlib @@ -0,0 +1 @@ +Subproject commit 743dbe0d493ddeb47eca981fa5be6d78e4d754c9 diff --git a/src/run.py b/src/run.py index 4b1ddc3..49214ce 100755 --- a/src/run.py +++ b/src/run.py @@ -6,7 +6,7 @@ from sys import argv, exit as sys_exit # ourselves from ytplom.db import DbFile from ytplom.primitives import HandledException -from ytplom.migrations import migrate +from ytplom.migrations import MIGRATIONS from ytplom.http import serve from ytplom.sync import sync @@ -19,7 +19,7 @@ if __name__ == '__main__': case 'create_db': DbFile.create() case 'migrate_db': - migrate() + DbFile(skip_validations=True).migrate(MIGRATIONS) case 'serve': serve() case 'sync': diff --git a/src/ytplom/db.py b/src/ytplom/db.py index 816e401..3bf7198 100644 --- a/src/ytplom/db.py +++ b/src/ytplom/db.py @@ -2,23 +2,20 @@ # included libs from base64 import urlsafe_b64decode, urlsafe_b64encode -from difflib import Differ from hashlib import file_digest from pathlib import Path -from sqlite3 import ( - connect as sql_connect, Cursor as SqlCursor, Row as SqlRow) -from typing import Callable, Literal, NewType, Optional, Self +from sqlite3 import Row as SqlRow +from typing import Self # ourselves +from plomlib.db import ( + PlomDbConn, PlomDbFile, PlomDbMigration, TypePlomDbMigration) from ytplom.primitives import ( HandledException, NotFoundException, PATH_APP_DATA) -EXPECTED_DB_VERSION = 6 PATH_DB = PATH_APP_DATA.joinpath('db.sql') -SqlText = NewType('SqlText', str) - -_SQL_DB_VERSION = SqlText('PRAGMA user_version') +_EXPECTED_DB_VERSION = 6 _PATH_MIGRATIONS = PATH_APP_DATA.joinpath('migrations') _HASH_ALGO = 'sha512' _PATH_DB_SCHEMA = _PATH_MIGRATIONS.joinpath('new_init.sql') @@ -58,124 +55,20 @@ class Hash: return urlsafe_b64encode(self.bytes).decode('utf8') -class DbFile: - """Wrapper around the file of a sqlite3 database.""" - - def __init__(self, - path: Path = PATH_DB, - version_to_validate: int = EXPECTED_DB_VERSION - ) -> None: - self.path = path - if not self.path.is_file(): - raise HandledException(f'no DB file at {self.path}') - if version_to_validate < 0: - return - if (user_version := self._get_user_version()) != version_to_validate: - raise HandledException( - f'wrong DB version {user_version} (!= {version_to_validate})') - with DbConn(self) as conn: - self._validate_schema(conn) - - @staticmethod - def _validate_schema(conn: 'DbConn'): - schema_rows_normed = [] - indent = ' ' - for row in [ - r[0] for r in conn.exec(SqlText( - 'SELECT sql FROM sqlite_master ORDER BY sql')) - if r[0]]: - row_normed = [] - for subrow in [sr.rstrip() for sr in row.split('\n')]: - in_parentheses = 0 - split_at = [] - for i, c in enumerate(subrow): - if '(' == c: - in_parentheses += 1 - elif ')' == c: - in_parentheses -= 1 - elif ',' == c and 0 == in_parentheses: - split_at += [i + 1] - prev_split = 0 - for i in split_at: - if segment := subrow[prev_split:i].strip(): - row_normed += [f'{indent}{segment}'] - prev_split = i - if segment := subrow[prev_split:].strip(): - row_normed += [f'{indent}{segment}'] - row_normed[0] = row_normed[0].lstrip() # no indent for opening … - row_normed[-1] = row_normed[-1].lstrip() # … and closing line - if row_normed[-1] != ')' and row_normed[-3][-1] != ',': - row_normed[-3] = row_normed[-3] + ',' - row_normed[-2:] = [indent + row_normed[-1][:-1]] + [')'] - row_normed[-1] = row_normed[-1] + ';' - schema_rows_normed += row_normed - if ((expected_rows := - _PATH_DB_SCHEMA.read_text(encoding='utf8').rstrip().splitlines() - ) != schema_rows_normed): - raise HandledException( - 'Unexpected tables schema. Diff to {_PATH_DB_SCHEMA}:\n' + - '\n'.join(Differ().compare(schema_rows_normed, expected_rows))) - - def _get_user_version(self) -> int: - with sql_connect(self.path) as conn: - return list(conn.execute(_SQL_DB_VERSION))[0][0] - - @staticmethod - def create(path: Path = PATH_DB) -> None: - """Create DB file at path according to _PATH_DB_SCHEMA.""" - if path.exists(): - raise HandledException( - f'There already exists a node at {path}.') - if not path.parent.is_dir(): - raise HandledException( - f'No directory {path.parent} found to write into.') - with sql_connect(path) as conn: - conn.executescript(_PATH_DB_SCHEMA.read_text(encoding='utf8')) - conn.execute(f'{_SQL_DB_VERSION} = {EXPECTED_DB_VERSION}') - - def migrate(self, migrations: set['DbMigration']) -> None: - """Migrate self towards EXPECTED_DB_VERSION""" - start_version = self._get_user_version() - if start_version == EXPECTED_DB_VERSION: - raise HandledException( - f'Already at {EXPECTED_DB_VERSION}, nothing to migrate.') - if start_version > EXPECTED_DB_VERSION: - raise HandledException( - f'Cannot migrate backwards from {start_version}' - f'to {EXPECTED_DB_VERSION}.') - with DbConn(self) as conn: - for migration in DbMigration.from_to_in_set( - start_version, EXPECTED_DB_VERSION, migrations): - migration.perform(conn) - self._validate_schema(conn) - conn.commit() - - -class DbMigration: - """Representation of DbFile migration data.""" - - def __init__(self, - version: int, - sql_path: Optional[Path] = None, - after_sql_steps: Optional[Callable[['DbConn'], None]] = None - ) -> None: - if sql_path: - start_tok = str(sql_path).split('_', maxsplit=1)[0] - if (not start_tok.isdigit()) or int(start_tok) != version: - raise HandledException( - f'migration {version} mapped to bad path {sql_path}') - self._version = version - self._sql_path = sql_path - self._after_sql_steps = after_sql_steps +class DbMigration(PlomDbMigration): + """Collects and enacts DbFile migration commands.""" + migs_dir_path = _PATH_MIGRATIONS @classmethod - def from_to_in_set( - cls, from_version: int, to_version: int, migs_set: set[Self] - ) -> list[Self]: - """From migs_set make sorted unbroken list from_version to_version.""" + def gather(cls, + from_version: int, + base_set: set[TypePlomDbMigration] + ) -> list[TypePlomDbMigration]: selected_migs = [] - for version in [n+1 for n in range(from_version, to_version)]: - matching_migs = [m for m in migs_set if version == m._version] + for version in [n+1 for n in range(from_version, + _EXPECTED_DB_VERSION)]: + matching_migs = [m for m in base_set # cls.collection + if version == m.target_version] if not matching_migs: raise HandledException(f'Missing migration of v{version}') if len(matching_migs) > 1: @@ -183,43 +76,19 @@ class DbMigration: selected_migs += [matching_migs[0]] return selected_migs - def perform(self, conn: 'DbConn') -> None: - """Do 1) script at sql_path, 2) after_sql_steps, 3) versino setting.""" - if self._sql_path: - conn.exec_script( - SqlText(_PATH_MIGRATIONS.joinpath(self._sql_path) - .read_text(encoding='utf8'))) - if self._after_sql_steps: - self._after_sql_steps(conn) - conn.exec(SqlText(f'{_SQL_DB_VERSION} = {self._version}')) - - -class DbConn: - """Wrapper for sqlite3 connections.""" - - def __init__(self, db_file: Optional[DbFile] = None) -> None: - self._conn = sql_connect((db_file or DbFile()).path, autocommit=False) - self.commit = self._conn.commit - - def __enter__(self) -> Self: - return self - def __exit__(self, exc_type, exc_value, traceback) -> Literal[False]: - self._conn.close() - return False +class DbFile(PlomDbFile): + """File readable as DB of expected schema, user version.""" + indent_n = 2 + target_version = _EXPECTED_DB_VERSION + path_schema = _PATH_DB_SCHEMA + default_path = PATH_DB + mig_class = DbMigration - def exec(self, sql: SqlText, inputs: tuple = tuple() - ) -> SqlCursor: - """Wrapper around sqlite3.Connection.execute, building '?' if inputs""" - if len(inputs) > 0: - q_marks = ('?' if len(inputs) == 1 - else '(' + ','.join(['?'] * len(inputs)) + ')') - return self._conn.execute(SqlText(f'{sql} {q_marks}'), inputs) - return self._conn.execute(sql) - def exec_script(self, sql: SqlText) -> None: - """Wrapper around sqlite3.Connection.executescript.""" - self._conn.executescript(sql) +class DbConn(PlomDbConn): + """SQL connection to DbFile.""" + default_path = PATH_DB class DbData: @@ -253,7 +122,7 @@ class DbData: @classmethod def get_one(cls, conn: DbConn, id_: str | Hash) -> Self: """Return single entry of id_ from DB.""" - sql = SqlText(f'SELECT * FROM {cls._table_name} WHERE {cls.id_name} =') + sql = f'SELECT * FROM {cls._table_name} WHERE {cls.id_name} =' id__ = id_.bytes if isinstance(id_, Hash) else id_ row = conn.exec(sql, (id__,)).fetchone() if not row: @@ -264,7 +133,7 @@ class DbData: @classmethod def get_all(cls, conn: DbConn) -> list[Self]: """Return all entries from DB.""" - sql = SqlText(f'SELECT * FROM {cls._table_name}') + sql = f'SELECT * FROM {cls._table_name}' rows = conn.exec(sql).fetchall() return [cls._from_table_row(row) for row in rows] @@ -277,5 +146,5 @@ class DbData: elif isinstance(val, Hash): val = val.bytes vals += [val] - conn.exec(SqlText(f'REPLACE INTO {self._table_name} VALUES'), + conn.exec(f'REPLACE INTO {self._table_name} VALUES', tuple(vals)) diff --git a/src/ytplom/http.py b/src/ytplom/http.py index 5fef79d..c57dfae 100644 --- a/src/ytplom/http.py +++ b/src/ytplom/http.py @@ -367,7 +367,6 @@ class _TaskHandler(BaseHTTPRequestHandler): self._send_http(f.read(), [(_HEADER_CONTENT_TYPE, 'image/jpg')]) def _send_yt_result(self, video_id: YoutubeId) -> None: - conn = DbConn() with DbConn() as conn: linked_queries = YoutubeQuery.get_all_for_video(conn, video_id) try: diff --git a/src/ytplom/migrations.py b/src/ytplom/migrations.py index f075cdf..aadec78 100644 --- a/src/ytplom/migrations.py +++ b/src/ytplom/migrations.py @@ -4,7 +4,7 @@ from pathlib import Path from typing import Callable # ourselves -from ytplom.db import DbConn, DbFile, DbMigration, SqlText +from ytplom.db import DbConn, DbMigration from ytplom.primitives import HandledException @@ -14,10 +14,10 @@ _LEGIT_YES = 'YES!!' def _rewrite_files_last_field_processing_first_field(conn: DbConn, cb: Callable ) -> None: - rows = conn.exec(SqlText('SELECT * FROM files')).fetchall() + rows = conn.exec('SELECT * FROM files').fetchall() for row in [list(r) for r in rows]: row[-1] = cb(row[0]) - conn.exec(SqlText('REPLACE INTO files VALUES'), tuple(row)) + conn.exec('REPLACE INTO files VALUES', tuple(row)) def _mig_2_calc_digests(conn: DbConn) -> None: @@ -27,7 +27,7 @@ def _mig_2_calc_digests(conn: DbConn) -> None: from ytplom.misc import PATH_DOWNLOADS rel_paths = [ p[0] for p - in conn.exec(SqlText('SELECT rel_path FROM files')).fetchall()] + in conn.exec('SELECT rel_path FROM files').fetchall()] missing = [p for p in rel_paths if not Path(PATH_DOWNLOADS).joinpath(p).exists()] if missing: @@ -40,7 +40,7 @@ def _mig_2_calc_digests(conn: DbConn) -> None: if _LEGIT_YES != reply: raise HandledException('Migration aborted!') for path in missing: - conn.exec(SqlText('DELETE FROM files WHERE rel_path ='), (path,)) + conn.exec('DELETE FROM files WHERE rel_path =', (path,)) def hexdigest_file(path): print(f'Calculating digest for: {path}') @@ -55,17 +55,13 @@ def _mig_4_convert_digests(conn: DbConn) -> None: _rewrite_files_last_field_processing_first_field(conn, bytes.fromhex) -_MIGRATIONS: set[DbMigration] = { +MIGRATIONS: set[DbMigration] = { DbMigration(0, Path('0_init.sql'), None), DbMigration(1, Path('1_add_files_last_updated.sql'), None), DbMigration(2, Path('2_add_files_sha512.sql'), _mig_2_calc_digests), DbMigration(3, Path('3_files_redo.sql'), None), - DbMigration(4, Path('4_add_files_sha512_blob.sql'), _mig_4_convert_digests), + DbMigration(4, Path('4_add_files_sha512_blob.sql'), + _mig_4_convert_digests), DbMigration(5, Path('5_files_redo.sql'), None), DbMigration(6, Path('6_add_files_tags.sql'), None) } - - -def migrate(): - """Migrate DB file at expected default path to most recent version.""" - DbFile(version_to_validate=-1).migrate(_MIGRATIONS) diff --git a/src/ytplom/misc.py b/src/ytplom/misc.py index 5dc6cf5..44ab385 100644 --- a/src/ytplom/misc.py +++ b/src/ytplom/misc.py @@ -18,7 +18,7 @@ import googleapiclient.discovery # type: ignore from mpv import MPV # type: ignore from yt_dlp import YoutubeDL # type: ignore # ourselves -from ytplom.db import DbConn, DbData, Hash, SqlText +from ytplom.db import DbConn, DbData, Hash from ytplom.primitives import HandledException, NotFoundException @@ -254,8 +254,7 @@ class YoutubeQuery(DbData): def get_all_for_video(cls, conn: DbConn, video_id: YoutubeId ) -> list[Self]: """Return YoutubeQueries containing YoutubeVideo's ID in results.""" - sql = SqlText('SELECT query_id FROM ' - 'yt_query_results WHERE video_id =') + sql = 'SELECT query_id FROM yt_query_results WHERE video_id =' query_ids = conn.exec(sql, (video_id,)).fetchall() return [cls.get_one(conn, query_id_tup[0]) for query_id_tup in query_ids] @@ -305,16 +304,14 @@ class YoutubeVideo(DbData): @classmethod def get_all_for_query(cls, conn: DbConn, query_id: QueryId) -> list[Self]: """Return all videos for query of query_id.""" - sql = SqlText('SELECT video_id ' - 'FROM yt_query_results WHERE query_id =') + sql = 'SELECT video_id FROM yt_query_results WHERE query_id =' video_ids = conn.exec(sql, (query_id,)).fetchall() return [cls.get_one(conn, video_id_tup[0]) for video_id_tup in video_ids] def save_to_query(self, conn: DbConn, query_id: QueryId) -> None: """Save inclusion of self in results to query of query_id.""" - conn.exec(SqlText('REPLACE INTO yt_query_results VALUES'), - (query_id, self.id_)) + conn.exec('REPLACE INTO yt_query_results VALUES', (query_id, self.id_)) class VideoFile(DbData): @@ -373,8 +370,8 @@ class VideoFile(DbData): @classmethod def get_by_yt_id(cls, conn: DbConn, yt_id: YoutubeId) -> Self: """Return VideoFile of .yt_id.""" - sql = SqlText(f'SELECT * FROM {cls._table_name} WHERE yt_id =') - row = conn.exec(sql, (yt_id,)).fetchone() + row = conn.exec(f'SELECT * FROM {cls._table_name} WHERE yt_id =', + (yt_id,)).fetchone() if not row: raise NotFoundException(f'no entry for file to Youtube ID {yt_id}') return cls._from_table_row(row) @@ -489,9 +486,8 @@ class VideoFile(DbData): if file.present: file.unlink_locally() print(f'SYNC: purging off DB: {file.digest.b64} ({file.rel_path})') - conn.exec( - SqlText(f'DELETE FROM {cls._table_name} WHERE digest ='), - (file.digest.bytes,)) + conn.exec(f'DELETE FROM {cls._table_name} WHERE digest =', + (file.digest.bytes,)) class QuotaLog(DbData): @@ -528,8 +524,8 @@ class QuotaLog(DbData): @classmethod def _remove_old(cls, conn: DbConn) -> None: cutoff = datetime.now() - timedelta(days=1) - sql = SqlText(f'DELETE FROM {cls._table_name} WHERE timestamp <') - conn.exec(SqlText(sql), (cutoff.strftime(TIMESTAMP_FMT),)) + conn.exec(f'DELETE FROM {cls._table_name} WHERE timestamp <', + (cutoff.strftime(TIMESTAMP_FMT),)) class Player: