home · contact · privacy
Adapt web server code to plomlib.web.
authorChristian Heller <c.heller@plomlompom.de>
Sat, 18 Jan 2025 02:27:51 +0000 (03:27 +0100)
committerChristian Heller <c.heller@plomlompom.de>
Sat, 18 Jan 2025 02:27:51 +0000 (03:27 +0100)
src/plomlib
src/ytplom/http.py

index 743dbe0d493ddeb47eca981fa5be6d78e4d754c9..e7202fcfd78c6a60bd90da789a68c8ec4baf7b1a 160000 (submodule)
@@ -1 +1 @@
-Subproject commit 743dbe0d493ddeb47eca981fa5be6d78e4d754c9
+Subproject commit e7202fcfd78c6a60bd90da789a68c8ec4baf7b1a
index c57dfaea10a034f9f898dd000ec13ea43102625a..c91100c37250873fafda692a42c90470cb03be61 100644 (file)
@@ -1,18 +1,16 @@
 """Collect directly HTTP-related elements."""
 
 # included libs
-from http.server import ThreadingHTTPServer, BaseHTTPRequestHandler
-from json import dumps as json_dumps, loads as json_loads
+from json import dumps as json_dumps
 from pathlib import Path
+from socketserver import ThreadingMixIn
 from time import sleep, time
 from typing import Any, Optional
-from urllib.parse import parse_qs, urlparse
 from urllib.request import urlretrieve
 from urllib.error import HTTPError
-# non-included libs
-from jinja2 import (  # type: ignore
-        Environment as JinjaEnv, FileSystemLoader as JinjaFSLoader)
 # ourselves
