From f997df770b4a4e6e615e013b0049fc2c88edefe5 Mon Sep 17 00:00:00 2001
From: Christian Heller <c.heller@plomlompom.de>
Date: Sat, 30 Nov 2024 23:45:10 +0100
Subject: [PATCH] Refactor DbConnection into context manager.

---
 src/ytplom/http.py | 86 +++++++++++++++++++++-------------------------
 src/ytplom/misc.py | 49 ++++++++++++++------------
 2 files changed, 67 insertions(+), 68 deletions(-)

diff --git a/src/ytplom/http.py b/src/ytplom/http.py
index c1c427a..bdedc3e 100644
--- a/src/ytplom/http.py
+++ b/src/ytplom/http.py
@@ -131,9 +131,9 @@ class _TaskHandler(BaseHTTPRequestHandler):
 
     def _receive_files_command(self, command: str) -> None:
         if command.startswith('play_'):
-            conn = DbConnection()
-            file = VideoFile.get_by_b64(conn, B64Str(command.split('_', 1)[1]))
-            conn.commit_close()
+            with DbConnection() as conn:
+                file = VideoFile.get_by_b64(conn,
+                                            B64Str(command.split('_', 1)[1]))
             self.server.player.inject_and_play(file)
         self._redirect(Path('/'))
 
@@ -141,21 +141,21 @@ class _TaskHandler(BaseHTTPRequestHandler):
                             rel_path_b64: B64Str,
                             flag_names: list[FlagName]
                             ) -> None:
-        conn = DbConnection()
-        file = VideoFile.get_by_b64(conn, rel_path_b64)
-        file.set_flags([FILE_FLAGS[name] for name in flag_names])
-        file.save(conn)
-        conn.commit_close()
+        with DbConnection() as conn:
+            file = VideoFile.get_by_b64(conn, rel_path_b64)
+            file.set_flags([FILE_FLAGS[name] for name in flag_names])
+            file.save(conn)
+            conn.commit()
         file.ensure_absence_if_deleted()
         self._redirect(Path('/')
                        .joinpath(PAGE_NAMES['file'])
                        .joinpath(rel_path_b64))
 
     def _receive_yt_query(self, query_txt: QueryText) -> None:
-        conn = DbConnection()
-        query_data = YoutubeQuery.new_by_request_saved(
-                conn, self.server.config, query_txt)
-        conn.commit_close()
+        with DbConnection() as conn:
+            query_data = YoutubeQuery.new_by_request_saved(
+                    conn, self.server.config, query_txt)
+            conn.commit()
         self._redirect(Path('/')
                        .joinpath(PAGE_NAMES['yt_query'])
                        .joinpath(query_data.id_))
@@ -218,35 +218,31 @@ class _TaskHandler(BaseHTTPRequestHandler):
         self._send_http(img, [('Content-type', 'image/jpg')])
 
     def _send_or_download_video(self, video_id: YoutubeId) -> None:
-        conn = DbConnection()
         try:
-            file_data = VideoFile.get_by_yt_id(conn, video_id)
+            with DbConnection() as conn:
+                file_data = VideoFile.get_by_yt_id(conn, video_id)
         except NotFoundException:
-            conn.commit_close()
             self.server.downloads.queue_download(video_id)
             self._redirect(Path('/')
                            .joinpath(PAGE_NAMES['yt_result'])
                            .joinpath(video_id))
             return
-        conn.commit_close()
         with file_data.full_path.open('rb') as video_file:
             video = video_file.read()
         self._send_http(content=video)
 
     def _send_yt_query_page(self, query_id: QueryId) -> None:
-        conn = DbConnection()
-        query = YoutubeQuery.get_one(conn, str(query_id))
-        results = YoutubeVideo.get_all_for_query(conn, query_id)
-        conn.commit_close()
+        with DbConnection() as conn:
+            query = YoutubeQuery.get_one(conn, str(query_id))
+            results = YoutubeVideo.get_all_for_query(conn, query_id)
         self._send_rendered_template(
                 _NAME_TEMPLATE_RESULTS,
                 {'query': query.text, 'videos': results})
 
     def _send_yt_queries_index_and_search(self) -> None:
-        conn = DbConnection()
-        quota_count = QuotaLog.current(conn)
-        queries_data = YoutubeQuery.get_all(conn)
-        conn.commit_close()
+        with DbConnection() as conn:
+            quota_count = QuotaLog.current(conn)
+            queries_data = YoutubeQuery.get_all(conn)
         queries_data.sort(key=lambda q: q.retrieved_at, reverse=True)
         self._send_rendered_template(
                 _NAME_TEMPLATE_QUERIES,
@@ -254,17 +250,17 @@ class _TaskHandler(BaseHTTPRequestHandler):
 
     def _send_yt_result(self, video_id: YoutubeId) -> None:
         conn = DbConnection()
-        linked_queries = YoutubeQuery.get_all_for_video(conn, video_id)
-        try:
-            video_data = YoutubeVideo.get_one(conn, video_id)
-        except NotFoundException:
-            video_data = YoutubeVideo(video_id)
-        try:
-            file = VideoFile.get_by_yt_id(conn, video_id)
-            file_path = file.full_path if file.present else None
-        except NotFoundException:
-            file_path = None
-        conn.commit_close()
+        with DbConnection() as conn:
+            linked_queries = YoutubeQuery.get_all_for_video(conn, video_id)
+            try:
+                video_data = YoutubeVideo.get_one(conn, video_id)
+            except NotFoundException:
+                video_data = YoutubeVideo(video_id)
+            try:
+                file = VideoFile.get_by_yt_id(conn, video_id)
+                file_path = file.full_path if file.present else None
+            except NotFoundException:
+                file_path = None
         self._send_rendered_template(
                 _NAME_TEMPLATE_YT_VIDEO,
                 {'video_data': video_data,
@@ -274,9 +270,8 @@ class _TaskHandler(BaseHTTPRequestHandler):
                  'queries': linked_queries})
 
     def _send_file_data(self, rel_path_b64: B64Str) -> None:
-        conn = DbConnection()
-        file = VideoFile.get_by_b64(conn, rel_path_b64)
-        conn.commit_close()
+        with DbConnection() as conn:
+            file = VideoFile.get_by_b64(conn, rel_path_b64)
         self._send_rendered_template(
                 _NAME_TEMPLATE_FILE_DATA,
                 {'file': file, 'flag_names': list(FILE_FLAGS)})
@@ -285,11 +280,10 @@ class _TaskHandler(BaseHTTPRequestHandler):
                           filter_: _ParamsStr,
                           show_absent: bool
                           ) -> None:
-        conn = DbConnection()
-        files = [f for f in VideoFile.get_all(conn)
-                 if filter_.lower() in str(f.rel_path).lower()
-                 and (show_absent or f.present)]
-        conn.commit_close()
+        with DbConnection() as conn:
+            files = [f for f in VideoFile.get_all(conn)
+                     if filter_.lower() in str(f.rel_path).lower()
+                     and (show_absent or f.present)]
         files.sort(key=lambda t: t.rel_path)
         self._send_rendered_template(
                 _NAME_TEMPLATE_FILES,
@@ -297,9 +291,9 @@ class _TaskHandler(BaseHTTPRequestHandler):
                  'show_absent': show_absent})
 
     def _send_missing_json(self) -> None:
-        conn = DbConnection()
-        missing = [f.rel_path for f in VideoFile.get_all(conn) if f.missing]
-        conn.commit_close()
+        with DbConnection() as conn:
+            missing = [f.rel_path for f in VideoFile.get_all(conn)
+                       if f.missing]
         self._send_http(bytes(json_dumps(missing), 'utf8'),
                         headers=[('Content-type', 'application/json')])
 
diff --git a/src/ytplom/misc.py b/src/ytplom/misc.py
index 012c3f5..c466e35 100644
--- a/src/ytplom/misc.py
+++ b/src/ytplom/misc.py
@@ -1,7 +1,7 @@
 """Main ytplom lib."""
 
 # included libs
-from typing import Any, NewType, Optional, Self, TypeAlias
+from typing import Any, Literal, NewType, Optional, Self, TypeAlias
 from os import chdir, environ
 from base64 import urlsafe_b64encode, urlsafe_b64decode
 from random import shuffle
@@ -150,14 +150,20 @@ class DbConnection:
                     f'{EXPECTED_DB_VERSION} – run "migrate"?')
         self._conn = sql_connect(self._path)
 
+    def __enter__(self) -> Self:
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback) -> Literal[False]:
+        self._conn.close()
+        return False
+
     def exec(self, sql: SqlText, inputs: tuple[Any, ...] = tuple()) -> Cursor:
         """Wrapper around sqlite3.Connection.execute."""
         return self._conn.execute(sql, inputs)
 
-    def commit_close(self) -> None:
-        """Run sqlite3.Connection.commit and .close."""
+    def commit(self) -> None:
+        """Commit changes (i.e. DbData.save() calls) to database."""
         self._conn.commit()
-        self._conn.close()
 
 
 class DbData:
@@ -506,9 +512,8 @@ class Player:
 
     def load_files(self) -> None:
         """Collect files in PATH_DOWNLOADS DB-known and of legal extension."""
-        conn = DbConnection()
-        known_files = {f.full_path: f for f in VideoFile.get_all(conn)}
-        conn.commit_close()
+        with DbConnection() as conn:
+            known_files = {f.full_path: f for f in VideoFile.get_all(conn)}
         self._files = [known_files[p] for p in PATH_DOWNLOADS.iterdir()
                        if p in known_files
                        and p.is_file()
@@ -667,21 +672,21 @@ class DownloadsManager:
         self._sync_db()
 
     def _sync_db(self):
-        conn = DbConnection()
-        known_paths = [file.rel_path for file in VideoFile.get_all(conn)]
-        old_cwd = Path.cwd()
-        chdir(PATH_DOWNLOADS)
-        for path in [p for p in Path('.').iterdir()
-                     if p.is_file() and p not in known_paths]:
-            yt_id = self._id_from_filename(path)
-            file = VideoFile(path, yt_id)
-            print(f'SYNC: new file {path}, saving with YT ID "{yt_id}".')
-            file.save(conn)
-        self._files = VideoFile.get_all(conn)
-        for file in self._files:
-            file.ensure_absence_if_deleted()
-        chdir(old_cwd)
-        conn.commit_close()
+        with DbConnection as conn:
+            known_paths = [file.rel_path for file in VideoFile.get_all(conn)]
+            old_cwd = Path.cwd()
+            chdir(PATH_DOWNLOADS)
+            for path in [p for p in Path('.').iterdir()
+                         if p.is_file() and p not in known_paths]:
+                yt_id = self._id_from_filename(path)
+                file = VideoFile(path, yt_id)
+                print(f'SYNC: new file {path}, saving with YT ID "{yt_id}".')
+                file.save(conn)
+            self._files = VideoFile.get_all(conn)
+            for file in self._files:
+                file.ensure_absence_if_deleted()
+            chdir(old_cwd)
+            conn.commit()
 
     @staticmethod
     def _id_from_filename(path: Path) -> YoutubeId:
-- 
2.30.2