From: Christian Heller Date: Sun, 10 Nov 2024 15:41:38 +0000 (+0100) Subject: Further harden type safety. X-Git-Url: https://plomlompom.com/repos/%7Broute%7D?a=commitdiff_plain;h=631b24c6af61bf080de31e438b582062cb1f13f5;p=ytplom Further harden type safety. --- diff --git a/ytplom.py b/ytplom.py index 0d55a26..7dcdfab 100755 --- a/ytplom.py +++ b/ytplom.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 """Minimalistic download-focused YouTube interface.""" -from typing import TypeAlias, Optional +from typing import TypeAlias, Optional, NewType from os import environ, makedirs, scandir, remove as os_remove from os.path import (isdir, isfile, exists as path_exists, join as path_join, splitext, basename) @@ -16,13 +16,13 @@ from jinja2 import Template from yt_dlp import YoutubeDL # type: ignore import googleapiclient.discovery # type: ignore -DatetimeStr: TypeAlias = str -QuotaCost: TypeAlias = int -VideoId: TypeAlias = str -FilePathStr: TypeAlias = str -QueryId = str +DatetimeStr = NewType('DatetimeStr', str) +QuotaCost = NewType('QuotaCost', int) +VideoId = NewType('VideoId', str) +FilePathStr = NewType('FilePathStr', str) +QueryId = NewType('QueryId', str) Result: TypeAlias = dict[str, str | bool] -QueryData: TypeAlias = dict[QueryId, str | int | list[Result]] +QueryData: TypeAlias = dict[str, str | int | list[Result]] QuotaLog: TypeAlias = dict[DatetimeStr, QuotaCost] Header: TypeAlias = tuple[str, str] DownloadsDB = dict[VideoId, FilePathStr] @@ -31,20 +31,20 @@ TemplateContext = dict[str, int | str | QueryData | list[QueryData]] API_KEY = environ.get('GOOGLE_API_KEY') HTTP_PORT = 8083 -PATH_QUOTA_LOG = 'quota_log.json' -PATH_DIR_DOWNLOADS = 'downloads' -PATH_DIR_THUMBNAILS = 'thumbnails' -PATH_DIR_REQUESTS_CACHE = 'cache_googleapi' -PATH_DIR_TEMPLATES = 'templates' -NAME_DIR_TEMP = 'temp' -NAME_TEMPLATE_INDEX = 'index.tmpl' -NAME_TEMPLATE_RESULTS = 'results.tmpl' +PATH_QUOTA_LOG = FilePathStr('quota_log.json') +PATH_DIR_DOWNLOADS = FilePathStr('downloads') +PATH_DIR_THUMBNAILS = FilePathStr('thumbnails') +PATH_DIR_REQUESTS_CACHE = FilePathStr('cache_googleapi') +PATH_DIR_TEMPLATES = FilePathStr('templates') +NAME_DIR_TEMP = FilePathStr('temp') +NAME_TEMPLATE_INDEX = FilePathStr('index.tmpl') +NAME_TEMPLATE_RESULTS = FilePathStr('results.tmpl') -PATH_DIR_TEMP: FilePathStr = path_join(PATH_DIR_DOWNLOADS, NAME_DIR_TEMP) +PATH_DIR_TEMP = FilePathStr(path_join(PATH_DIR_DOWNLOADS, NAME_DIR_TEMP)) EXPECTED_DIRS = [PATH_DIR_DOWNLOADS, PATH_DIR_TEMP, PATH_DIR_THUMBNAILS, PATH_DIR_REQUESTS_CACHE] -PATH_TEMPLATE_INDEX: FilePathStr = path_join(PATH_DIR_TEMPLATES, - NAME_TEMPLATE_INDEX) +PATH_TEMPLATE_INDEX = FilePathStr(path_join(PATH_DIR_TEMPLATES, + NAME_TEMPLATE_INDEX)) TIMESTAMP_FMT = '%Y-%m-%d %H:%M:%S.%f' YOUTUBE_URL_PREFIX = 'https://www.youtube.com/watch?v=' YT_DOWNLOAD_FORMAT = 'bestvideo[height<=1080][width<=1920]+bestaudio'\ @@ -53,8 +53,8 @@ YT_DL_PARAMS = {'paths': {'home': PATH_DIR_DOWNLOADS, 'temp': NAME_DIR_TEMP}, 'format': YT_DOWNLOAD_FORMAT} -QUOTA_COST_YOUTUBE_SEARCH: QuotaCost = 100 -QUOTA_COST_YOUTUBE_DETAILS: QuotaCost = 1 +QUOTA_COST_YOUTUBE_SEARCH = QuotaCost(100) +QUOTA_COST_YOUTUBE_DETAILS = QuotaCost(1) to_download: list[VideoId] = [] @@ -114,7 +114,7 @@ def read_quota_log() -> QuotaLog: def update_quota_log(now: DatetimeStr, cost: QuotaCost) -> None: """Update quota log from read_quota_log, add cost to now's row.""" quota_log = read_quota_log() - quota_log[now] = quota_log.get(now, 0) + cost + quota_log[now] = QuotaCost(quota_log.get(now, 0) + cost) with open(PATH_QUOTA_LOG, 'w', encoding='utf8') as f: json_dump(quota_log, f) @@ -155,7 +155,7 @@ class TaskHandler(BaseHTTPRequestHandler): with open(path_join(PATH_DIR_REQUESTS_CACHE, f'{md5sum}.json'), 'w', encoding='utf8') as f: json_dump(query_data, f) - return md5sum + return QueryId(md5sum) def collect_results(now: DatetimeStr, query_txt: str) -> list[Result]: youtube = googleapiclient.discovery.build('youtube', 'v3', @@ -195,7 +195,7 @@ class TaskHandler(BaseHTTPRequestHandler): body_length = int(self.headers['content-length']) postvars = parse_qs(self.rfile.read(body_length).decode()) query_txt = postvars['query'][0] - now = datetime.now().strftime(TIMESTAMP_FMT) + now = DatetimeStr(datetime.now().strftime(TIMESTAMP_FMT)) results = collect_results(now, query_txt) md5sum = store_at_filename_hashing_query( {'text': query_txt, 'retrieved_at': now, 'results': results}) @@ -207,11 +207,12 @@ class TaskHandler(BaseHTTPRequestHandler): toks_url: list[str] = url.path.split('/') page_name = toks_url[1] if 'thumbnails' == page_name: - self._send_thumbnail(toks_url[2]) + self._send_thumbnail(FilePathStr(toks_url[2])) if 'dl' == page_name: - self._send_or_download_video(toks_url[2], parse_qs(url.query)) + self._send_or_download_video(VideoId(toks_url[2]), + parse_qs(url.query)) if 'query' == page_name: - self._send_query_page(toks_url[2]) + self._send_query_page(QueryId(toks_url[2])) else: # e.g. for / self._send_queries_index_and_search() @@ -224,7 +225,7 @@ class TaskHandler(BaseHTTPRequestHandler): 'r', encoding='utf8' ) as templ_file: tmpl = Template(str(templ_file.read())) - html= tmpl.render(**tmpl_ctx) + html = tmpl.render(**tmpl_ctx) self._send_http(bytes(html, 'utf8')) @staticmethod @@ -240,9 +241,9 @@ class TaskHandler(BaseHTTPRequestHandler): """Create dictionary of downloads mapping video IDs to file paths.""" downloads_db = {} for e in [e for e in scandir(PATH_DIR_DOWNLOADS) if isfile(e.path)]: - before_ext, _ = splitext(e.path) - id_: VideoId = before_ext.split('[')[-1].split(']')[0] - downloads_db[id_] = e.path + before_ext = splitext(e.path)[0] + id_ = VideoId(before_ext.split('[')[-1].split(']')[0]) + downloads_db[id_] = FilePathStr(e.path) return downloads_db def _send_thumbnail(self, filename: FilePathStr) -> None: @@ -263,7 +264,7 @@ class TaskHandler(BaseHTTPRequestHandler): self._send_http(content=video) return to_download.append(video_id) - dl_query_id: QueryId = params.get('from_query', [''])[0] + dl_query_id = params.get('from_query', [''])[0] redir_path = f'/query/{dl_query_id}' if dl_query_id else '/' self._send_http(headers=[('Location', redir_path)], code=302) @@ -272,7 +273,7 @@ class TaskHandler(BaseHTTPRequestHandler): def reformat_duration(duration_str: str): date_dur, time_dur = duration_str.split('T') - seconds: int = 0 + seconds = 0 date_dur = date_dur[1:] for dur_char, len_seconds in (('Y', 60*60*24*365.25), ('M', 60*60*24*30), @@ -313,7 +314,7 @@ class TaskHandler(BaseHTTPRequestHandler): queries: list[QueryData] = [] for file in [f for f in scandir(PATH_DIR_REQUESTS_CACHE) if isfile(f.path)]: - id_, _ = splitext(basename(file.path)) + id_ = splitext(basename(file.path))[0] with open(file.path, 'r', encoding='utf8') as query_file: filed_query: QueryData = json_load(query_file) filed_query['id'] = id_