home · contact · privacy
d41e7b39ad5fe07abdc701be674d9f81678cd30f
[plomtask] / tests / utils.py
1 """Shared test utilities."""
2 from unittest import TestCase
3 from threading import Thread
4 from http.client import HTTPConnection
5 from urllib.parse import urlencode
6 from datetime import datetime
7 from os import remove as remove_file
8 from typing import Mapping
9 from plomtask.db import DatabaseFile, DatabaseConnection
10 from plomtask.http import TaskHandler, TaskServer
11
12
class TestCaseWithDB(TestCase):
    """Base class for module tests that need a database.

    setUp gives every test its own DatabaseFile (built from a name that
    contains the current timestamp) and an open DatabaseConnection to it;
    tearDown closes the connection and deletes the file at db_file.path.
    """

    def setUp(self) -> None:
        """Create this test's database file and connect to it."""
        timestamp = datetime.now().timestamp()
        self.db_file = DatabaseFile(f'test_db:{timestamp}')
        # NOTE(review): remake() presumably (re-)creates the file with a
        # fresh schema -- confirm against plomtask.db.
        self.db_file.remake()
        self.db_conn = DatabaseConnection(self.db_file)

    def tearDown(self) -> None:
        """Close the connection and delete this test's database file."""
        try:
            self.db_conn.close()
        finally:
            # Delete the file even when closing fails, so that a broken
            # test does not leave stale database files behind.
            remove_file(self.db_file.path)
25
26
class TestCaseWithServer(TestCaseWithDB):
    """Module tests against our HTTP server/handler (and database).

    On top of the database provided by the parent class, setUp starts a
    TaskServer in a background thread and opens an HTTPConnection to it as
    self.conn; the check_* helpers send requests over that connection and
    assert on the responses.
    """

    def setUp(self) -> None:
        """Start the server thread and open a client connection to it."""
        super().setUp()
        # Port 0 is requested here; the address actually in use is read
        # back from server_address when building the client connection.
        self.httpd = TaskServer(self.db_file, ('localhost', 0), TaskHandler)
        self.server_thread = Thread(target=self.httpd.serve_forever)
        self.server_thread.daemon = True
        self.server_thread.start()
        self.conn = HTTPConnection(str(self.httpd.server_address[0]),
                                   self.httpd.server_address[1])

    def tearDown(self) -> None:
        """Close the client connection, stop the server, clean up the DB."""
        try:
            # Close our side first: it releases the client socket, and a
            # handler possibly still waiting on this connection cannot
            # hold up the server's shutdown.
            self.conn.close()
            self.httpd.shutdown()
            self.httpd.server_close()
            self.server_thread.join()
        finally:
            # Always run the parent's cleanup, even if stopping the server
            # failed, so the database file does not linger.
            super().tearDown()

    def check_redirect(self, target: str) -> None:
        """Check that self.conn answers with a 302 redirect to target.

        Expects a request to have been sent on self.conn already; only the
        response is fetched and inspected here.
        """
        response = self.conn.getresponse()
        self.assertEqual(response.status, 302)
        self.assertEqual(response.getheader('Location'), target)

    def check_get(self, target: str, expected_code: int) -> None:
        """Check that a GET to target yields expected_code."""
        self.conn.request('GET', target)
        self.assertEqual(self.conn.getresponse().status, expected_code)

    def check_post(self, data: Mapping[str, object], target: str,
                   expected_code: int, redirect_location: str = '') -> None:
        """Check that POST of data to target yields expected_code.

        data is sent URL-encoded as form data. If expected_code is 302, the
        response's Location header is additionally compared against
        redirect_location.
        """
        encoded_form_data = urlencode(data).encode('utf-8')
        headers = {'Content-Type': 'application/x-www-form-urlencoded',
                   'Content-Length': str(len(encoded_form_data))}
        self.conn.request('POST', target,
                          body=encoded_form_data, headers=headers)
        if 302 == expected_code:
            self.check_redirect(redirect_location)
        else:
            self.assertEqual(self.conn.getresponse().status, expected_code)