+from plomlib.web import (
+        PlomHttpHandler, PlomHttpServer, PlomQueryMap, MIME_APP_JSON)
 from ytplom.db import Hash, DbConn
 from ytplom.misc import (
         FilterStr, FlagName, QueryId, QueryText, TagSet, YoutubeId,
@@ -38,36 +36,35 @@ _NAME_TEMPLATE_YT_RESULT = Path('yt_result.tmpl')
 _NAME_TEMPLATE_YT_RESULTS = Path('yt_results.tmpl')
 
 # page names
-PAGE_NAMES: dict[str, Path] = {
-    'download': Path('dl'),
-    'events': Path('events'),
-    'file': Path('file'),
-    'files': Path('files'),
-    'missing': Path('missing'),
-    'player': Path('player'),
-    'playlist': Path('playlist'),
-    'purge': Path('purge'),
-    'thumbnails': Path('thumbnails'),
-    'yt_queries': Path('yt_queries'),
-    'yt_query': Path('yt_query'),
-    'yt_result': Path('yt_result')
+PAGE_NAMES: dict[str, str] = {
+    'download': 'dl',
+    'events': 'events',
+    'file': 'file',
+    'files': 'files',
+    'missing': 'missing',
+    'player': 'player',
+    'playlist': 'playlist',
+    'purge': 'purge',
+    'thumbnails': 'thumbnails',
+    'yt_queries': 'yt_queries',
+    'yt_query': 'yt_query',
+    'yt_result': 'yt_result'
 }
 
 # misc
 _PING_INTERVAL_S = 1
 _EVENTS_UPDATE_INTERVAL_S = 0.1
 _HEADER_CONTENT_TYPE = 'Content-Type'
-_HEADER_APP_JSON = 'application/json'
 
 
-class Server(ThreadingHTTPServer):
+class Server(ThreadingMixIn, PlomHttpServer):
     """Extension of parent server providing for Player and DownloadsManager."""
+    daemon_threads = True
 
     def __init__(self, config: Config, *args, **kwargs) -> None:
-        super().__init__(
-                (config.host, config.port), _TaskHandler, *args, **kwargs)
+        super().__init__(_PATH_TEMPLATES, (config.host, config.port),
+                         _TaskHandler, *args, **kwargs)
         self.config = config
-        self.jinja = JinjaEnv(loader=JinjaFSLoader(_PATH_TEMPLATES))
         self.player = Player(config.whitelist_tags_display,
                              config.whitelist_tags_prefilter,
                              config.needed_tags_prefilter)
@@ -76,95 +73,77 @@ class Server(ThreadingHTTPServer):
         self.downloads.start_thread()
 
 
-class _ReqMap:
+class _ReqMap(PlomQueryMap):
     """Wrapper over dictionary-like HTTP postings."""
 
-    def __init__(self, as_str: str, is_json: bool = False) -> None:
-        self._as_dict = json_loads(as_str) if is_json else parse_qs(as_str)
-
     def has_key(self, key: str) -> bool:
         """Return if key exists at all."""
-        return key in self._as_dict
+        return key in self.as_dict
 
     def first_for(self, key: str) -> str:
         """Return first value mapped to key, '' if none."""
-        return self._as_dict.get(key, [''])[0]
+        return self.first(key) or ''
 
     def all_for(self, key: str) -> list[str]:
         """Return all values mapped to key."""
-        return self._as_dict.get(key, [])
+        return self.all(key) or []
 
     def keys_starting_with(self, prefix: str) -> tuple[str, ...]:
         """Return all keys present starting with prefix."""
-        return tuple(k for k in self._as_dict if k.startswith(prefix))
+        return self.keys_prefixed(prefix)
 
 
-class _TaskHandler(BaseHTTPRequestHandler):
+class _TaskHandler(PlomHttpHandler):
     """Handler for GET and POST requests to our server."""
     server: Server
-
-    def _send_http(self,
-                   content: str | bytes = b'',
-                   headers: Optional[list[tuple[str, str]]] = None,
-                   code: int = 200
-                   ) -> None:
-        headers = headers if headers else []
-        self.send_response(code)
-        for header_tuple in headers:
-            self.send_header(header_tuple[0], header_tuple[1])
-        self.end_headers()
-        if content:
-            self.wfile.write(bytes(content, 'utf8') if isinstance(content, str)
-                             else content)
+    params: _ReqMap
+    postvars: _ReqMap
+    mapper = _ReqMap
 
     def _redirect(self, target: Path) -> None:
-        self._send_http(headers=[('Location', str(target))], code=302)
+        self.redirect(target)
 
     def do_POST(self) -> None:  # pylint:disable=invalid-name
         """Map POST requests to handlers for various paths."""
-        toks_url = Path(urlparse(self.path).path).parts
-        page_name = Path(toks_url[1] if len(toks_url) > 1 else '')
-        postvars = _ReqMap(
-                self.rfile.read(int(self.headers['content-length'])).decode(),
-                _HEADER_APP_JSON == self.headers[_HEADER_CONTENT_TYPE])
-        if PAGE_NAMES['file'] == page_name:
-            self._receive_file_data(Hash.from_b64(toks_url[2]), postvars)
-        elif PAGE_NAMES['files'] == page_name:
-            self._receive_files_command(postvars)
-        elif PAGE_NAMES['player'] == page_name:
-            self._receive_player_command(postvars)
-        elif PAGE_NAMES['purge'] == page_name:
+        if self.pagename == PAGE_NAMES['file']:
+            self._receive_file_data()
+        elif self.pagename == PAGE_NAMES['files']:
+            self._receive_files_command()
+        elif self.pagename == PAGE_NAMES['player']:
+            self._receive_player_command()
+        elif self.pagename == PAGE_NAMES['purge']:
             self._purge_deleted_files()
-        elif PAGE_NAMES['yt_queries'] == page_name:
-            self._receive_yt_query(QueryText(postvars.first_for('query')))
+        elif self.pagename == PAGE_NAMES['yt_queries']:
+            self._receive_yt_query()
 
-    def _receive_file_data(self, digest: Hash, postvars: _ReqMap) -> None:
+    def _receive_file_data(self) -> None:
+        digest = Hash.from_b64(self.path_toks[2])
         if not (self.server.config.allow_file_edit  # also if whitelist, …
                 and self.server.config.whitelist_tags_display.empty):
-            self._send_http('no way', code=403)  # … cuz input form under …
+            self.send_http(b'no way', code=403)  # … cuz input form under …
             return  # … this display filter might have suppressed set tags
         with DbConn() as conn:
             file = VideoFile.get_one(conn, digest)
-            if postvars.has_key('unlink'):
+            if self.postvars.has_key('unlink'):
                 file.unlink_locally()
             file.set_flags({FILE_FLAGS[FlagName(name)]
-                            for name in postvars.all_for('flags')})
-            file.tags = TagSet.from_str_list(postvars.all_for('tags'))
+                            for name in self.postvars.all_for('flags')})
+            file.tags = TagSet.from_str_list(self.postvars.all_for('tags'))
             file.save(conn)
             conn.commit()
         file.ensure_absence_if_deleted()
