1 """Shared test utilities."""
from unittest import TestCase
from threading import Thread
from http.client import HTTPConnection
from json import loads as json_loads
from urllib.parse import urlencode
from uuid import uuid4
from os import remove as remove_file
from typing import Mapping, Any
from plomtask.db import DatabaseFile, DatabaseConnection
from plomtask.http import TaskHandler, TaskServer
from plomtask.processes import Process, ProcessStep
from plomtask.conditions import Condition
from plomtask.days import Day
from plomtask.todos import Todo
from plomtask.exceptions import NotFoundException, HandledException
19 class TestCaseSansDB(TestCase):
20 """Tests requiring no DB setup."""
22 do_id_test: bool = False
23 default_init_args: list[Any] = []
24 versioned_defaults_to_test: dict[str, str | float] = {}
26 def test_id_setting(self) -> None:
27 """Test .id_ being set and its legal range being enforced."""
28 if not self.do_id_test:
30 with self.assertRaises(HandledException):
31 self.checked_class(0, *self.default_init_args)
32 obj = self.checked_class(5, *self.default_init_args)
33 self.assertEqual(obj.id_, 5)
35 def test_versioned_defaults(self) -> None:
36 """Test defaults of VersionedAttributes."""
37 if len(self.versioned_defaults_to_test) == 0:
39 obj = self.checked_class(1, *self.default_init_args)
40 for k, v in self.versioned_defaults_to_test.items():
41 self.assertEqual(getattr(obj, k).newest, v)
class TestCaseWithDB(TestCase):
    """Module tests requiring DB setup.

    Each test gets a fresh DB file and connection, and starts with emptied
    model caches. Subclasses set .checked_class to the class under test; the
    check_* methods are building blocks for their own test methods.
    """
    # class under test; annotation only, so hasattr() stays False if unset
    checked_class: Any
    # IDs used for instantiating test objects of checked_class
    default_ids: tuple[int | str, int | str, int | str] = (1, 2, 3)
    # keyword arguments handed to checked_class in test_saving_and_caching
    default_init_kwargs: dict[str, Any] = {}
    # versioned attribute names of checked_class mapped to their value type
    test_versioneds: dict[str, type] = {}

    def setUp(self) -> None:
        """Empty all model caches, create fresh DB file and connection."""
        for model in (Condition, Day, Process, ProcessStep, Todo):
            model.empty_cache()
        # random file name so every test works on its own DB file
        self.db_file = DatabaseFile.create_at(f'test_db:{uuid4()}')
        self.db_conn = DatabaseConnection(self.db_file)

    def tearDown(self) -> None:
        """Close DB connection and delete the test's DB file."""
        # NOTE(review): assumes DatabaseConnection offers .close() -- confirm
        self.db_conn.close()
        remove_file(self.db_file.path)

    def test_saving_and_caching(self) -> None:
        """Test storage and initialization of instances and attributes."""
        if not hasattr(self, 'checked_class'):
            return
        self.check_saving_and_caching(id_=1, **self.default_init_kwargs)
        # with ID 1 taken, saving without ID is expected to assign the next
        obj = self.checked_class(None, **self.default_init_kwargs)
        obj.save(self.db_conn)
        self.assertEqual(obj.id_, 2)
        for attr_name, type_ in self.test_versioneds.items():
            self.check_saving_of_versioned(attr_name, type_)

    def check_storage(self, content: list[Any]) -> None:
        """Test cache and DB equal content."""
        expected_cache = {item.id_: item for item in content}
        self.assertEqual(self.checked_class.get_cache(), expected_cache)
        hashes_content = [hash(x) for x in content]
        db_found: list[Any] = []
        for item in content:
            assert isinstance(item.id_, type(self.default_ids[0]))
            for row in self.db_conn.row_where(self.checked_class.table_name,
                                              'id', item.id_):
                db_found.append(
                        self.checked_class.from_table_row(self.db_conn, row))
        hashes_db_found = [hash(x) for x in db_found]
        self.assertEqual(sorted(hashes_content), sorted(hashes_db_found))

    def check_saving_and_caching(self, **kwargs: Any) -> None:
        """Test instance.save in its core without relations."""
        obj = self.checked_class(**kwargs)  # pylint: disable=not-callable
        # check object init itself doesn't store anything yet
        self.check_storage([])
        # check saving sets core attributes properly
        obj.save(self.db_conn)
        for key, value in kwargs.items():
            self.assertEqual(getattr(obj, key), value)
        # check saving stored properly in cache and DB
        self.check_storage([obj])

    def _saved_owner_with_versions(self, attr_name: str,
                                   type_: type) -> tuple[Any, list[Any]]:
        """Save new instance with two values set on attr_name.

        Returns the saved instance and the (sorted) values that were set.
        """
        owner = self.checked_class(None)  # pylint: disable=not-callable
        vals: list[Any] = ['t1', 't2'] if type_ is str else [0.9, 1.1]
        attr = getattr(owner, attr_name)
        for val in vals:
            attr.set(val)
        owner.save(self.db_conn)
        return owner, vals

    def check_saving_of_versioned(self, attr_name: str, type_: type) -> None:
        """Test owner's versioned attributes."""
        owner, vals = self._saved_owner_with_versions(attr_name, type_)
        retrieved = owner.__class__.by_id(self.db_conn, owner.id_)
        attr = getattr(retrieved, attr_name)
        self.assertEqual(sorted(attr.history.values()), vals)

    def check_by_id(self) -> None:
        """Test .by_id(), including creation."""
        # check failure if not yet saved
        id1, id2 = self.default_ids[0], self.default_ids[1]
        obj = self.checked_class(id1)  # pylint: disable=not-callable
        with self.assertRaises(NotFoundException):
            self.checked_class.by_id(self.db_conn, id1)
        # check identity of saved and retrieved
        obj.save(self.db_conn)
        self.assertEqual(obj, self.checked_class.by_id(self.db_conn, id1))
        # check create=True acts like normal instantiation (sans saving)
        by_id_created = self.checked_class.by_id(self.db_conn, id2,
                                                 create=True)
        # pylint: disable=not-callable
        self.assertEqual(self.checked_class(id2), by_id_created)
        self.check_storage([obj])

    def check_from_table_row(self, *args: Any) -> None:
        """Test .from_table_row() properly reads in class from DB"""
        id_ = self.default_ids[0]
        obj = self.checked_class(id_, *args)  # pylint: disable=not-callable
        obj.save(self.db_conn)
        assert isinstance(obj.id_, type(self.default_ids[0]))
        for row in self.db_conn.row_where(self.checked_class.table_name,
                                          'id', obj.id_):
            # hash before re-reading, to compare against the retrieved one
            hash_original = hash(obj)
            retrieved = self.checked_class.from_table_row(self.db_conn, row)
            self.assertEqual(hash_original, hash(retrieved))
            self.assertEqual({retrieved.id_: retrieved},
                             self.checked_class.get_cache())

    def check_versioned_from_table_row(self, attr_name: str,
                                       type_: type) -> None:
        """Test .from_table_row() reads versioned attributes from DB."""
        owner, vals = self._saved_owner_with_versions(attr_name, type_)
        for row in self.db_conn.row_where(owner.table_name, 'id', owner.id_):
            retrieved = owner.__class__.from_table_row(self.db_conn, row)
            attr = getattr(retrieved, attr_name)
            self.assertEqual(sorted(attr.history.values()), vals)

    def check_all(self) -> tuple[Any, Any, Any]:
        """Test .all() lists all saved items, return the three created."""
        # pylint: disable=not-callable
        item1 = self.checked_class(self.default_ids[0])
        item2 = self.checked_class(self.default_ids[1])
        item3 = self.checked_class(self.default_ids[2])
        # check pre-save .all() returns empty list
        self.assertEqual(self.checked_class.all(self.db_conn), [])
        # check that all() shows all saved, but no unsaved items
        item1.save(self.db_conn)
        item3.save(self.db_conn)
        self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
                         sorted([item1, item3]))
        item2.save(self.db_conn)
        self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
                         sorted([item1, item2, item3]))
        return item1, item2, item3

    def check_singularity(self, defaulting_field: str,
                          non_default_value: Any, *args: Any) -> None:
        """Test pointers made for single object keep pointing to it."""
        id1 = self.default_ids[0]
        obj = self.checked_class(id1, *args)  # pylint: disable=not-callable
        obj.save(self.db_conn)
        # change only after saving: retrieval must yield the changed object
        setattr(obj, defaulting_field, non_default_value)
        retrieved = self.checked_class.by_id(self.db_conn, id1)
        self.assertEqual(non_default_value,
                         getattr(retrieved, defaulting_field))

    def check_versioned_singularity(self) -> None:
        """Test singularity of VersionedAttributes on saving (with .title)."""
        obj = self.checked_class(None)  # pylint: disable=not-callable
        obj.save(self.db_conn)
        assert isinstance(obj.id_, int)
        # change only after saving: retrieval must share the same history
        obj.title.set('named')
        retrieved = self.checked_class.by_id(self.db_conn, obj.id_)
        self.assertEqual(obj.title.history, retrieved.title.history)

    def check_remove(self, *args: Any) -> None:
        """Test .remove() effects on DB and cache."""
        id_ = self.default_ids[0]
        obj = self.checked_class(id_, *args)  # pylint: disable=not-callable
        # removing what was never saved is expected to fail
        with self.assertRaises(HandledException):
            obj.remove(self.db_conn)
        obj.save(self.db_conn)
        obj.remove(self.db_conn)
        self.check_storage([])
