1 """Shared test utilities."""
2 # pylint: disable=too-many-lines
3 from __future__ import annotations
4 from unittest import TestCase
5 from typing import Mapping, Any, Callable
from functools import wraps
6 from threading import Thread
7 from http.client import HTTPConnection
8 from datetime import datetime, timedelta
10 from json import loads as json_loads, dumps as json_dumps
11 from urllib.parse import urlencode
12 from uuid import uuid4
13 from os import remove as remove_file
14 from pprint import pprint
15 from plomtask.db import DatabaseFile, DatabaseConnection
16 from plomtask.http import TaskHandler, TaskServer
17 from plomtask.processes import Process, ProcessStep
18 from plomtask.conditions import Condition
19 from plomtask.days import Day
20 from plomtask.dating import DATE_FORMAT
21 from plomtask.todos import Todo
22 from plomtask.versioned_attributes import VersionedAttribute, TIMESTAMP_FMT
23 from plomtask.exceptions import NotFoundException, HandledException
# Sample values per VersionedAttribute.value_type_name ('str', …), looked up
# by TestCaseAugmented._run_on_versioned_attributes; the decorated tests set
# to_set[0] and to_set[1] of the chosen list on the attribute under test.
26 VERSIONED_VALS: dict[str,
27 list[str] | list[float]] = {'str': ['A', 'B'],
29 VALID_TRUES = {True, 'True', 'true', '1', 'on'}  # NOTE(review): unused in visible lines; presumably values read as True — confirm
32 class TestCaseAugmented(TestCase):
33 """Tester core providing helpful basic internal decorators and methods."""
35 default_init_kwargs: dict[str, Any] = {}
38 def _run_if_checked_class(f: Callable[..., None]) -> Callable[..., None]:
39 def wrapper(self: TestCase) -> None:
40 if hasattr(self, 'checked_class'):
45 def _run_on_versioned_attributes(cls,
46 f: Callable[..., None]
47 ) -> Callable[..., None]:
48 @cls._run_if_checked_class
49 def wrapper(self: TestCase) -> None:
50 assert isinstance(self, TestCaseAugmented)
51 for attr_name in self.checked_class.to_save_versioned():
52 default = self.checked_class.versioned_defaults[attr_name]
53 owner = self.checked_class(None, **self.default_init_kwargs)
54 attr = getattr(owner, attr_name)
55 to_set = VERSIONED_VALS[attr.value_type_name]
56 f(self, owner, attr_name, attr, default, to_set)
60 def _make_from_defaults(cls, id_: float | str | None) -> Any:
61 return cls.checked_class(id_, **cls.default_init_kwargs)
64 class TestCaseSansDB(TestCaseAugmented):
65 """Tests requiring no DB setup."""
66 legal_ids: list[str] | list[int] = [1, 5]
67 illegal_ids: list[str] | list[int] = [0]
69 @TestCaseAugmented._run_if_checked_class
70 def test_id_validation(self) -> None:
71 """Test .id_ validation/setting."""
72 for id_ in self.illegal_ids:
73 with self.assertRaises(HandledException):
74 self._make_from_defaults(id_)
75 for id_ in self.legal_ids:
76 obj = self._make_from_defaults(id_)
77 self.assertEqual(obj.id_, id_)
79 @TestCaseAugmented._run_on_versioned_attributes
80 def test_versioned_set(self,
83 attr: VersionedAttribute,
85 to_set: list[str] | list[float]
87 """Test VersionedAttribute.set() behaves as expected."""
89 self.assertEqual(list(attr.history.values()), [default])
90 # check same value does not get set twice in a row,
91 # and that not even its timestamp get updated
92 timestamp = list(attr.history.keys())[0]
94 self.assertEqual(list(attr.history.values()), [default])
95 self.assertEqual(list(attr.history.keys())[0], timestamp)
96 # check that different value _will_ be set/added
98 timesorted_vals = [attr.history[t] for
99 t in sorted(attr.history.keys())]
100 expected = [default, to_set[0]]
101 self.assertEqual(timesorted_vals, expected)
102 # check that a previously used value can be set if not most recent
104 timesorted_vals = [attr.history[t] for
105 t in sorted(attr.history.keys())]
106 expected = [default, to_set[0], default]
107 self.assertEqual(timesorted_vals, expected)
108 # again check for same value not being set twice in a row, even for
111 timesorted_vals = [attr.history[t] for
112 t in sorted(attr.history.keys())]
113 expected = [default, to_set[0], default, to_set[1]]
114 self.assertEqual(timesorted_vals, expected)
116 self.assertEqual(timesorted_vals, expected)
118 @TestCaseAugmented._run_on_versioned_attributes
119 def test_versioned_newest(self,
122 attr: VersionedAttribute,
123 default: str | float,
124 to_set: list[str] | list[float]
126 """Test VersionedAttribute.newest."""
127 # check .newest on empty history returns .default
128 self.assertEqual(attr.newest, default)
129 # check newest element always returned
130 for v in [to_set[0], to_set[1]]:
132 self.assertEqual(attr.newest, v)
133 # check newest element returned even if also early value
135 self.assertEqual(attr.newest, default)
137 @TestCaseAugmented._run_on_versioned_attributes
138 def test_versioned_at(self,
141 attr: VersionedAttribute,
142 default: str | float,
143 to_set: list[str] | list[float]
145 """Test .at() returns values nearest to queried time, or default."""
146 # check .at() return default on empty history
147 timestamp_a = datetime.now().strftime(TIMESTAMP_FMT)
148 self.assertEqual(attr.at(timestamp_a), default)
149 # check value exactly at timestamp returned
151 timestamp_b = list(attr.history.keys())[0]
152 self.assertEqual(attr.at(timestamp_b), to_set[0])
153 # check earliest value returned if exists, rather than default
154 self.assertEqual(attr.at(timestamp_a), to_set[0])
155 # check reverts to previous value for timestamps not indexed
157 timestamp_between = datetime.now().strftime(TIMESTAMP_FMT)
160 timestamp_c = sorted(attr.history.keys())[-1]
161 self.assertEqual(attr.at(timestamp_c), to_set[1])
162 self.assertEqual(attr.at(timestamp_between), to_set[0])
164 timestamp_after_c = datetime.now().strftime(TIMESTAMP_FMT)
165 self.assertEqual(attr.at(timestamp_after_c), to_set[1])
168 class TestCaseWithDB(TestCaseAugmented):
169 """Module tests not requiring DB setup."""
170 default_ids: tuple[int, int, int] | tuple[str, str, str] = (1, 2, 3)
172 def setUp(self) -> None:
173 Condition.empty_cache()
175 Process.empty_cache()
176 ProcessStep.empty_cache()
178 self.db_file = DatabaseFile.create_at(f'test_db:{uuid4()}')
179 self.db_conn = DatabaseConnection(self.db_file)
181 def tearDown(self) -> None:
183 remove_file(self.db_file.path)
185 def _load_from_db(self, id_: int | str) -> list[object]:
186 db_found: list[object] = []
187 for row in self.db_conn.row_where(self.checked_class.table_name,
189 db_found += [self.checked_class.from_table_row(self.db_conn,
193 def _change_obj(self, obj: object) -> str:
194 attr_name: str = self.checked_class.to_save_simples[-1]
195 attr = getattr(obj, attr_name)
196 new_attr: str | int | float | bool
197 if isinstance(attr, (int, float)):
199 elif isinstance(attr, str):
200 new_attr = attr + '_'
201 elif isinstance(attr, bool):
203 setattr(obj, attr_name, new_attr)
206 def check_identity_with_cache_and_db(self, content: list[Any]) -> None:
207 """Test both cache and DB equal content."""
210 expected_cache[item.id_] = item
211 self.assertEqual(self.checked_class.get_cache(), expected_cache)
212 hashes_content = [hash(x) for x in content]
213 db_found: list[Any] = []
215 assert isinstance(item.id_, type(self.default_ids[0]))
216 db_found += self._load_from_db(item.id_)
217 hashes_db_found = [hash(x) for x in db_found]
218 self.assertEqual(sorted(hashes_content), sorted(hashes_db_found))
220 def check_by_date_range_with_limits(self,
222 set_id_field: bool = True
224 """Test .by_date_range_with_limits."""
225 # pylint: disable=too-many-locals
226 f = self.checked_class.by_date_range_with_limits
227 # check illegal ranges
228 legal_range = ('yesterday', 'tomorrow')
230 for bad_date in ['foo', '2024-02-30', '2024-01-01 12:00:00']:
231 date_range = list(legal_range[:])
232 date_range[i] = bad_date
233 with self.assertRaises(HandledException):
234 f(self.db_conn, date_range, date_col)
235 # check empty, translation of 'yesterday' and 'tomorrow'
236 items, start, end = f(self.db_conn, legal_range, date_col)
237 self.assertEqual(items, [])
238 yesterday = datetime.now() + timedelta(days=-1)
239 tomorrow = datetime.now() + timedelta(days=+1)
240 self.assertEqual(start, yesterday.strftime(DATE_FORMAT))
241 self.assertEqual(end, tomorrow.strftime(DATE_FORMAT))
242 # prepare dated items for non-empty results
243 kwargs_with_date = self.default_init_kwargs.copy()
245 kwargs_with_date['id_'] = None
247 dates = ['2024-01-01', '2024-01-02', '2024-01-04']
248 for date in ['2024-01-01', '2024-01-02', '2024-01-04']:
249 kwargs_with_date['date'] = date
250 obj = self.checked_class(**kwargs_with_date)
252 # check ranges still empty before saving
253 date_range = [dates[0], dates[-1]]
254 self.assertEqual(f(self.db_conn, date_range, date_col)[0], [])
255 # check all objs displayed within closed interval
257 obj.save(self.db_conn)
258 self.assertEqual(f(self.db_conn, date_range, date_col)[0], objs)
259 # check that only displayed what exists within interval
260 date_range = ['2023-12-20', '2024-01-03']
261 expected = [objs[0], objs[1]]
262 self.assertEqual(f(self.db_conn, date_range, date_col)[0], expected)
263 date_range = ['2024-01-03', '2024-01-30']
265 self.assertEqual(f(self.db_conn, date_range, date_col)[0], expected)
266 # check that inverted interval displays nothing
267 date_range = [dates[-1], dates[0]]
268 self.assertEqual(f(self.db_conn, date_range, date_col)[0], [])
269 # check that "today" is interpreted, and single-element interval
270 today_date = datetime.now().strftime(DATE_FORMAT)
271 kwargs_with_date['date'] = today_date
272 obj_today = self.checked_class(**kwargs_with_date)
273 obj_today.save(self.db_conn)
274 date_range = ['today', 'today']
275 items, start, end = f(self.db_conn, date_range, date_col)
276 self.assertEqual(start, today_date)
277 self.assertEqual(start, end)
278 self.assertEqual(items, [obj_today])
280 @TestCaseAugmented._run_on_versioned_attributes
281 def test_saving_versioned_attributes(self,
284 attr: VersionedAttribute,
286 to_set: list[str] | list[float]
288 """Test storage and initialization of versioned attributes."""
290 def retrieve_attr_vals(attr: VersionedAttribute) -> list[object]:
291 attr_vals_saved: list[object] = []
292 for row in self.db_conn.row_where(attr.table_name, 'parent',
294 attr_vals_saved += [row[2]]
295 return attr_vals_saved
298 # check that without attr.save() no rows in DB
299 rows = self.db_conn.row_where(attr.table_name, 'parent', owner.id_)
300 self.assertEqual([], rows)
301 # fail saving attributes on non-saved owner
302 with self.assertRaises(NotFoundException):
303 attr.save(self.db_conn)
304 # check owner.save() created entries as expected in attr table
305 owner.save(self.db_conn)
306 attr_vals_saved = retrieve_attr_vals(attr)
307 self.assertEqual([to_set[0]], attr_vals_saved)
308 # check changing attr val without save affects owner in memory …
310 cmp_attr = getattr(owner, attr_name)
311 self.assertEqual(to_set, list(cmp_attr.history.values()))
312 self.assertEqual(cmp_attr.history, attr.history)
313 # … but does not yet affect DB
314 attr_vals_saved = retrieve_attr_vals(attr)
315 self.assertEqual([to_set[0]], attr_vals_saved)
316 # check individual attr.save also stores new val to DB
317 attr.save(self.db_conn)
318 attr_vals_saved = retrieve_attr_vals(attr)
319 self.assertEqual(to_set, attr_vals_saved)
321 @TestCaseAugmented._run_if_checked_class
322 def test_saving_and_caching(self) -> None:
323 """Test effects of .cache() and .save()."""
324 id1 = self.default_ids[0]
325 # check failure to cache without ID (if None-ID input possible)
326 if isinstance(id1, int):
327 obj0 = self._make_from_defaults(None)
328 with self.assertRaises(HandledException):
330 # check mere object init itself doesn't even store in cache
331 obj1 = self._make_from_defaults(id1)
332 self.assertEqual(self.checked_class.get_cache(), {})
333 # check .cache() fills cache, but not DB
335 self.assertEqual(self.checked_class.get_cache(), {id1: obj1})
336 found_in_db = self._load_from_db(id1)
337 self.assertEqual(found_in_db, [])
338 # check .save() sets ID (for int IDs), updates cache, and fills DB
339 # (expect ID to be set to id1, despite obj1 already having that as ID:
340 # it's generated by cursor.lastrowid on the DB table, and with obj1
341 # not written there, obj2 should get it first!)
342 id_input = None if isinstance(id1, int) else id1
343 obj2 = self._make_from_defaults(id_input)
344 obj2.save(self.db_conn)
345 self.assertEqual(self.checked_class.get_cache(), {id1: obj2})
346 # NB: we'll only compare hashes because obj2 itself disappears on
347 # .from_table_row-triggered database reload
348 obj2_hash = hash(obj2)
349 found_in_db += self._load_from_db(id1)
350 self.assertEqual([hash(o) for o in found_in_db], [obj2_hash])
351 # check we cannot overwrite obj2 with obj1 despite its same ID,
352 # since it has disappeared now
353 with self.assertRaises(HandledException):
354 obj1.save(self.db_conn)
356 @TestCaseAugmented._run_if_checked_class
357 def test_by_id(self) -> None:
359 id1, id2, _ = self.default_ids
360 # check failure if not yet saved
361 obj1 = self._make_from_defaults(id1)
362 with self.assertRaises(NotFoundException):
363 self.checked_class.by_id(self.db_conn, id1)
364 # check identity of cached and retrieved
366 self.assertEqual(obj1, self.checked_class.by_id(self.db_conn, id1))
367 # check identity of saved and retrieved
368 obj2 = self._make_from_defaults(id2)
369 obj2.save(self.db_conn)
370 self.assertEqual(obj2, self.checked_class.by_id(self.db_conn, id2))
372 @TestCaseAugmented._run_if_checked_class
373 def test_by_id_or_create(self) -> None:
374 """Test .by_id_or_create."""
375 # check .by_id_or_create fails if wrong class
376 if not self.checked_class.can_create_by_id:
377 with self.assertRaises(HandledException):
378 self.checked_class.by_id_or_create(self.db_conn, None)
380 # check ID input of None creates, on saving, ID=1,2,… for int IDs
381 if isinstance(self.default_ids[0], int):
383 item = self.checked_class.by_id_or_create(self.db_conn, None)
384 self.assertEqual(item.id_, None)
385 item.save(self.db_conn)
386 self.assertEqual(item.id_, n+1)
387 # check .by_id_or_create acts like normal instantiation (sans saving)
388 id_ = self.default_ids[2]
389 item = self.checked_class.by_id_or_create(self.db_conn, id_)
390 self.assertEqual(item.id_, id_)
391 with self.assertRaises(NotFoundException):
392 self.checked_class.by_id(self.db_conn, item.id_)
393 self.assertEqual(self.checked_class(item.id_), item)
395 @TestCaseAugmented._run_if_checked_class
396 def test_from_table_row(self) -> None:
397 """Test .from_table_row() properly reads in class directly from DB."""
398 id_ = self.default_ids[0]
399 obj = self._make_from_defaults(id_)
400 obj.save(self.db_conn)
401 assert isinstance(obj.id_, type(id_))
402 for row in self.db_conn.row_where(self.checked_class.table_name,
404 # check .from_table_row reproduces state saved, no matter if obj
405 # later changed (with caching even)
406 # NB: we'll only compare hashes because obj itself disappears on
407 # .from_table_row-triggered database reload
408 hash_original = hash(obj)
409 attr_name = self._change_obj(obj)
411 to_cmp = getattr(obj, attr_name)
412 retrieved = self.checked_class.from_table_row(self.db_conn, row)
413 self.assertNotEqual(to_cmp, getattr(retrieved, attr_name))
414 self.assertEqual(hash_original, hash(retrieved))
415 # check cache contains what .from_table_row just produced
416 self.assertEqual({retrieved.id_: retrieved},
417 self.checked_class.get_cache())
419 @TestCaseAugmented._run_on_versioned_attributes
420 def test_versioned_history_from_row(self,
423 attr: VersionedAttribute,
424 default: str | float,
425 to_set: list[str] | list[float]
427 """"Test VersionedAttribute.history_from_row() knows its DB rows."""
430 owner.save(self.db_conn)
431 # make empty VersionedAttribute, fill from rows, compare to owner's
432 for row in self.db_conn.row_where(owner.table_name, 'id', owner.id_):
433 loaded_attr = VersionedAttribute(owner, attr.table_name, default)
434 for row in self.db_conn.row_where(attr.table_name, 'parent',
436 loaded_attr.history_from_row(row)
437 self.assertEqual(len(attr.history.keys()),
438 len(loaded_attr.history.keys()))
439 for timestamp, value in attr.history.items():
440 self.assertEqual(value, loaded_attr.history[timestamp])
442 @TestCaseAugmented._run_if_checked_class
443 def test_all(self) -> None:
444 """Test .all() and its relation to cache and savings."""
445 id1, id2, id3 = self.default_ids
446 item1 = self._make_from_defaults(id1)
447 item2 = self._make_from_defaults(id2)
448 item3 = self._make_from_defaults(id3)
449 # check .all() returns empty list on un-cached items
450 self.assertEqual(self.checked_class.all(self.db_conn), [])
451 # check that all() shows only cached/saved items
453 item3.save(self.db_conn)
454 self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
455 sorted([item1, item3]))
456 item2.save(self.db_conn)
457 self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
458 sorted([item1, item2, item3]))
460 @TestCaseAugmented._run_if_checked_class
461 def test_singularity(self) -> None:
462 """Test pointers made for single object keep pointing to it."""
463 id1 = self.default_ids[0]
464 obj = self._make_from_defaults(id1)
465 obj.save(self.db_conn)
466 # change object, expect retrieved through .by_id to carry change
467 attr_name = self._change_obj(obj)
468 new_attr = getattr(obj, attr_name)
469 retrieved = self.checked_class.by_id(self.db_conn, id1)
470 self.assertEqual(new_attr, getattr(retrieved, attr_name))
472 @TestCaseAugmented._run_on_versioned_attributes
473 def test_versioned_singularity(self,
476 attr: VersionedAttribute,
478 to_set: list[str] | list[float]
480 """Test singularity of VersionedAttributes on saving."""
481 owner.save(self.db_conn)
482 # change obj, expect retrieved through .by_id to carry change
484 retrieved = self.checked_class.by_id(self.db_conn, owner.id_)
485 attr_retrieved = getattr(retrieved, attr_name)
486 self.assertEqual(attr.history, attr_retrieved.history)
488 @TestCaseAugmented._run_if_checked_class
489 def test_remove(self) -> None:
490 """Test .remove() effects on DB and cache."""
491 id_ = self.default_ids[0]
492 obj = self._make_from_defaults(id_)
493 # check removal only works after saving
494 with self.assertRaises(HandledException):
495 obj.remove(self.db_conn)
496 obj.save(self.db_conn)
497 obj.remove(self.db_conn)
498 # check access to obj fails after removal
499 with self.assertRaises(HandledException):
501 # check DB and cache now empty
502 self.check_identity_with_cache_and_db([])
506 """Builder of (JSON-like) dict to compare against responses of test server.
508 Collects all items and relations we expect expressed in the server's JSON
509 responses and puts them into the proper json.dumps-friendly dict structure,
510 accessible via .as_dict, to compare them in TestCaseWithServer.check_json_get.
512 On its own, it provides as .as_dict output only {"_library": …}, initialized
513 from .__init__ and to be directly manipulated via the .lib* methods.
514 Further structures of the expected response may be added and kept
515 up-to-date by subclassing .__init__, .recalc, and .as_dict.
517 NB: Lots of expectations towards server behavior will be made explicit here
518 (or in the subclasses) rather than in the actual TestCase methods' code.
520 _default_dict: dict[str, Any]  # presumably fallback values for _fields (see __init__)
521 _forced: dict[str, Any]  # fields as_dict emits as-is, exempt from its list sorting
522 _fields: dict[str, Any]  # top-level as_dict fields besides '_library' (see .set)
523 _on_empty_make_temp: tuple[str, str]  # optional (category, dicter method name), see as_dict
# NOTE(review): the `def` line of this signature (presumably `__init__`, per
# the class docstring) is absent from this listing; the parameters below
# seed the library per category.
526 todos: list[dict[str, Any]] | None = None,
527 procs: list[dict[str, Any]] | None = None,
528 procsteps: list[dict[str, Any]] | None = None,
529 conds: list[dict[str, Any]] | None = None,
530 days: list[dict[str, Any]] | None = None
532 # pylint: disable=too-many-arguments
# Create a per-instance dict for any of these not already provided (e.g. by
# a subclass); the class only declares their types.
533 for name in ['_default_dict', '_fields', '_forced']:
534 if not hasattr(self, name):
535 setattr(self, name, {})
# Library entries are stored keyed by stringified 'id' (see _as_refs).
537 for title, items in [('Todo', todos),
539 ('ProcessStep', procsteps),
540 ('Condition', conds),
543 self._lib[title] = self._as_refs(items)
# For every _default_dict key not yet in _fields, presumably copy the default
# over (the assignment line is absent from this listing).
544 for k, v in self._default_dict.items():
545 if k not in self._fields:
548 def recalc(self) -> None:
549 """Update internal dictionary by subclass-specific rules."""
550 todos = self.lib_all('Todo')
# Each child listed in a Todo's 'children' gets that Todo's id appended to
# its 'parents'. NOTE(review): for a child id absent from the library,
# lib_get returns an empty dict (per its docstring), so this would raise
# KeyError.
554 for child_id in todo['children']:
555 self.lib_get('Todo', child_id)['parents'] += [todo['id']]
556 todo['children'].sort()
557 procsteps = self.lib_all('ProcessStep')
558 procs = self.lib_all('Process')
# A Process's explicit_steps are the ids of the ProcessSteps it owns.
560 proc['explicit_steps'] = [s['id'] for s in procsteps
561 if s['owner_id'] == proc['id']]
564 def as_dict(self) -> dict[str, Any]:
565 """Return dict to compare against test server JSON responses."""
# Subclasses may set _on_empty_make_temp = (category, dicter method name):
# when the library lacks the item that _fields[category.lower()] points to,
# a temporary one built by that method is presumably added for this call
# (the guarding lines are absent from this listing) and deleted at the end.
567 if hasattr(self, '_on_empty_make_temp'):
568 category, dicter = getattr(self, '_on_empty_make_temp')
569 id_ = self._fields[category.lower()]
570 make_temp = not bool(self.lib_get(category, id_))
572 f = getattr(self, dicter)
573 self.lib_set(category, [f(id_)])
# NOTE(review): d['_library'] is self._lib itself, not a copy; unless the
# lines absent below copy d, the final lib_del also removes the temporary
# item from the returned dict — confirm.
575 d = {'_library': self._lib}
576 for k, v in self._fields.items():
577 # we expect everything sortable to be sorted
578 if isinstance(v, list) and k not in self._forced:
579 # NB: if we don't test for v being list, sorted() on an empty
580 # dict may return an empty list
# forced fields bypass the sorting above
586 for k, v in self._forced.items():
590 self.lib_del(category, id_)
594 def lib_get(self, category: str, id_: str | int) -> dict[str, Any]:
595 """From library, return item of category and id_, or empty dict."""
# Library keys are stringified ids (_as_refs stores str(item['id'])), so the
# lookup uses str_id (assigned on a line absent from this listing).
597 if category in self._lib and str_id in self._lib[category]:
598 return self._lib[category][str_id]
601 def lib_all(self, category: str) -> list[dict[str, Any]]:
602 """From library, return items of category, or [] if none."""
# The list is new, but its dicts are the library's own, so mutating them
# (as recalc does) edits the library.
603 if category in self._lib:
604 return list(self._lib[category].values())
607 def lib_set(self, category: str, items: list[dict[str, object]]) -> None:
608 """Update library for category with items."""
609 if category not in self._lib:
610 self._lib[category] = {}
611 for k, v in self._as_refs(items).items():
612 self._lib[category][k] = v
614 def lib_del(self, category: str, id_: str | int) -> None:
615 """Remove category element of id_ from library."""
616 del self._lib[category][str(id_)]
617 if 0 == len(self._lib[category]):
618 del self._lib[category]
620 def lib_wipe(self, category: str) -> None:
621 """Remove category from library."""
622 if category in self._lib:
623 del self._lib[category]
625 def set(self, field_name: str, value: object) -> None:
626 """Set top-level .as_dict field."""
627 self._fields[field_name] = value
629 def force(self, field_name: str, value: object) -> None:
630 """Set ._forced field to ensure value in .as_dict."""
631 self._forced[field_name] = value
633 def unforce(self, field_name: str) -> None:
634 """Unset ._forced field."""
635 del self._forced[field_name]
638 def _as_refs(items: list[dict[str, object]]
639 ) -> dict[str, dict[str, object]]:
640 """Return dictionary of items by their 'id' fields."""
# Keys are stringified so int and str ids index alike, presumably to match
# the string keys of JSON objects in server responses.
643 refs[str(item['id'])] = item
647 def as_ids(items: list[dict[str, Any]]) -> list[int] | list[str]:
648 """Return list of only 'id' fields of items."""
649 return [item['id'] for item in items]
652 def day_as_dict(date: str, comment: str = '') -> dict[str, object]:
653 """Return JSON of Day to expect."""
654 return {'id': date, 'comment': comment, 'todos': []}
656 def set_day_from_post(self, date: str, d: dict[str, Any]) -> None:
657 """Set Day of date in library based on POST dict d."""
658 day = self.day_as_dict(date)
659 for k, v in d.items():
660 if 'day_comment' == k:
662 elif 'new_todo' == k:
# Raise next_id (start value set on a line absent here) above every Todo
# id in the library.
664 for todo in self.lib_all('Todo'):
665 if next_id <= todo['id']:
666 next_id = todo['id'] + 1
# New Todos for the posted process ids, handled in ascending order.
667 for proc_id in sorted(v):
668 todo = self.todo_as_dict(next_id, proc_id, date)
669 self.lib_set('Todo', [todo])
673 self.lib_get('Todo', todo_id)['is_done'] = True
# Parallel lists: d['comment'][i] and d['effort'][i] belong to the i-th id
# in v; falsy (e.g. empty) efforts become None.
675 for i, todo_id in enumerate(v):
676 t = self.lib_get('Todo', todo_id)
678 t['comment'] = d['comment'][i]
680 effort = d['effort'][i] if d['effort'][i] else None
682 self.lib_set('Day', [day])
685 def cond_as_dict(id_: int = 1,
686 is_active: bool = False,
687 title: None | str = None,
688 description: None | str = None,
689 ) -> dict[str, object]:
690 """Return JSON of Condition to expect."""
691 versioned: dict[str, dict[str, object]]
692 versioned = {'title': {}, 'description': {}}
693 if title is not None:
694 versioned['title']['0'] = title
695 if description is not None:
696 versioned['description']['0'] = description
697 return {'id': id_, 'is_active': is_active, '_versioned': versioned}
699 def set_cond_from_post(self, id_: int, d: dict[str, Any]) -> None:
700 """Set Condition of id_ in library based on POST dict d."""
# a POST of only {'delete': ''} removes the Condition from the library
701 if d == {'delete': ''}:
702 self.lib_del('Condition', id_)
704 cond = self.lib_get('Condition', id_)
706 cond['is_active'] = d['is_active']
# Versioned histories are keyed '0', '1', …; a posted value is appended
# under the next key only if it differs from the latest entry.
707 for category in ['title', 'description']:
708 history = cond['_versioned'][category]
710 last_i = sorted([int(k) for k in history.keys()])[-1]
711 if d[category] != history[str(last_i)]:
712 history[str(last_i + 1)] = d[category]
714 history['0'] = d[category]
# presumably the branch for a Condition not yet in the library: build it
# fresh from the posted values
716 cond = self.cond_as_dict(
717 id_, d['is_active'], d['title'], d['description'])
718 self.lib_set('Condition', [cond])
721 def todo_as_dict(id_: int = 1,
723 date: str = '2024-01-01',
724 conditions: None | list[int] = None,
725 disables: None | list[int] = None,
726 blockers: None | list[int] = None,
727 enables: None | list[int] = None,
728 calendarize: bool = False,
730 is_done: bool = False,
731 effort: float | None = None,
732 children: list[int] | None = None,
733 parents: list[int] | None = None,
734 ) -> dict[str, object]:
735 """Return JSON of Todo to expect."""
736 # pylint: disable=too-many-arguments
# list arguments default to None and become fresh empty lists here, so no
# mutable default is shared between calls
739 'process_id': process_id,
741 'calendarize': calendarize,
743 'children': children if children else [],
744 'parents': parents if parents else [],
746 'conditions': conditions if conditions else [],
747 'disables': disables if disables else [],
748 'blockers': blockers if blockers else [],
749 'enables': enables if enables else []}
752 def set_todo_from_post(self, id_: int, d: dict[str, Any]) -> None:
753 """Set Todo of id_ in library based on POST dict d."""
754 corrected_kwargs: dict[str, Any] = {'children': []}
755 for k, v in d.items():
# 'adopt' and 'step_filler' values (single or list) are all collected
# as children
756 if k in {'adopt', 'step_filler'}:
757 new_children = v if isinstance(v, list) else [v]
758 corrected_kwargs['children'] += new_children
760 if k in {'is_done', 'calendarize'}:
762 corrected_kwargs[k] = v
# presumably: update an existing library Todo field by field, else build it
# via todo_as_dict (the branching lines are absent from this listing)
763 todo = self.lib_get('Todo', id_)
765 for k, v in corrected_kwargs.items():
768 todo = self.todo_as_dict(id_, **corrected_kwargs)
769 self.lib_set('Todo', [todo])
772 def procstep_as_dict(id_: int,
774 step_process_id: int,
775 parent_step_id: int | None = None
776 ) -> dict[str, object]:
777 """Return JSON of ProcessStep to expect."""
# NOTE(review): parent_step_id None presumably marks a top-level step of the
# owning Process — confirm
779 'owner_id': owner_id,
780 'step_process_id': step_process_id,
781 'parent_step_id': parent_step_id}
784 def proc_as_dict(id_: int = 1,
785 title: None | str = None,
786 description: None | str = None,
787 effort: None | float = None,
788 conditions: None | list[int] = None,
789 disables: None | list[int] = None,
790 blockers: None | list[int] = None,
791 enables: None | list[int] = None,
792 explicit_steps: None | list[int] = None
793 ) -> dict[str, object]:
794 """Return JSON of Process to expect."""
795 # pylint: disable=too-many-arguments
# each given versioned value becomes history entry '0'; None leaves that
# history empty
796 versioned: dict[str, dict[str, object]]
797 versioned = {'title': {}, 'description': {}, 'effort': {}}
798 if title is not None:
799 versioned['title']['0'] = title
800 if description is not None:
801 versioned['description']['0'] = description
802 if effort is not None:
803 versioned['effort']['0'] = effort
# calendarize and suppressed_steps always start out False / empty here
805 'calendarize': False,
806 'suppressed_steps': [],
807 'explicit_steps': explicit_steps if explicit_steps else [],
808 '_versioned': versioned,
809 'conditions': conditions if conditions else [],
810 'disables': disables if disables else [],
811 'enables': enables if enables else [],
812 'blockers': blockers if blockers else []}
815 def set_proc_from_post(self, id_: int, d: dict[str, Any]) -> None:
816 """Set Process of id_ in library based on POST dict d."""
817 proc = self.lib_get('Process', id_)
# Versioned histories are keyed '0', '1', …; a posted value is appended
# under the next key only if it differs from the latest entry (same rule as
# set_cond_from_post).
819 for category in ['title', 'description', 'effort']:
820 history = proc['_versioned'][category]
822 last_i = sorted([int(k) for k in history.keys()])[-1]
823 if d[category] != history[str(last_i)]:
824 history[str(last_i + 1)] = d[category]
826 history['0'] = d[category]
# presumably the branch for a Process not yet in the library
828 proc = self.proc_as_dict(id_,
829 d['title'], d['description'], d['effort'])
# Posted fields in `ignore` or step-related keys are skipped; of the rest,
# list-valued relation fields are presumably normalized to lists before
# being stored on proc (some of those lines are absent from this listing).
830 ignore = {'title', 'description', 'effort', 'new_top_step', 'step_of',
832 for k, v in d.items():
834 or k.startswith('step_') or k.startswith('new_step_to'):
836 if k in {'calendarize'}:
838 elif k in {'suppressed_steps', 'explicit_steps', 'conditions',
839 'disables', 'enables', 'blockers'}:
840 if not isinstance(v, list):
843 self.lib_set('Process', [proc])
846 class TestCaseWithServer(TestCaseWithDB):
847 """Module tests against our HTTP server/handler (and database)."""
# Builds on TestCaseWithDB: each test's server serves that class's per-test
# DB file (see setUp).
849 def setUp(self) -> None:
# Port 0 lets the OS pick a free port; the actual address is read back
# below for the client connection.
851 self.httpd = TaskServer(self.db_file, ('localhost', 0), TaskHandler)
# daemon thread: it cannot keep the interpreter alive if shutdown fails
852 self.server_thread = Thread(target=self.httpd.serve_forever)
853 self.server_thread.daemon = True
854 self.server_thread.start()
855 self.conn = HTTPConnection(str(self.httpd.server_address[0]),
856 self.httpd.server_address[1])
# render responses as JSON, which check_json_get parses and compares
857 self.httpd.render_mode = 'json'
859 def tearDown(self) -> None:
# Stop the serve_forever loop, close the listening socket, then wait for
# the server thread to end. NOTE(review): self.conn is not closed in the
# visible lines.
860 self.httpd.shutdown()
861 self.httpd.server_close()
862 self.server_thread.join()
865 def post_exp_cond(self,
866 exps: list[Expected],
868 payload: dict[str, object],
869 path_suffix: str = '',
870 redir_suffix: str = ''
872 """POST /condition(s), appropriately update Expecteds."""
873 # pylint: disable=too-many-arguments
# POST to /condition<path_suffix>, expecting (check_post's default) a 302
# to /condition<redir_suffix>; then mirror the change in each Expected
874 path = f'/condition{path_suffix}'
875 redir = f'/condition{redir_suffix}'
876 self.check_post(payload, path, redir=redir)
878 exp.set_cond_from_post(id_, payload)
880 def post_exp_day(self,
881 exps: list[Expected],
882 payload: dict[str, Any],
883 date: str = '2024-01-01'
885 """POST /day, appropriately update Expecteds."""
# NB: fills in missing make_type/day_comment on the caller's payload dict
# itself (no copy is made)
886 if 'make_type' not in payload:
887 payload['make_type'] = 'empty'
888 if 'day_comment' not in payload:
889 payload['day_comment'] = ''
# expect a 302 back to the same day page, carrying the make_type
890 target = f'/day?date={date}'
891 redir_to = f'{target}&make_type={payload["make_type"]}'
892 self.check_post(payload, target, 302, redir_to)
894 exp.set_day_from_post(date, payload)
896 def post_exp_process(self,
897 exps: list[Expected],
898 payload: dict[str, Any],
900 ) -> dict[str, object]:
901 """POST /process, appropriately update Expecteds."""
# NB: fills in missing title/description/effort on the caller's payload
# dict itself (no copy is made)
902 if 'title' not in payload:
903 payload['title'] = 'foo'
904 if 'description' not in payload:
905 payload['description'] = 'foo'
906 if 'effort' not in payload:
907 payload['effort'] = 1.1
# expect a 302 back to the same process page
908 self.check_post(payload, f'/process?id={id_}',
909 redir=f'/process?id={id_}')
# NOTE(review): the return statement is absent from this listing;
# presumably the completed payload is returned
911 exp.set_proc_from_post(id_, payload)
914 def check_filter(self, exp: Expected, category: str, key: str,
915 val: str, list_ids: list[int]) -> None:
916 """Check GET /{category}?{key}={val} sorts to list_ids."""
917 # pylint: disable=too-many-arguments
# Forced fields skip as_dict's sorting, so list_ids' own order is compared.
# NB: exp stays forced afterwards in the visible lines; val and key are put
# into the URL unencoded.
919 exp.force(category, list_ids)
920 self.check_json_get(f'/{category}?{key}={val}', exp)
922 def check_redirect(self, target: str) -> None:
923 """Check that self.conn answers with a 302 redirect to target."""
924 response = self.conn.getresponse()
925 self.assertEqual(response.status, 302)
926 self.assertEqual(response.getheader('Location'), target)
928 def check_get(self, target: str, expected_code: int) -> None:
929 """Check that a GET to target yields expected_code."""
930 self.conn.request('GET', target)
931 self.assertEqual(self.conn.getresponse().status, expected_code)
933 def check_post(self, data: Mapping[str, object], target: str,
934 expected_code: int = 302, redir: str = '') -> None:
935 """Check that POST of data to target yields expected_code."""
# doseq=True: list values are sent as repeated key=value pairs
936 encoded_form_data = urlencode(data, doseq=True).encode('utf-8')
937 headers = {'Content-Type': 'application/x-www-form-urlencoded',
938 'Content-Length': str(len(encoded_form_data))}
939 self.conn.request('POST', target,
940 body=encoded_form_data, headers=headers)
# for 302, also check the Location header; redir defaults to target itself
941 if 302 == expected_code:
942 redir = target if redir == '' else redir
943 self.check_redirect(redir)
945 self.assertEqual(self.conn.getresponse().status, expected_code)
947 def check_get_defaults(self, path: str) -> None:
948 """Some standard model paths to test."""
949 self.check_get(path, 200)
950 self.check_get(f'{path}?id=', 200)
951 self.check_get(f'{path}?id=foo', 400)
952 self.check_get(f'/{path}?id=0', 500)
953 self.check_get(f'{path}?id=1', 200)
955 def check_json_get(self, path: str, expected: Expected) -> None:
956 """Compare JSON on GET path with expected.
958 To simplify comparison of VersionedAttribute histories, transforms
959 timestamp keys of VersionedAttribute history keys into (strings of)
960 integers counting chronologically forward from 0.
963 def rewrite_history_keys_in(item: Any) -> Any:
964 if isinstance(item, dict):
965 if '_versioned' in item.keys():
966 for category in item['_versioned']:
967 vals = item['_versioned'][category].values()
969 for i, val in enumerate(vals):
970 history[str(i)] = val
971 item['_versioned'][category] = history
972 for category in list(item.keys()):
973 rewrite_history_keys_in(item[category])
974 elif isinstance(item, list):
975 item[:] = [rewrite_history_keys_in(i) for i in item]
978 def walk_diffs(path: str, cmp1: object, cmp2: object) -> None:
979 # pylint: disable=too-many-branches
980 def warn(intro: str, val: object) -> None:
981 if isinstance(val, (str, int, float)):
987 if isinstance(cmp1, dict) and isinstance(cmp2, dict):
988 for k, v in cmp1.items():
990 warn(f'DIFF {path}: retrieved lacks {k}', v)
992 walk_diffs(f'{path}:{k}', v, cmp2[k])
993 for k in [k for k in cmp2.keys() if k not in cmp1]:
994 warn(f'DIFF {path}: expected lacks retrieved\'s {k}',
996 elif isinstance(cmp1, list) and isinstance(cmp2, list):
997 for i, v1 in enumerate(cmp1):
999 warn(f'DIFF {path}[{i}] retrieved misses:', v1)
1001 walk_diffs(f'{path}[{i}]', v1, cmp2[i])
1002 if len(cmp2) > len(cmp1):
1003 for i, v2 in enumerate(cmp2[len(cmp1):]):
1004 warn(f'DIFF {path}[{len(cmp1)+i}] misses:', v2)
1006 warn(f'DIFF {path} – for expected:', cmp1)
1007 warn('… and for retrieved:', cmp2)
1009 self.conn.request('GET', path)
1010 response = self.conn.getresponse()
1011 self.assertEqual(response.status, 200)
1012 retrieved = json_loads(response.read().decode())
1013 rewrite_history_keys_in(retrieved)
1014 cmp = expected.as_dict
1016 self.assertEqual(cmp, retrieved)
1017 except AssertionError as e:
1022 walk_diffs('', cmp, retrieved)