-        self._redirect(Path(postvars.first_for('redir_target')))
+        self._redirect(Path(self.postvars.first_for('redir_target')))
 
-    def _receive_files_command(self, postvars: _ReqMap) -> None:
-        for k in postvars.keys_starting_with('play_'):
+    def _receive_files_command(self) -> None:
+        for k in self.postvars.keys_starting_with('play_'):
             with DbConn() as conn:
                 file = VideoFile.get_one(
                         conn, Hash.from_b64(k.split('_', 1)[1]))
             self.server.player.inject_and_play(file)
-        self._redirect(Path(postvars.first_for('redir_target')))
+        self._redirect(Path(self.postvars.first_for('redir_target')))
 
-    def _receive_player_command(self, postvars: _ReqMap) -> None:
-        command = postvars.first_for('command')
+    def _receive_player_command(self) -> None:
+        command = self.postvars.first_for('command')
         if 'play' == command:
             self.server.player.toggle_play()
         elif 'prev' == command:
@@ -179,22 +158,23 @@ class _TaskHandler(BaseHTTPRequestHandler):
             self.server.player.move_entry(int(command.split('_')[1]))
         elif command.startswith('down_'):
             self.server.player.move_entry(int(command.split('_')[1]), False)
-        if postvars.has_key('filter_path'):
+        if self.postvars.has_key('filter_path'):
             self.server.player.filter_path = FilterStr(
-                    postvars.first_for('filter_path'))
-        if postvars.has_key('needed_tags'):
+                    self.postvars.first_for('filter_path'))
+        if self.postvars.has_key('needed_tags'):
             self.server.player.needed_tags = TagSet.from_joined(
-                    postvars.first_for('needed_tags'))
-        self._send_http('OK')
+                    self.postvars.first_for('needed_tags'))
+        self.send_http(b'OK')
 
     def _purge_deleted_files(self) -> None:
         with DbConn() as conn:
             VideoFile.purge_deleteds(conn)
             self.server.player.load_files_and_mpv()
             conn.commit()
-        self._send_http('OK')
+        self.send_http(b'OK')
 
-    def _receive_yt_query(self, query_txt: QueryText) -> None:
+    def _receive_yt_query(self) -> None:
+        query_txt = QueryText(self.postvars.first_for('query'))
         with DbConn() as conn:
             query_data = YoutubeQuery.new_by_request_saved(
                     conn, self.server.config, query_txt)
@@ -205,45 +185,42 @@ class _TaskHandler(BaseHTTPRequestHandler):
 
     def do_GET(self) -> None:  # pylint:disable=invalid-name
         """Map GET requests to handlers for various paths."""
-        url = urlparse(self.path)
-        toks_url = Path(url.path).parts
-        page_name = Path(toks_url[1] if len(toks_url) > 1 else '')
         try:
