home · contact · privacy
Extend POST tests, and handling of missing form data.
[plomtask] / tests / utils.py
1 """Shared test utilities."""
2 from unittest import TestCase
3 from threading import Thread
4 from http.client import HTTPConnection
5 from urllib.parse import urlencode
6 from datetime import datetime
7 from os import remove as remove_file
8 from plomtask.db import DatabaseFile, DatabaseConnection
9 from plomtask.http import TaskHandler, TaskServer
10
11
class TestCaseWithDB(TestCase):
    """Module tests requiring DB setup.

    Each test gets its own database file in self.db_file.
    - Its name is derived from the current timestamp.
    - It is (re)created via DatabaseFile.remake().

    Each test also gets a DatabaseConnection to that file in self.db_conn.
    - tearDown closes the connection.
    - The file itself is removed by a registered cleanup, so it is deleted
      even if setUp fails after the file was created.
    """

    def setUp(self) -> None:
        timestamp = datetime.now().timestamp()
        self.db_file = DatabaseFile(f'test_db:{timestamp}')
        self.db_file.remake()
        # Register removal immediately after creation: unittest runs
        # cleanups even when setUp raises later on (tearDown is then
        # skipped) and also when tearDown itself raises.
        # NOTE(review): assumes remake() leaves a file at self.db_file.path.
        self.addCleanup(remove_file, self.db_file.path)
        self.db_conn = DatabaseConnection(self.db_file)

    def tearDown(self) -> None:
        # File removal happens afterwards via the cleanup registered in setUp.
        self.db_conn.close()
24
25
class TestCaseWithServer(TestCaseWithDB):
    """Module tests against our HTTP server/handler (and database).

    setUp does three things:
    - starts a TaskServer on localhost with an OS-chosen port (port 0);
    - serves it from a daemon thread;
    - opens an HTTPConnection to it as self.conn.

    tearDown closes that client connection, stops and closes the server,
    joins its thread, and then always runs the DB teardown.
    """

    def setUp(self) -> None:
        super().setUp()
        self.httpd = TaskServer(self.db_file, ('localhost', 0), TaskHandler)
        self.server_thread = Thread(target=self.httpd.serve_forever)
        # Daemon thread so a stuck server cannot keep the test process alive.
        self.server_thread.daemon = True
        self.server_thread.start()
        # server_address holds the actual bound host/port (port 0 above asks
        # the OS for a free one).
        self.conn = HTTPConnection(str(self.httpd.server_address[0]),
                                   self.httpd.server_address[1])

    def tearDown(self) -> None:
        try:
            # Close the client socket first; it was previously never closed.
            self.conn.close()
            self.httpd.shutdown()
            self.httpd.server_close()
            self.server_thread.join()
        finally:
            # DB teardown must run even if stopping the server fails.
            super().tearDown()

    def post_to(self, data: dict[str, object], target: str) -> None:
        """Post data URL-encoded as form data to target via self.conn.

        Only sends the request; read the answer with e.g. check_redirect.
        """
        encoded_form_data = urlencode(data).encode('utf-8')
        headers = {'Content-Type': 'application/x-www-form-urlencoded',
                   'Content-Length': str(len(encoded_form_data))}
        self.conn.request('POST', target,
                          body=encoded_form_data, headers=headers)

    def check_redirect(self, target: str) -> None:
        """Check that self.conn answers with a 302 redirect to target.

        The response body is read in full and the response closed before
        the assertions. This frees its socket and leaves self.conn usable
        for the next request, even if an assertion fails.
        """
        with self.conn.getresponse() as response:
            # http.client requires the whole response to be read before the
            # connection can send another request.
            response.read()
        # status and headers remain available after the response is closed.
        self.assertEqual(response.status, 302)
        self.assertEqual(response.getheader('Location'), target)