1 """Shared test utilities."""
2 from __future__ import annotations
3 from unittest import TestCase
4 from typing import Mapping, Any, Callable
5 from threading import Thread
6 from http.client import HTTPConnection
7 from json import loads as json_loads
8 from urllib.parse import urlencode
10 from os import remove as remove_file
11 from plomtask.db import DatabaseFile, DatabaseConnection
12 from plomtask.http import TaskHandler, TaskServer
13 from plomtask.processes import Process, ProcessStep
14 from plomtask.conditions import Condition
15 from plomtask.days import Day
16 from plomtask.todos import Todo
17 from plomtask.exceptions import NotFoundException, HandledException
# Decorator for the shared test methods below. The visible wrapper only
# proceeds when the test case instance has a ``checked_class`` attribute,
# presumably so the generic base-class tests are no-ops unless a subclass
# names the model class to test.
# NOTE(review): the embedded numbering jumps from 22 to 27, so the body of
# the ``if`` (presumably ``f(self)``) and the decorator's return of
# ``wrapper`` are missing from this listing — restore from the full file.
20 def _within_checked_class(f: Callable[..., None]) -> Callable[..., None]:
21 def wrapper(self: TestCase) -> None:
22 if hasattr(self, 'checked_class'):
# Base class for model tests that need no database; subclasses configure
# the shared tests through the class attributes below.
# NOTE(review): lines 29 and 32-34 are absent from this listing;
# ``checked_class``, ``legal_ids`` and ``illegal_ids``, which the tests
# below read, are presumably declared there — verify.
27 class TestCaseSansDB(TestCase):
28 """Tests requiring no DB setup."""
# Positional args passed after the ID whenever checked_class is instantiated.
30 default_init_args: list[Any] = []
# Maps VersionedAttribute names to the value expected as their ``.newest``
# on a freshly initialized object.
31 versioned_defaults_to_test: dict[str, str | float] = {}
@_within_checked_class
def test_id_validation(self) -> None:
    """Check illegal IDs are rejected and legal IDs end up as .id_."""
    make = self.checked_class
    extra_args = self.default_init_args
    # every illegal ID must make construction fail
    for bad_id in self.illegal_ids:
        with self.assertRaises(HandledException):
            make(bad_id, *extra_args)
    # every legal ID must be stored unchanged on the new object
    for good_id in self.legal_ids:
        created = make(good_id, *extra_args)
        self.assertEqual(created.id_, good_id)
@_within_checked_class
def test_versioned_defaults(self) -> None:
    """Check a fresh object's VersionedAttributes start at the defaults."""
    fresh = self.checked_class(self.legal_ids[0], *self.default_init_args)
    for attr_name, expected in self.versioned_defaults_to_test.items():
        versioned = getattr(fresh, attr_name)
        self.assertEqual(versioned.newest, expected)