-            if PAGE_NAMES['download'] == page_name:
-                self._send_or_download_video(YoutubeId(toks_url[2]))
-            elif PAGE_NAMES['events'] == page_name:
-                self._send_events(_ReqMap(url.query))
-            elif PAGE_NAMES['file'] == page_name:
-                self._send_file_data(Hash.from_b64(toks_url[2]))
-            elif PAGE_NAMES['files'] == page_name:
-                self._send_files_index(_ReqMap(url.query))
-            elif PAGE_NAMES['missing'] == page_name:
+            if self.pagename == PAGE_NAMES['download']:
+                self._send_or_download_video()
+            elif self.pagename == PAGE_NAMES['events']:
+                self._send_events()
+            elif self.pagename == PAGE_NAMES['file']:
+                self._send_file_data()
+            elif self.pagename == PAGE_NAMES['files']:
+                self._send_files_index()
+            elif self.pagename == PAGE_NAMES['missing']:
                 self._send_missing_json()
-            elif PAGE_NAMES['thumbnails'] == page_name:
-                self._send_thumbnail(Path(toks_url[2]))
-            elif PAGE_NAMES['yt_result'] == page_name:
-                self._send_yt_result(YoutubeId(toks_url[2]))
-            elif PAGE_NAMES['yt_queries'] == page_name:
+            elif self.pagename == PAGE_NAMES['thumbnails']:
+                self._send_thumbnail()
+            elif self.pagename == PAGE_NAMES['yt_result']:
+                self._send_yt_result()
+            elif self.pagename == PAGE_NAMES['yt_queries']:
                 self._send_yt_queries_index_and_search()
-            elif PAGE_NAMES['yt_query'] == page_name:
-                self._send_yt_query_page(QueryId(toks_url[2]))
+            elif self.pagename == PAGE_NAMES['yt_query']:
+                self._send_yt_query_page()
             else:  # e.g. for /
                 self._send_playlist()
         except NotFoundException as e:
-            self._send_http(str(e), code=404)
+            self.send_http(bytes(str(e), encoding='utf8'), code=404)
 
     def _send_rendered_template(self,
                                 tmpl_name: Path,
                                 tmpl_ctx: dict[str, Any]
                                 ) -> None:
-        tmpl = self.server.jinja.get_template(str(tmpl_name))
         tmpl_ctx['selected'] = tmpl_ctx.get('selected', '')
         tmpl_ctx['redir_target'] = self.path
         tmpl_ctx['background_color'] = self.server.config.background_color
         tmpl_ctx['page_names'] = PAGE_NAMES
-        self._send_http(tmpl.render(**tmpl_ctx))
+        self.send_rendered(tmpl_name, tmpl_ctx)
 
-    def _send_or_download_video(self, video_id: YoutubeId) -> None:
+    def _send_or_download_video(self) -> None:
+        video_id = YoutubeId(self.path_toks[2])
         try:
             with DbConn() as conn:
                 file_data = VideoFile.get_by_yt_id(conn, video_id)
@@ -258,10 +235,10 @@ class _TaskHandler(BaseHTTPRequestHandler):
                            .joinpath(PAGE_NAMES['yt_result'])
                            .joinpath(video_id))
 
-    def _send_events(self, params: _ReqMap) -> None:
-        self._send_http(headers=[(_HEADER_CONTENT_TYPE, 'text/event-stream'),
-                                 ('Cache-Control', 'no-cache'),
-                                 ('Connection', 'keep-alive')])
+    def _send_events(self) -> None:
+        self.send_http(headers=[(_HEADER_CONTENT_TYPE, 'text/event-stream'),
+                                ('Cache-Control', 'no-cache'),
+                                ('Connection', 'keep-alive')])
         selected: Optional[VideoFile] = None
         last_sent = ''
         payload: dict[str, Any] = {}
@@ -300,7 +277,7 @@ class _TaskHandler(BaseHTTPRequestHandler):
                 payload['title_tags'] = tags
                 payload['title_digest'] = digest
                 payload['title'] = title
