From: Christian Heller Date: Tue, 26 Mar 2024 22:58:03 +0000 (+0100) Subject: Put mypy into strict mode, adapt code to still pass. X-Git-Url: https://plomlompom.com/repos/%22https:/validator.w3.org/%7B%7Bprefix%7D%7D/condition?a=commitdiff_plain;h=3558a14701955de18ae7adbda0e93aaee7710a92;p=plomtask Put mypy into strict mode, adapt code to still pass. --- diff --git a/plomtask/days.py b/plomtask/days.py index 0622f1d..3b81a7f 100644 --- a/plomtask/days.py +++ b/plomtask/days.py @@ -1,4 +1,5 @@ """Collecting Day and date-related items.""" +from __future__ import annotations from datetime import datetime, timedelta from sqlite3 import Row from plomtask.misc import HandledException @@ -7,16 +8,19 @@ from plomtask.db import DatabaseConnection DATE_FORMAT = '%Y-%m-%d' -def date_valid(date: str): - """Validate date against DATE_FORMAT, return Datetime or None.""" +def valid_date(date_str: str) -> str: + """Validate date against DATE_FORMAT or 'today', return in DATE_FORMAT.""" + if date_str == 'today': + date_str = todays_date() try: - result = datetime.strptime(date, DATE_FORMAT) - except (ValueError, TypeError): - return None - return result + dt = datetime.strptime(date_str, DATE_FORMAT) + except (ValueError, TypeError) as e: + msg = f'Given date of wrong format: {date_str}' + raise HandledException(msg) from e + return dt.strftime(DATE_FORMAT) -def todays_date(): +def todays_date() -> str: """Return current date in DATE_FORMAT.""" return datetime.now().strftime(DATE_FORMAT) @@ -24,38 +28,31 @@ def todays_date(): class Day: """Individual days defined by their dates.""" - def __init__(self, date: str, comment: str = ''): - self.date = date - self.datetime = date_valid(self.date) - if not self.datetime: - raise HandledException(f'Given date of wrong format: {self.date}') + def __init__(self, date: str, comment: str = '') -> None: + self.date = valid_date(date) + self.datetime = datetime.strptime(self.date, DATE_FORMAT) self.comment = comment - def __eq__(self, other: object): + def __eq__(self, other: object) -> bool: return isinstance(other, self.__class__) and self.date == other.date - def __lt__(self, other): + def __lt__(self, other: Day) -> bool: return self.date < other.date @classmethod - def from_table_row(cls, row: Row): + def from_table_row(cls, row: Row) -> Day: """Make Day from database row.""" return cls(row[0], row[1]) @classmethod def all(cls, db_conn: DatabaseConnection, - date_range: tuple[str, str] = ('', ''), fill_gaps: bool = False): + date_range: tuple[str, str] = ('', ''), + fill_gaps: bool = False) -> list[Day]: """Return list of Days in database within date_range.""" - def date_from_range_str(date_str: str, default: str): - if date_str == '': - date_str = default - elif date_str == 'today': - date_str = todays_date() - elif not date_valid(date_str): - raise HandledException(f'Bad date: {date_str}') - return date_str - start_date = date_from_range_str(date_range[0], '2024-01-01') - end_date = date_from_range_str(date_range[1], '2030-01-01') + min_date = '2024-01-01' + max_date = '2030-12-31' + start_date = valid_date(date_range[0] if date_range[0] else min_date) + end_date = valid_date(date_range[1] if date_range[1] else max_date) days = [] sql = 'SELECT * FROM days WHERE date >= ? AND date <= ?' for row in db_conn.exec(sql, (start_date, end_date)): @@ -74,32 +71,32 @@ class Day: @classmethod def by_date(cls, db_conn: DatabaseConnection, - date: str, create: bool = False): + date: str, create: bool = False) -> Day: """Retrieve Day by date if in DB, else return None.""" for row in db_conn.exec('SELECT * FROM days WHERE date = ?', (date,)): return cls.from_table_row(row) - if create: - return cls(date) - return None + if not create: + raise HandledException(f'Day not found for date: {date}') + return cls(date) @property - def weekday(self): + def weekday(self) -> str: """Return what weekday matches self.date.""" return self.datetime.strftime('%A') @property - def prev_date(self): + def prev_date(self) -> str: """Return date preceding date of this Day.""" prev_datetime = self.datetime - timedelta(days=1) return prev_datetime.strftime(DATE_FORMAT) @property - def next_date(self): + def next_date(self) -> str: """Return date succeeding date of this Day.""" next_datetime = self.datetime + timedelta(days=1) return next_datetime.strftime(DATE_FORMAT) - def save(self, db_conn: DatabaseConnection): + def save(self, db_conn: DatabaseConnection) -> None: """Add (or re-write) self to database.""" db_conn.exec('REPLACE INTO days VALUES (?, ?)', (self.date, self.comment)) diff --git a/plomtask/db.py b/plomtask/db.py index d6966e6..e0a5d4f 100644 --- a/plomtask/db.py +++ b/plomtask/db.py @@ -1,7 +1,8 @@ """Database management.""" from os.path import isfile from difflib import Differ -from sqlite3 import connect as sql_connect +from sqlite3 import connect as sql_connect, Cursor +from typing import Any from plomtask.misc import HandledException PATH_DB_SCHEMA = 'scripts/init.sql' @@ -10,24 +11,24 @@ PATH_DB_SCHEMA = 'scripts/init.sql' class DatabaseFile: # pylint: disable=too-few-public-methods """Represents the sqlite3 database's file.""" - def __init__(self, path): + def __init__(self, path: str) -> None: self.path = path self._check() - def remake(self): + def remake(self) -> None: """Create tables in self.path file as per PATH_DB_SCHEMA sql file.""" with sql_connect(self.path) as conn: with open(PATH_DB_SCHEMA, 'r', encoding='utf-8') as f: conn.executescript(f.read()) self._check() - def _check(self): + def _check(self) -> None: """Check file exists and is of proper schema.""" self.exists = isfile(self.path) if self.exists: self._validate_schema() - def _validate_schema(self): + def _validate_schema(self) -> None: """Compare found schema with what's stored at PATH_DB_SCHEMA.""" sql_for_schema = 'SELECT sql FROM sqlite_master ORDER BY sql' msg_err = 'Database has wrong tables schema. Diff:\n' @@ -45,18 +46,18 @@ class DatabaseFile: # pylint: disable=too-few-public-methods class DatabaseConnection: """A single connection to the database.""" - def __init__(self, db_file: DatabaseFile): + def __init__(self, db_file: DatabaseFile) -> None: self.file = db_file self.conn = sql_connect(self.file.path) - def commit(self): + def commit(self) -> None: """Commit SQL transaction.""" self.conn.commit() - def exec(self, code: str, inputs: tuple = ()): + def exec(self, code: str, inputs: tuple[Any, ...] = tuple()) -> Cursor: """Add commands to SQL transaction.""" return self.conn.execute(code, inputs) - def close(self): + def close(self) -> None: """Close DB connection.""" self.conn.close() diff --git a/plomtask/http.py b/plomtask/http.py index 9a68221..ddea087 100644 --- a/plomtask/http.py +++ b/plomtask/http.py @@ -1,4 +1,5 @@ """Web server stuff.""" +from typing import Any from http.server import BaseHTTPRequestHandler from http.server import HTTPServer from urllib.parse import urlparse, parse_qs @@ -6,7 +7,7 @@ from os.path import split as path_split from jinja2 import Environment as JinjaEnv, FileSystemLoader as JinjaFSLoader from plomtask.days import Day, todays_date from plomtask.misc import HandledException -from plomtask.db import DatabaseConnection +from plomtask.db import DatabaseConnection, DatabaseFile from plomtask.processes import Process TEMPLATES_DIR = 'templates' @@ -15,7 +16,8 @@ TEMPLATES_DIR = 'templates' class TaskServer(HTTPServer): """Variant of HTTPServer that knows .jinja as Jinja Environment.""" - def __init__(self, db_file, *args, **kwargs): + def __init__(self, db_file: DatabaseFile, + *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) self.db = db_file self.jinja = JinjaEnv(loader=JinjaFSLoader(TEMPLATES_DIR)) @@ -25,7 +27,7 @@ class TaskHandler(BaseHTTPRequestHandler): """Handles single HTTP request.""" server: TaskServer - def do_GET(self): + def do_GET(self) -> None: """Handle any GET request.""" try: conn, site, params = self._init_handling() @@ -39,10 +41,10 @@ class TaskHandler(BaseHTTPRequestHandler): elif 'process' == site: id_ = params.get('id', [None])[0] try: - id_ = int(id_) if id_ else None + id__ = int(id_) if id_ else None except ValueError as e: raise HandledException(f'Bad ?id= value: {id_}') from e - html = self.do_GET_process(conn, id_) + html = self.do_GET_process(conn, id__) elif 'processes' == site: html = self.do_GET_processes(conn) else: @@ -53,89 +55,91 @@ class TaskHandler(BaseHTTPRequestHandler): except HandledException as error: self._send_msg(error) - def do_GET_calendar(self, conn: DatabaseConnection, start: str, end: str): + def do_GET_calendar(self, conn: DatabaseConnection, + start: str, end: str) -> str: """Show Days.""" days = Day.all(conn, date_range=(start, end), fill_gaps=True) return self.server.jinja.get_template('calendar.html').render( days=days, start=start, end=end) - def do_GET_day(self, conn: DatabaseConnection, date: str): + def do_GET_day(self, conn: DatabaseConnection, date: str) -> str: """Show single Day.""" day = Day.by_date(conn, date, create=True) return self.server.jinja.get_template('day.html').render(day=day) - def do_GET_process(self, conn: DatabaseConnection, id_: int | None): + def do_GET_process(self, conn: DatabaseConnection, id_: int | None) -> str: """Show process of id_.""" return self.server.jinja.get_template('process.html').render( process=Process.by_id(conn, id_, create=True)) - def do_GET_processes(self, conn: DatabaseConnection): + def do_GET_processes(self, conn: DatabaseConnection) -> str: """Show all Processes.""" return self.server.jinja.get_template('processes.html').render( processes=Process.all(conn)) - def do_POST(self): + def do_POST(self) -> None: """Handle any POST request.""" try: conn, site, params = self._init_handling() length = int(self.headers['content-length']) postvars = parse_qs(self.rfile.read(length).decode(), - keep_blank_values=1) + keep_blank_values=True) if 'day' == site: - date = params.get('date', [None])[0] + date = params.get('date', [''])[0] self.do_POST_day(conn, date, postvars) elif 'process' == site: - id_ = params.get('id', [None])[0] + id_ = params.get('id', [''])[0] try: - id_ = int(id_) if id_ else None + id__ = int(id_) if id_ else None except ValueError as e: raise HandledException(f'Bad ?id= value: {id_}') from e - self.do_POST_process(conn, id_, postvars) + self.do_POST_process(conn, id__, postvars) conn.commit() conn.close() self._redirect('/') except HandledException as error: self._send_msg(error) - def do_POST_day(self, conn: DatabaseConnection, date: str, postvars: dict): + def do_POST_day(self, conn: DatabaseConnection, + date: str, postvars: dict[str, list[str]]) -> None: """Update or insert Day of date and fields defined in postvars.""" day = Day.by_date(conn, date, create=True) day.comment = postvars['comment'][0] day.save(conn) def do_POST_process(self, conn: DatabaseConnection, id_: int | None, - postvars: dict): + postvars: dict[str, list[str]]) -> None: """Update or insert Process of id_ and fields defined in postvars.""" process = Process.by_id(conn, id_, create=True) - if process: - process.title.set(postvars['title'][0]) - process.description.set(postvars['description'][0]) - effort = postvars['effort'][0] - try: - process.effort.set(float(effort)) - except ValueError as e: - raise HandledException(f'Bad effort value: {effort}') from e - process.save(conn) - - def _init_handling(self): + process.title.set(postvars['title'][0]) + process.description.set(postvars['description'][0]) + effort = postvars['effort'][0] + try: + process.effort.set(float(effort)) + except ValueError as e: + raise HandledException(f'Bad effort value: {effort}') from e + process.save(conn) + + def _init_handling(self) -> \ + tuple[DatabaseConnection, str, dict[str, list[str]]]: conn = DatabaseConnection(self.server.db) parsed_url = urlparse(self.path) site = path_split(parsed_url.path)[1] params = parse_qs(parsed_url.query) return conn, site, params - def _redirect(self, target: str): + def _redirect(self, target: str) -> None: self.send_response(302) self.send_header('Location', target) self.end_headers() - def _send_html(self, html: str, code: int = 200): + def _send_html(self, html: str, code: int = 200) -> None: """Send HTML as proper HTTP response.""" self.send_response(code) self.end_headers() self.wfile.write(bytes(html, 'utf-8')) - def _send_msg(self, msg: str, code: int = 400): + def _send_msg(self, msg: Exception, code: int = 400) -> None: """Send message in HTML formatting as HTTP response.""" html = self.server.jinja.get_template('msg.html').render(msg=msg) self._send_html(html, code) diff --git a/plomtask/processes.py b/plomtask/processes.py index 8a5bf64..4867227 100644 --- a/plomtask/processes.py +++ b/plomtask/processes.py @@ -3,6 +3,7 @@ from __future__ import annotations from sqlite3 import Row from datetime import datetime from plomtask.db import DatabaseConnection +from plomtask.misc import HandledException class Process: @@ -35,15 +36,17 @@ class Process: return list(processes.values()) @classmethod - def by_id(cls, db_conn: DatabaseConnection, - id_: int | None, create: bool = False) -> Process | None: + def by_id(cls, db_conn: DatabaseConnection, id_: int | None, + create: bool = False) -> Process: """Collect all Processes and their connected VersionedAttributes.""" process = None for row in db_conn.exec('SELECT * FROM processes ' 'WHERE id = ?', (id_,)): process = cls(row[0]) break - if create and not process: + if not process: + if not create: + raise HandledException(f'Process not found of id: {id_}') process = Process(id_) if process: for row in db_conn.exec('SELECT * FROM process_titles ' diff --git a/scripts/pre-commit b/scripts/pre-commit index cab4553..6f84c41 100755 --- a/scripts/pre-commit +++ b/scripts/pre-commit @@ -2,7 +2,7 @@ set -e for dir in $(echo '.' 'plomtask' 'tests'); do echo "Running mypy on ${dir}/ …." - python3 -m mypy ${dir}/*.py + python3 -m mypy --strict ${dir}/*.py echo "Running flake8 on ${dir}/ …" python3 -m flake8 ${dir}/*.py echo "Running pylint on ${dir}/ …" diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/days.py b/tests/days.py index 5ff6459..f3ed082 100644 --- a/tests/days.py +++ b/tests/days.py @@ -1,6 +1,5 @@ """Test Days module.""" from unittest import TestCase -from http.client import HTTPConnection from datetime import datetime from tests.utils import TestCaseWithDB, TestCaseWithServer from plomtask.days import Day, todays_date @@ -10,21 +9,17 @@ from plomtask.misc import HandledException class TestsSansDB(TestCase): """Days module tests not requiring DB setup.""" - def test_Day_dates(self): + def test_Day_dates(self) -> None: """Test Day's date format.""" with self.assertRaises(HandledException): Day('foo') - with self.assertRaises(HandledException): - Day(None) - with self.assertRaises(HandledException): - Day(3) with self.assertRaises(HandledException): Day('2024-02-30') with self.assertRaises(HandledException): Day('2024-02-01 23:00:00') self.assertEqual(datetime(2024, 1, 1), Day('2024-01-01').datetime) - def test_Day_sorting(self): + def test_Day_sorting(self) -> None: """Test Day.__lt__.""" day1 = Day('2024-01-01') day2 = Day('2024-01-02') @@ -32,7 +27,7 @@ class TestsSansDB(TestCase): days = [day3, day1, day2] self.assertEqual(sorted(days), [day1, day2, day3]) - def test_Day_weekday(self): + def test_Day_weekday(self) -> None: """Test Day.weekday.""" self.assertEqual(Day('2024-03-17').weekday, 'Sunday') @@ -40,18 +35,19 @@ class TestsSansDB(TestCase): class TestsWithDB(TestCaseWithDB): """Days module tests not requiring DB setup.""" - def test_Day_by_date(self): + def test_Day_by_date(self) -> None: """Test Day.by_date().""" - self.assertEqual(None, Day.by_date(self.db_conn, '2024-01-01')) + with self.assertRaises(HandledException): + Day.by_date(self.db_conn, '2024-01-01') Day('2024-01-01').save(self.db_conn) self.assertEqual(Day('2024-01-01'), Day.by_date(self.db_conn, '2024-01-01')) - self.assertEqual(None, - Day.by_date(self.db_conn, '2024-01-02')) + with self.assertRaises(HandledException): + Day.by_date(self.db_conn, '2024-01-02') self.assertEqual(Day('2024-01-02'), Day.by_date(self.db_conn, '2024-01-02', create=True)) - def test_Day_all(self): + def test_Day_all(self) -> None: """Test Day.all(), especially in regards to date range filtering.""" day1 = Day('2024-01-01') day2 = Day('2024-01-02') @@ -59,8 +55,7 @@ class TestsWithDB(TestCaseWithDB): day1.save(self.db_conn) day2.save(self.db_conn) day3.save(self.db_conn) - self.assertEqual(Day.all(self.db_conn), - [day1, day2, day3]) + self.assertEqual(Day.all(self.db_conn), [day1, day2, day3]) self.assertEqual(Day.all(self.db_conn, ('', '')), [day1, day2, day3]) self.assertEqual(Day.all(self.db_conn, ('2024-01-01', '2024-01-03')), @@ -86,7 +81,7 @@ class TestsWithDB(TestCaseWithDB): today.save(self.db_conn) self.assertEqual(Day.all(self.db_conn, ('today', 'today')), [today]) - def test_Day_neighbor_dates(self): + def test_Day_neighbor_dates(self) -> None: """Test Day.prev_date and Day.next_date.""" self.assertEqual(Day('2024-01-01').prev_date, '2023-12-31') self.assertEqual(Day('2023-02-28').next_date, '2023-03-01') @@ -95,22 +90,21 @@ class TestsWithDB(TestCaseWithDB): class TestsWithServer(TestCaseWithServer): """Tests against our HTTP server/handler (and database).""" - def test_do_GET(self): + def test_do_GET(self) -> None: """Test /day and /calendar response codes.""" - http_conn = HTTPConnection(*self.httpd.server_address) - http_conn.request('GET', '/day') - self.assertEqual(http_conn.getresponse().status, 200) - http_conn.request('GET', '/day?date=3000-01-01') - self.assertEqual(http_conn.getresponse().status, 200) - http_conn.request('GET', '/day?date=FOO') - self.assertEqual(http_conn.getresponse().status, 400) - http_conn.request('GET', '/calendar') - self.assertEqual(http_conn.getresponse().status, 200) - http_conn.request('GET', '/calendar?start=&end=') - self.assertEqual(http_conn.getresponse().status, 200) - http_conn.request('GET', '/calendar?start=today&end=today') - self.assertEqual(http_conn.getresponse().status, 200) - http_conn.request('GET', '/calendar?start=2024-01-01&end=2025-01-01') - self.assertEqual(http_conn.getresponse().status, 200) - http_conn.request('GET', '/calendar?start=foo') - self.assertEqual(http_conn.getresponse().status, 400) + self.conn.request('GET', '/day') + self.assertEqual(self.conn.getresponse().status, 200) + self.conn.request('GET', '/day?date=3000-01-01') + self.assertEqual(self.conn.getresponse().status, 200) + self.conn.request('GET', '/day?date=FOO') + self.assertEqual(self.conn.getresponse().status, 400) + self.conn.request('GET', '/calendar') + self.assertEqual(self.conn.getresponse().status, 200) + self.conn.request('GET', '/calendar?start=&end=') + self.assertEqual(self.conn.getresponse().status, 200) + self.conn.request('GET', '/calendar?start=today&end=today') + self.assertEqual(self.conn.getresponse().status, 200) + self.conn.request('GET', '/calendar?start=2024-01-01&end=2025-01-01') + self.assertEqual(self.conn.getresponse().status, 200) + self.conn.request('GET', '/calendar?start=foo') + self.assertEqual(self.conn.getresponse().status, 400) diff --git a/tests/processes.py b/tests/processes.py index 271289b..17af14e 100644 --- a/tests/processes.py +++ b/tests/processes.py @@ -1,15 +1,15 @@ """Test Processes module.""" from unittest import TestCase -from http.client import HTTPConnection from urllib.parse import urlencode from tests.utils import TestCaseWithDB, TestCaseWithServer from plomtask.processes import Process +from plomtask.misc import HandledException class TestsSansDB(TestCase): """Module tests not requiring DB setup.""" - def test_Process_versioned_defaults(self): + def test_Process_versioned_defaults(self) -> None: """Test defaults of Process' VersionedAttributes.""" self.assertEqual(Process(None).title.newest, 'UNNAMED') self.assertEqual(Process(None).description.newest, '') @@ -19,7 +19,7 @@ class TestsSansDB(TestCase): class TestsWithDB(TestCaseWithDB): """Mdule tests not requiring DB setup.""" - def test_Process_save(self): + def test_Process_save(self) -> None: """Test Process.save().""" p_saved = Process(None) p_saved.save(self.db_conn) @@ -40,11 +40,14 @@ class TestsWithDB(TestCaseWithDB): p_loaded = Process.by_id(self.db_conn, p_saved.id_) self.assertEqual(p_saved.title.history, p_loaded.title.history) - def test_Process_by_id(self): + def test_Process_by_id(self) -> None: """Test Process.by_id().""" - self.assertEqual(None, Process.by_id(self.db_conn, None, create=False)) - self.assertEqual(None, Process.by_id(self.db_conn, 0, create=False)) - self.assertEqual(None, Process.by_id(self.db_conn, 1, create=False)) + with self.assertRaises(HandledException): + Process.by_id(self.db_conn, None, create=False) + with self.assertRaises(HandledException): + Process.by_id(self.db_conn, 0, create=False) + with self.assertRaises(HandledException): + Process.by_id(self.db_conn, 1, create=False) self.assertNotEqual(Process(1).id_, Process.by_id(self.db_conn, None, create=True).id_) self.assertNotEqual(Process(1).id_, @@ -54,7 +57,7 @@ class TestsWithDB(TestCaseWithDB): self.assertEqual(Process(2).id_, Process.by_id(self.db_conn, 2, create=True).id_) - def test_Process_all(self): + def test_Process_all(self) -> None: """Test Process.all().""" p_1 = Process(None) p_1.save(self.db_conn) @@ -67,16 +70,16 @@ class TestsWithDB(TestCaseWithDB): class TestsWithServer(TestCaseWithServer): """Module tests against our HTTP server/handler (and database).""" - def test_do_POST_process(self): + def test_do_POST_process(self) -> None: """Test POST /process and its effect on the database.""" - def post_data_to_expect(form_data: dict, to_: str, expect: int): + def post_data_to_expect(form_data: dict[str, object], + to_: str, expect: int) -> None: encoded_form_data = urlencode(form_data).encode('utf-8') headers = {'Content-Type': 'application/x-www-form-urlencoded', 'Content-Length': str(len(encoded_form_data))} - http_conn.request('POST', to_, + self.conn.request('POST', to_, body=encoded_form_data, headers=headers) - self.assertEqual(http_conn.getresponse().status, expect) - http_conn = HTTPConnection(*self.httpd.server_address) + self.assertEqual(self.conn.getresponse().status, expect) form_data = {'title': 'foo', 'description': 'foo', 'effort': 1.0} post_data_to_expect(form_data, '/process?id=FOO', 400) form_data['effort'] = 'foo' @@ -90,16 +93,15 @@ class TestsWithServer(TestCaseWithServer): self.assertEqual([p.id_ for p in Process.all(self.db_conn)], [retrieved.id_]) - def test_do_GET(self): + def test_do_GET(self) -> None: """Test /process and /processes response codes.""" - http_conn = HTTPConnection(*self.httpd.server_address) - http_conn.request('GET', '/process') - self.assertEqual(http_conn.getresponse().status, 200) - http_conn.request('GET', '/process?id=') - self.assertEqual(http_conn.getresponse().status, 200) - http_conn.request('GET', '/process?id=0') - self.assertEqual(http_conn.getresponse().status, 200) - http_conn.request('GET', '/process?id=FOO') - self.assertEqual(http_conn.getresponse().status, 400) - http_conn.request('GET', '/processes') - self.assertEqual(http_conn.getresponse().status, 200) + self.conn.request('GET', '/process') + self.assertEqual(self.conn.getresponse().status, 200) + self.conn.request('GET', '/process?id=') + self.assertEqual(self.conn.getresponse().status, 200) + self.conn.request('GET', '/process?id=0') + self.assertEqual(self.conn.getresponse().status, 200) + self.conn.request('GET', '/process?id=FOO') + self.assertEqual(self.conn.getresponse().status, 400) + self.conn.request('GET', '/processes') + self.assertEqual(self.conn.getresponse().status, 200) diff --git a/tests/utils.py b/tests/utils.py index cd0c457..9964201 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,6 +1,7 @@ """Shared test utilities.""" from unittest import TestCase from threading import Thread +from http.client import HTTPConnection from datetime import datetime from os import remove as remove_file from plomtask.db import DatabaseFile, DatabaseConnection @@ -10,13 +11,13 @@ from plomtask.http import TaskHandler, TaskServer class TestCaseWithDB(TestCase): """Module tests not requiring DB setup.""" - def setUp(self): + def setUp(self) -> None: timestamp = datetime.now().timestamp() self.db_file = DatabaseFile(f'test_db:{timestamp}') self.db_file.remake() self.db_conn = DatabaseConnection(self.db_file) - def tearDown(self): + def tearDown(self) -> None: self.db_conn.close() remove_file(self.db_file.path) @@ -24,14 +25,16 @@ class TestCaseWithDB(TestCase): class TestCaseWithServer(TestCaseWithDB): """Module tests against our HTTP server/handler (and database).""" - def setUp(self): + def setUp(self) -> None: super().setUp() self.httpd = TaskServer(self.db_file, ('localhost', 0), TaskHandler) self.server_thread = Thread(target=self.httpd.serve_forever) self.server_thread.daemon = True self.server_thread.start() + self.conn = HTTPConnection(str(self.httpd.server_address[0]), + self.httpd.server_address[1]) - def tearDown(self): + def tearDown(self) -> None: self.httpd.shutdown() self.httpd.server_close() self.server_thread.join()