from base64 import urlsafe_b64decode, urlsafe_b64encode
from hashlib import file_digest
from pathlib import Path
-from sqlite3 import (
- connect as sql_connect, Connection as SqlConnection, Cursor, Row)
+from sqlite3 import connect as sql_connect, Cursor as DbCursor, Row
from typing import Any, Literal, NewType, Self
from ytplom.primitives import (
HandledException, NotFoundException, PATH_APP_DATA)
return urlsafe_b64encode(self.bytes).decode('utf8')
-class BaseDbConn:
- """Wrapper for pre-established sqlite3.Connection."""
+class DbConn:
+ """Wrapper for sqlite3 connections."""
- def __init__(self, sql_conn: SqlConnection) -> None:
- self._conn = sql_conn
-
- def exec(self, sql: SqlText, inputs: tuple[Any, ...] = tuple()) -> Cursor:
- """Wrapper around sqlite3.Connection.execute."""
- return self._conn.execute(sql, inputs)
-
- def exec_on_values(self, sql: SqlText, inputs: tuple[Any, ...]) -> Cursor:
- """Wraps .exec on inputs, affixes to sql proper ' VALUES (?, …)'."""
- q_marks = '(' + ','.join(['?'] * len(inputs)) + ')'
- return self._conn.execute(f'{sql} VALUES {q_marks}', inputs)
-
- def commit(self) -> None:
- """Commit changes (i.e. DbData.save() calls) to database."""
- self._conn.commit()
-
-
-class DbConn(BaseDbConn):
- """Like parent, but opening (and as context mgr: closing) connection."""
-
- def __init__(self, path: Path = PATH_DB) -> None:
+ def __init__(self,
+ path: Path = PATH_DB,
+ expected_version: int = EXPECTED_DB_VERSION
+ ) -> None:
if not path.is_file():
if path.exists():
raise HandledException(f'no DB at {path}; would create, '
f'cannot find {path.parent} as directory to put '
f'DB into, did you run {_NAME_INSTALLER}?')
with sql_connect(path) as conn:
+ print(f'No DB found at {path}, creating …')
conn.executescript(_PATH_DB_SCHEMA.read_text(encoding='utf8'))
conn.execute(f'{SQL_DB_VERSION} = {EXPECTED_DB_VERSION}')
- cur_version = get_db_version(path)
- if cur_version != EXPECTED_DB_VERSION:
- raise HandledException(
- f'wrong database version {cur_version}, expected: '
- f'{EXPECTED_DB_VERSION} – run "migrate"?')
- super().__init__(sql_connect(path, autocommit=False))
+ if expected_version >= 0:
+ cur_version = get_db_version(path)
+ if cur_version != expected_version:
+ raise HandledException(
+ f'wrong database version {cur_version}, expected: '
+ f'{expected_version} – run "migrate"?')
+ self._conn = sql_connect(path, autocommit=False)
def __enter__(self) -> Self:
return self
self._conn.close()
return False
+ def exec(self, sql: SqlText, inputs: tuple[Any, ...] = tuple()
+ ) -> DbCursor:
+ """Wrapper around sqlite3.Connection.execute."""
+ return self._conn.execute(sql, inputs)
+
+ def exec_on_values(self, sql: SqlText, inputs: tuple[Any, ...]
+ ) -> DbCursor:
+ """Wraps .exec on inputs, affixes to sql proper ' VALUES (?, …)'."""
+ q_marks = '(' + ','.join(['?'] * len(inputs)) + ')'
+ return self._conn.execute(f'{sql} VALUES {q_marks}', inputs)
+
+ def exec_script(self, sql: SqlText) -> DbCursor:
+ """Wrapper around sqlite3.Connection.executescript."""
+ return self._conn.executescript(sql)
+
+ def commit(self) -> None:
+ """Commit changes (i.e. DbData.save() calls) to database."""
+ self._conn.commit()
+
class DbData:
"""Abstraction of common DB operation."""
return cls(**kwargs)
@classmethod
- def get_one(cls, conn: BaseDbConn, id_: str | Hash) -> Self:
+ def get_one(cls, conn: DbConn, id_: str | Hash) -> Self:
"""Return single entry of id_ from DB."""
sql = SqlText(f'SELECT * FROM {cls._table_name} '
f'WHERE {cls.id_name} = ?')
return cls._from_table_row(row)
@classmethod
- def get_all(cls, conn: BaseDbConn) -> list[Self]:
+ def get_all(cls, conn: DbConn) -> list[Self]:
"""Return all entries from DB."""
sql = SqlText(f'SELECT * FROM {cls._table_name}')
rows = conn.exec(sql).fetchall()
return [cls._from_table_row(row) for row in rows]
- def save(self, conn: BaseDbConn) -> Cursor:
+ def save(self, conn: DbConn) -> DbCursor:
"""Save entry to DB."""
vals = []
for val in [getattr(self, col_name) for col_name in self._cols]:
"""Anything pertaining specifically to DB migrations."""
from pathlib import Path
-from sqlite3 import connect as sql_connect, Connection as SqlConnection
-from typing import Callable, Optional
+from typing import Callable
from ytplom.db import (
- get_db_version, BaseDbConn, SqlText, EXPECTED_DB_VERSION, PATH_DB,
- PATH_MIGRATIONS, SQL_DB_VERSION)
+ get_db_version, DbConn, SqlText,
+ EXPECTED_DB_VERSION, PATH_DB, PATH_MIGRATIONS, SQL_DB_VERSION)
from ytplom.primitives import HandledException
_LEGIT_YES = 'YES!!'
-class _Migration:
- """Wrapper for SQL and Python code to apply on migrating."""
-
- def __init__(self,
- version: int,
- filename_sql: Optional[Path] = None,
- after_sql_steps: Optional[Callable] = None
- ) -> None:
- self.version = version
- self._filename_sql = filename_sql
- self._sql_code = None
- if filename_sql:
- path_sql = PATH_MIGRATIONS.joinpath(filename_sql)
- self._sql_code = path_sql.read_text(encoding='utf8')
- self._after_sql_steps = after_sql_steps
-
- def apply_to(self, path_db: Path):
- """Apply to DB at path_db migration code stored in self."""
- with sql_connect(path_db, autocommit=False) as conn:
- if self._sql_code:
- print(f'Executing {self._filename_sql}')
- conn.executescript(self._sql_code)
- if self._after_sql_steps:
- print('Running additional steps')
- self._after_sql_steps(conn)
- conn.execute(SqlText(f'{SQL_DB_VERSION} = {self.version}'))
-
-
def run_migrations() -> None:
"""Try to migrate DB towards EXPECTED_DB_VERSION."""
start_version = get_db_version(PATH_DB)
print(f'Trying to migrate from DB version {start_version} to '
f'{EXPECTED_DB_VERSION} …')
migs_to_do = []
- migs_by_n = {mig.version: mig for mig in MIGRATIONS}
+ migs_by_n = dict(enumerate(MIGRATIONS))
for n in [n+1 for n in range(start_version, EXPECTED_DB_VERSION)]:
if n not in migs_by_n:
raise HandledException(f'Needed migration missing: {n}')
- migs_to_do += [migs_by_n[n]]
- for mig in migs_to_do:
- print(f'Running migration towards: {mig.version}')
- mig.apply_to(PATH_DB)
+ migs_to_do += [(n, *migs_by_n[n])]
+ for version, filename_sql, after_sql_steps in migs_to_do:
+ print(f'Running migration towards: {version}')
+ with DbConn(expected_version=version-1) as conn:
+ if filename_sql:
+ print(f'Executing {filename_sql}')
+ path_sql = PATH_MIGRATIONS.joinpath(filename_sql)
+ conn.exec_script(SqlText(path_sql.read_text(encoding='utf8')))
+ if after_sql_steps:
+ print('Running additional steps')
+ after_sql_steps(conn)
+ conn.exec(SqlText(f'{SQL_DB_VERSION} = {version}'))
+ conn.commit()
print('Finished migrations.')
-def _rewrite_files_last_field_processing_first_field(conn: BaseDbConn,
+def _rewrite_files_last_field_processing_first_field(conn: DbConn,
cb: Callable
) -> None:
rows = conn.exec(SqlText('SELECT * FROM files')).fetchall()
conn.exec_on_values(SqlText('REPLACE INTO files'), tuple(row))
-def _mig_2_calc_digests(sql_conn: SqlConnection) -> None:
+def _mig_2_calc_digests(conn: DbConn) -> None:
"""Calculate sha512 digests to all known video files."""
# pylint: disable=import-outside-toplevel
from hashlib import file_digest
from ytplom.misc import PATH_DOWNLOADS
- conn = BaseDbConn(sql_conn)
rel_paths = [
p[0] for p
in conn.exec(SqlText('SELECT rel_path FROM files')).fetchall()]
_rewrite_files_last_field_processing_first_field(conn, hexdigest_file)
-def _mig_4_convert_digests(sql_conn: SqlConnection) -> None:
+def _mig_4_convert_digests(conn: DbConn) -> None:
"""Fill new files.sha512_blob field with binary .sha512_digest."""
- _rewrite_files_last_field_processing_first_field(
- BaseDbConn(sql_conn), bytes.fromhex)
+ _rewrite_files_last_field_processing_first_field(conn, bytes.fromhex)
MIGRATIONS = [
- _Migration(0, Path('0_init.sql')),
- _Migration(1, Path('1_add_files_last_updated.sql')),
- _Migration(2, Path('2_add_files_sha512.sql'), _mig_2_calc_digests),
- _Migration(3, Path('3_files_redo.sql')),
- _Migration(4, Path('4_add_files_sha512_blob.sql'), _mig_4_convert_digests),
- _Migration(5, Path('5_files_redo.sql')),
- _Migration(6, Path('6_add_files_tags.sql'))
+ (Path('0_init.sql'), None),
+ (Path('1_add_files_last_updated.sql'), None),
+ (Path('2_add_files_sha512.sql'), _mig_2_calc_digests),
+ (Path('3_files_redo.sql'), None),
+ (Path('4_add_files_sha512_blob.sql'), _mig_4_convert_digests),
+ (Path('5_files_redo.sql'), None),
+ (Path('6_add_files_tags.sql'), None)
]
from pathlib import Path
from threading import Thread
from queue import Queue
-from sqlite3 import Cursor
# non-included libs
from ffmpeg import probe as ffprobe # type: ignore
import googleapiclient.discovery # type: ignore
from mpv import MPV # type: ignore
from yt_dlp import YoutubeDL # type: ignore
# ourselves
-from ytplom.db import BaseDbConn, DbConn, DbData, Hash, SqlText
+from ytplom.db import DbConn, DbCursor, DbData, Hash, SqlText
from ytplom.primitives import HandledException, NotFoundException
@classmethod
def new_by_request_saved(cls,
- conn: BaseDbConn,
+ conn: DbConn,
config: Config,
query_txt: QueryText
) -> Self:
return query
@classmethod
- def get_all_for_video(cls,
- conn: BaseDbConn,
- video_id: YoutubeId
+ def get_all_for_video(cls, conn: DbConn, video_id: YoutubeId
) -> list[Self]:
"""Return YoutubeQueries containing YoutubeVideo's ID in results."""
sql = SqlText('SELECT query_id FROM '
self.duration = _readable_seconds(seconds)
@classmethod
- def get_all_for_query(cls,
- conn: BaseDbConn,
- query_id: QueryId
- ) -> list[Self]:
+ def get_all_for_query(cls, conn: DbConn, query_id: QueryId) -> list[Self]:
"""Return all videos for query of query_id."""
sql = SqlText('SELECT video_id '
'FROM yt_query_results WHERE query_id = ?')
return [cls.get_one(conn, video_id_tup[0])
for video_id_tup in video_ids]
- def save_to_query(self, conn: BaseDbConn, query_id: QueryId) -> None:
+ def save_to_query(self, conn: DbConn, query_id: QueryId) -> None:
"""Save inclusion of self in results to query of query_id."""
conn.exec_on_values(SqlText('REPLACE INTO yt_query_results'),
(query_id, self.id_))
def _renew_last_update(self):
self.last_update = DatetimeStr(datetime.now().strftime(TIMESTAMP_FMT))
- def save(self, conn: BaseDbConn) -> Cursor:
+ def save(self, conn: DbConn) -> DbCursor:
"""Extend super().save by new .last_update if sufficient changes."""
if hash(self) != self._hash_on_last_update:
self._renew_last_update()
return super().save(conn)
@classmethod
- def get_one_with_whitelist_tags_display(cls, conn: BaseDbConn, id_: Hash,
- whitelist_tags_display: TagSet
- ) -> Self:
+ def get_one_with_whitelist_tags_display(
+ cls, conn: DbConn, id_: Hash, whitelist_tags_display: TagSet
+ ) -> Self:
"""Same as .get_one except sets .whitelist_tags_display."""
vf = cls.get_one(conn, id_)
vf.whitelist_tags_display = whitelist_tags_display
return vf
@classmethod
- def get_by_yt_id(cls, conn: BaseDbConn, yt_id: YoutubeId) -> Self:
+ def get_by_yt_id(cls, conn: DbConn, yt_id: YoutubeId) -> Self:
"""Return VideoFile of .yt_id."""
sql = SqlText(f'SELECT * FROM {cls._table_name} WHERE yt_id = ?')
row = conn.exec(sql, (yt_id,)).fetchone()
@classmethod
def get_filtered(cls,
- conn: BaseDbConn,
+ conn: DbConn,
filter_path: FilterStr,
needed_tags_dark: TagSet,
needed_tags_seen: TagSet,
'canot show display-whitelisted tags on unset whitelist')
return self.tags.whitelisted(self.whitelist_tags_display)
- def unused_tags(self, conn: BaseDbConn) -> TagSet:
+ def unused_tags(self, conn: DbConn) -> TagSet:
"""Return tags used among other VideoFiles, not in self."""
if self.whitelist_tags_display is None:
raise HandledException(
self.full_path.unlink()
@classmethod
- def purge_deleteds(cls, conn: BaseDbConn) -> None:
+ def purge_deleteds(cls, conn: DbConn) -> None:
"""For all of .is_flag_set("deleted"), remove file _and_ DB entry."""
for file in [f for f in cls.get_all(conn)
if f.is_flag_set(FlagName('delete'))]:
self.cost = cost
@classmethod
- def update(cls, conn: BaseDbConn, cost: QuotaCost) -> None:
+ def update(cls, conn: DbConn, cost: QuotaCost) -> None:
"""Adds cost mapped to current datetime."""
cls._remove_old(conn)
new = cls(None,
new.save(conn)
@classmethod
- def current(cls, conn: BaseDbConn) -> QuotaCost:
+ def current(cls, conn: DbConn) -> QuotaCost:
"""Returns quota cost total for last 24 hours, purges old data."""
cls._remove_old(conn)
quota_costs = cls.get_all(conn)
return QuotaCost(sum(c.cost for c in quota_costs))
@classmethod
- def _remove_old(cls, conn: BaseDbConn) -> None:
+ def _remove_old(cls, conn: DbConn) -> None:
cutoff = datetime.now() - timedelta(days=1)
sql = SqlText(f'DELETE FROM {cls._table_name} WHERE timestamp < ?')
conn.exec(SqlText(sql), (cutoff.strftime(TIMESTAMP_FMT),))