-                if params.has_key('playlist'):
+                if self.params.has_key('playlist'):
                     payload['idx'] = self.server.player.idx
                     payload['playlist_files'] = [
                         {'rel_path': str(f.rel_path), 'digest': f.digest.b64}
@@ -308,7 +285,8 @@ class _TaskHandler(BaseHTTPRequestHandler):
             else:
                 sleep(_EVENTS_UPDATE_INTERVAL_S)
 
-    def _send_file_data(self, digest: Hash) -> None:
+    def _send_file_data(self) -> None:
+        digest = Hash.from_b64(self.path_toks[2])
         with DbConn() as conn:
             file = VideoFile.get_one_with_whitelist_tags_display(
                     conn, digest, self.server.config.whitelist_tags_display)
@@ -320,10 +298,10 @@ class _TaskHandler(BaseHTTPRequestHandler):
                  'flag_names': list(FILE_FLAGS),
                  'unused_tags': unused_tags})
 
-    def _send_files_index(self, params: _ReqMap) -> None:
-        filter_path = FilterStr(params.first_for('filter_path'))
-        needed_tags_str = params.first_for('needed_tags')
-        show_absent = bool(params.first_for('show_absent'))
+    def _send_files_index(self) -> None:
+        filter_path = FilterStr(self.params.first_for('filter_path'))
+        needed_tags_str = self.params.first_for('needed_tags')
+        show_absent = bool(self.params.first_for('show_absent'))
         with DbConn() as conn:
             files = VideoFile.get_filtered(
                     conn,
@@ -348,10 +326,11 @@ class _TaskHandler(BaseHTTPRequestHandler):
         with DbConn() as conn:
             missing = [f.digest.b64 for f in VideoFile.get_all(conn)
                        if f.missing]
-        self._send_http(json_dumps(missing),
-                        headers=[(_HEADER_CONTENT_TYPE, _HEADER_APP_JSON)])
+        self.send_http(bytes(json_dumps(missing), encoding='utf8'),
+                       headers=[(_HEADER_CONTENT_TYPE, MIME_APP_JSON)])
 
-    def _send_thumbnail(self, filename: Path) -> None:
+    def _send_thumbnail(self) -> None:
+        filename = Path(self.path_toks[2])
         ensure_expected_dirs([PATH_THUMBNAILS])
         path_thumbnail = PATH_THUMBNAILS.joinpath(filename)
         if not path_thumbnail.exists():
@@ -364,9 +343,10 @@ class _TaskHandler(BaseHTTPRequestHandler):
                     raise NotFoundException from e
                 raise e
         with path_thumbnail.open('rb') as f:
-            self._send_http(f.read(), [(_HEADER_CONTENT_TYPE, 'image/jpg')])
+            self.send_http(f.read(), [(_HEADER_CONTENT_TYPE, 'image/jpg')])
 
-    def _send_yt_result(self, video_id: YoutubeId) -> None:
+    def _send_yt_result(self) -> None:
+        video_id = YoutubeId(self.path_toks[2])
         with DbConn() as conn:
             linked_queries = YoutubeQuery.get_all_for_video(conn, video_id)
             try:
@@ -402,7 +382,8 @@ class _TaskHandler(BaseHTTPRequestHandler):
                                       'quota_count': quota_count,
                                       'selected': 'yt_queries'})
 
-    def _send_yt_query_page(self, query_id: QueryId) -> None:
+    def _send_yt_query_page(self) -> None:
+        query_id = QueryId(self.path_toks[2])
         with DbConn() as conn:
             query = YoutubeQuery.get_one(conn, str(query_id))
             results = YoutubeVideo.get_all_for_query(conn, query_id)
@@ -419,12 +400,4 @@ class _TaskHandler(BaseHTTPRequestHandler):
 
 def serve():
     """Do Server.serve_forever on .config.port until keyboard interrupt."""
-    config = Config()
-    server = Server(Config())
-    print(f'running at port {config.port}')
-    try:
-        server.serve_forever()
-    except KeyboardInterrupt:
-        print('aborted due to keyboard interrupt; '
-              'repeat to end download thread too')
-    server.server_close()
+    Server(Config()).serve()