From 2c02e74e57e74ec53f3fb5c61516d5a9b12a8e0b Mon Sep 17 00:00:00 2001 From: Christian Heller Date: Sat, 30 Nov 2024 23:55:21 +0100 Subject: [PATCH] Minor DbConnection usage fixes, a renaming for convenience. --- src/sync.py | 25 +++++++++++++------------ src/ytplom/http.py | 24 ++++++++++++------------ src/ytplom/misc.py | 30 +++++++++++++++--------------- 3 files changed, 40 insertions(+), 39 deletions(-) diff --git a/src/sync.py b/src/sync.py index 757deb5..641e78c 100755 --- a/src/sync.py +++ b/src/sync.py @@ -11,7 +11,7 @@ from paramiko import SSHClient # type: ignore from scp import SCPClient # type: ignore from ytplom.misc import ( PATH_DB, PATH_DOWNLOADS, PATH_TEMP, - Config, DbConnection, QuotaLog, VideoFile, YoutubeQuery, YoutubeVideo) + Config, DbConn, QuotaLog, VideoFile, YoutubeQuery, YoutubeVideo) from ytplom.http import PAGE_NAMES @@ -20,7 +20,7 @@ ATTR_NAME_LAST_UPDATE = 'last_update' def back_and_forth(sync_func: Callable, - dbs: tuple[DbConnection, DbConnection], + dbs: tuple[DbConn, DbConn], shared: YoutubeVideo | tuple[Any, str] ) -> None: """Run sync_func on arg pairs + shared, and again with pairs switched.""" @@ -30,7 +30,7 @@ def back_and_forth(sync_func: Callable, def sync_objects(host_names: tuple[str, str], - dbs: tuple[DbConnection, DbConnection], + dbs: tuple[DbConn, DbConn], shared: tuple[Any, str] ) -> None: """Equalize both DB's object collections; prefer newer states to older.""" @@ -58,7 +58,7 @@ def sync_objects(host_names: tuple[str, str], def sync_relations(host_names: tuple[str, str], - dbs: tuple[DbConnection, DbConnection], + dbs: tuple[DbConn, DbConn], yt_result: YoutubeVideo ) -> None: """To dbs[1] add YT yt_video->yt_q_colls[0] mapping not in yt_q_colls[1]""" @@ -78,14 +78,15 @@ def main(): ssh.connect(config.remote) scp = SCPClient(ssh.get_transport()) scp.get(PATH_DB, PATH_DB_REMOTE) - local_db, remote_db = DbConnection(PATH_DB), DbConnection(PATH_DB_REMOTE) - for cls in (QuotaLog, YoutubeQuery, YoutubeVideo, VideoFile): - back_and_forth(sync_objects, (local_db, remote_db), - (cls, 'rel_path' if cls is VideoFile else 'id_')) - for yt_video_local in YoutubeVideo.get_all(local_db): - back_and_forth(sync_relations, (local_db, remote_db), yt_video_local) - local_db.commit_close() - remote_db.commit_close() + with DbConn(PATH_DB) as db_local, DbConn(PATH_DB_REMOTE) as db_remote: + for cls in (QuotaLog, YoutubeQuery, YoutubeVideo, VideoFile): + back_and_forth(sync_objects, (db_local, db_remote), + (cls, 'rel_path' if cls is VideoFile else 'id_')) + for yt_video_local in YoutubeVideo.get_all(db_local): + back_and_forth(sync_relations, (db_local, db_remote), + yt_video_local) + db_remote.commit() + db_local.commit() scp.put(PATH_DB_REMOTE, PATH_DB) PATH_DB_REMOTE.unlink() missings = [] diff --git a/src/ytplom/http.py b/src/ytplom/http.py index bdedc3e..9cfc48b 100644 --- a/src/ytplom/http.py +++ b/src/ytplom/http.py @@ -14,7 +14,7 @@ from ytplom.misc import ( QueryId, QueryText, QuotaCost, UrlStr, YoutubeId, FILE_FLAGS, PATH_APP_DATA, PATH_THUMBNAILS, YOUTUBE_URL_PREFIX, ensure_expected_dirs, - Config, DbConnection, DownloadsManager, Player, QuotaLog, VideoFile, + Config, DbConn, DownloadsManager, Player, QuotaLog, VideoFile, YoutubeQuery, YoutubeVideo ) @@ -131,7 +131,7 @@ class _TaskHandler(BaseHTTPRequestHandler): def _receive_files_command(self, command: str) -> None: if command.startswith('play_'): - with DbConnection() as conn: + with DbConn() as conn: file = VideoFile.get_by_b64(conn, B64Str(command.split('_', 1)[1])) self.server.player.inject_and_play(file) @@ -141,7 +141,7 @@ class _TaskHandler(BaseHTTPRequestHandler): rel_path_b64: B64Str, flag_names: list[FlagName] ) -> None: - with DbConnection() as conn: + with DbConn() as conn: file = VideoFile.get_by_b64(conn, rel_path_b64) file.set_flags([FILE_FLAGS[name] for name in flag_names]) file.save(conn) @@ -152,7 +152,7 @@ class _TaskHandler(BaseHTTPRequestHandler): .joinpath(rel_path_b64)) def _receive_yt_query(self, query_txt: QueryText) -> None: - with DbConnection() as conn: + with DbConn() as conn: query_data = YoutubeQuery.new_by_request_saved( conn, self.server.config, query_txt) conn.commit() @@ -219,7 +219,7 @@ class _TaskHandler(BaseHTTPRequestHandler): def _send_or_download_video(self, video_id: YoutubeId) -> None: try: - with DbConnection() as conn: + with DbConn() as conn: file_data = VideoFile.get_by_yt_id(conn, video_id) except NotFoundException: self.server.downloads.queue_download(video_id) @@ -232,7 +232,7 @@ class _TaskHandler(BaseHTTPRequestHandler): self._send_http(content=video) def _send_yt_query_page(self, query_id: QueryId) -> None: - with DbConnection() as conn: + with DbConn() as conn: query = YoutubeQuery.get_one(conn, str(query_id)) results = YoutubeVideo.get_all_for_query(conn, query_id) self._send_rendered_template( @@ -240,7 +240,7 @@ class _TaskHandler(BaseHTTPRequestHandler): {'query': query.text, 'videos': results}) def _send_yt_queries_index_and_search(self) -> None: - with DbConnection() as conn: + with DbConn() as conn: quota_count = QuotaLog.current(conn) queries_data = YoutubeQuery.get_all(conn) queries_data.sort(key=lambda q: q.retrieved_at, reverse=True) @@ -249,8 +249,8 @@ class _TaskHandler(BaseHTTPRequestHandler): {'queries': queries_data, 'quota_count': quota_count}) def _send_yt_result(self, video_id: YoutubeId) -> None: - conn = DbConnection() - with DbConnection() as conn: + conn = DbConn() + with DbConn() as conn: linked_queries = YoutubeQuery.get_all_for_video(conn, video_id) try: video_data = YoutubeVideo.get_one(conn, video_id) @@ -270,7 +270,7 @@ class _TaskHandler(BaseHTTPRequestHandler): 'queries': linked_queries}) def _send_file_data(self, rel_path_b64: B64Str) -> None: - with DbConnection() as conn: + with DbConn() as conn: file = VideoFile.get_by_b64(conn, rel_path_b64) self._send_rendered_template( _NAME_TEMPLATE_FILE_DATA, @@ -280,7 +280,7 @@ class _TaskHandler(BaseHTTPRequestHandler): filter_: _ParamsStr, show_absent: bool ) -> None: - with DbConnection() as conn: + with DbConn() as conn: files = [f for f in VideoFile.get_all(conn) if filter_.lower() in str(f.rel_path).lower() and (show_absent or f.present)] @@ -291,7 +291,7 @@ class _TaskHandler(BaseHTTPRequestHandler): 'show_absent': show_absent}) def _send_missing_json(self) -> None: - with DbConnection() as conn: + with DbConn() as conn: missing = [f.rel_path for f in VideoFile.get_all(conn) if f.missing] self._send_http(bytes(json_dumps(missing), 'utf8'), diff --git a/src/ytplom/misc.py b/src/ytplom/misc.py index c466e35..c0aa85f 100644 --- a/src/ytplom/misc.py +++ b/src/ytplom/misc.py @@ -127,7 +127,7 @@ class Config: if k.isupper() and k.startswith(ENVIRON_PREFIX)}) -class DbConnection: +class DbConn: """Wrapped sqlite3.Connection.""" def __init__(self, path: Path = PATH_DB) -> None: @@ -191,7 +191,7 @@ class DbData: return cls(**kwargs) @classmethod - def get_one(cls, conn: DbConnection, id_: str) -> Self: + def get_one(cls, conn: DbConn, id_: str) -> Self: """Return single entry of id_ from DB.""" sql = SqlText(f'SELECT * FROM {cls._table_name} ' f'WHERE {cls.id_name} = ?') @@ -202,13 +202,13 @@ class DbData: return cls._from_table_row(row) @classmethod - def get_all(cls, conn: DbConnection) -> list[Self]: + def get_all(cls, conn: DbConn) -> list[Self]: """Return all entries from DB.""" sql = SqlText(f'SELECT * FROM {cls._table_name}') rows = conn.exec(sql).fetchall() return [cls._from_table_row(row) for row in rows] - def save(self, conn: DbConnection) -> Cursor: + def save(self, conn: DbConn) -> Cursor: """Save entry to DB.""" vals = [getattr(self, col_name) for col_name in self._cols] q_marks = '(' + ','.join(['?'] * len(vals)) + ')' @@ -233,7 +233,7 @@ class YoutubeQuery(DbData): @classmethod def new_by_request_saved(cls, - conn: DbConnection, + conn: DbConn, config: Config, query_txt: QueryText ) -> Self: @@ -284,7 +284,7 @@ class YoutubeQuery(DbData): @classmethod def get_all_for_video(cls, - conn: DbConnection, + conn: DbConn, video_id: YoutubeId ) -> list[Self]: """Return YoutubeQueries containing YoutubeVideo's ID in results.""" @@ -341,7 +341,7 @@ class YoutubeVideo(DbData): @classmethod def get_all_for_query(cls, - conn: DbConnection, + conn: DbConn, query_id: QueryId ) -> list[Self]: """Return all videos for query of query_id.""" @@ -351,7 +351,7 @@ class YoutubeVideo(DbData): return [cls.get_one(conn, video_id_tup[0]) for video_id_tup in video_ids] - def save_to_query(self, conn: DbConnection, query_id: QueryId) -> None: + def save_to_query(self, conn: DbConn, query_id: QueryId) -> None: """Save inclusion of self in results to query of query_id.""" conn.exec(SqlText('REPLACE INTO yt_query_results VALUES (?, ?)'), (query_id, self.id_)) @@ -383,7 +383,7 @@ class VideoFile(DbData): self.last_update = DatetimeStr(datetime.now().strftime(TIMESTAMP_FMT)) @classmethod - def get_by_yt_id(cls, conn: DbConnection, yt_id: YoutubeId) -> Self: + def get_by_yt_id(cls, conn: DbConn, yt_id: YoutubeId) -> Self: """Return VideoFile of .yt_id.""" sql = SqlText(f'SELECT * FROM {cls._table_name} WHERE yt_id = ?') row = conn.exec(sql, (yt_id,)).fetchone() @@ -392,7 +392,7 @@ class VideoFile(DbData): return cls._from_table_row(row) @classmethod - def get_by_b64(cls, conn: DbConnection, rel_path_b64: B64Str) -> Self: + def get_by_b64(cls, conn: DbConn, rel_path_b64: B64Str) -> Self: """Retrieve by .rel_path provided as urlsafe_b64 encoding.""" return cls.get_one(conn, urlsafe_b64decode(rel_path_b64).decode()) @@ -457,7 +457,7 @@ class QuotaLog(DbData): self.cost = cost @classmethod - def update(cls, conn: DbConnection, cost: QuotaCost) -> None: + def update(cls, conn: DbConn, cost: QuotaCost) -> None: """Adds cost mapped to current datetime.""" cls._remove_old(conn) new = cls(None, @@ -466,14 +466,14 @@ class QuotaLog(DbData): new.save(conn) @classmethod - def current(cls, conn: DbConnection) -> QuotaCost: + def current(cls, conn: DbConn) -> QuotaCost: """Returns quota cost total for last 24 hours, purges old data.""" cls._remove_old(conn) quota_costs = cls.get_all(conn) return QuotaCost(sum(c.cost for c in quota_costs)) @classmethod - def _remove_old(cls, conn: DbConnection) -> None: + def _remove_old(cls, conn: DbConn) -> None: cutoff = datetime.now() - timedelta(days=1) sql = SqlText(f'DELETE FROM {cls._table_name} WHERE timestamp < ?') conn.exec(SqlText(sql), (cutoff.strftime(TIMESTAMP_FMT),)) @@ -512,7 +512,7 @@ class Player: def load_files(self) -> None: """Collect files in PATH_DOWNLOADS DB-known and of legal extension.""" - with DbConnection() as conn: + with DbConn() as conn: known_files = {f.full_path: f for f in VideoFile.get_all(conn)} self._files = [known_files[p] for p in PATH_DOWNLOADS.iterdir() if p in known_files @@ -672,7 +672,7 @@ class DownloadsManager: self._sync_db() def _sync_db(self): - with DbConnection as conn: + with DbConn() as conn: known_paths = [file.rel_path for file in VideoFile.get_all(conn)] old_cwd = Path.cwd() chdir(PATH_DOWNLOADS) -- 2.30.2