home · contact · privacy
61dbb36b949ee3fdfd5a38dcada155a5ab18a924
[plomtask] / tests / utils.py
1 """Shared test utilities."""
2 from unittest import TestCase
3 from threading import Thread
4 from http.client import HTTPConnection
5 from urllib.parse import urlencode
6 from datetime import datetime
7 from os import remove as remove_file
8 from typing import Mapping, Any
9 from plomtask.db import DatabaseFile, DatabaseConnection
10 from plomtask.http import TaskHandler, TaskServer
11 from plomtask.processes import Process, ProcessStep
12 from plomtask.conditions import Condition
13 from plomtask.days import Day
14 from plomtask.todos import Todo
15 from plomtask.exceptions import NotFoundException, HandledException
16
17
18 class TestCaseWithDB(TestCase):
19     """Module tests not requiring DB setup."""
20     checked_class: Any
21     default_ids: tuple[int | str, int | str, int | str]
22
23     def setUp(self) -> None:
24         Condition.empty_cache()
25         Day.empty_cache()
26         Process.empty_cache()
27         ProcessStep.empty_cache()
28         Todo.empty_cache()
29         timestamp = datetime.now().timestamp()
30         self.db_file = DatabaseFile(f'test_db:{timestamp}')
31         self.db_file.remake()
32         self.db_conn = DatabaseConnection(self.db_file)
33
34     def tearDown(self) -> None:
35         self.db_conn.close()
36         remove_file(self.db_file.path)
37
38     def check_storage(self, content: list[Any]) -> None:
39         """Test cache and DB equal content."""
40         expected_cache = {}
41         for item in content:
42             expected_cache[item.id_] = item
43         self.assertEqual(self.checked_class.get_cache(), expected_cache)
44         db_found: list[Any] = []
45         for item in content:
46             assert isinstance(item.id_, (str, int))
47             for row in self.db_conn.row_where(self.checked_class.table_name,
48                                               'id', item.id_):
49                 db_found += [self.checked_class.from_table_row(self.db_conn,
50                                                                row)]
51         self.assertEqual(sorted(content), sorted(db_found))
52
53     def check_saving_and_caching(self, **kwargs: Any) -> Any:
54         """Test instance.save in its core without relations."""
55         obj = self.checked_class(**kwargs)  # pylint: disable=not-callable
56         # check object init itself doesn't store anything yet
57         self.check_storage([])
58         # check saving stores in cache and DB
59         obj.save(self.db_conn)
60         self.check_storage([obj])
61         # check core attributes set properly (and not unset by saving)
62         for key, value in kwargs.items():
63             self.assertEqual(getattr(obj, key), value)
64
65     def check_by_id(self) -> None:
66         """Test .by_id(), including creation."""
67         # check failure if not yet saved
68         id1, id2 = self.default_ids[0], self.default_ids[1]
69         obj = self.checked_class(id1)  # pylint: disable=not-callable
70         with self.assertRaises(NotFoundException):
71             self.checked_class.by_id(self.db_conn, id1)
72         # check identity of saved and retrieved
73         obj.save(self.db_conn)
74         self.assertEqual(obj, self.checked_class.by_id(self.db_conn, id1))
75         # check create=True acts like normal instantiation (sans saving)
76         by_id_created = self.checked_class.by_id(self.db_conn, id2,
77                                                  create=True)
78         # pylint: disable=not-callable
79         self.assertEqual(self.checked_class(id2), by_id_created)
80         self.check_storage([obj])
81
82     def check_from_table_row(self) -> None:
83         """Test .from_table_row() properly reads in class from DB"""
84         id_ = self.default_ids[0]
85         obj = self.checked_class(id_)  # pylint: disable=not-callable
86         obj.save(self.db_conn)
87         assert isinstance(obj.id_, (str, int))
88         for row in self.db_conn.row_where(self.checked_class.table_name,
89                                           'id', obj.id_):
90             retrieved = self.checked_class.from_table_row(self.db_conn, row)
91             self.assertEqual(obj, retrieved)
92             self.assertEqual({obj.id_: obj}, self.checked_class.get_cache())
93
94     def check_all(self) -> tuple[Any, Any, Any]:
95         """Test .all()."""
96         # pylint: disable=not-callable
97         item1 = self.checked_class(self.default_ids[0])
98         item2 = self.checked_class(self.default_ids[1])
99         item3 = self.checked_class(self.default_ids[2])
100         # check pre-save .all() returns empty list
101         self.assertEqual(self.checked_class.all(self.db_conn), [])
102         # check that all() shows all saved, but no unsaved items
103         item1.save(self.db_conn)
104         item3.save(self.db_conn)
105         self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
106                          sorted([item1, item3]))
107         item2.save(self.db_conn)
108         self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
109                          sorted([item1, item2, item3]))
110         return item1, item2, item3
111
112     def check_singularity(self, defaulting_field: str,
113                           non_default_value: Any) -> None:
114         """Test pointers made for single object keep pointing to it."""
115         id1 = self.default_ids[0]
116         obj = self.checked_class(id1)  # pylint: disable=not-callable
117         obj.save(self.db_conn)
118         setattr(obj, defaulting_field, non_default_value)
119         retrieved = self.checked_class.by_id(self.db_conn, id1)
120         self.assertEqual(non_default_value,
121                          getattr(retrieved, defaulting_field))
122
123     def check_remove(self) -> None:
124         """Test .remove() effects on DB and cache."""
125         id_ = self.default_ids[0]
126         obj = self.checked_class(id_)  # pylint: disable=not-callable
127         with self.assertRaises(HandledException):
128             obj.remove(self.db_conn)
129         obj.save(self.db_conn)
130         obj.remove(self.db_conn)
131         self.check_storage([])
132
133
class TestCaseWithServer(TestCaseWithDB):
    """Module tests against our HTTP server/handler (and database).

    On top of TestCaseWithDB's setup, runs a TaskServer on the test DB file
    in a background thread and opens an HTTPConnection to it as self.conn.
    """

    def setUp(self) -> None:
        super().setUp()
        # port 0: presumably TaskServer binds like socketserver, letting the
        # OS choose a free port; the real address is read back below
        self.httpd = TaskServer(self.db_file, ('localhost', 0), TaskHandler)
        self.server_thread = Thread(target=self.httpd.serve_forever)
        # daemon, so a stuck server thread cannot keep the process alive
        self.server_thread.daemon = True
        self.server_thread.start()
        self.conn = HTTPConnection(str(self.httpd.server_address[0]),
                                   self.httpd.server_address[1])

    def tearDown(self) -> None:
        # release the client side's socket, not only the server's
        self.conn.close()
        self.httpd.shutdown()
        self.httpd.server_close()
        self.server_thread.join()
        super().tearDown()

    def check_redirect(self, target: str) -> None:
        """Check that self.conn answers with a 302 redirect to target."""
        response = self.conn.getresponse()
        self.assertEqual(response.status, 302)
        self.assertEqual(response.getheader('Location'), target)

    def check_get(self, target: str, expected_code: int) -> None:
        """Check that a GET to target yields expected_code."""
        self.conn.request('GET', target)
        self.assertEqual(self.conn.getresponse().status, expected_code)

    def check_post(self, data: Mapping[str, object], target: str,
                   expected_code: int, redirect_location: str = '') -> None:
        """Check that POST of data to target yields expected_code.

        data is sent URL-encoded (sequence values as repeated keys). If
        expected_code is 302, also checks the Location header against
        redirect_location, which defaults to target itself when empty.
        """
        encoded_form_data = urlencode(data, doseq=True).encode('utf-8')
        headers = {'Content-Type': 'application/x-www-form-urlencoded',
                   'Content-Length': str(len(encoded_form_data))}
        self.conn.request('POST', target,
                          body=encoded_form_data, headers=headers)
        if expected_code == 302:
            self.check_redirect(redirect_location or target)
        else:
            self.assertEqual(self.conn.getresponse().status, expected_code)

    def check_get_defaults(self, path: str) -> None:
        """Some standard model paths to test.

        Checks GET status codes for path without ID, with empty ID, with a
        non-numeric ID, with ID 0, and with ID 1.
        """
        self.check_get(path, 200)
        self.check_get(f'{path}?id=', 200)
        self.check_get(f'{path}?id=foo', 400)
        # no extra leading '/' here: '//x?id=0' would parse as host 'x'
        self.check_get(f'{path}?id=0', 500)
        self.check_get(f'{path}?id=1', 200)

    def post_process(self, id_: int = 1,
                     form_data: dict[str, Any] | None = None
                     ) -> dict[str, Any]:
        """POST basic Process.

        Posts form_data to '/process?id=' and expects a 302 redirect to
        f'/process?id={id_}'. Falsy form_data (None or empty) is replaced by
        a default title/description/effort. Returns the form data posted.
        """
        if not form_data:
            form_data = {'title': 'foo', 'description': 'foo', 'effort': 1.1}
        self.check_post(form_data, '/process?id=', 302, f'/process?id={id_}')
        return form_data