# Base class for model tests run against a database file created anew for
# each test (see setUp).
# NOTE(review): line 56 is absent from this listing; ``checked_class``,
# which the tests below read, is presumably declared there — verify.
54 class TestCaseWithDB(TestCase):
55 """Module tests requiring DB setup."""
# Three distinct IDs for test objects; the type of the first decides
# whether int-ID-only checks (e.g. None-ID creation) run.
57 default_ids: tuple[int | str, int | str, int | str] = (1, 2, 3)
# Keyword args passed whenever checked_class is instantiated.
58 default_init_kwargs: dict[str, Any] = {}
# Maps VersionedAttribute names to their value type: str gets string test
# values, any other type float test values.
59 test_versioneds: dict[str, type] = {}
# Reset model caches, then create a uniquely named DB file and a
# connection to it for the test about to run.
# NOTE(review): lines 63-64 and 66 are absent from this listing; as Day,
# Process and Todo are imported but not used anywhere in view, these lines
# presumably empty those caches too — verify. ``uuid4`` is used below but
# not imported in the visible import lines (line 9 is missing as well).
61 def setUp(self) -> None:
62 Condition.empty_cache()
65 ProcessStep.empty_cache()
# uuid4() keeps the DB files of different tests apart
67 self.db_file = DatabaseFile.create_at(f'test_db:{uuid4()}')
68 self.db_conn = DatabaseConnection(self.db_file)
# Delete the per-test DB file created in setUp.
# NOTE(review): line 71 is absent from this listing; presumably it closes
# self.db_conn before the file is removed — verify.
70 def tearDown(self) -> None:
72 remove_file(self.db_file.path)
# Build checked_class objects via .from_table_row from the rows of
# checked_class.table_name selected by row_where (presumably matching
# id_; the rest of that call is not visible).
# NOTE(review): lines 77 and 79-80 are absent from this listing — the
# remaining arguments of row_where and from_table_row, and presumably
# ``return db_found`` — verify.
74 def _load_from_db(self, id_: int | str) -> list[object]:
75 db_found: list[object] = []
76 for row in self.db_conn.row_where(self.checked_class.table_name,
78 db_found += [self.checked_class.from_table_row(self.db_conn,
# Overwrite the last attribute named in checked_class.to_save with a new
# value chosen by the attribute's type. Callers (test_from_table_row,
# test_singularity) use the result as that attribute's name.
# NOTE(review): lines 87, 89, 91 (the new_attr assignments) and 93
# (presumably ``return attr_name``) are absent from this listing — verify.
82 def _change_obj(self, obj: object) -> str:
83 attr_name: str = self.checked_class.to_save[-1]
84 attr = getattr(obj, attr_name)
85 new_attr: str | int | float | bool
# NOTE(review): bool is a subclass of int, so a bool attribute already
# matches this first branch and the ``bool`` branch below is unreachable;
# test for bool first if a bool-specific change is intended.
86 if isinstance(attr, (int, float)):
88 elif isinstance(attr, str):
90 elif isinstance(attr, bool):
92 setattr(obj, attr_name, new_attr)
# Assert that checked_class's cache maps the items of ``content`` by their
# id_, and that loading those IDs from the DB yields objects hashing the
# same as ``content``. Hashes are compared as sorted lists, so order does
# not matter.
# NOTE(review): lines 97-98 and 103 are absent from this listing —
# presumably the creation of ``expected_cache`` and ``for item in content:``
# loop headers — verify. If the loop at 103 iterates ``content``, passing
# [] queries the DB for nothing, so DB emptiness is not actually checked.
95 def check_identity_with_cache_and_db(self, content: list[Any]) -> None:
96 """Test both cache and DB equal content."""
99 expected_cache[item.id_] = item
100 self.assertEqual(self.checked_class.get_cache(), expected_cache)
101 hashes_content = [hash(x) for x in content]
102 db_found: list[Any] = []
# each item's ID must have the same type as the configured default IDs
104 assert isinstance(item.id_, type(self.default_ids[0]))
105 db_found += self._load_from_db(item.id_)
106 hashes_db_found = [hash(x) for x in db_found]
107 self.assertEqual(sorted(hashes_content), sorted(hashes_db_found))
# For each VersionedAttribute in test_versioneds, check three things.
# Saving it fails while its owner is unsaved. owner.save() stores its
# values in the attribute's own table, and they are read back on loading.
# Later changes only reach the DB via an explicit attr.save().
# NOTE(review): lines 116, 124-125 and 137 are absent from this listing.
# Judging from the assertions, 124-125 presumably set ``vals`` on ``attr``,
# and 137 presumably sets vals[0] again without saving — verify.
109 @_within_checked_class
110 def test_saving_versioned(self) -> None:
111 """Test storage and initialization of versioned attributes."""
# Collect column 2 of the rows of attr's table selected via 'parent'
# (presumably against retrieved.id_ on missing line 116). ``attr`` and
# ``retrieved`` are looked up in the enclosing scope at call time, i.e.
# they are whatever the loop below last bound them to.
112 def retrieve_attr_vals() -> list[object]:
113 attr_vals_saved: list[object] = []
114 assert hasattr(retrieved, 'id_')
115 for row in self.db_conn.row_where(attr.table_name, 'parent',
117 attr_vals_saved += [row[2]]
118 return attr_vals_saved
119 for attr_name, type_ in self.test_versioneds.items():
120 # fail saving attributes on non-saved owner
121 owner = self.checked_class(None, **self.default_init_kwargs)
# str attributes get string test values, all others floats; both lists
# are already ascending, which the sorted() comparison below relies on
122 vals: list[Any] = ['t1', 't2'] if type_ == str else [0.9, 1.1]
123 attr = getattr(owner, attr_name)
126 with self.assertRaises(NotFoundException):
127 attr.save(self.db_conn)
128 owner.save(self.db_conn)
129 # check stored attribute is as expected
130 retrieved = self._load_from_db(owner.id_)[0]
131 attr = getattr(retrieved, attr_name)
132 self.assertEqual(sorted(attr.history.values()), vals)
133 # check owner.save() created entries in attr table
134 attr_vals_saved = retrieve_attr_vals()
135 self.assertEqual(vals, attr_vals_saved)
136 # check setting new val to attr inconsequential to DB without save
138 attr_vals_saved = retrieve_attr_vals()
139 self.assertEqual(vals, attr_vals_saved)
140 # check save finally adds new val
141 attr.save(self.db_conn)
142 attr_vals_saved = retrieve_attr_vals()
143 self.assertEqual(vals + [vals[0]], attr_vals_saved)
# Check how instantiation, .cache() and .save() affect the class cache and
# the DB, including that a saved object displaces a cached-only object
# with the same ID.
# NOTE(review): lines 153 and 158 are absent from this listing —
# presumably the ``obj0.cache()`` expected to raise and the
# ``obj1.cache()`` call whose effect is asserted right after — verify.
145 @_within_checked_class
146 def test_saving_and_caching(self) -> None:
147 """Test effects of .cache() and .save()."""
148 id1 = self.default_ids[0]
149 # check failure to cache without ID (if None-ID input possible)
150 if isinstance(id1, int):
151 obj0 = self.checked_class(None, **self.default_init_kwargs)
152 with self.assertRaises(HandledException):
154 # check mere object init itself doesn't even store in cache
155 obj1 = self.checked_class(id1, **self.default_init_kwargs)
156 self.assertEqual(self.checked_class.get_cache(), {})
157 # check .cache() fills cache, but not DB
159 self.assertEqual(self.checked_class.get_cache(), {id1: obj1})
160 db_found = self._load_from_db(id1)
161 self.assertEqual(db_found, [])
162 # check .save() sets ID (for int IDs), updates cache, and fills DB
163 # (expect ID to be set to id1, despite obj1 already having that as ID:
164 # it's generated by cursor.lastrowid on the DB table, and with obj1
165 # not written there, obj2 should get it first!)
166 id_input = None if isinstance(id1, int) else id1
167 obj2 = self.checked_class(id_input, **self.default_init_kwargs)
168 obj2.save(self.db_conn)
169 obj2_hash = hash(obj2)
170 self.assertEqual(self.checked_class.get_cache(), {id1: obj2})
# db_found was asserted empty above, so this just collects the new result
171 db_found += self._load_from_db(id1)
172 self.assertEqual([hash(o) for o in db_found], [obj2_hash])
173 # check we cannot overwrite obj2 with obj1 despite its same ID,
174 # since it has disappeared now
175 with self.assertRaises(HandledException):
176 obj1.save(self.db_conn)
# Check .by_id: it raises NotFoundException for an object neither cached
# nor saved, and returns an equal object once it is cached or saved.
# NOTE(review): lines 180 and 187 are absent from this listing —
# presumably the docstring and the ``obj1.cache()`` call that makes obj1
# retrievable before the next assertion — verify.
178 @_within_checked_class
179 def test_by_id(self) -> None:
181 id1, id2, _ = self.default_ids
182 # check failure if not yet saved
183 obj1 = self.checked_class(id1, **self.default_init_kwargs)
184 with self.assertRaises(NotFoundException):
185 self.checked_class.by_id(self.db_conn, id1)
186 # check identity of cached and retrieved
188 self.assertEqual(obj1, self.checked_class.by_id(self.db_conn, id1))
189 # check identity of saved and retrieved
190 obj2 = self.checked_class(id2, **self.default_init_kwargs)
191 obj2.save(self.db_conn)
192 self.assertEqual(obj2, self.checked_class.by_id(self.db_conn, id2))
# Check .by_id_or_create. It raises for classes with can_create_by_id
# false. For int IDs, a None ID gets numbered on save. For a given ID, it
# builds an unsaved object equal to a plain instantiation.
# NOTE(review): lines 201 and 204 are absent from this listing —
# presumably a ``return`` ending the test for classes that cannot create by
# ID, and the loop header defining ``n`` (used at line 208) — verify.
194 @_within_checked_class
195 def test_by_id_or_create(self) -> None:
196 """Test .by_id_or_create."""
197 # check .by_id_or_create fails if wrong class
198 if not self.checked_class.can_create_by_id:
199 with self.assertRaises(HandledException):
200 self.checked_class.by_id_or_create(self.db_conn, None)
202 # check ID input of None creates, on saving, ID=1,2,… for int IDs
203 if isinstance(self.default_ids[0], int):
205 item = self.checked_class.by_id_or_create(self.db_conn, None)
206 self.assertEqual(item.id_, None)
207 item.save(self.db_conn)
208 self.assertEqual(item.id_, n+1)
209 # check .by_id_or_create acts like normal instantiation (sans saving)
210 id_ = self.default_ids[2]
211 item = self.checked_class.by_id_or_create(self.db_conn, id_)
212 self.assertEqual(item.id_, id_)
213 with self.assertRaises(NotFoundException):
214 self.checked_class.by_id(self.db_conn, item.id_)
# NOTE(review): unlike the other tests, this instantiation passes no
# default_init_kwargs; classes needing them would fail here.
215 self.assertEqual(self.checked_class(item.id_), item)
# Check that .from_table_row rebuilds the saved state from a DB row (even
# after the in-memory object changed), puts the result into the cache,
# and also loads VersionedAttribute histories.
# NOTE(review): lines 225, 230, 243-244 and 247 are absent from this
# listing. They hold the rest of both row_where calls, presumably a cache
# call after _change_obj, and presumably the setting of ``vals`` on
# ``attr`` — verify.
217 @_within_checked_class
218 def test_from_table_row(self) -> None:
219 """Test .from_table_row() properly reads in class directly from DB."""
220 id_ = self.default_ids[0]
221 obj = self.checked_class(id_, **self.default_init_kwargs)
222 obj.save(self.db_conn)
223 assert isinstance(obj.id_, type(id_))
224 for row in self.db_conn.row_where(self.checked_class.table_name,
226 # check .from_table_row reproduces state saved, no matter if obj
227 # later changed (with caching even)
228 hash_original = hash(obj)
229 attr_name = self._change_obj(obj)
231 to_cmp = getattr(obj, attr_name)
232 retrieved = self.checked_class.from_table_row(self.db_conn, row)
233 self.assertNotEqual(to_cmp, getattr(retrieved, attr_name))
234 self.assertEqual(hash_original, hash(retrieved))
235 # check cache contains what .from_table_row just produced
236 self.assertEqual({retrieved.id_: retrieved},
237 self.checked_class.get_cache())
238 # check .from_table_row also reads versioned attributes from DB
239 for attr_name, type_ in self.test_versioneds.items():
# NOTE(review): no default_init_kwargs passed here, unlike the same
# construction in test_saving_versioned.
240 owner = self.checked_class(None)
241 vals: list[Any] = ['t1', 't2'] if type_ == str else [0.9, 1.1]
242 attr = getattr(owner, attr_name)
245 owner.save(self.db_conn)
246 for row in self.db_conn.row_where(owner.table_name, 'id',
248 retrieved = owner.__class__.from_table_row(self.db_conn, row)
249 attr = getattr(retrieved, attr_name)
250 self.assertEqual(sorted(attr.history.values()), vals)
# Check that .all() returns nothing for merely instantiated objects and
# exactly the cached or saved ones afterwards. Results are compared after
# sorted(), so this relies on checked_class instances being orderable.
# NOTE(review): line 262 is absent from this listing; presumably it caches
# item1, which is expected in the next result without being saved — verify.
252 @_within_checked_class
253 def test_all(self) -> None:
254 """Test .all() and its relation to cache and savings."""
255 id_1, id_2, id_3 = self.default_ids
256 item1 = self.checked_class(id_1, **self.default_init_kwargs)
257 item2 = self.checked_class(id_2, **self.default_init_kwargs)
258 item3 = self.checked_class(id_3, **self.default_init_kwargs)
259 # check .all() returns empty list on un-cached items
260 self.assertEqual(self.checked_class.all(self.db_conn), [])
261 # check that all() shows only cached/saved items
263 item3.save(self.db_conn)
264 self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
265 sorted([item1, item3]))
266 item2.save(self.db_conn)
267 self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
268 sorted([item1, item2, item3]))
@_within_checked_class
def test_singularity(self) -> None:
    """Check every reference to one saved object sees changes made to it."""
    first_id = self.default_ids[0]
    original = self.checked_class(first_id, **self.default_init_kwargs)
    original.save(self.db_conn)
    # mutate the saved object; a lookup by ID must reflect the change
    changed_name = self._change_obj(original)
    expected = getattr(original, changed_name)
    fetched = self.checked_class.by_id(self.db_conn, first_id)
    self.assertEqual(expected, getattr(fetched, changed_name))
@_within_checked_class
def test_versioned_singularity_title(self) -> None:
    """Test singularity of VersionedAttributes on saving (with .title).

    Only runs for classes listing 'title' in test_versioneds. After an
    object is saved, a title set on it must also appear in the .title
    history of the object returned by .by_id.
    """
    if 'title' not in self.test_versioneds:
        return
    # pass default_init_kwargs like the other tests, so classes whose
    # constructor needs extra arguments can be instantiated here too
    obj = self.checked_class(None, **self.default_init_kwargs)
    obj.save(self.db_conn)
    assert isinstance(obj.id_, int)
    # change obj, expect retrieved through .by_id to carry change
    obj.title.set('named')
    retrieved = self.checked_class.by_id(self.db_conn, obj.id_)
    self.assertEqual(obj.title.history, retrieved.title.history)
# Check that .remove() raises before the object is saved, succeeds after
# saving, and leaves the object inaccessible afterwards.
# NOTE(review): line 306 is absent from this listing; it is the statement
# expected to raise after removal — verify.
294 @_within_checked_class
295 def test_remove(self) -> None:
296 """Test .remove() effects on DB and cache."""
297 id_ = self.default_ids[0]
298 obj = self.checked_class(id_, **self.default_init_kwargs)
299 # check removal only works after saving
300 with self.assertRaises(HandledException):
301 obj.remove(self.db_conn)
302 obj.save(self.db_conn)
303 obj.remove(self.db_conn)
304 # check access to obj fails after removal
305 with self.assertRaises(HandledException):
307 # check DB and cache now empty
# NOTE(review): with an empty list the helper loads no IDs from the DB,
# so effectively only cache emptiness is asserted here.
308 self.check_identity_with_cache_and_db([])
# Base class for tests against a TaskServer that runs in a background
# thread on top of the per-test database.
# NOTE(review): line 315 is absent from this listing. self.db_file (set
# in TestCaseWithDB.setUp) is used right below, so 315 is presumably
# ``super().setUp()`` — verify.
311 class TestCaseWithServer(TestCaseWithDB):
312 """Module tests against our HTTP server/handler (and database)."""
314 def setUp(self) -> None:
# port 0 presumably lets the OS pick a free port (standard socket
# behavior); the address actually bound is read back below
316 self.httpd = TaskServer(self.db_file, ('localhost', 0), TaskHandler)
# daemon thread, so a still-running server cannot block interpreter exit
317 self.server_thread = Thread(target=self.httpd.serve_forever)
318 self.server_thread.daemon = True
319 self.server_thread.start()
320 self.conn = HTTPConnection(str(self.httpd.server_address[0]),
321 self.httpd.server_address[1])
# presumably switches the server to JSON responses, which check_json_get
# parses — verify against TaskServer
322 self.httpd.set_json_mode()
# Stop the server's serve_forever loop, release its socket, and wait for
# the server thread to finish.
# NOTE(review): line 328 is absent from this listing. If it is not
# ``super().tearDown()``, TestCaseWithDB.tearDown (which removes the DB
# file) would be skipped — verify.
324 def tearDown(self) -> None:
325 self.httpd.shutdown()
326 self.httpd.server_close()
327 self.server_thread.join()
# Build the JSON dict expected for a Process, for comparison via
# check_json_get. Takes no ``self``; presumably it is a @staticmethod
# declared on a line absent from this listing.
# NOTE(review): lines 332, 334, 342, 346-347 and 350 are absent from this
# listing. ``effort`` (used below) is presumably a parameter declared on
# one of them, and the opening of the returned dict is not visible.
331 def proc_as_dict(id_: int = 1,
333 description: str = '',
335 enables: None | list[dict[str, object]] = None,
336 disables: None | list[dict[str, object]] = None,
337 conditions: None | list[dict[str, object]] = None,
338 blockers: None | list[dict[str, object]] = None
339 ) -> dict[str, object]:
340 """Return JSON of Process to expect."""
341 # pylint: disable=too-many-arguments
343 'calendarize': False,
344 'suppressed_steps': [],
345 'explicit_steps': [],
# history key 0: check_json_get rewrites VersionedAttribute history keys
# into integers counting from 0
348 'description': {0: description},
349 'effort': {0: effort}
# None defaults plus these fallbacks give each call fresh empty lists
# instead of sharing one mutable default
351 'conditions': conditions if conditions else [],
352 'disables': disables if disables else [],
353 'enables': enables if enables else [],
354 'blockers': blockers if blockers else []}
def check_redirect(self, target: str) -> None:
    """Assert the pending response on self.conn is a 302 pointing to target."""
    answer = self.conn.getresponse()
    self.assertEqual(answer.status, 302)
    location = answer.getheader('Location')
    self.assertEqual(location, target)
def check_get(self, target: str, expected_code: int) -> None:
    """Send a GET for target over self.conn and assert the response status."""
    conn = self.conn
    conn.request('GET', target)
    received = conn.getresponse().status
    self.assertEqual(received, expected_code)
# POST ``data`` form-encoded to target over self.conn, then check the
# response. For an expected 302, check_redirect verifies the Location
# header, which defaults to target when redirect_location is empty. For
# other codes the response status is compared (the branch structure is
# only partly visible, see note).
# NOTE(review): line 380 is absent from this listing. It is presumably the
# ``else:`` that keeps the plain status check from reading a second
# response after the redirect check — verify.
368 def check_post(self, data: Mapping[str, object], target: str,
369 expected_code: int, redirect_location: str = '') -> None:
370 """Check that POST of data to target yields expected_code."""
# doseq=True encodes sequence values as repeated keys (k=a&k=b)
371 encoded_form_data = urlencode(data, doseq=True).encode('utf-8')
372 headers = {'Content-Type': 'application/x-www-form-urlencoded',
373 'Content-Length': str(len(encoded_form_data))}
374 self.conn.request('POST', target,
375 body=encoded_form_data, headers=headers)
376 if 302 == expected_code:
377 if redirect_location == '':
378 redirect_location = target
379 self.check_redirect(redirect_location)
381 self.assertEqual(self.conn.getresponse().status, expected_code)
def check_get_defaults(self, path: str) -> None:
    """Run the standard GET checks against a model page at path.

    Expects 200 for the bare path, for an empty ``id`` and for ``id=1``,
    400 for a non-numeric ``id``, and 500 for ``id=0``. path is used
    verbatim as the request target, so it is assumed to start with '/'
    (e.g. '/process'), as the bare ``check_get(path, 200)`` requires.
    """
    self.check_get(path, 200)
    self.check_get(f'{path}?id=', 200)
    self.check_get(f'{path}?id=foo', 400)
    # no extra '/' prefix: path already has one, and a '//…' target can be
    # parsed as a network location (e.g. urllib.parse.urlparse reads
    # 'process' in '//process?id=0' as netloc), missing the model page
    self.check_get(f'{path}?id=0', 500)
    self.check_get(f'{path}?id=1', 200)
# POST a Process to /process?id=<id_> and expect a 302 redirect back to
# that same URL. form_data presumably falls back to a fixed
# title/description/effort set when none is given.
# NOTE(review): lines 393, 395 and 399-400 are absent from this listing.
# They hold the end of the signature (return type not visible), presumably
# the condition guarding the default form_data, and whatever follows the
# check_post call (possibly a return) — verify.
391 def post_process(self, id_: int = 1,
392 form_data: dict[str, Any] | None = None
394 """POST basic Process."""
396 form_data = {'title': 'foo', 'description': 'foo', 'effort': 1.1}
397 self.check_post(form_data, f'/process?id={id_}', 302,
398 f'/process?id={id_}')
# GET path over self.conn, require status 200, parse the body as JSON,
# normalize VersionedAttribute history keys in place, and compare the
# result with ``expected``.
# NOTE(review): lines 403, 407, 413, 415 and 421 are absent from this
# listing: the docstring's blank and closing lines, presumably the creation
# of ``history`` (assigned at 416) and its filling in the enumerate loop,
# and presumably ``return item`` — which the list branch needs, as it
# replaces list elements with the helper's return values.
# NOTE(review): the 0, 1, … numbering follows the order of the history
# values in the parsed JSON; whether that order is chronological depends
# on how the server serializes histories — confirm.
401 def check_json_get(self, path: str, expected: dict[str, object]) -> None:
402 """Compare JSON on GET path with expected.
404 To simplify comparison of VersionedAttribute histories, transforms
405 timestamp keys of VersionedAttribute histories into integers
406 counting chronologically forward from 0.
408 def rewrite_history_keys_in(item: Any) -> Any:
409 if isinstance(item, dict):
410 if '_versioned' in item.keys():
411 for k in item['_versioned']:
412 vals = item['_versioned'][k].values()
414 for i, val in enumerate(vals):
416 item['_versioned'][k] = history
417 for k in list(item.keys()):
418 rewrite_history_keys_in(item[k])
419 elif isinstance(item, list):
420 item[:] = [rewrite_history_keys_in(i) for i in item]
422 self.conn.request('GET', path)
423 response = self.conn.getresponse()
424 self.assertEqual(response.status, 200)
425 retrieved = json_loads(response.read().decode())
426 rewrite_history_keys_in(retrieved)
427 self.assertEqual(expected, retrieved)