1 """Shared test utilities."""
2 from __future__ import annotations
3 from unittest import TestCase
4 from typing import Mapping, Any, Callable
5 from threading import Thread
6 from http.client import HTTPConnection
7 from json import loads as json_loads
8 from urllib.parse import urlencode
10 from os import remove as remove_file
11 from plomtask.db import DatabaseFile, DatabaseConnection
12 from plomtask.http import TaskHandler, TaskServer
13 from plomtask.processes import Process, ProcessStep
14 from plomtask.conditions import Condition
15 from plomtask.days import Day
16 from plomtask.todos import Todo
17 from plomtask.exceptions import NotFoundException, HandledException
20 def _within_checked_class(f: Callable[..., None]) -> Callable[..., None]:
21 def wrapper(self: TestCase) -> None:
22 if hasattr(self, 'checked_class'):
27 class TestCaseSansDB(TestCase):
28 """Tests requiring no DB setup."""
30 default_init_args: list[Any] = []
31 versioned_defaults_to_test: dict[str, str | float] = {}
35 @_within_checked_class
36 def test_id_validation(self) -> None:
37 """Test .id_ validation/setting."""
38 for id_ in self.illegal_ids:
39 with self.assertRaises(HandledException):
40 self.checked_class(id_, *self.default_init_args)
41 for id_ in self.legal_ids:
42 obj = self.checked_class(id_, *self.default_init_args)
43 self.assertEqual(obj.id_, id_)
45 @_within_checked_class
46 def test_versioned_defaults(self) -> None:
47 """Test defaults of VersionedAttributes."""
48 id_ = self.legal_ids[0]
49 obj = self.checked_class(id_, *self.default_init_args)
50 for k, v in self.versioned_defaults_to_test.items():
51 self.assertEqual(getattr(obj, k).newest, v)
54 class TestCaseWithDB(TestCase):
55 """Module tests not requiring DB setup."""
57 default_ids: tuple[int | str, int | str, int | str] = (1, 2, 3)
58 default_init_kwargs: dict[str, Any] = {}
59 test_versioneds: dict[str, type] = {}
61 def setUp(self) -> None:
62 Condition.empty_cache()
65 ProcessStep.empty_cache()
67 self.db_file = DatabaseFile.create_at(f'test_db:{uuid4()}')
68 self.db_conn = DatabaseConnection(self.db_file)
70 def tearDown(self) -> None:
72 remove_file(self.db_file.path)
74 def _load_from_db(self, id_: int | str) -> list[object]:
75 db_found: list[object] = []
76 for row in self.db_conn.row_where(self.checked_class.table_name,
78 db_found += [self.checked_class.from_table_row(self.db_conn,
82 def _change_obj(self, obj: object) -> str:
83 attr_name: str = self.checked_class.to_save[-1]
84 attr = getattr(obj, attr_name)
85 new_attr: str | int | float | bool
86 if isinstance(attr, (int, float)):
88 elif isinstance(attr, str):
90 elif isinstance(attr, bool):
92 setattr(obj, attr_name, new_attr)
95 def check_identity_with_cache_and_db(self, content: list[Any]) -> None:
96 """Test both cache and DB equal content."""
99 expected_cache[item.id_] = item
100 self.assertEqual(self.checked_class.get_cache(), expected_cache)
101 hashes_content = [hash(x) for x in content]
102 db_found: list[Any] = []
104 assert isinstance(item.id_, type(self.default_ids[0]))
105 db_found += self._load_from_db(item.id_)
106 hashes_db_found = [hash(x) for x in db_found]
107 self.assertEqual(sorted(hashes_content), sorted(hashes_db_found))
109 @_within_checked_class
110 def test_saving_versioned(self) -> None:
111 """Test storage and initialization of versioned attributes."""
112 def retrieve_attr_vals() -> list[object]:
113 attr_vals_saved: list[object] = []
114 assert hasattr(retrieved, 'id_')
115 for row in self.db_conn.row_where(attr.table_name, 'parent',
117 attr_vals_saved += [row[2]]
118 return attr_vals_saved
119 for attr_name, type_ in self.test_versioneds.items():
120 # fail saving attributes on non-saved owner
121 owner = self.checked_class(None, **self.default_init_kwargs)
122 vals: list[Any] = ['t1', 't2'] if type_ == str else [0.9, 1.1]
123 attr = getattr(owner, attr_name)
126 with self.assertRaises(NotFoundException):
127 attr.save(self.db_conn)
128 owner.save(self.db_conn)
129 # check stored attribute is as expected
130 retrieved = self._load_from_db(owner.id_)[0]
131 attr = getattr(retrieved, attr_name)
132 self.assertEqual(sorted(attr.history.values()), vals)
133 # check owner.save() created entries in attr table
134 attr_vals_saved = retrieve_attr_vals()
135 self.assertEqual(vals, attr_vals_saved)
136 # check setting new val to attr inconsequential to DB without save
138 attr_vals_saved = retrieve_attr_vals()
139 self.assertEqual(vals, attr_vals_saved)
140 # check save finally adds new val
141 attr.save(self.db_conn)
142 attr_vals_saved = retrieve_attr_vals()
143 self.assertEqual(vals + [vals[0]], attr_vals_saved)
145 @_within_checked_class
146 def test_saving_and_caching(self) -> None:
147 """Test effects of .cache() and .save()."""
148 id1 = self.default_ids[0]
149 # check failure to cache without ID (if None-ID input possible)
150 if isinstance(id1, int):
151 obj0 = self.checked_class(None, **self.default_init_kwargs)
152 with self.assertRaises(HandledException):
154 # check mere object init itself doesn't even store in cache
155 obj1 = self.checked_class(id1, **self.default_init_kwargs)
156 self.assertEqual(self.checked_class.get_cache(), {})
157 # check .cache() fills cache, but not DB
159 self.assertEqual(self.checked_class.get_cache(), {id1: obj1})
160 db_found = self._load_from_db(id1)
161 self.assertEqual(db_found, [])
162 # check .save() sets ID (for int IDs), updates cache, and fills DB
163 # (expect ID to be set to id1, despite obj1 already having that as ID:
164 # it's generated by cursor.lastrowid on the DB table, and with obj1
165 # not written there, obj2 should get it first!)
166 id_input = None if isinstance(id1, int) else id1
167 obj2 = self.checked_class(id_input, **self.default_init_kwargs)
168 obj2.save(self.db_conn)
169 obj2_hash = hash(obj2)
170 self.assertEqual(self.checked_class.get_cache(), {id1: obj2})
171 db_found += self._load_from_db(id1)
172 self.assertEqual([hash(o) for o in db_found], [obj2_hash])
173 # check we cannot overwrite obj2 with obj1 despite its same ID,
174 # since it has disappeared now
175 with self.assertRaises(HandledException):
176 obj1.save(self.db_conn)
178 @_within_checked_class
179 def test_by_id(self) -> None:
181 id1, id2, _ = self.default_ids
182 # check failure if not yet saved
183 obj1 = self.checked_class(id1, **self.default_init_kwargs)
184 with self.assertRaises(NotFoundException):
185 self.checked_class.by_id(self.db_conn, id1)
186 # check identity of cached and retrieved
188 self.assertEqual(obj1, self.checked_class.by_id(self.db_conn, id1))
189 # check identity of saved and retrieved
190 obj2 = self.checked_class(id2, **self.default_init_kwargs)
191 obj2.save(self.db_conn)
192 self.assertEqual(obj2, self.checked_class.by_id(self.db_conn, id2))
194 @_within_checked_class
195 def test_by_id_or_create(self) -> None:
196 """Test .by_id_or_create."""
197 # check .by_id_or_create fails if wrong class
198 if not self.checked_class.can_create_by_id:
199 with self.assertRaises(HandledException):
200 self.checked_class.by_id_or_create(self.db_conn, None)
202 # check ID input of None creates, on saving, ID=1,2,… for int IDs
203 if isinstance(self.default_ids[0], int):
205 item = self.checked_class.by_id_or_create(self.db_conn, None)
206 self.assertEqual(item.id_, None)
207 item.save(self.db_conn)
208 self.assertEqual(item.id_, n+1)
209 # check .by_id_or_create acts like normal instantiation (sans saving)
210 id_ = self.default_ids[2]
211 item = self.checked_class.by_id_or_create(self.db_conn, id_)
212 self.assertEqual(item.id_, id_)
213 with self.assertRaises(NotFoundException):
214 self.checked_class.by_id(self.db_conn, item.id_)
215 self.assertEqual(self.checked_class(item.id_), item)
217 @_within_checked_class
218 def test_from_table_row(self) -> None:
219 """Test .from_table_row() properly reads in class directly from DB."""
220 id_ = self.default_ids[0]
221 obj = self.checked_class(id_, **self.default_init_kwargs)
222 obj.save(self.db_conn)
223 assert isinstance(obj.id_, type(id_))
224 for row in self.db_conn.row_where(self.checked_class.table_name,
226 # check .from_table_row reproduces state saved, no matter if obj
227 # later changed (with caching even)
228 hash_original = hash(obj)
229 attr_name = self._change_obj(obj)
231 to_cmp = getattr(obj, attr_name)
232 retrieved = self.checked_class.from_table_row(self.db_conn, row)
233 self.assertNotEqual(to_cmp, getattr(retrieved, attr_name))
234 self.assertEqual(hash_original, hash(retrieved))
235 # check cache contains what .from_table_row just produced
236 self.assertEqual({retrieved.id_: retrieved},
237 self.checked_class.get_cache())
238 # check .from_table_row also reads versioned attributes from DB
239 for attr_name, type_ in self.test_versioneds.items():
240 owner = self.checked_class(None)
241 vals: list[Any] = ['t1', 't2'] if type_ == str else [0.9, 1.1]
242 attr = getattr(owner, attr_name)
245 owner.save(self.db_conn)
246 for row in self.db_conn.row_where(owner.table_name, 'id',
248 retrieved = owner.__class__.from_table_row(self.db_conn, row)
249 attr = getattr(retrieved, attr_name)
250 self.assertEqual(sorted(attr.history.values()), vals)
252 @_within_checked_class
253 def test_all(self) -> None:
254 """Test .all() and its relation to cache and savings."""
255 id_1, id_2, id_3 = self.default_ids
256 item1 = self.checked_class(id_1, **self.default_init_kwargs)
257 item2 = self.checked_class(id_2, **self.default_init_kwargs)
258 item3 = self.checked_class(id_3, **self.default_init_kwargs)
259 # check .all() returns empty list on un-cached items
260 self.assertEqual(self.checked_class.all(self.db_conn), [])
261 # check that all() shows only cached/saved items
263 item3.save(self.db_conn)
264 self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
265 sorted([item1, item3]))
266 item2.save(self.db_conn)
267 self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
268 sorted([item1, item2, item3]))
270 @_within_checked_class
271 def test_singularity(self) -> None:
272 """Test pointers made for single object keep pointing to it."""
273 id1 = self.default_ids[0]
274 obj = self.checked_class(id1, **self.default_init_kwargs)
275 obj.save(self.db_conn)
276 # change object, expect retrieved through .by_id to carry change
277 attr_name = self._change_obj(obj)
278 new_attr = getattr(obj, attr_name)
279 retrieved = self.checked_class.by_id(self.db_conn, id1)
280 self.assertEqual(new_attr, getattr(retrieved, attr_name))
282 @_within_checked_class
283 def test_versioned_singularity_title(self) -> None:
284 """Test singularity of VersionedAttributes on saving (with .title)."""
285 if 'title' in self.test_versioneds:
286 obj = self.checked_class(None)
287 obj.save(self.db_conn)
288 assert isinstance(obj.id_, int)
289 # change obj, expect retrieved through .by_id to carry change
290 obj.title.set('named')
291 retrieved = self.checked_class.by_id(self.db_conn, obj.id_)
292 self.assertEqual(obj.title.history, retrieved.title.history)
294 @_within_checked_class
295 def test_remove(self) -> None:
296 """Test .remove() effects on DB and cache."""
297 id_ = self.default_ids[0]
298 obj = self.checked_class(id_, **self.default_init_kwargs)
299 # check removal only works after saving
300 with self.assertRaises(HandledException):
301 obj.remove(self.db_conn)
302 obj.save(self.db_conn)
303 obj.remove(self.db_conn)
304 # check access to obj fails after removal
305 with self.assertRaises(HandledException):
307 # check DB and cache now empty
308 self.check_identity_with_cache_and_db([])
311 class TestCaseWithServer(TestCaseWithDB):
312 """Module tests against our HTTP server/handler (and database)."""
314 def setUp(self) -> None:
316 self.httpd = TaskServer(self.db_file, ('localhost', 0), TaskHandler)
317 self.server_thread = Thread(target=self.httpd.serve_forever)
318 self.server_thread.daemon = True
319 self.server_thread.start()
320 self.conn = HTTPConnection(str(self.httpd.server_address[0]),
321 self.httpd.server_address[1])
322 self.httpd.set_json_mode()
324 def tearDown(self) -> None:
325 self.httpd.shutdown()
326 self.httpd.server_close()
327 self.server_thread.join()
331 def as_id_list(items: list[dict[str, object]]) -> list[int | str]:
332 """Return list of only 'id' fields of items."""
335 assert isinstance(item['id'], (int, str))
336 id_list += [item['id']]
340 def as_refs(items: list[dict[str, object]]
341 ) -> dict[str, dict[str, object]]:
342 """Return dictionary of items by their 'id' fields."""
345 refs[str(item['id'])] = item
349 def cond_as_dict(id_: int = 1,
350 is_active: bool = False,
351 titles: None | list[str] = None,
352 descriptions: None | list[str] = None
353 ) -> dict[str, object]:
354 """Return JSON of Condition to expect."""
356 'is_active': is_active,
360 titles = titles if titles else []
361 descriptions = descriptions if descriptions else []
362 assert isinstance(d['_versioned'], dict)
363 for i, title in enumerate(titles):
364 d['_versioned']['title'][i] = title
365 for i, description in enumerate(descriptions):
366 d['_versioned']['description'][i] = description
370 def proc_as_dict(id_: int = 1,
372 description: str = '',
374 enables: None | list[dict[str, object]] = None,
375 disables: None | list[dict[str, object]] = None,
376 conditions: None | list[dict[str, object]] = None,
377 blockers: None | list[dict[str, object]] = None
378 ) -> dict[str, object]:
379 """Return JSON of Process to expect."""
380 # pylint: disable=too-many-arguments
381 as_id_list = TestCaseWithServer.as_id_list
383 'calendarize': False,
384 'suppressed_steps': [],
385 'explicit_steps': [],
388 'description': {0: description},
389 'effort': {0: effort}},
390 'conditions': as_id_list(conditions) if conditions else [],
391 'disables': as_id_list(disables) if disables else [],
392 'enables': as_id_list(enables) if enables else [],
393 'blockers': as_id_list(blockers) if blockers else []}
396 def check_redirect(self, target: str) -> None:
397 """Check that self.conn answers with a 302 redirect to target."""
398 response = self.conn.getresponse()
399 self.assertEqual(response.status, 302)
400 self.assertEqual(response.getheader('Location'), target)
402 def check_get(self, target: str, expected_code: int) -> None:
403 """Check that a GET to target yields expected_code."""
404 self.conn.request('GET', target)
405 self.assertEqual(self.conn.getresponse().status, expected_code)
407 def check_post(self, data: Mapping[str, object], target: str,
408 expected_code: int, redirect_location: str = '') -> None:
409 """Check that POST of data to target yields expected_code."""
410 encoded_form_data = urlencode(data, doseq=True).encode('utf-8')
411 headers = {'Content-Type': 'application/x-www-form-urlencoded',
412 'Content-Length': str(len(encoded_form_data))}
413 self.conn.request('POST', target,
414 body=encoded_form_data, headers=headers)
415 if 302 == expected_code:
416 if redirect_location == '':
417 redirect_location = target
418 self.check_redirect(redirect_location)
420 self.assertEqual(self.conn.getresponse().status, expected_code)
422 def check_get_defaults(self, path: str) -> None:
423 """Some standard model paths to test."""
424 self.check_get(path, 200)
425 self.check_get(f'{path}?id=', 200)
426 self.check_get(f'{path}?id=foo', 400)
427 self.check_get(f'/{path}?id=0', 500)
428 self.check_get(f'{path}?id=1', 200)
430 def post_process(self, id_: int = 1,
431 form_data: dict[str, Any] | None = None
433 """POST basic Process."""
435 form_data = {'title': 'foo', 'description': 'foo', 'effort': 1.1}
436 self.check_post(form_data, f'/process?id={id_}', 302,
437 f'/process?id={id_}')
440 def check_json_get(self, path: str, expected: dict[str, object]) -> None:
441 """Compare JSON on GET path with expected.
443 To simplify comparison of VersionedAttribute histories, transforms
444 timestamp keys of VersionedAttribute history keys into integers
445 counting chronologically forward from 0.
447 def rewrite_history_keys_in(item: Any) -> Any:
448 if isinstance(item, dict):
449 if '_versioned' in item.keys():
450 for k in item['_versioned']:
451 vals = item['_versioned'][k].values()
453 for i, val in enumerate(vals):
455 item['_versioned'][k] = history
456 for k in list(item.keys()):
457 rewrite_history_keys_in(item[k])
458 elif isinstance(item, list):
459 item[:] = [rewrite_history_keys_in(i) for i in item]
461 self.conn.request('GET', path)
462 response = self.conn.getresponse()
463 self.assertEqual(response.status, 200)
464 retrieved = json_loads(response.read().decode())
465 rewrite_history_keys_in(retrieved)
466 self.assertEqual(expected, retrieved)