# non-included libs
from paramiko import SSHClient # type: ignore
from scp import SCPClient # type: ignore
-from ytplom.db import DbConn, Hash, PATH_DB
+from ytplom.db import DbConn, DbFile, Hash, PATH_DB
from ytplom.misc import (PATH_TEMP, Config, FlagName, QuotaLog, VideoFile,
YoutubeQuery, YoutubeVideo)
from ytplom.http import PAGE_NAMES
def sync_dbs(scp: SCPClient) -> None:
"""Download remote DB, run sync_(objects|relations), put remote DB back."""
scp.get(PATH_DB, PATH_DB_REMOTE)
- with DbConn(PATH_DB) as db_local, DbConn(PATH_DB_REMOTE) as db_remote:
+ with DbConn(DbFile(PATH_DB).connect()) as db_local, \
+ DbConn(DbFile(PATH_DB_REMOTE).connect()) as db_remote:
for cls in (QuotaLog, YoutubeQuery, YoutubeVideo, VideoFile):
back_and_forth(sync_objects, (db_local, db_remote), cls)
for yt_video_local in YoutubeVideo.get_all(db_local):
for url_missing in _urls_here_and_there(config, 'missing'):
with urlopen(url_missing) as response:
missings += [list(json_loads(response.read()))]
- conn = DbConn()
- for i, direction_mover in enumerate([('local->remote', scp.put),
- ('remote->local', scp.get)]):
- direction, mover = direction_mover
- for digest in (d for d in missings[i]
- if d not in missings[int(not bool(i))]):
- vf = VideoFile.get_one(conn, Hash.from_b64(digest))
- if vf.is_flag_set(FlagName('do not sync')):
- print(f'SYNC: not sending ("do not sync" set): {vf.full_path}')
- return
- print(f'SYNC: sending {direction}: {vf.full_path}')
- mover(vf.full_path, vf.full_path)
+ with DbConn() as conn:
+ for i, direction_mover in enumerate([('local->remote', scp.put),
+ ('remote->local', scp.get)]):
+ direction, mover = direction_mover
+ for digest in (d for d in missings[i]
+ if d not in missings[int(not bool(i))]):
+ vf = VideoFile.get_one(conn, Hash.from_b64(digest))
+ if vf.is_flag_set(FlagName('do not sync')):
+ print(f'SYNC: not sending ("do not sync" set)'
+ f': {vf.full_path}')
+ return
+ print(f'SYNC: sending {direction}: {vf.full_path}')
+ mover(vf.full_path, vf.full_path)
def main():
from base64 import urlsafe_b64decode, urlsafe_b64encode
from hashlib import file_digest
from pathlib import Path
-from sqlite3 import connect as sql_connect, Cursor as SqlCursor, Row as SqlRow
-from typing import Any, Literal, NewType, Self
+from sqlite3 import (connect as sql_connect, Connection as SqlConnection,
+ Cursor as SqlCursor, Row as SqlRow)
+from typing import Callable, Literal, NewType, Optional, Self
from ytplom.primitives import (
HandledException, NotFoundException, PATH_APP_DATA)
-SqlText = NewType('SqlText', str)
-
EXPECTED_DB_VERSION = 6
-PATH_DB = PATH_APP_DATA.joinpath('db.sql')
-SQL_DB_VERSION = SqlText('PRAGMA user_version')
-PATH_MIGRATIONS = PATH_APP_DATA.joinpath('migrations')
-_HASH_ALGO = 'sha512'
-_PATH_DB_SCHEMA = PATH_MIGRATIONS.joinpath('new_init.sql')
-_NAME_INSTALLER = Path('install.sh')
+PATH_DB = PATH_APP_DATA.joinpath('db.sql')
+SqlText = NewType('SqlText', str)
+MigrationsList = list[tuple[Path, Optional[Callable]]]
-def get_db_version(db_path: Path) -> int:
- """Return user_version value of DB at db_path."""
- with sql_connect(db_path) as conn:
- return list(conn.execute(SQL_DB_VERSION))[0][0]
+_SQL_DB_VERSION = SqlText('PRAGMA user_version')
+_PATH_MIGRATIONS = PATH_APP_DATA.joinpath('migrations')
+_HASH_ALGO = 'sha512'
+_PATH_DB_SCHEMA = _PATH_MIGRATIONS.joinpath('new_init.sql')
class Hash:
return urlsafe_b64encode(self.bytes).decode('utf8')
-class DbConn:
- """Wrapper for sqlite3 connections."""
+class DbFile:
+ """Wrapper around the file of a sqlite3 database."""
def __init__(self,
path: Path = PATH_DB,
expected_version: int = EXPECTED_DB_VERSION
) -> None:
+ self._path = path
if not path.is_file():
- if path.exists():
- raise HandledException(f'no DB at {path}; would create, '
- 'but something\'s already there?')
- if not path.parent.is_dir():
- raise HandledException(
- f'cannot find {path.parent} as directory to put '
- f'DB into, did you run {_NAME_INSTALLER}?')
- with sql_connect(path) as conn:
- print(f'No DB found at {path}, creating …')
- conn.executescript(_PATH_DB_SCHEMA.read_text(encoding='utf8'))
- conn.execute(f'{SQL_DB_VERSION} = {EXPECTED_DB_VERSION}')
+ raise HandledException(
+ f'no DB file at {path} – run "create"?')
if expected_version >= 0:
- cur_version = get_db_version(path)
- if cur_version != expected_version:
+ user_version = self._get_user_version()
+ if user_version != expected_version:
raise HandledException(
- f'wrong database version {cur_version}, expected: '
+ f'wrong database version {user_version}, expected: '
f'{expected_version} – run "migrate"?')
- self._conn = sql_connect(path, autocommit=False)
+
+ def _get_user_version(self) -> int:
+ with sql_connect(self._path) as conn:
+ return list(conn.execute(_SQL_DB_VERSION))[0][0]
+
+ @staticmethod
+ def create(path: Path = PATH_DB) -> None:
+ """Create DB file at path according to _PATH_DB_SCHEMA."""
+ if path.exists():
+ raise HandledException(
+ f'There already exists a node at {path}.')
+ if not path.parent.is_dir():
+ raise HandledException(
+ f'No directory {path.parent} found to write into.')
+ with sql_connect(path) as conn:
+ conn.executescript(_PATH_DB_SCHEMA.read_text(encoding='utf8'))
+ conn.execute(f'{_SQL_DB_VERSION} = {EXPECTED_DB_VERSION}')
+
+ def connect(self) -> SqlConnection:
+ """Open database file into SQL connection, with autocommit off."""
+ return sql_connect(self._path, autocommit=False)
+
+ def migrate(self, migrations: MigrationsList) -> None:
+ """Migrate self towards EXPECTED_DB_VERSION"""
+ start_version = self._get_user_version()
+ if start_version == EXPECTED_DB_VERSION:
+ print('Database at expected version, no migrations to do.')
+ return
+ if start_version > EXPECTED_DB_VERSION:
+ raise HandledException(
+ f'Cannot migrate backward from version {start_version} to '
+ f'{EXPECTED_DB_VERSION}.')
+ print(f'Trying to migrate from DB version {start_version} to '
+ f'{EXPECTED_DB_VERSION} …')
+ migs_to_do = []
+ migs_by_n = dict(enumerate(migrations))
+ for n in [n+1 for n in range(start_version, EXPECTED_DB_VERSION)]:
+ if n not in migs_by_n:
+ raise HandledException(f'Needed migration missing: {n}')
+ migs_to_do += [(n, *migs_by_n[n])]
+ for version, filename_sql, after_sql_steps in migs_to_do:
+ print(f'Running migration towards: {version}')
+ with DbConn(self.connect()) as conn:
+ if filename_sql:
+ print(f'Executing {filename_sql}')
+ path_sql = _PATH_MIGRATIONS.joinpath(filename_sql)
+ conn.exec_script(
+ SqlText(path_sql.read_text(encoding='utf8')))
+ if after_sql_steps:
+ print('Running additional steps')
+ after_sql_steps(conn)
+ conn.exec(SqlText(f'{_SQL_DB_VERSION} = {version}'))
+ conn.commit()
+ print('Finished migrations.')
+
+
+class DbConn:
+ """Wrapper for sqlite3 connections."""
+
+ def __init__(self, sql_conn: Optional[SqlConnection] = None) -> None:
+ self._conn = sql_conn or DbFile().connect()
def __enter__(self) -> Self:
return self
self._conn.close()
return False
- def exec(self, sql: SqlText, inputs: tuple[Any, ...] = tuple()
+ def exec(self, sql: SqlText, inputs: tuple = tuple()
) -> SqlCursor:
"""Wrapper around sqlite3.Connection.execute, building '?' if inputs"""
if len(inputs) > 0:
_str_field: str
_cols: tuple[str, ...]
- def __eq__(self, other: Any) -> bool:
+ def __eq__(self, other) -> bool:
if not isinstance(other, self.__class__):
return False
for attr_name in self._cols:
"""Anything pertaining specifically to DB migrations."""
from pathlib import Path
from typing import Callable
-from ytplom.db import (
- get_db_version, DbConn, SqlText,
- EXPECTED_DB_VERSION, PATH_DB, PATH_MIGRATIONS, SQL_DB_VERSION)
+from ytplom.db import DbConn, MigrationsList, SqlText
from ytplom.primitives import HandledException
_LEGIT_YES = 'YES!!'
-def run_migrations() -> None:
- """Try to migrate DB towards EXPECTED_DB_VERSION."""
- start_version = get_db_version(PATH_DB)
- if start_version == EXPECTED_DB_VERSION:
- print('Database at expected version, no migrations to do.')
- return
- if start_version > EXPECTED_DB_VERSION:
- raise HandledException(
- f'Cannot migrate backward from version {start_version} to '
- f'{EXPECTED_DB_VERSION}.')
- print(f'Trying to migrate from DB version {start_version} to '
- f'{EXPECTED_DB_VERSION} …')
- migs_to_do = []
- migs_by_n = dict(enumerate(MIGRATIONS))
- for n in [n+1 for n in range(start_version, EXPECTED_DB_VERSION)]:
- if n not in migs_by_n:
- raise HandledException(f'Needed migration missing: {n}')
- migs_to_do += [(n, *migs_by_n[n])]
- for version, filename_sql, after_sql_steps in migs_to_do:
- print(f'Running migration towards: {version}')
- with DbConn(expected_version=version-1) as conn:
- if filename_sql:
- print(f'Executing {filename_sql}')
- path_sql = PATH_MIGRATIONS.joinpath(filename_sql)
- conn.exec_script(SqlText(path_sql.read_text(encoding='utf8')))
- if after_sql_steps:
- print('Running additional steps')
- after_sql_steps(conn)
- conn.exec(SqlText(f'{SQL_DB_VERSION} = {version}'))
- conn.commit()
- print('Finished migrations.')
-
-
def _rewrite_files_last_field_processing_first_field(conn: DbConn,
cb: Callable
) -> None:
_rewrite_files_last_field_processing_first_field(conn, bytes.fromhex)
-MIGRATIONS = [
+MIGRATIONS: MigrationsList = [
(Path('0_init.sql'), None),
(Path('1_add_files_last_updated.sql'), None),
(Path('2_add_files_sha512.sql'), _mig_2_calc_digests),