class TestCaseWithServer(TestCaseWithDB):
    """Module tests against our HTTP server/handler (and database)."""

    def setUp(self) -> None:
        """Start TaskServer in a background thread, open connection to it."""
        super().setUp()
        # port 0: have the OS pick any free port, read back below
        self.httpd = TaskServer(self.db_file, ('localhost', 0), TaskHandler)
        self.server_thread = Thread(target=self.httpd.serve_forever)
        self.server_thread.daemon = True
        self.server_thread.start()
        self.conn = HTTPConnection(str(self.httpd.server_address[0]),
                                   self.httpd.server_address[1])
        self.httpd.set_json_mode()

    def tearDown(self) -> None:
        """Close client connection, stop server and thread, clean up DB."""
        self.conn.close()
        self.httpd.shutdown()
        self.httpd.server_close()
        self.server_thread.join()
        super().tearDown()

    def check_redirect(self, target: str) -> None:
        """Check that self.conn answers with a 302 redirect to target."""
        response = self.conn.getresponse()
        self.assertEqual(response.status, 302)
        self.assertEqual(response.getheader('Location'), target)

    def check_get(self, target: str, expected_code: int) -> None:
        """Check that a GET to target yields expected_code."""
        self.conn.request('GET', target)
        self.assertEqual(self.conn.getresponse().status, expected_code)

    def check_post(self, data: Mapping[str, object], target: str,
                   expected_code: int, redirect_location: str = '') -> None:
        """Check that POST of data to target yields expected_code.

        If expected_code is 302, also check the redirect's Location, which
        is expected to be redirect_location, or target if that is empty.
        """
        encoded_form_data = urlencode(data, doseq=True).encode('utf-8')
        headers = {'Content-Type': 'application/x-www-form-urlencoded',
                   'Content-Length': str(len(encoded_form_data))}
        self.conn.request('POST', target,
                          body=encoded_form_data, headers=headers)
        if expected_code == 302:
            # check_redirect consumes the response, so we must stop here
            self.check_redirect(redirect_location or target)
            return
        self.assertEqual(self.conn.getresponse().status, expected_code)

    def check_get_defaults(self, path: str) -> None:
        """Some standard model paths to test."""
        self.check_get(path, 200)
        self.check_get(f'{path}?id=', 200)
        self.check_get(f'{path}?id=foo', 400)
        self.check_get(f'{path}?id=0', 500)
        self.check_get(f'{path}?id=1', 200)

    def post_process(self, id_: int = 1,
                     form_data: dict[str, Any] | None = None
                     ) -> dict[str, Any]:
        """POST basic Process, return the form data that was sent."""
        if not form_data:
            form_data = {'title': 'foo', 'description': 'foo', 'effort': 1.1}
        self.check_post(form_data, f'/process?id={id_}', 302,
                        f'/process?id={id_}')
        return form_data

    def check_json_get(self, path: str, expected: dict[str, object]) -> None:
        """Compare JSON on GET path with expected.

        To simplify comparison of VersionedAttribute histories, transforms
        keys under "history"-named dicts into bracketed integer strings
        counting upwards in the order the entries appear in the JSON.
        """

        def rewrite_history_keys_in(item: Any) -> Any:
            """Recursively rewrite "history" dicts in place, return item."""
            if isinstance(item, dict):
                if 'history' in item:
                    vals = item['history'].values()
                    item['history'] = {f'[{i}]': val
                                       for i, val in enumerate(vals)}
                # recursion only mutates nested values, never item's keys
                for val in item.values():
                    rewrite_history_keys_in(val)
            elif isinstance(item, list):
                item[:] = [rewrite_history_keys_in(i) for i in item]
            return item

        self.conn.request('GET', path)
        response = self.conn.getresponse()
        self.assertEqual(response.status, 200)
        retrieved = json_loads(response.read().decode())
        rewrite_history_keys_in(retrieved)
        self.assertEqual(expected, retrieved)