1 """Shared test utilities."""
2 from __future__ import annotations
3 from unittest import TestCase
4 from typing import Mapping, Any, Callable
5 from threading import Thread
6 from http.client import HTTPConnection
7 from datetime import datetime, timedelta
9 from json import loads as json_loads
10 from urllib.parse import urlencode
11 from uuid import uuid4
12 from os import remove as remove_file
13 from plomtask.db import DatabaseFile, DatabaseConnection
14 from plomtask.http import TaskHandler, TaskServer
15 from plomtask.processes import Process, ProcessStep
16 from plomtask.conditions import Condition
17 from plomtask.days import Day
18 from plomtask.dating import DATE_FORMAT
19 from plomtask.todos import Todo
20 from plomtask.versioned_attributes import VersionedAttribute, TIMESTAMP_FMT
21 from plomtask.exceptions import NotFoundException, HandledException
24 VERSIONED_VALS: dict[str,
25 list[str] | list[float]] = {'str': ['A', 'B'],
29 class TestCaseAugmented(TestCase):
30 """Tester core providing helpful basic internal decorators and methods."""
32 default_init_kwargs: dict[str, Any] = {}
35 def _run_if_checked_class(f: Callable[..., None]) -> Callable[..., None]:
36 def wrapper(self: TestCase) -> None:
37 if hasattr(self, 'checked_class'):
42 def _run_on_versioned_attributes(cls,
43 f: Callable[..., None]
44 ) -> Callable[..., None]:
45 @cls._run_if_checked_class
46 def wrapper(self: TestCase) -> None:
47 assert isinstance(self, TestCaseAugmented)
48 for attr_name in self.checked_class.to_save_versioned():
49 default = self.checked_class.versioned_defaults[attr_name]
50 owner = self.checked_class(None, **self.default_init_kwargs)
51 attr = getattr(owner, attr_name)
52 to_set = VERSIONED_VALS[attr.value_type_name]
53 f(self, owner, attr_name, attr, default, to_set)
57 def _make_from_defaults(cls, id_: float | str | None) -> Any:
58 return cls.checked_class(id_, **cls.default_init_kwargs)
61 class TestCaseSansDB(TestCaseAugmented):
62 """Tests requiring no DB setup."""
63 legal_ids: list[str] | list[int] = [1, 5]
64 illegal_ids: list[str] | list[int] = [0]
@TestCaseAugmented._run_if_checked_class
def test_id_validation(self) -> None:
    """Test .id_ validation/setting.

    Each entry of self.illegal_ids must make construction raise
    HandledException; each entry of self.legal_ids must construct and
    end up as the new object's .id_.
    """
    for rejected in self.illegal_ids:
        with self.assertRaises(HandledException):
            self._make_from_defaults(rejected)
    for accepted in self.legal_ids:
        built = self._make_from_defaults(accepted)
        self.assertEqual(built.id_, accepted)
76 @TestCaseAugmented._run_on_versioned_attributes
77 def test_versioned_set(self,
80 attr: VersionedAttribute,
82 to_set: list[str | float]
84 """Test VersionedAttribute.set() behaves as expected."""
86 self.assertEqual(list(attr.history.values()), [default])
87 # check same value does not get set twice in a row,
# and that not even its timestamp gets updated
89 timestamp = list(attr.history.keys())[0]
91 self.assertEqual(list(attr.history.values()), [default])
92 self.assertEqual(list(attr.history.keys())[0], timestamp)
93 # check that different value _will_ be set/added
95 timesorted_vals = [attr.history[t] for
96 t in sorted(attr.history.keys())]
97 expected = [default, to_set[0]]
98 self.assertEqual(timesorted_vals, expected)
99 # check that a previously used value can be set if not most recent
101 timesorted_vals = [attr.history[t] for
102 t in sorted(attr.history.keys())]
103 expected = [default, to_set[0], default]
104 self.assertEqual(timesorted_vals, expected)
105 # again check for same value not being set twice in a row, even for
108 timesorted_vals = [attr.history[t] for
109 t in sorted(attr.history.keys())]
110 expected = [default, to_set[0], default, to_set[1]]
111 self.assertEqual(timesorted_vals, expected)
113 self.assertEqual(timesorted_vals, expected)
115 @TestCaseAugmented._run_on_versioned_attributes
116 def test_versioned_newest(self,
119 attr: VersionedAttribute,
120 default: str | float,
121 to_set: list[str | float]
123 """Test VersionedAttribute.newest."""
124 # check .newest on empty history returns .default
125 self.assertEqual(attr.newest, default)
126 # check newest element always returned
127 for v in [to_set[0], to_set[1]]:
129 self.assertEqual(attr.newest, v)
130 # check newest element returned even if also early value
132 self.assertEqual(attr.newest, default)
134 @TestCaseAugmented._run_on_versioned_attributes
135 def test_versioned_at(self,
138 attr: VersionedAttribute,
139 default: str | float,
140 to_set: list[str | float]
142 """Test .at() returns values nearest to queried time, or default."""
143 # check .at() return default on empty history
144 timestamp_a = datetime.now().strftime(TIMESTAMP_FMT)
145 self.assertEqual(attr.at(timestamp_a), default)
146 # check value exactly at timestamp returned
148 timestamp_b = list(attr.history.keys())[0]
149 self.assertEqual(attr.at(timestamp_b), to_set[0])
150 # check earliest value returned if exists, rather than default
151 self.assertEqual(attr.at(timestamp_a), to_set[0])
152 # check reverts to previous value for timestamps not indexed
154 timestamp_between = datetime.now().strftime(TIMESTAMP_FMT)
157 timestamp_c = sorted(attr.history.keys())[-1]
158 self.assertEqual(attr.at(timestamp_c), to_set[1])
159 self.assertEqual(attr.at(timestamp_between), to_set[0])
161 timestamp_after_c = datetime.now().strftime(TIMESTAMP_FMT)
162 self.assertEqual(attr.at(timestamp_after_c), to_set[1])
165 class TestCaseWithDB(TestCaseAugmented):
"""Module tests requiring DB setup."""
167 default_ids: tuple[int | str, int | str, int | str] = (1, 2, 3)
169 def setUp(self) -> None:
170 Condition.empty_cache()
172 Process.empty_cache()
173 ProcessStep.empty_cache()
175 self.db_file = DatabaseFile.create_at(f'test_db:{uuid4()}')
176 self.db_conn = DatabaseConnection(self.db_file)
178 def tearDown(self) -> None:
180 remove_file(self.db_file.path)
182 def _load_from_db(self, id_: int | str) -> list[object]:
183 db_found: list[object] = []
184 for row in self.db_conn.row_where(self.checked_class.table_name,
186 db_found += [self.checked_class.from_table_row(self.db_conn,
190 def _change_obj(self, obj: object) -> str:
191 attr_name: str = self.checked_class.to_save_simples[-1]
192 attr = getattr(obj, attr_name)
193 new_attr: str | int | float | bool
194 if isinstance(attr, (int, float)):
196 elif isinstance(attr, str):
197 new_attr = attr + '_'
198 elif isinstance(attr, bool):
200 setattr(obj, attr_name, new_attr)
203 def check_identity_with_cache_and_db(self, content: list[Any]) -> None:
204 """Test both cache and DB equal content."""
207 expected_cache[item.id_] = item
208 self.assertEqual(self.checked_class.get_cache(), expected_cache)
209 hashes_content = [hash(x) for x in content]
210 db_found: list[Any] = []
212 assert isinstance(item.id_, type(self.default_ids[0]))
213 db_found += self._load_from_db(item.id_)
214 hashes_db_found = [hash(x) for x in db_found]
215 self.assertEqual(sorted(hashes_content), sorted(hashes_db_found))
217 def check_by_date_range_with_limits(self,
219 set_id_field: bool = True
221 """Test .by_date_range_with_limits."""
222 # pylint: disable=too-many-locals
223 f = self.checked_class.by_date_range_with_limits
224 # check illegal ranges
225 legal_range = ('yesterday', 'tomorrow')
227 for bad_date in ['foo', '2024-02-30', '2024-01-01 12:00:00']:
228 date_range = list(legal_range[:])
229 date_range[i] = bad_date
230 with self.assertRaises(HandledException):
231 f(self.db_conn, date_range, date_col)
232 # check empty, translation of 'yesterday' and 'tomorrow'
233 items, start, end = f(self.db_conn, legal_range, date_col)
234 self.assertEqual(items, [])
235 yesterday = datetime.now() + timedelta(days=-1)
236 tomorrow = datetime.now() + timedelta(days=+1)
237 self.assertEqual(start, yesterday.strftime(DATE_FORMAT))
238 self.assertEqual(end, tomorrow.strftime(DATE_FORMAT))
239 # prepare dated items for non-empty results
240 kwargs_with_date = self.default_init_kwargs.copy()
242 kwargs_with_date['id_'] = None
244 dates = ['2024-01-01', '2024-01-02', '2024-01-04']
245 for date in ['2024-01-01', '2024-01-02', '2024-01-04']:
246 kwargs_with_date['date'] = date
247 obj = self.checked_class(**kwargs_with_date)
249 # check ranges still empty before saving
250 date_range = [dates[0], dates[-1]]
251 self.assertEqual(f(self.db_conn, date_range, date_col)[0], [])
252 # check all objs displayed within closed interval
254 obj.save(self.db_conn)
255 self.assertEqual(f(self.db_conn, date_range, date_col)[0], objs)
256 # check that only displayed what exists within interval
257 date_range = ['2023-12-20', '2024-01-03']
258 expected = [objs[0], objs[1]]
259 self.assertEqual(f(self.db_conn, date_range, date_col)[0], expected)
260 date_range = ['2024-01-03', '2024-01-30']
262 self.assertEqual(f(self.db_conn, date_range, date_col)[0], expected)
263 # check that inverted interval displays nothing
264 date_range = [dates[-1], dates[0]]
265 self.assertEqual(f(self.db_conn, date_range, date_col)[0], [])
266 # check that "today" is interpreted, and single-element interval
267 today_date = datetime.now().strftime(DATE_FORMAT)
268 kwargs_with_date['date'] = today_date
269 obj_today = self.checked_class(**kwargs_with_date)
270 obj_today.save(self.db_conn)
271 date_range = ['today', 'today']
272 items, start, end = f(self.db_conn, date_range, date_col)
273 self.assertEqual(start, today_date)
274 self.assertEqual(start, end)
275 self.assertEqual(items, [obj_today])
277 @TestCaseAugmented._run_on_versioned_attributes
278 def test_saving_versioned_attributes(self,
281 attr: VersionedAttribute,
283 to_set: list[str | float]
285 """Test storage and initialization of versioned attributes."""
287 def retrieve_attr_vals(attr: VersionedAttribute) -> list[object]:
288 attr_vals_saved: list[object] = []
289 for row in self.db_conn.row_where(attr.table_name, 'parent',
291 attr_vals_saved += [row[2]]
292 return attr_vals_saved
295 # check that without attr.save() no rows in DB
296 rows = self.db_conn.row_where(attr.table_name, 'parent', owner.id_)
297 self.assertEqual([], rows)
298 # fail saving attributes on non-saved owner
299 with self.assertRaises(NotFoundException):
300 attr.save(self.db_conn)
301 # check owner.save() created entries as expected in attr table
302 owner.save(self.db_conn)
303 attr_vals_saved = retrieve_attr_vals(attr)
304 self.assertEqual([to_set[0]], attr_vals_saved)
305 # check changing attr val without save affects owner in memory …
307 cmp_attr = getattr(owner, attr_name)
308 self.assertEqual(to_set, list(cmp_attr.history.values()))
309 self.assertEqual(cmp_attr.history, attr.history)
310 # … but does not yet affect DB
311 attr_vals_saved = retrieve_attr_vals(attr)
312 self.assertEqual([to_set[0]], attr_vals_saved)
313 # check individual attr.save also stores new val to DB
314 attr.save(self.db_conn)
315 attr_vals_saved = retrieve_attr_vals(attr)
316 self.assertEqual(to_set, attr_vals_saved)
318 @TestCaseAugmented._run_if_checked_class
319 def test_saving_and_caching(self) -> None:
320 """Test effects of .cache() and .save()."""
321 id1 = self.default_ids[0]
322 # check failure to cache without ID (if None-ID input possible)
323 if isinstance(id1, int):
324 obj0 = self._make_from_defaults(None)
325 with self.assertRaises(HandledException):
327 # check mere object init itself doesn't even store in cache
328 obj1 = self._make_from_defaults(id1)
329 self.assertEqual(self.checked_class.get_cache(), {})
330 # check .cache() fills cache, but not DB
332 self.assertEqual(self.checked_class.get_cache(), {id1: obj1})
333 found_in_db = self._load_from_db(id1)
334 self.assertEqual(found_in_db, [])
335 # check .save() sets ID (for int IDs), updates cache, and fills DB
336 # (expect ID to be set to id1, despite obj1 already having that as ID:
337 # it's generated by cursor.lastrowid on the DB table, and with obj1
338 # not written there, obj2 should get it first!)
339 id_input = None if isinstance(id1, int) else id1
340 obj2 = self._make_from_defaults(id_input)
341 obj2.save(self.db_conn)
342 self.assertEqual(self.checked_class.get_cache(), {id1: obj2})
343 # NB: we'll only compare hashes because obj2 itself disappears on
# .from_table_row-triggered database reload
345 obj2_hash = hash(obj2)
346 found_in_db += self._load_from_db(id1)
347 self.assertEqual([hash(o) for o in found_in_db], [obj2_hash])
348 # check we cannot overwrite obj2 with obj1 despite its same ID,
349 # since it has disappeared now
350 with self.assertRaises(HandledException):
351 obj1.save(self.db_conn)
353 @TestCaseAugmented._run_if_checked_class
354 def test_by_id(self) -> None:
356 id1, id2, _ = self.default_ids
357 # check failure if not yet saved
358 obj1 = self._make_from_defaults(id1)
359 with self.assertRaises(NotFoundException):
360 self.checked_class.by_id(self.db_conn, id1)
361 # check identity of cached and retrieved
363 self.assertEqual(obj1, self.checked_class.by_id(self.db_conn, id1))
364 # check identity of saved and retrieved
365 obj2 = self._make_from_defaults(id2)
366 obj2.save(self.db_conn)
367 self.assertEqual(obj2, self.checked_class.by_id(self.db_conn, id2))
369 @TestCaseAugmented._run_if_checked_class
370 def test_by_id_or_create(self) -> None:
371 """Test .by_id_or_create."""
372 # check .by_id_or_create fails if wrong class
373 if not self.checked_class.can_create_by_id:
374 with self.assertRaises(HandledException):
375 self.checked_class.by_id_or_create(self.db_conn, None)
377 # check ID input of None creates, on saving, ID=1,2,… for int IDs
378 if isinstance(self.default_ids[0], int):
380 item = self.checked_class.by_id_or_create(self.db_conn, None)
381 self.assertEqual(item.id_, None)
382 item.save(self.db_conn)
383 self.assertEqual(item.id_, n+1)
384 # check .by_id_or_create acts like normal instantiation (sans saving)
385 id_ = self.default_ids[2]
386 item = self.checked_class.by_id_or_create(self.db_conn, id_)
387 self.assertEqual(item.id_, id_)
388 with self.assertRaises(NotFoundException):
389 self.checked_class.by_id(self.db_conn, item.id_)
390 self.assertEqual(self.checked_class(item.id_), item)
392 @TestCaseAugmented._run_if_checked_class
393 def test_from_table_row(self) -> None:
394 """Test .from_table_row() properly reads in class directly from DB."""
395 id_ = self.default_ids[0]
396 obj = self._make_from_defaults(id_)
397 obj.save(self.db_conn)
398 assert isinstance(obj.id_, type(id_))
399 for row in self.db_conn.row_where(self.checked_class.table_name,
401 # check .from_table_row reproduces state saved, no matter if obj
402 # later changed (with caching even)
403 # NB: we'll only compare hashes because obj itself disappears on
404 # .from_table_row-triggered database reload
405 hash_original = hash(obj)
406 attr_name = self._change_obj(obj)
408 to_cmp = getattr(obj, attr_name)
409 retrieved = self.checked_class.from_table_row(self.db_conn, row)
410 self.assertNotEqual(to_cmp, getattr(retrieved, attr_name))
411 self.assertEqual(hash_original, hash(retrieved))
412 # check cache contains what .from_table_row just produced
413 self.assertEqual({retrieved.id_: retrieved},
414 self.checked_class.get_cache())
416 @TestCaseAugmented._run_on_versioned_attributes
417 def test_versioned_history_from_row(self,
420 attr: VersionedAttribute,
421 default: str | float,
422 to_set: list[str | float]
"""Test VersionedAttribute.history_from_row() knows its DB rows."""
427 owner.save(self.db_conn)
428 # make empty VersionedAttribute, fill from rows, compare to owner's
429 for row in self.db_conn.row_where(owner.table_name, 'id', owner.id_):
430 loaded_attr = VersionedAttribute(owner, attr.table_name, default)
431 for row in self.db_conn.row_where(attr.table_name, 'parent',
433 loaded_attr.history_from_row(row)
434 self.assertEqual(len(attr.history.keys()),
435 len(loaded_attr.history.keys()))
436 for timestamp, value in attr.history.items():
437 self.assertEqual(value, loaded_attr.history[timestamp])
439 @TestCaseAugmented._run_if_checked_class
440 def test_all(self) -> None:
441 """Test .all() and its relation to cache and savings."""
442 id1, id2, id3 = self.default_ids
443 item1 = self._make_from_defaults(id1)
444 item2 = self._make_from_defaults(id2)
445 item3 = self._make_from_defaults(id3)
446 # check .all() returns empty list on un-cached items
447 self.assertEqual(self.checked_class.all(self.db_conn), [])
448 # check that all() shows only cached/saved items
450 item3.save(self.db_conn)
451 self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
452 sorted([item1, item3]))
453 item2.save(self.db_conn)
454 self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
455 sorted([item1, item2, item3]))
@TestCaseAugmented._run_if_checked_class
def test_singularity(self) -> None:
    """Test pointers made for single object keep pointing to it.

    The object is saved, then changed in memory only (no second save);
    a subsequent .by_id lookup must carry the changed attribute value.
    """
    target_id = self.default_ids[0]
    original = self._make_from_defaults(target_id)
    original.save(self.db_conn)
    # mutate in memory only, then look the object up again via .by_id
    changed_name = self._change_obj(original)
    expected_val = getattr(original, changed_name)
    fetched = self.checked_class.by_id(self.db_conn, target_id)
    self.assertEqual(expected_val, getattr(fetched, changed_name))
469 @TestCaseAugmented._run_on_versioned_attributes
470 def test_versioned_singularity(self,
473 attr: VersionedAttribute,
475 to_set: list[str | float]
477 """Test singularity of VersionedAttributes on saving."""
478 owner.save(self.db_conn)
479 # change obj, expect retrieved through .by_id to carry change
481 retrieved = self.checked_class.by_id(self.db_conn, owner.id_)
482 attr_retrieved = getattr(retrieved, attr_name)
483 self.assertEqual(attr.history, attr_retrieved.history)
485 @TestCaseAugmented._run_if_checked_class
486 def test_remove(self) -> None:
487 """Test .remove() effects on DB and cache."""
488 id_ = self.default_ids[0]
489 obj = self._make_from_defaults(id_)
490 # check removal only works after saving
491 with self.assertRaises(HandledException):
492 obj.remove(self.db_conn)
493 obj.save(self.db_conn)
494 obj.remove(self.db_conn)
495 # check access to obj fails after removal
496 with self.assertRaises(HandledException):
498 # check DB and cache now empty
499 self.check_identity_with_cache_and_db([])
502 class TestCaseWithServer(TestCaseWithDB):
503 """Module tests against our HTTP server/handler (and database)."""
505 def setUp(self) -> None:
507 self.httpd = TaskServer(self.db_file, ('localhost', 0), TaskHandler)
508 self.server_thread = Thread(target=self.httpd.serve_forever)
509 self.server_thread.daemon = True
510 self.server_thread.start()
511 self.conn = HTTPConnection(str(self.httpd.server_address[0]),
512 self.httpd.server_address[1])
513 self.httpd.set_json_mode()
515 def tearDown(self) -> None:
516 self.httpd.shutdown()
517 self.httpd.server_close()
518 self.server_thread.join()
522 def as_id_list(items: list[dict[str, object]]) -> list[int | str]:
523 """Return list of only 'id' fields of items."""
526 assert isinstance(item['id'], (int, str))
527 id_list += [item['id']]
531 def as_refs(items: list[dict[str, object]]
532 ) -> dict[str, dict[str, object]]:
533 """Return dictionary of items by their 'id' fields."""
536 refs[str(item['id'])] = item
540 def cond_as_dict(id_: int = 1,
541 is_active: bool = False,
542 titles: None | list[str] = None,
543 descriptions: None | list[str] = None
544 ) -> dict[str, object]:
545 """Return JSON of Condition to expect."""
547 'is_active': is_active,
551 titles = titles if titles else []
552 descriptions = descriptions if descriptions else []
553 assert isinstance(d['_versioned'], dict)
554 for i, title in enumerate(titles):
555 d['_versioned']['title'][i] = title
556 for i, description in enumerate(descriptions):
557 d['_versioned']['description'][i] = description
561 def proc_as_dict(id_: int = 1,
563 description: str = '',
565 conditions: None | list[int] = None,
566 disables: None | list[int] = None,
567 blockers: None | list[int] = None,
568 enables: None | list[int] = None
569 ) -> dict[str, object]:
570 """Return JSON of Process to expect."""
571 # pylint: disable=too-many-arguments
573 'calendarize': False,
574 'suppressed_steps': [],
575 'explicit_steps': [],
578 'description': {0: description},
579 'effort': {0: effort}},
580 'conditions': conditions if conditions else [],
581 'disables': disables if disables else [],
582 'enables': enables if enables else [],
583 'blockers': blockers if blockers else []}
def check_redirect(self, target: str) -> None:
    """Check that self.conn answers with a 302 redirect to target.

    Fetches the pending response from self.conn and reads its body to the
    end, so that self.conn can carry a further request even if the server
    keeps the connection alive. Then asserts status 302 and a Location
    header equal to target.
    """
    response = self.conn.getresponse()
    # http.client refuses a new response on this connection until the
    # previous one has been read completely
    response.read()
    self.assertEqual(response.status, 302)
    self.assertEqual(response.getheader('Location'), target)
def check_get(self, target: str, expected_code: int) -> None:
    """Check that a GET to target yields expected_code.

    The response body is read to the end, so that self.conn can be reused
    for the next request regardless of connection keep-alive.
    """
    self.conn.request('GET', target)
    response = self.conn.getresponse()
    # drain before asserting, so the connection is clean either way
    response.read()
    self.assertEqual(response.status, expected_code)
def check_post(self, data: Mapping[str, object], target: str,
               expected_code: int = 302, redir: str = '') -> None:
    """Check that POST of data to target yields expected_code.

    data is sent form-urlencoded with doseq=True, so sequence values become
    repeated keys. If expected_code is 302, the response must redirect to
    redir, or to target itself when redir is empty. Otherwise the response
    status must equal expected_code; its body is read to the end so that
    self.conn can be reused for the next request.
    """
    encoded_form_data = urlencode(data, doseq=True).encode('utf-8')
    headers = {'Content-Type': 'application/x-www-form-urlencoded',
               'Content-Length': str(len(encoded_form_data))}
    self.conn.request('POST', target,
                      body=encoded_form_data, headers=headers)
    if 302 == expected_code:
        redir = target if redir == '' else redir
        self.check_redirect(redir)
    else:
        response = self.conn.getresponse()
        response.read()
        self.assertEqual(response.status, expected_code)
def check_get_defaults(self, path: str) -> None:
    """Some standard model paths to test.

    GETs path bare and with the queries ?id=, ?id=foo, ?id=0 and ?id=1
    appended, expecting status codes 200, 200, 400, 500 and 200
    respectively. Every probe uses path unchanged, with no extra leading '/'.
    """
    # (query suffix appended to path, expected HTTP status)
    cases = (('', 200),
             ('?id=', 200),
             ('?id=foo', 400),
             ('?id=0', 500),
             ('?id=1', 200))
    for query, expected_code in cases:
        self.check_get(f'{path}{query}', expected_code)
619 def post_process(self, id_: int = 1,
620 form_data: dict[str, Any] | None = None
622 """POST basic Process."""
624 form_data = {'title': 'foo', 'description': 'foo', 'effort': 1.1}
625 self.check_post(form_data, f'/process?id={id_}',
626 redir=f'/process?id={id_}')
629 def check_json_get(self, path: str, expected: dict[str, object]) -> None:
630 """Compare JSON on GET path with expected.
632 To simplify comparison of VersionedAttribute histories, transforms
633 timestamp keys of VersionedAttribute history keys into integers
634 counting chronologically forward from 0.
637 def rewrite_history_keys_in(item: Any) -> Any:
638 if isinstance(item, dict):
639 if '_versioned' in item.keys():
640 for k in item['_versioned']:
641 vals = item['_versioned'][k].values()
643 for i, val in enumerate(vals):
645 item['_versioned'][k] = history
646 for k in list(item.keys()):
647 rewrite_history_keys_in(item[k])
648 elif isinstance(item, list):
649 item[:] = [rewrite_history_keys_in(i) for i in item]
652 self.conn.request('GET', path)
653 response = self.conn.getresponse()
654 self.assertEqual(response.status, 200)
655 retrieved = json_loads(response.read().decode())
656 rewrite_history_keys_in(retrieved)
657 self.assertEqual(expected, retrieved)