From: Christian Heller Date: Sat, 30 Nov 2024 22:45:10 +0000 (+0100) Subject: Refactor DbConnection into context manager. X-Git-Url: https://plomlompom.com/repos/%7B%7B%20web_path%20%7D%7D/%7B%7Bdb.prefix%7D%7D/static/balance?a=commitdiff_plain;h=f997df770b4a4e6e615e013b0049fc2c88edefe5;p=ytplom Refactor DbConnection into context manager. --- diff --git a/src/ytplom/http.py b/src/ytplom/http.py index c1c427a..bdedc3e 100644 --- a/src/ytplom/http.py +++ b/src/ytplom/http.py @@ -131,9 +131,9 @@ class _TaskHandler(BaseHTTPRequestHandler): def _receive_files_command(self, command: str) -> None: if command.startswith('play_'): - conn = DbConnection() - file = VideoFile.get_by_b64(conn, B64Str(command.split('_', 1)[1])) - conn.commit_close() + with DbConnection() as conn: + file = VideoFile.get_by_b64(conn, + B64Str(command.split('_', 1)[1])) self.server.player.inject_and_play(file) self._redirect(Path('/')) @@ -141,21 +141,21 @@ class _TaskHandler(BaseHTTPRequestHandler): rel_path_b64: B64Str, flag_names: list[FlagName] ) -> None: - conn = DbConnection() - file = VideoFile.get_by_b64(conn, rel_path_b64) - file.set_flags([FILE_FLAGS[name] for name in flag_names]) - file.save(conn) - conn.commit_close() + with DbConnection() as conn: + file = VideoFile.get_by_b64(conn, rel_path_b64) + file.set_flags([FILE_FLAGS[name] for name in flag_names]) + file.save(conn) + conn.commit() file.ensure_absence_if_deleted() self._redirect(Path('/') .joinpath(PAGE_NAMES['file']) .joinpath(rel_path_b64)) def _receive_yt_query(self, query_txt: QueryText) -> None: - conn = DbConnection() - query_data = YoutubeQuery.new_by_request_saved( - conn, self.server.config, query_txt) - conn.commit_close() + with DbConnection() as conn: + query_data = YoutubeQuery.new_by_request_saved( + conn, self.server.config, query_txt) + conn.commit() self._redirect(Path('/') .joinpath(PAGE_NAMES['yt_query']) .joinpath(query_data.id_)) @@ -218,35 +218,31 @@ class _TaskHandler(BaseHTTPRequestHandler): self._send_http(img, [('Content-type', 'image/jpg')]) def _send_or_download_video(self, video_id: YoutubeId) -> None: - conn = DbConnection() try: - file_data = VideoFile.get_by_yt_id(conn, video_id) + with DbConnection() as conn: + file_data = VideoFile.get_by_yt_id(conn, video_id) except NotFoundException: - conn.commit_close() self.server.downloads.queue_download(video_id) self._redirect(Path('/') .joinpath(PAGE_NAMES['yt_result']) .joinpath(video_id)) return - conn.commit_close() with file_data.full_path.open('rb') as video_file: video = video_file.read() self._send_http(content=video) def _send_yt_query_page(self, query_id: QueryId) -> None: - conn = DbConnection() - query = YoutubeQuery.get_one(conn, str(query_id)) - results = YoutubeVideo.get_all_for_query(conn, query_id) - conn.commit_close() + with DbConnection() as conn: + query = YoutubeQuery.get_one(conn, str(query_id)) + results = YoutubeVideo.get_all_for_query(conn, query_id) self._send_rendered_template( _NAME_TEMPLATE_RESULTS, {'query': query.text, 'videos': results}) def _send_yt_queries_index_and_search(self) -> None: - conn = DbConnection() - quota_count = QuotaLog.current(conn) - queries_data = YoutubeQuery.get_all(conn) - conn.commit_close() + with DbConnection() as conn: + quota_count = QuotaLog.current(conn) + queries_data = YoutubeQuery.get_all(conn) queries_data.sort(key=lambda q: q.retrieved_at, reverse=True) self._send_rendered_template( _NAME_TEMPLATE_QUERIES, @@ -254,17 +250,17 @@ class _TaskHandler(BaseHTTPRequestHandler): def _send_yt_result(self, video_id: YoutubeId) -> None: conn = DbConnection() - linked_queries = YoutubeQuery.get_all_for_video(conn, video_id) - try: - video_data = YoutubeVideo.get_one(conn, video_id) - except NotFoundException: - video_data = YoutubeVideo(video_id) - try: - file = VideoFile.get_by_yt_id(conn, video_id) - file_path = file.full_path if file.present else None - except NotFoundException: - file_path = None - conn.commit_close() + with DbConnection() as conn: + linked_queries = YoutubeQuery.get_all_for_video(conn, video_id) + try: + video_data = YoutubeVideo.get_one(conn, video_id) + except NotFoundException: + video_data = YoutubeVideo(video_id) + try: + file = VideoFile.get_by_yt_id(conn, video_id) + file_path = file.full_path if file.present else None + except NotFoundException: + file_path = None self._send_rendered_template( _NAME_TEMPLATE_YT_VIDEO, {'video_data': video_data, @@ -274,9 +270,8 @@ class _TaskHandler(BaseHTTPRequestHandler): 'queries': linked_queries}) def _send_file_data(self, rel_path_b64: B64Str) -> None: - conn = DbConnection() - file = VideoFile.get_by_b64(conn, rel_path_b64) - conn.commit_close() + with DbConnection() as conn: + file = VideoFile.get_by_b64(conn, rel_path_b64) self._send_rendered_template( _NAME_TEMPLATE_FILE_DATA, {'file': file, 'flag_names': list(FILE_FLAGS)}) @@ -285,11 +280,10 @@ class _TaskHandler(BaseHTTPRequestHandler): filter_: _ParamsStr, show_absent: bool ) -> None: - conn = DbConnection() - files = [f for f in VideoFile.get_all(conn) - if filter_.lower() in str(f.rel_path).lower() - and (show_absent or f.present)] - conn.commit_close() + with DbConnection() as conn: + files = [f for f in VideoFile.get_all(conn) + if filter_.lower() in str(f.rel_path).lower() + and (show_absent or f.present)] files.sort(key=lambda t: t.rel_path) self._send_rendered_template( _NAME_TEMPLATE_FILES, @@ -297,9 +291,9 @@ class _TaskHandler(BaseHTTPRequestHandler): 'show_absent': show_absent}) def _send_missing_json(self) -> None: - conn = DbConnection() - missing = [f.rel_path for f in VideoFile.get_all(conn) if f.missing] - conn.commit_close() + with DbConnection() as conn: + missing = [f.rel_path for f in VideoFile.get_all(conn) + if f.missing] self._send_http(bytes(json_dumps(missing), 'utf8'), headers=[('Content-type', 'application/json')]) diff --git a/src/ytplom/misc.py b/src/ytplom/misc.py index 012c3f5..c466e35 100644 --- a/src/ytplom/misc.py +++ b/src/ytplom/misc.py @@ -1,7 +1,7 @@ """Main ytplom lib.""" # included libs -from typing import Any, NewType, Optional, Self, TypeAlias +from typing import Any, Literal, NewType, Optional, Self, TypeAlias from os import chdir, environ from base64 import urlsafe_b64encode, urlsafe_b64decode from random import shuffle @@ -150,14 +150,20 @@ class DbConnection: f'{EXPECTED_DB_VERSION} – run "migrate"?') self._conn = sql_connect(self._path) + def __enter__(self) -> Self: + return self + + def __exit__(self, exc_type, exc_value, traceback) -> Literal[False]: + self._conn.close() + return False + def exec(self, sql: SqlText, inputs: tuple[Any, ...] = tuple()) -> Cursor: """Wrapper around sqlite3.Connection.execute.""" return self._conn.execute(sql, inputs) - def commit_close(self) -> None: - """Run sqlite3.Connection.commit and .close.""" + def commit(self) -> None: + """Commit changes (i.e. DbData.save() calls) to database.""" self._conn.commit() - self._conn.close() class DbData: @@ -506,9 +512,8 @@ class Player: def load_files(self) -> None: """Collect files in PATH_DOWNLOADS DB-known and of legal extension.""" - conn = DbConnection() - known_files = {f.full_path: f for f in VideoFile.get_all(conn)} - conn.commit_close() + with DbConnection() as conn: + known_files = {f.full_path: f for f in VideoFile.get_all(conn)} self._files = [known_files[p] for p in PATH_DOWNLOADS.iterdir() if p in known_files and p.is_file() @@ -667,21 +672,21 @@ class DownloadsManager: self._sync_db() def _sync_db(self): - conn = DbConnection() - known_paths = [file.rel_path for file in VideoFile.get_all(conn)] - old_cwd = Path.cwd() - chdir(PATH_DOWNLOADS) - for path in [p for p in Path('.').iterdir() - if p.is_file() and p not in known_paths]: - yt_id = self._id_from_filename(path) - file = VideoFile(path, yt_id) - print(f'SYNC: new file {path}, saving with YT ID "{yt_id}".') - file.save(conn) - self._files = VideoFile.get_all(conn) - for file in self._files: - file.ensure_absence_if_deleted() - chdir(old_cwd) - conn.commit_close() + with DbConnection as conn: + known_paths = [file.rel_path for file in VideoFile.get_all(conn)] + old_cwd = Path.cwd() + chdir(PATH_DOWNLOADS) + for path in [p for p in Path('.').iterdir() + if p.is_file() and p not in known_paths]: + yt_id = self._id_from_filename(path) + file = VideoFile(path, yt_id) + print(f'SYNC: new file {path}, saving with YT ID "{yt_id}".') + file.save(conn) + self._files = VideoFile.get_all(conn) + for file in self._files: + file.ensure_absence_if_deleted() + chdir(old_cwd) + conn.commit() @staticmethod def _id_from_filename(path: Path) -> YoutubeId: