From 2c02e74e57e74ec53f3fb5c61516d5a9b12a8e0b Mon Sep 17 00:00:00 2001
From: Christian Heller <c.heller@plomlompom.de>
Date: Sat, 30 Nov 2024 23:55:21 +0100
Subject: [PATCH] Minor DbConnection usage fixes, a renaming for convenience.

---
 src/sync.py        | 25 +++++++++++++------------
 src/ytplom/http.py | 24 ++++++++++++------------
 src/ytplom/misc.py | 30 +++++++++++++++---------------
 3 files changed, 40 insertions(+), 39 deletions(-)

diff --git a/src/sync.py b/src/sync.py
index 757deb5..641e78c 100755
--- a/src/sync.py
+++ b/src/sync.py
@@ -11,7 +11,7 @@ from paramiko import SSHClient  # type: ignore
 from scp import SCPClient  # type: ignore
 from ytplom.misc import (
         PATH_DB, PATH_DOWNLOADS, PATH_TEMP,
-        Config, DbConnection, QuotaLog, VideoFile, YoutubeQuery, YoutubeVideo)
+        Config, DbConn, QuotaLog, VideoFile, YoutubeQuery, YoutubeVideo)
 from ytplom.http import PAGE_NAMES
 
 
@@ -20,7 +20,7 @@ ATTR_NAME_LAST_UPDATE = 'last_update'
 
 
 def back_and_forth(sync_func: Callable,
-                   dbs: tuple[DbConnection, DbConnection],
+                   dbs: tuple[DbConn, DbConn],
                    shared: YoutubeVideo | tuple[Any, str]
                    ) -> None:
     """Run sync_func on arg pairs + shared, and again with pairs switched."""
@@ -30,7 +30,7 @@ def back_and_forth(sync_func: Callable,
 
 
 def sync_objects(host_names: tuple[str, str],
-                 dbs: tuple[DbConnection, DbConnection],
+                 dbs: tuple[DbConn, DbConn],
                  shared: tuple[Any, str]
                  ) -> None:
     """Equalize both DB's object collections; prefer newer states to older."""
@@ -58,7 +58,7 @@ def sync_objects(host_names: tuple[str, str],
 
 
 def sync_relations(host_names: tuple[str, str],
-                   dbs: tuple[DbConnection, DbConnection],
+                   dbs: tuple[DbConn, DbConn],
                    yt_result: YoutubeVideo
                    ) -> None:
     """To dbs[1] add YT yt_video->yt_q_colls[0] mapping not in yt_q_colls[1]"""
@@ -78,14 +78,15 @@ def main():
     ssh.connect(config.remote)
     scp = SCPClient(ssh.get_transport())
     scp.get(PATH_DB, PATH_DB_REMOTE)
-    local_db, remote_db = DbConnection(PATH_DB), DbConnection(PATH_DB_REMOTE)
-    for cls in (QuotaLog, YoutubeQuery, YoutubeVideo, VideoFile):
-        back_and_forth(sync_objects, (local_db, remote_db),
-                       (cls, 'rel_path' if cls is VideoFile else 'id_'))
-    for yt_video_local in YoutubeVideo.get_all(local_db):
-        back_and_forth(sync_relations, (local_db, remote_db), yt_video_local)
-    local_db.commit_close()
-    remote_db.commit_close()
+    with DbConn(PATH_DB) as db_local, DbConn(PATH_DB_REMOTE) as db_remote:
+        for cls in (QuotaLog, YoutubeQuery, YoutubeVideo, VideoFile):
+            back_and_forth(sync_objects, (db_local, db_remote),
+                           (cls, 'rel_path' if cls is VideoFile else 'id_'))
+        for yt_video_local in YoutubeVideo.get_all(db_local):
+            back_and_forth(sync_relations, (db_local, db_remote),
+                           yt_video_local)
+        db_remote.commit()
+        db_local.commit()
     scp.put(PATH_DB_REMOTE, PATH_DB)
     PATH_DB_REMOTE.unlink()
     missings = []
diff --git a/src/ytplom/http.py b/src/ytplom/http.py
index bdedc3e..9cfc48b 100644
--- a/src/ytplom/http.py
+++ b/src/ytplom/http.py
@@ -14,7 +14,7 @@ from ytplom.misc import (
         QueryId, QueryText, QuotaCost, UrlStr, YoutubeId,
         FILE_FLAGS, PATH_APP_DATA, PATH_THUMBNAILS, YOUTUBE_URL_PREFIX,
         ensure_expected_dirs,
-        Config, DbConnection, DownloadsManager, Player, QuotaLog, VideoFile,
+        Config, DbConn, DownloadsManager, Player, QuotaLog, VideoFile,
         YoutubeQuery, YoutubeVideo
 )
 
@@ -131,7 +131,7 @@ class _TaskHandler(BaseHTTPRequestHandler):
 
     def _receive_files_command(self, command: str) -> None:
         if command.startswith('play_'):
-            with DbConnection() as conn:
+            with DbConn() as conn:
                 file = VideoFile.get_by_b64(conn,
                                             B64Str(command.split('_', 1)[1]))
             self.server.player.inject_and_play(file)
@@ -141,7 +141,7 @@ class _TaskHandler(BaseHTTPRequestHandler):
                             rel_path_b64: B64Str,
                             flag_names: list[FlagName]
                             ) -> None:
-        with DbConnection() as conn:
+        with DbConn() as conn:
             file = VideoFile.get_by_b64(conn, rel_path_b64)
             file.set_flags([FILE_FLAGS[name] for name in flag_names])
             file.save(conn)
@@ -152,7 +152,7 @@ class _TaskHandler(BaseHTTPRequestHandler):
                        .joinpath(rel_path_b64))
 
     def _receive_yt_query(self, query_txt: QueryText) -> None:
-        with DbConnection() as conn:
+        with DbConn() as conn:
             query_data = YoutubeQuery.new_by_request_saved(
                     conn, self.server.config, query_txt)
             conn.commit()
@@ -219,7 +219,7 @@ class _TaskHandler(BaseHTTPRequestHandler):
 
     def _send_or_download_video(self, video_id: YoutubeId) -> None:
         try:
-            with DbConnection() as conn:
+            with DbConn() as conn:
                 file_data = VideoFile.get_by_yt_id(conn, video_id)
         except NotFoundException:
             self.server.downloads.queue_download(video_id)
@@ -232,7 +232,7 @@ class _TaskHandler(BaseHTTPRequestHandler):
         self._send_http(content=video)
 
     def _send_yt_query_page(self, query_id: QueryId) -> None:
-        with DbConnection() as conn:
+        with DbConn() as conn:
             query = YoutubeQuery.get_one(conn, str(query_id))
             results = YoutubeVideo.get_all_for_query(conn, query_id)
         self._send_rendered_template(
@@ -240,7 +240,7 @@ class _TaskHandler(BaseHTTPRequestHandler):
                 {'query': query.text, 'videos': results})
 
     def _send_yt_queries_index_and_search(self) -> None:
-        with DbConnection() as conn:
+        with DbConn() as conn:
             quota_count = QuotaLog.current(conn)
             queries_data = YoutubeQuery.get_all(conn)
         queries_data.sort(key=lambda q: q.retrieved_at, reverse=True)
@@ -249,8 +249,8 @@ class _TaskHandler(BaseHTTPRequestHandler):
                 {'queries': queries_data, 'quota_count': quota_count})
 
     def _send_yt_result(self, video_id: YoutubeId) -> None:
-        conn = DbConnection()
-        with DbConnection() as conn:
+        conn = DbConn()
+        with DbConn() as conn:
             linked_queries = YoutubeQuery.get_all_for_video(conn, video_id)
             try:
                 video_data = YoutubeVideo.get_one(conn, video_id)
@@ -270,7 +270,7 @@ class _TaskHandler(BaseHTTPRequestHandler):
                  'queries': linked_queries})
 
     def _send_file_data(self, rel_path_b64: B64Str) -> None:
-        with DbConnection() as conn:
+        with DbConn() as conn:
             file = VideoFile.get_by_b64(conn, rel_path_b64)
         self._send_rendered_template(
                 _NAME_TEMPLATE_FILE_DATA,
@@ -280,7 +280,7 @@ class _TaskHandler(BaseHTTPRequestHandler):
                           filter_: _ParamsStr,
                           show_absent: bool
                           ) -> None:
-        with DbConnection() as conn:
+        with DbConn() as conn:
             files = [f for f in VideoFile.get_all(conn)
                      if filter_.lower() in str(f.rel_path).lower()
                      and (show_absent or f.present)]
@@ -291,7 +291,7 @@ class _TaskHandler(BaseHTTPRequestHandler):
                  'show_absent': show_absent})
 
     def _send_missing_json(self) -> None:
-        with DbConnection() as conn:
+        with DbConn() as conn:
             missing = [f.rel_path for f in VideoFile.get_all(conn)
                        if f.missing]
         self._send_http(bytes(json_dumps(missing), 'utf8'),
diff --git a/src/ytplom/misc.py b/src/ytplom/misc.py
index c466e35..c0aa85f 100644
--- a/src/ytplom/misc.py
+++ b/src/ytplom/misc.py
@@ -127,7 +127,7 @@ class Config:
                              if k.isupper() and k.startswith(ENVIRON_PREFIX)})
 
 
-class DbConnection:
+class DbConn:
     """Wrapped sqlite3.Connection."""
 
     def __init__(self, path: Path = PATH_DB) -> None:
@@ -191,7 +191,7 @@ class DbData:
         return cls(**kwargs)
 
     @classmethod
-    def get_one(cls, conn: DbConnection, id_: str) -> Self:
+    def get_one(cls, conn: DbConn, id_: str) -> Self:
         """Return single entry of id_ from DB."""
         sql = SqlText(f'SELECT * FROM {cls._table_name} '
                       f'WHERE {cls.id_name} = ?')
@@ -202,13 +202,13 @@ class DbData:
         return cls._from_table_row(row)
 
     @classmethod
-    def get_all(cls, conn: DbConnection) -> list[Self]:
+    def get_all(cls, conn: DbConn) -> list[Self]:
         """Return all entries from DB."""
         sql = SqlText(f'SELECT * FROM {cls._table_name}')
         rows = conn.exec(sql).fetchall()
         return [cls._from_table_row(row) for row in rows]
 
-    def save(self, conn: DbConnection) -> Cursor:
+    def save(self, conn: DbConn) -> Cursor:
         """Save entry to DB."""
         vals = [getattr(self, col_name) for col_name in self._cols]
         q_marks = '(' + ','.join(['?'] * len(vals)) + ')'
@@ -233,7 +233,7 @@ class YoutubeQuery(DbData):
 
     @classmethod
     def new_by_request_saved(cls,
-                             conn: DbConnection,
+                             conn: DbConn,
                              config: Config,
                              query_txt: QueryText
                              ) -> Self:
@@ -284,7 +284,7 @@ class YoutubeQuery(DbData):
 
     @classmethod
     def get_all_for_video(cls,
-                          conn: DbConnection,
+                          conn: DbConn,
                           video_id: YoutubeId
                           ) -> list[Self]:
         """Return YoutubeQueries containing YoutubeVideo's ID in results."""
@@ -341,7 +341,7 @@ class YoutubeVideo(DbData):
 
     @classmethod
     def get_all_for_query(cls,
-                          conn: DbConnection,
+                          conn: DbConn,
                           query_id: QueryId
                           ) -> list[Self]:
         """Return all videos for query of query_id."""
@@ -351,7 +351,7 @@ class YoutubeVideo(DbData):
         return [cls.get_one(conn, video_id_tup[0])
                 for video_id_tup in video_ids]
 
-    def save_to_query(self, conn: DbConnection, query_id: QueryId) -> None:
+    def save_to_query(self, conn: DbConn, query_id: QueryId) -> None:
         """Save inclusion of self in results to query of query_id."""
         conn.exec(SqlText('REPLACE INTO yt_query_results VALUES (?, ?)'),
                   (query_id, self.id_))
@@ -383,7 +383,7 @@ class VideoFile(DbData):
         self.last_update = DatetimeStr(datetime.now().strftime(TIMESTAMP_FMT))
 
     @classmethod
-    def get_by_yt_id(cls, conn: DbConnection, yt_id: YoutubeId) -> Self:
+    def get_by_yt_id(cls, conn: DbConn, yt_id: YoutubeId) -> Self:
         """Return VideoFile of .yt_id."""
         sql = SqlText(f'SELECT * FROM {cls._table_name} WHERE yt_id = ?')
         row = conn.exec(sql, (yt_id,)).fetchone()
@@ -392,7 +392,7 @@ class VideoFile(DbData):
         return cls._from_table_row(row)
 
     @classmethod
-    def get_by_b64(cls, conn: DbConnection, rel_path_b64: B64Str) -> Self:
+    def get_by_b64(cls, conn: DbConn, rel_path_b64: B64Str) -> Self:
         """Retrieve by .rel_path provided as urlsafe_b64 encoding."""
         return cls.get_one(conn, urlsafe_b64decode(rel_path_b64).decode())
 
@@ -457,7 +457,7 @@ class QuotaLog(DbData):
         self.cost = cost
 
     @classmethod
-    def update(cls, conn: DbConnection, cost: QuotaCost) -> None:
+    def update(cls, conn: DbConn, cost: QuotaCost) -> None:
         """Adds cost mapped to current datetime."""
         cls._remove_old(conn)
         new = cls(None,
@@ -466,14 +466,14 @@ class QuotaLog(DbData):
         new.save(conn)
 
     @classmethod
-    def current(cls, conn: DbConnection) -> QuotaCost:
+    def current(cls, conn: DbConn) -> QuotaCost:
         """Returns quota cost total for last 24 hours, purges old data."""
         cls._remove_old(conn)
         quota_costs = cls.get_all(conn)
         return QuotaCost(sum(c.cost for c in quota_costs))
 
     @classmethod
-    def _remove_old(cls, conn: DbConnection) -> None:
+    def _remove_old(cls, conn: DbConn) -> None:
         cutoff = datetime.now() - timedelta(days=1)
         sql = SqlText(f'DELETE FROM {cls._table_name} WHERE timestamp < ?')
         conn.exec(SqlText(sql), (cutoff.strftime(TIMESTAMP_FMT),))
@@ -512,7 +512,7 @@ class Player:
 
     def load_files(self) -> None:
         """Collect files in PATH_DOWNLOADS DB-known and of legal extension."""
-        with DbConnection() as conn:
+        with DbConn() as conn:
             known_files = {f.full_path: f for f in VideoFile.get_all(conn)}
         self._files = [known_files[p] for p in PATH_DOWNLOADS.iterdir()
                        if p in known_files
@@ -672,7 +672,7 @@ class DownloadsManager:
         self._sync_db()
 
     def _sync_db(self):
-        with DbConnection as conn:
+        with DbConn() as conn:
             known_paths = [file.rel_path for file in VideoFile.get_all(conn)]
             old_cwd = Path.cwd()
             chdir(PATH_DOWNLOADS)
-- 
2.30.2