# included libs
from base64 import urlsafe_b64decode, urlsafe_b64encode
-from difflib import Differ
from hashlib import file_digest
from pathlib import Path
-from sqlite3 import (
- connect as sql_connect, Cursor as SqlCursor, Row as SqlRow)
-from typing import Callable, Literal, NewType, Optional, Self
+from sqlite3 import Row as SqlRow
+from typing import Self
# ourselves
+from plomlib.db import (
+ PlomDbConn, PlomDbFile, PlomDbMigration, TypePlomDbMigration)
from ytplom.primitives import (
HandledException, NotFoundException, PATH_APP_DATA)
-EXPECTED_DB_VERSION = 6
PATH_DB = PATH_APP_DATA.joinpath('db.sql')
-SqlText = NewType('SqlText', str)
-
-_SQL_DB_VERSION = SqlText('PRAGMA user_version')
+_EXPECTED_DB_VERSION = 6
_PATH_MIGRATIONS = PATH_APP_DATA.joinpath('migrations')
_HASH_ALGO = 'sha512'
_PATH_DB_SCHEMA = _PATH_MIGRATIONS.joinpath('new_init.sql')
return urlsafe_b64encode(self.bytes).decode('utf8')
-class DbFile:
- """Wrapper around the file of a sqlite3 database."""
-
- def __init__(self,
- path: Path = PATH_DB,
- version_to_validate: int = EXPECTED_DB_VERSION
- ) -> None:
- self.path = path
- if not self.path.is_file():
- raise HandledException(f'no DB file at {self.path}')
- if version_to_validate < 0:
- return
- if (user_version := self._get_user_version()) != version_to_validate:
- raise HandledException(
- f'wrong DB version {user_version} (!= {version_to_validate})')
- with DbConn(self) as conn:
- self._validate_schema(conn)
-
- @staticmethod
- def _validate_schema(conn: 'DbConn'):
- schema_rows_normed = []
- indent = ' '
- for row in [
- r[0] for r in conn.exec(SqlText(
- 'SELECT sql FROM sqlite_master ORDER BY sql'))
- if r[0]]:
- row_normed = []
- for subrow in [sr.rstrip() for sr in row.split('\n')]:
- in_parentheses = 0
- split_at = []
- for i, c in enumerate(subrow):
- if '(' == c:
- in_parentheses += 1
- elif ')' == c:
- in_parentheses -= 1
- elif ',' == c and 0 == in_parentheses:
- split_at += [i + 1]
- prev_split = 0
- for i in split_at:
- if segment := subrow[prev_split:i].strip():
- row_normed += [f'{indent}{segment}']
- prev_split = i
- if segment := subrow[prev_split:].strip():
- row_normed += [f'{indent}{segment}']
- row_normed[0] = row_normed[0].lstrip() # no indent for opening …
- row_normed[-1] = row_normed[-1].lstrip() # … and closing line
- if row_normed[-1] != ')' and row_normed[-3][-1] != ',':
- row_normed[-3] = row_normed[-3] + ','
- row_normed[-2:] = [indent + row_normed[-1][:-1]] + [')']
- row_normed[-1] = row_normed[-1] + ';'
- schema_rows_normed += row_normed
- if ((expected_rows :=
- _PATH_DB_SCHEMA.read_text(encoding='utf8').rstrip().splitlines()
- ) != schema_rows_normed):
- raise HandledException(
- 'Unexpected tables schema. Diff to {_PATH_DB_SCHEMA}:\n' +
- '\n'.join(Differ().compare(schema_rows_normed, expected_rows)))
-
- def _get_user_version(self) -> int:
- with sql_connect(self.path) as conn:
- return list(conn.execute(_SQL_DB_VERSION))[0][0]
-
- @staticmethod
- def create(path: Path = PATH_DB) -> None:
- """Create DB file at path according to _PATH_DB_SCHEMA."""
- if path.exists():
- raise HandledException(
- f'There already exists a node at {path}.')
- if not path.parent.is_dir():
- raise HandledException(
- f'No directory {path.parent} found to write into.')
- with sql_connect(path) as conn:
- conn.executescript(_PATH_DB_SCHEMA.read_text(encoding='utf8'))
- conn.execute(f'{_SQL_DB_VERSION} = {EXPECTED_DB_VERSION}')
-
- def migrate(self, migrations: set['DbMigration']) -> None:
- """Migrate self towards EXPECTED_DB_VERSION"""
- start_version = self._get_user_version()
- if start_version == EXPECTED_DB_VERSION:
- raise HandledException(
- f'Already at {EXPECTED_DB_VERSION}, nothing to migrate.')
- if start_version > EXPECTED_DB_VERSION:
- raise HandledException(
- f'Cannot migrate backwards from {start_version}'
- f'to {EXPECTED_DB_VERSION}.')
- with DbConn(self) as conn:
- for migration in DbMigration.from_to_in_set(
- start_version, EXPECTED_DB_VERSION, migrations):
- migration.perform(conn)
- self._validate_schema(conn)
- conn.commit()
-
-
-class DbMigration:
- """Representation of DbFile migration data."""
-
- def __init__(self,
- version: int,
- sql_path: Optional[Path] = None,
- after_sql_steps: Optional[Callable[['DbConn'], None]] = None
- ) -> None:
- if sql_path:
- start_tok = str(sql_path).split('_', maxsplit=1)[0]
- if (not start_tok.isdigit()) or int(start_tok) != version:
- raise HandledException(
- f'migration {version} mapped to bad path {sql_path}')
- self._version = version
- self._sql_path = sql_path
- self._after_sql_steps = after_sql_steps
+class DbMigration(PlomDbMigration):
+ """Collects and enacts DbFile migration commands."""
+ migs_dir_path = _PATH_MIGRATIONS
@classmethod
- def from_to_in_set(
- cls, from_version: int, to_version: int, migs_set: set[Self]
- ) -> list[Self]:
- """From migs_set make sorted unbroken list from_version to_version."""
+ def gather(cls,
+ from_version: int,
+ base_set: set[TypePlomDbMigration]
+ ) -> list[TypePlomDbMigration]:
selected_migs = []
- for version in [n+1 for n in range(from_version, to_version)]:
- matching_migs = [m for m in migs_set if version == m._version]
+ for version in [n+1 for n in range(from_version,
+ _EXPECTED_DB_VERSION)]:
+ matching_migs = [m for m in base_set # cls.collection
+ if version == m.target_version]
if not matching_migs:
raise HandledException(f'Missing migration of v{version}')
if len(matching_migs) > 1:
selected_migs += [matching_migs[0]]
return selected_migs
- def perform(self, conn: 'DbConn') -> None:
- """Do 1) script at sql_path, 2) after_sql_steps, 3) versino setting."""
- if self._sql_path:
- conn.exec_script(
- SqlText(_PATH_MIGRATIONS.joinpath(self._sql_path)
- .read_text(encoding='utf8')))
- if self._after_sql_steps:
- self._after_sql_steps(conn)
- conn.exec(SqlText(f'{_SQL_DB_VERSION} = {self._version}'))
-
-
-class DbConn:
- """Wrapper for sqlite3 connections."""
-
- def __init__(self, db_file: Optional[DbFile] = None) -> None:
- self._conn = sql_connect((db_file or DbFile()).path, autocommit=False)
- self.commit = self._conn.commit
-
- def __enter__(self) -> Self:
- return self
- def __exit__(self, exc_type, exc_value, traceback) -> Literal[False]:
- self._conn.close()
- return False
+class DbFile(PlomDbFile):
+ """File readable as DB of expected schema, user version."""
+ indent_n = 2
+ target_version = _EXPECTED_DB_VERSION
+ path_schema = _PATH_DB_SCHEMA
+ default_path = PATH_DB
+ mig_class = DbMigration
- def exec(self, sql: SqlText, inputs: tuple = tuple()
- ) -> SqlCursor:
- """Wrapper around sqlite3.Connection.execute, building '?' if inputs"""
- if len(inputs) > 0:
- q_marks = ('?' if len(inputs) == 1
- else '(' + ','.join(['?'] * len(inputs)) + ')')
- return self._conn.execute(SqlText(f'{sql} {q_marks}'), inputs)
- return self._conn.execute(sql)
- def exec_script(self, sql: SqlText) -> None:
- """Wrapper around sqlite3.Connection.executescript."""
- self._conn.executescript(sql)
+class DbConn(PlomDbConn):
+ """SQL connection to DbFile."""
+ default_path = PATH_DB
class DbData:
@classmethod
def get_one(cls, conn: DbConn, id_: str | Hash) -> Self:
"""Return single entry of id_ from DB."""
- sql = SqlText(f'SELECT * FROM {cls._table_name} WHERE {cls.id_name} =')
+ sql = f'SELECT * FROM {cls._table_name} WHERE {cls.id_name} ='
id__ = id_.bytes if isinstance(id_, Hash) else id_
row = conn.exec(sql, (id__,)).fetchone()
if not row:
@classmethod
def get_all(cls, conn: DbConn) -> list[Self]:
"""Return all entries from DB."""
- sql = SqlText(f'SELECT * FROM {cls._table_name}')
+ sql = f'SELECT * FROM {cls._table_name}'
rows = conn.exec(sql).fetchall()
return [cls._from_table_row(row) for row in rows]
elif isinstance(val, Hash):
val = val.bytes
vals += [val]
- conn.exec(SqlText(f'REPLACE INTO {self._table_name} VALUES'),
+ conn.exec(f'REPLACE INTO {self._table_name} VALUES',
tuple(vals))
from pathlib import Path
from typing import Callable
# ourselves
-from ytplom.db import DbConn, DbFile, DbMigration, SqlText
+from ytplom.db import DbConn, DbMigration
from ytplom.primitives import HandledException
def _rewrite_files_last_field_processing_first_field(conn: DbConn,
cb: Callable
) -> None:
- rows = conn.exec(SqlText('SELECT * FROM files')).fetchall()
+ rows = conn.exec('SELECT * FROM files').fetchall()
for row in [list(r) for r in rows]:
row[-1] = cb(row[0])
- conn.exec(SqlText('REPLACE INTO files VALUES'), tuple(row))
+ conn.exec('REPLACE INTO files VALUES', tuple(row))
def _mig_2_calc_digests(conn: DbConn) -> None:
from ytplom.misc import PATH_DOWNLOADS
rel_paths = [
p[0] for p
- in conn.exec(SqlText('SELECT rel_path FROM files')).fetchall()]
+ in conn.exec('SELECT rel_path FROM files').fetchall()]
missing = [p for p in rel_paths
if not Path(PATH_DOWNLOADS).joinpath(p).exists()]
if missing:
if _LEGIT_YES != reply:
raise HandledException('Migration aborted!')
for path in missing:
- conn.exec(SqlText('DELETE FROM files WHERE rel_path ='), (path,))
+ conn.exec('DELETE FROM files WHERE rel_path =', (path,))
def hexdigest_file(path):
print(f'Calculating digest for: {path}')
_rewrite_files_last_field_processing_first_field(conn, bytes.fromhex)
-_MIGRATIONS: set[DbMigration] = {
+MIGRATIONS: set[DbMigration] = {
DbMigration(0, Path('0_init.sql'), None),
DbMigration(1, Path('1_add_files_last_updated.sql'), None),
DbMigration(2, Path('2_add_files_sha512.sql'), _mig_2_calc_digests),
DbMigration(3, Path('3_files_redo.sql'), None),
- DbMigration(4, Path('4_add_files_sha512_blob.sql'), _mig_4_convert_digests),
+ DbMigration(4, Path('4_add_files_sha512_blob.sql'),
+ _mig_4_convert_digests),
DbMigration(5, Path('5_files_redo.sql'), None),
DbMigration(6, Path('6_add_files_tags.sql'), None)
}
-
-
-def migrate():
- """Migrate DB file at expected default path to most recent version."""
- DbFile(version_to_validate=-1).migrate(_MIGRATIONS)
from mpv import MPV # type: ignore
from yt_dlp import YoutubeDL # type: ignore
# ourselves
-from ytplom.db import DbConn, DbData, Hash, SqlText
+from ytplom.db import DbConn, DbData, Hash
from ytplom.primitives import HandledException, NotFoundException
def get_all_for_video(cls, conn: DbConn, video_id: YoutubeId
) -> list[Self]:
"""Return YoutubeQueries containing YoutubeVideo's ID in results."""
- sql = SqlText('SELECT query_id FROM '
- 'yt_query_results WHERE video_id =')
+ sql = 'SELECT query_id FROM yt_query_results WHERE video_id ='
query_ids = conn.exec(sql, (video_id,)).fetchall()
return [cls.get_one(conn, query_id_tup[0])
for query_id_tup in query_ids]
@classmethod
def get_all_for_query(cls, conn: DbConn, query_id: QueryId) -> list[Self]:
"""Return all videos for query of query_id."""
- sql = SqlText('SELECT video_id '
- 'FROM yt_query_results WHERE query_id =')
+ sql = 'SELECT video_id FROM yt_query_results WHERE query_id ='
video_ids = conn.exec(sql, (query_id,)).fetchall()
return [cls.get_one(conn, video_id_tup[0])
for video_id_tup in video_ids]
def save_to_query(self, conn: DbConn, query_id: QueryId) -> None:
"""Save inclusion of self in results to query of query_id."""
- conn.exec(SqlText('REPLACE INTO yt_query_results VALUES'),
- (query_id, self.id_))
+ conn.exec('REPLACE INTO yt_query_results VALUES', (query_id, self.id_))
class VideoFile(DbData):
@classmethod
def get_by_yt_id(cls, conn: DbConn, yt_id: YoutubeId) -> Self:
"""Return VideoFile of .yt_id."""
- sql = SqlText(f'SELECT * FROM {cls._table_name} WHERE yt_id =')
- row = conn.exec(sql, (yt_id,)).fetchone()
+ row = conn.exec(f'SELECT * FROM {cls._table_name} WHERE yt_id =',
+ (yt_id,)).fetchone()
if not row:
raise NotFoundException(f'no entry for file to Youtube ID {yt_id}')
return cls._from_table_row(row)
if file.present:
file.unlink_locally()
print(f'SYNC: purging off DB: {file.digest.b64} ({file.rel_path})')
- conn.exec(
- SqlText(f'DELETE FROM {cls._table_name} WHERE digest ='),
- (file.digest.bytes,))
+ conn.exec(f'DELETE FROM {cls._table_name} WHERE digest =',
+ (file.digest.bytes,))
class QuotaLog(DbData):
@classmethod
def _remove_old(cls, conn: DbConn) -> None:
cutoff = datetime.now() - timedelta(days=1)
- sql = SqlText(f'DELETE FROM {cls._table_name} WHERE timestamp <')
- conn.exec(SqlText(sql), (cutoff.strftime(TIMESTAMP_FMT),))
+ conn.exec(f'DELETE FROM {cls._table_name} WHERE timestamp <',
+ (cutoff.strftime(TIMESTAMP_FMT),))
class Player: