home · contact · privacy
Re-write caching.
[plomtask] / tests / utils.py
1 """Shared test utilities."""
2 from unittest import TestCase
3 from threading import Thread
4 from http.client import HTTPConnection
5 from urllib.parse import urlencode
6 from datetime import datetime
7 from os import remove as remove_file
8 from typing import Mapping
9 from plomtask.db import DatabaseFile, DatabaseConnection
10 from plomtask.http import TaskHandler, TaskServer
11 from plomtask.processes import Process, ProcessStep
12 from plomtask.conditions import Condition
13 from plomtask.days import Day
14 from plomtask.todos import Todo
15
16
class TestCaseWithDB(TestCase):
    """Module tests requiring DB setup.

    Before each test, the model classes' caches are emptied and a freshly
    remade database file is created and connected to; after each test the
    connection is closed and the file deleted.
    """

    def setUp(self) -> None:
        """Empty model caches, then create and connect to a fresh test DB."""
        # Empty the model classes' caches so state from a previous test
        # does not carry over into this one.
        Condition.empty_cache()
        Day.empty_cache()
        Process.empty_cache()
        ProcessStep.empty_cache()
        Todo.empty_cache()
        # The timestamp in the file name keeps successive tests' DB files
        # apart.  NOTE(review): ':' is not a valid file name character on
        # Windows, should that platform ever need supporting.
        timestamp = datetime.now().timestamp()
        self.db_file = DatabaseFile(f'test_db:{timestamp}')
        self.db_file.remake()
        self.db_conn = DatabaseConnection(self.db_file)

    def tearDown(self) -> None:
        """Close the DB connection and delete the test DB file.

        The file is removed even if closing the connection raises, so a
        failing close does not leave a stray test DB file behind; the
        exception from close() still propagates.
        """
        try:
            self.db_conn.close()
        finally:
            remove_file(self.db_file.path)
34
35
class TestCaseWithServer(TestCaseWithDB):
    """Module tests against our HTTP server/handler (and database)."""

    def setUp(self) -> None:
        """Start a TaskServer on a free local port and connect a client."""
        super().setUp()
        # Port 0 lets the OS choose a free port; the actual address is read
        # back from server_address when building the client connection.
        self.httpd = TaskServer(self.db_file, ('localhost', 0), TaskHandler)
        self.server_thread = Thread(target=self.httpd.serve_forever)
        # Daemon thread, so it cannot keep the interpreter alive on its own.
        self.server_thread.daemon = True
        self.server_thread.start()
        self.conn = HTTPConnection(str(self.httpd.server_address[0]),
                                   self.httpd.server_address[1])

    def tearDown(self) -> None:
        """Close the client, stop the server thread, then tear down the DB.

        The DB teardown runs even if stopping the server raises.
        """
        try:
            self.conn.close()
            self.httpd.shutdown()
            self.httpd.server_close()
            self.server_thread.join()
        finally:
            super().tearDown()

    def _check_status(self, expected_code: int) -> None:
        """Read self.conn's pending response and assert its status code."""
        response = self.conn.getresponse()
        # Drain the body so the response (and its socket) is closed and the
        # connection is free for the next request.
        response.read()
        self.assertEqual(response.status, expected_code)

    def check_redirect(self, target: str) -> None:
        """Check that self.conn answers with a 302 redirect to target."""
        response = self.conn.getresponse()
        response.read()  # drain so the connection can be reused
        self.assertEqual(response.status, 302)
        self.assertEqual(response.getheader('Location'), target)

    def check_get(self, target: str, expected_code: int) -> None:
        """Check that a GET to target yields expected_code."""
        self.conn.request('GET', target)
        self._check_status(expected_code)

    def check_post(self, data: Mapping[str, object], target: str,
                   expected_code: int, redirect_location: str = '/') -> None:
        """Check that POST of data to target yields expected_code.

        data is sent form-urlencoded; sequence values become repeated keys
        (urlencode's doseq=True). If expected_code is 302, additionally
        check that the redirect points to redirect_location.
        """
        encoded_form_data = urlencode(data, doseq=True).encode('utf-8')
        headers = {'Content-Type': 'application/x-www-form-urlencoded',
                   'Content-Length': str(len(encoded_form_data))}
        self.conn.request('POST', target,
                          body=encoded_form_data, headers=headers)
        if expected_code == 302:
            self.check_redirect(redirect_location)
        else:
            self._check_status(expected_code)