1 """Shared test utilities."""
2 # pylint: disable=too-many-lines
3 from __future__ import annotations
4 from unittest import TestCase
5 from typing import Mapping, Any, Callable
6 from threading import Thread
7 from http.client import HTTPConnection
8 from datetime import datetime, timedelta
10 from json import loads as json_loads, dumps as json_dumps
11 from urllib.parse import urlencode
12 from uuid import uuid4
13 from os import remove as remove_file
14 from pprint import pprint
15 from plomtask.db import DatabaseFile, DatabaseConnection
16 from plomtask.http import TaskHandler, TaskServer
17 from plomtask.processes import Process, ProcessStep
18 from plomtask.conditions import Condition
19 from plomtask.days import Day
20 from plomtask.dating import DATE_FORMAT
21 from plomtask.todos import Todo
22 from plomtask.versioned_attributes import VersionedAttribute, TIMESTAMP_FMT
23 from plomtask.exceptions import NotFoundException, HandledException
26 VERSIONED_VALS: dict[str,
27 list[str] | list[float]] = {'str': ['A', 'B'],
29 VALID_TRUES = {True, 'True', 'true', '1', 'on'}
32 class TestCaseAugmented(TestCase):
33 """Tester core providing helpful basic internal decorators and methods."""
35 default_init_kwargs: dict[str, Any] = {}
38 def _run_on_versioned_attributes(f: Callable[..., None]
39 ) -> Callable[..., None]:
40 def wrapper(self: TestCase) -> None:
41 assert isinstance(self, TestCaseAugmented)
42 for attr_name in self.checked_class.to_save_versioned():
43 default = self.checked_class.versioned_defaults[attr_name]
44 owner = self.checked_class(None, **self.default_init_kwargs)
45 attr = getattr(owner, attr_name)
46 to_set = VERSIONED_VALS[attr.value_type_name]
47 f(self, owner, attr_name, attr, default, to_set)
51 def _run_if_sans_db(cls, f: Callable[..., None]) -> Callable[..., None]:
52 def wrapper(self: TestCaseSansDB) -> None:
53 if issubclass(cls, TestCaseSansDB):
58 def _run_if_with_db_but_not_server(cls,
59 f: Callable[..., None]
60 ) -> Callable[..., None]:
61 def wrapper(self: TestCaseWithDB) -> None:
62 if issubclass(cls, TestCaseWithDB) and\
63 not issubclass(cls, TestCaseWithServer):
def _make_from_defaults(cls, id_: float | str | None) -> Any:
    """Instantiate cls.checked_class with id_ and cls.default_init_kwargs."""
    init_kwargs = cls.default_init_kwargs
    return cls.checked_class(id_, **init_kwargs)
72 class TestCaseSansDB(TestCaseAugmented):
73 """Tests requiring no DB setup."""
74 legal_ids: list[str] | list[int] = [1, 5]
75 illegal_ids: list[str] | list[int] = [0]
@TestCaseAugmented._run_if_sans_db
def test_id_validation(self) -> None:
    """Test illegal IDs get rejected, legal ones stored as .id_."""
    # constructing with any illegal ID must raise HandledException
    for bad_id in self.illegal_ids:
        self.assertRaises(HandledException,
                          self._make_from_defaults, bad_id)
    # constructing with any legal ID must keep that ID unchanged
    for good_id in self.legal_ids:
        made = self._make_from_defaults(good_id)
        self.assertEqual(made.id_, good_id)
87 @TestCaseAugmented._run_if_sans_db
88 @TestCaseAugmented._run_on_versioned_attributes
89 def test_versioned_set(self,
92 attr: VersionedAttribute,
94 to_set: list[str] | list[float]
96 """Test VersionedAttribute.set() behaves as expected."""
98 self.assertEqual(list(attr.history.values()), [default])
99 # check same value does not get set twice in a row,
100 # and that not even its timestamp gets updated
101 timestamp = list(attr.history.keys())[0]
103 self.assertEqual(list(attr.history.values()), [default])
104 self.assertEqual(list(attr.history.keys())[0], timestamp)
105 # check that different value _will_ be set/added
107 timesorted_vals = [attr.history[t] for
108 t in sorted(attr.history.keys())]
109 expected = [default, to_set[0]]
110 self.assertEqual(timesorted_vals, expected)
111 # check that a previously used value can be set if not most recent
113 timesorted_vals = [attr.history[t] for
114 t in sorted(attr.history.keys())]
115 expected = [default, to_set[0], default]
116 self.assertEqual(timesorted_vals, expected)
117 # again check for same value not being set twice in a row, even for
120 timesorted_vals = [attr.history[t] for
121 t in sorted(attr.history.keys())]
122 expected = [default, to_set[0], default, to_set[1]]
123 self.assertEqual(timesorted_vals, expected)
125 self.assertEqual(timesorted_vals, expected)
127 @TestCaseAugmented._run_if_sans_db
128 @TestCaseAugmented._run_on_versioned_attributes
129 def test_versioned_newest(self,
132 attr: VersionedAttribute,
133 default: str | float,
134 to_set: list[str] | list[float]
136 """Test VersionedAttribute.newest."""
137 # check .newest on empty history returns .default
138 self.assertEqual(attr.newest, default)
139 # check newest element always returned
140 for v in [to_set[0], to_set[1]]:
142 self.assertEqual(attr.newest, v)
143 # check newest element returned even if also early value
145 self.assertEqual(attr.newest, default)
147 @TestCaseAugmented._run_if_sans_db
148 @TestCaseAugmented._run_on_versioned_attributes
149 def test_versioned_at(self,
152 attr: VersionedAttribute,
153 default: str | float,
154 to_set: list[str] | list[float]
156 """Test .at() returns values nearest to queried time, or default."""
157 # check .at() return default on empty history
158 timestamp_a = datetime.now().strftime(TIMESTAMP_FMT)
159 self.assertEqual(attr.at(timestamp_a), default)
160 # check value exactly at timestamp returned
162 timestamp_b = list(attr.history.keys())[0]
163 self.assertEqual(attr.at(timestamp_b), to_set[0])
164 # check earliest value returned if exists, rather than default
165 self.assertEqual(attr.at(timestamp_a), to_set[0])
166 # check reverts to previous value for timestamps not indexed
168 timestamp_between = datetime.now().strftime(TIMESTAMP_FMT)
171 timestamp_c = sorted(attr.history.keys())[-1]
172 self.assertEqual(attr.at(timestamp_c), to_set[1])
173 self.assertEqual(attr.at(timestamp_between), to_set[0])
175 timestamp_after_c = datetime.now().strftime(TIMESTAMP_FMT)
176 self.assertEqual(attr.at(timestamp_after_c), to_set[1])
179 class TestCaseWithDB(TestCaseAugmented):
180 """Module tests requiring DB setup."""
181 default_ids: tuple[int, int, int] | tuple[str, str, str] = (1, 2, 3)
183 def setUp(self) -> None:
184 Condition.empty_cache()
186 Process.empty_cache()
187 ProcessStep.empty_cache()
189 self.db_file = DatabaseFile.create_at(f'test_db:{uuid4()}')
190 self.db_conn = DatabaseConnection(self.db_file)
192 def tearDown(self) -> None:
194 remove_file(self.db_file.path)
196 def _load_from_db(self, id_: int | str) -> list[object]:
197 db_found: list[object] = []
198 for row in self.db_conn.row_where(self.checked_class.table_name,
200 db_found += [self.checked_class.from_table_row(self.db_conn,
204 def _change_obj(self, obj: object) -> str:
205 attr_name: str = self.checked_class.to_save_simples[-1]
206 attr = getattr(obj, attr_name)
207 new_attr: str | int | float | bool
208 if isinstance(attr, (int, float)):
210 elif isinstance(attr, str):
211 new_attr = attr + '_'
212 elif isinstance(attr, bool):
214 setattr(obj, attr_name, new_attr)
217 def check_identity_with_cache_and_db(self, content: list[Any]) -> None:
218 """Test both cache and DB equal content."""
221 expected_cache[item.id_] = item
222 self.assertEqual(self.checked_class.get_cache(), expected_cache)
223 hashes_content = [hash(x) for x in content]
224 db_found: list[Any] = []
226 assert isinstance(item.id_, type(self.default_ids[0]))
227 db_found += self._load_from_db(item.id_)
228 hashes_db_found = [hash(x) for x in db_found]
229 self.assertEqual(sorted(hashes_content), sorted(hashes_db_found))
231 def check_by_date_range_with_limits(self,
233 set_id_field: bool = True
235 """Test .by_date_range_with_limits."""
236 # pylint: disable=too-many-locals
237 f = self.checked_class.by_date_range_with_limits
238 # check illegal ranges
239 legal_range = ('yesterday', 'tomorrow')
241 for bad_date in ['foo', '2024-02-30', '2024-01-01 12:00:00']:
242 date_range = list(legal_range[:])
243 date_range[i] = bad_date
244 with self.assertRaises(HandledException):
245 f(self.db_conn, date_range, date_col)
246 # check empty, translation of 'yesterday' and 'tomorrow'
247 items, start, end = f(self.db_conn, legal_range, date_col)
248 self.assertEqual(items, [])
249 yesterday = datetime.now() + timedelta(days=-1)
250 tomorrow = datetime.now() + timedelta(days=+1)
251 self.assertEqual(start, yesterday.strftime(DATE_FORMAT))
252 self.assertEqual(end, tomorrow.strftime(DATE_FORMAT))
253 # prepare dated items for non-empty results
254 kwargs_with_date = self.default_init_kwargs.copy()
256 kwargs_with_date['id_'] = None
258 dates = ['2024-01-01', '2024-01-02', '2024-01-04']
259 for date in ['2024-01-01', '2024-01-02', '2024-01-04']:
260 kwargs_with_date['date'] = date
261 obj = self.checked_class(**kwargs_with_date)
263 # check ranges still empty before saving
264 date_range = [dates[0], dates[-1]]
265 self.assertEqual(f(self.db_conn, date_range, date_col)[0], [])
266 # check all objs displayed within closed interval
268 obj.save(self.db_conn)
269 self.assertEqual(f(self.db_conn, date_range, date_col)[0], objs)
270 # check that only displayed what exists within interval
271 date_range = ['2023-12-20', '2024-01-03']
272 expected = [objs[0], objs[1]]
273 self.assertEqual(f(self.db_conn, date_range, date_col)[0], expected)
274 date_range = ['2024-01-03', '2024-01-30']
276 self.assertEqual(f(self.db_conn, date_range, date_col)[0], expected)
277 # check that inverted interval displays nothing
278 date_range = [dates[-1], dates[0]]
279 self.assertEqual(f(self.db_conn, date_range, date_col)[0], [])
280 # check that "today" is interpreted, and single-element interval
281 today_date = datetime.now().strftime(DATE_FORMAT)
282 kwargs_with_date['date'] = today_date
283 obj_today = self.checked_class(**kwargs_with_date)
284 obj_today.save(self.db_conn)
285 date_range = ['today', 'today']
286 items, start, end = f(self.db_conn, date_range, date_col)
287 self.assertEqual(start, today_date)
288 self.assertEqual(start, end)
289 self.assertEqual(items, [obj_today])
291 @TestCaseAugmented._run_if_with_db_but_not_server
292 @TestCaseAugmented._run_on_versioned_attributes
293 def test_saving_versioned_attributes(self,
296 attr: VersionedAttribute,
298 to_set: list[str] | list[float]
300 """Test storage and initialization of versioned attributes."""
302 def retrieve_attr_vals(attr: VersionedAttribute) -> list[object]:
303 attr_vals_saved: list[object] = []
304 for row in self.db_conn.row_where(attr.table_name, 'parent',
306 attr_vals_saved += [row[2]]
307 return attr_vals_saved
310 # check that without attr.save() no rows in DB
311 rows = self.db_conn.row_where(attr.table_name, 'parent', owner.id_)
312 self.assertEqual([], rows)
313 # fail saving attributes on non-saved owner
314 with self.assertRaises(NotFoundException):
315 attr.save(self.db_conn)
316 # check owner.save() created entries as expected in attr table
317 owner.save(self.db_conn)
318 attr_vals_saved = retrieve_attr_vals(attr)
319 self.assertEqual([to_set[0]], attr_vals_saved)
320 # check changing attr val without save affects owner in memory …
322 cmp_attr = getattr(owner, attr_name)
323 self.assertEqual(to_set, list(cmp_attr.history.values()))
324 self.assertEqual(cmp_attr.history, attr.history)
325 # … but does not yet affect DB
326 attr_vals_saved = retrieve_attr_vals(attr)
327 self.assertEqual([to_set[0]], attr_vals_saved)
328 # check individual attr.save also stores new val to DB
329 attr.save(self.db_conn)
330 attr_vals_saved = retrieve_attr_vals(attr)
331 self.assertEqual(to_set, attr_vals_saved)
333 @TestCaseAugmented._run_if_with_db_but_not_server
334 def test_saving_and_caching(self) -> None:
335 """Test effects of .cache() and .save()."""
336 id1 = self.default_ids[0]
337 # check failure to cache without ID (if None-ID input possible)
338 if isinstance(id1, int):
339 obj0 = self._make_from_defaults(None)
340 with self.assertRaises(HandledException):
342 # check mere object init itself doesn't even store in cache
343 obj1 = self._make_from_defaults(id1)
344 self.assertEqual(self.checked_class.get_cache(), {})
345 # check .cache() fills cache, but not DB
347 self.assertEqual(self.checked_class.get_cache(), {id1: obj1})
348 found_in_db = self._load_from_db(id1)
349 self.assertEqual(found_in_db, [])
350 # check .save() sets ID (for int IDs), updates cache, and fills DB
351 # (expect ID to be set to id1, despite obj1 already having that as ID:
352 # it's generated by cursor.lastrowid on the DB table, and with obj1
353 # not written there, obj2 should get it first!)
354 id_input = None if isinstance(id1, int) else id1
355 obj2 = self._make_from_defaults(id_input)
356 obj2.save(self.db_conn)
357 self.assertEqual(self.checked_class.get_cache(), {id1: obj2})
358 # NB: we'll only compare hashes because obj2 itself disappears on
359 # .from_table_row-triggered database reload
360 obj2_hash = hash(obj2)
361 found_in_db += self._load_from_db(id1)
362 self.assertEqual([hash(o) for o in found_in_db], [obj2_hash])
363 # check we cannot overwrite obj2 with obj1 despite its same ID,
364 # since it has disappeared now
365 with self.assertRaises(HandledException):
366 obj1.save(self.db_conn)
368 @TestCaseAugmented._run_if_with_db_but_not_server
369 def test_by_id(self) -> None:
371 id1, id2, _ = self.default_ids
372 # check failure if not yet saved
373 obj1 = self._make_from_defaults(id1)
374 with self.assertRaises(NotFoundException):
375 self.checked_class.by_id(self.db_conn, id1)
376 # check identity of cached and retrieved
378 self.assertEqual(obj1, self.checked_class.by_id(self.db_conn, id1))
379 # check identity of saved and retrieved
380 obj2 = self._make_from_defaults(id2)
381 obj2.save(self.db_conn)
382 self.assertEqual(obj2, self.checked_class.by_id(self.db_conn, id2))
384 @TestCaseAugmented._run_if_with_db_but_not_server
385 def test_by_id_or_create(self) -> None:
386 """Test .by_id_or_create."""
387 # check .by_id_or_create fails if wrong class
388 if not self.checked_class.can_create_by_id:
389 with self.assertRaises(HandledException):
390 self.checked_class.by_id_or_create(self.db_conn, None)
392 # check ID input of None creates, on saving, ID=1,2,… for int IDs
393 if isinstance(self.default_ids[0], int):
395 item = self.checked_class.by_id_or_create(self.db_conn, None)
396 self.assertEqual(item.id_, None)
397 item.save(self.db_conn)
398 self.assertEqual(item.id_, n+1)
399 # check .by_id_or_create acts like normal instantiation (sans saving)
400 id_ = self.default_ids[2]
401 item = self.checked_class.by_id_or_create(self.db_conn, id_)
402 self.assertEqual(item.id_, id_)
403 with self.assertRaises(NotFoundException):
404 self.checked_class.by_id(self.db_conn, item.id_)
405 self.assertEqual(self.checked_class(item.id_), item)
407 @TestCaseAugmented._run_if_with_db_but_not_server
408 def test_from_table_row(self) -> None:
409 """Test .from_table_row() properly reads in class directly from DB."""
410 id_ = self.default_ids[0]
411 obj = self._make_from_defaults(id_)
412 obj.save(self.db_conn)
413 assert isinstance(obj.id_, type(id_))
414 for row in self.db_conn.row_where(self.checked_class.table_name,
416 # check .from_table_row reproduces state saved, no matter if obj
417 # later changed (with caching even)
418 # NB: we'll only compare hashes because obj itself disappears on
419 # .from_table_row-triggered database reload
420 hash_original = hash(obj)
421 attr_name = self._change_obj(obj)
423 to_cmp = getattr(obj, attr_name)
424 retrieved = self.checked_class.from_table_row(self.db_conn, row)
425 self.assertNotEqual(to_cmp, getattr(retrieved, attr_name))
426 self.assertEqual(hash_original, hash(retrieved))
427 # check cache contains what .from_table_row just produced
428 self.assertEqual({retrieved.id_: retrieved},
429 self.checked_class.get_cache())
431 @TestCaseAugmented._run_if_with_db_but_not_server
432 @TestCaseAugmented._run_on_versioned_attributes
433 def test_versioned_history_from_row(self,
436 attr: VersionedAttribute,
437 default: str | float,
438 to_set: list[str] | list[float]
440 """Test VersionedAttribute.history_from_row() knows its DB rows."""
443 owner.save(self.db_conn)
444 # make empty VersionedAttribute, fill from rows, compare to owner's
445 for row in self.db_conn.row_where(owner.table_name, 'id', owner.id_):
446 loaded_attr = VersionedAttribute(owner, attr.table_name, default)
447 for row in self.db_conn.row_where(attr.table_name, 'parent',
449 loaded_attr.history_from_row(row)
450 self.assertEqual(len(attr.history.keys()),
451 len(loaded_attr.history.keys()))
452 for timestamp, value in attr.history.items():
453 self.assertEqual(value, loaded_attr.history[timestamp])
455 @TestCaseAugmented._run_if_with_db_but_not_server
456 def test_all(self) -> None:
457 """Test .all() and its relation to cache and savings."""
458 id1, id2, id3 = self.default_ids
459 item1 = self._make_from_defaults(id1)
460 item2 = self._make_from_defaults(id2)
461 item3 = self._make_from_defaults(id3)
462 # check .all() returns empty list on un-cached items
463 self.assertEqual(self.checked_class.all(self.db_conn), [])
464 # check that all() shows only cached/saved items
466 item3.save(self.db_conn)
467 self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
468 sorted([item1, item3]))
469 item2.save(self.db_conn)
470 self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
471 sorted([item1, item2, item3]))
@TestCaseAugmented._run_if_with_db_but_not_server
def test_singularity(self) -> None:
    """Test that .by_id on a saved object reflects later in-memory changes."""
    first_id = self.default_ids[0]
    original = self._make_from_defaults(first_id)
    original.save(self.db_conn)
    # modify the object after saving; what .by_id hands back for the same
    # ID is expected to show the modified attribute value, too
    changed_name = self._change_obj(original)
    expected_val = getattr(original, changed_name)
    fetched = self.checked_class.by_id(self.db_conn, first_id)
    actual_val = getattr(fetched, changed_name)
    self.assertEqual(expected_val, actual_val)
485 @TestCaseAugmented._run_if_with_db_but_not_server
486 @TestCaseAugmented._run_on_versioned_attributes
487 def test_versioned_singularity(self,
490 attr: VersionedAttribute,
492 to_set: list[str] | list[float]
494 """Test singularity of VersionedAttributes on saving."""
495 owner.save(self.db_conn)
496 # change obj, expect retrieved through .by_id to carry change
498 retrieved = self.checked_class.by_id(self.db_conn, owner.id_)
499 attr_retrieved = getattr(retrieved, attr_name)
500 self.assertEqual(attr.history, attr_retrieved.history)
502 @TestCaseAugmented._run_if_with_db_but_not_server
503 def test_remove(self) -> None:
504 """Test .remove() effects on DB and cache."""
505 id_ = self.default_ids[0]
506 obj = self._make_from_defaults(id_)
507 # check removal only works after saving
508 with self.assertRaises(HandledException):
509 obj.remove(self.db_conn)
510 obj.save(self.db_conn)
511 obj.remove(self.db_conn)
512 # check access to obj fails after removal
513 with self.assertRaises(HandledException):
515 # check DB and cache now empty
516 self.check_identity_with_cache_and_db([])
520 """Builder of (JSON-like) dict to compare against responses of test server.
522 Collects all items and relations we expect expressed in the server's JSON
523 responses and puts them into the proper json.dumps-friendly dict structure,
524 accessible via .as_dict, to compare them in TestsWithServer.check_json_get.
526 On its own provides for .as_dict output only {"_library": …}, initialized
527 from .__init__ and to be directly manipulated via the .lib* methods.
528 Further structures of the expected response may be added and kept
529 up-to-date by subclassing .__init__, .recalc, and .d.
531 NB: Lots of expectations towards server behavior will be made explicit here
532 (or in the subclasses) rather than in the actual TestCase methods' code.
534 _default_dict: dict[str, Any]
535 _forced: dict[str, Any]
536 _fields: dict[str, Any]
537 _on_empty_make_temp: tuple[str, str]
540 todos: list[dict[str, Any]] | None = None,
541 procs: list[dict[str, Any]] | None = None,
542 procsteps: list[dict[str, Any]] | None = None,
543 conds: list[dict[str, Any]] | None = None,
544 days: list[dict[str, Any]] | None = None
546 # pylint: disable=too-many-arguments
547 for name in ['_default_dict', '_fields', '_forced']:
548 if not hasattr(self, name):
549 setattr(self, name, {})
551 for title, items in [('Todo', todos),
553 ('ProcessStep', procsteps),
554 ('Condition', conds),
557 self._lib[title] = self._as_refs(items)
558 for k, v in self._default_dict.items():
559 if k not in self._fields:
562 def recalc(self) -> None:
563 """Update internal dictionary by subclass-specific rules."""
564 todos = self.lib_all('Todo')
568 for child_id in todo['children']:
569 self.lib_get('Todo', child_id)['parents'] += [todo['id']]
570 todo['children'].sort()
571 procsteps = self.lib_all('ProcessStep')
572 procs = self.lib_all('Process')
574 proc['explicit_steps'] = [s['id'] for s in procsteps
575 if s['owner_id'] == proc['id']]
578 def as_dict(self) -> dict[str, Any]:
579 """Return dict to compare against test server JSON responses."""
581 if hasattr(self, '_on_empty_make_temp'):
582 category, dicter = getattr(self, '_on_empty_make_temp')
583 id_ = self._fields[category.lower()]
584 make_temp = not bool(self.lib_get(category, id_))
586 self.lib_set(category, [getattr(self, dicter)(id_)])
588 d = {'_library': self._lib}
589 for k, v in self._fields.items():
590 # we expect everything sortable to be sorted
591 if isinstance(v, list) and k not in self._forced:
592 # NB: if we don't test for v being list, sorted() on an empty
593 # dict may return an empty list
599 for k, v in self._forced.items():
603 id_ = id_ if id_ is not None else '?'
604 self.lib_del(category, id_)
608 def lib_get(self, category: str, id_: str | int) -> dict[str, Any]:
609 """From library, return item of category and id_, or empty dict."""
611 if category in self._lib and str_id in self._lib[category]:
612 return self._lib[category][str_id]
615 def lib_all(self, category: str) -> list[dict[str, Any]]:
616 """From library, return items of category, or [] if none."""
617 if category in self._lib:
618 return list(self._lib[category].values())
def lib_set(self, category: str, items: list[dict[str, object]]) -> None:
    """Add items to library's category, replacing entries of same ID.

    The category is created if not yet present (even if items is empty);
    items are keyed by what ._as_refs derives from them.
    """
    bucket = self._lib.setdefault(category, {})
    bucket.update(self._as_refs(items))
def lib_del(self, category: str, id_: str | int) -> None:
    """Delete entry of id_ from library's category.

    The category itself is dropped once its last entry is gone.  Raises
    KeyError if category or (stringified) id_ is not in the library.
    """
    bucket = self._lib[category]
    del bucket[str(id_)]
    if not bucket:
        del self._lib[category]
def lib_wipe(self, category: str) -> None:
    """Drop whole category from library; do nothing if it is not there."""
    self._lib.pop(category, None)
def set(self, field_name: str, value: object) -> None:
    """Store value in ._fields under field_name, for output by .as_dict."""
    fields = self._fields
    fields[field_name] = value
def force(self, field_name: str, value: object) -> None:
    """Store value in ._forced under field_name, to enforce it in .as_dict."""
    forced = self._forced
    forced[field_name] = value
def unforce(self, field_name: str) -> None:
    """Remove field_name from ._forced; raises KeyError if not set there."""
    self._forced.pop(field_name)
652 def _as_refs(items: list[dict[str, object]]
653 ) -> dict[str, dict[str, object]]:
654 """Return dictionary of items by their 'id' fields."""
657 id_ = str(item['id']) if item['id'] is not None else '?'
def as_ids(items: list[dict[str, Any]]) -> list[int] | list[str]:
    """Collect the 'id' value of each item, keeping the order of items."""
    ids: list[Any] = []
    for entry in items:
        ids.append(entry['id'])
    return ids
def day_as_dict(date: str, comment: str = '') -> dict[str, object]:
    """Return dict to expect as JSON of a Day: date as ID, no Todos yet."""
    expected: dict[str, object] = {'id': date}
    expected['comment'] = comment
    expected['todos'] = []
    return expected
671 def set_day_from_post(self, date: str, d: dict[str, Any]) -> None:
672 """Set Day of date in library based on POST dict d."""
673 day = self.day_as_dict(date)
674 for k, v in d.items():
675 if 'day_comment' == k:
677 elif 'new_todo' == k:
679 for todo in self.lib_all('Todo'):
680 if next_id <= todo['id']:
681 next_id = todo['id'] + 1
682 for proc_id in sorted([id_ for id_ in v if id_]):
683 todo = self.todo_as_dict(next_id, proc_id, date)
684 self.lib_set('Todo', [todo])
688 self.lib_get('Todo', todo_id)['is_done'] = True
690 for i, todo_id in enumerate(v):
691 t = self.lib_get('Todo', todo_id)
693 t['comment'] = d['comment'][i]
695 effort = d['effort'][i] if d['effort'][i] else None
697 self.lib_set('Day', [day])
def cond_as_dict(id_: int = 1,
                 is_active: bool = False,
                 title: None | str = None,
                 description: None | str = None,
                 ) -> dict[str, object]:
    """Return dict to expect as JSON of a Condition.

    Each of title and description, if not None, becomes the sole entry
    (keyed '0') of its history below '_versioned'; if None, that history
    stays empty.
    """
    histories: dict[str, dict[str, object]] = {}
    for name, first_val in (('title', title), ('description', description)):
        histories[name] = {} if first_val is None else {'0': first_val}
    return {'id': id_, 'is_active': is_active, '_versioned': histories}
714 def set_cond_from_post(self, id_: int, d: dict[str, Any]) -> None:
715 """Set Condition of id_ in library based on POST dict d."""
716 if d == {'delete': ''}:
717 self.lib_del('Condition', id_)
719 cond = self.lib_get('Condition', id_)
722 cond['is_active'] = d['is_active']
723 for category in ['title', 'description']:
724 history = cond['_versioned'][category]
726 last_i = sorted([int(k) for k in history.keys()])[-1]
727 if d[category] != history[str(last_i)]:
728 history[str(last_i + 1)] = d[category]
730 history['0'] = d[category]
732 cond = self.cond_as_dict(id_, **d)
733 self.lib_set('Condition', [cond])
736 def todo_as_dict(id_: int = 1,
738 date: str = '2024-01-01',
739 conditions: None | list[int] = None,
740 disables: None | list[int] = None,
741 blockers: None | list[int] = None,
742 enables: None | list[int] = None,
743 calendarize: bool = False,
745 is_done: bool = False,
746 effort: float | None = None,
747 children: list[int] | None = None,
748 parents: list[int] | None = None,
749 ) -> dict[str, object]:
750 """Return JSON of Todo to expect."""
751 # pylint: disable=too-many-arguments
754 'process_id': process_id,
756 'calendarize': calendarize,
758 'children': children if children else [],
759 'parents': parents if parents else [],
761 'conditions': conditions if conditions else [],
762 'disables': disables if disables else [],
763 'blockers': blockers if blockers else [],
764 'enables': enables if enables else []}
767 def set_todo_from_post(self, id_: int, d: dict[str, Any]) -> None:
768 """Set Todo of id_ in library based on POST dict d."""
769 corrected_kwargs: dict[str, Any] = {'children': []}
770 for k, v in d.items():
771 if k.startswith('step_filler_to_'):
774 new_children = v if isinstance(v, list) else [v]
775 corrected_kwargs['children'] += new_children
777 if k in {'is_done', 'calendarize'}:
779 corrected_kwargs[k] = v
780 todo = self.lib_get('Todo', id_)
782 for k, v in corrected_kwargs.items():
785 todo = self.todo_as_dict(id_, **corrected_kwargs)
786 self.lib_set('Todo', [todo])
789 def procstep_as_dict(id_: int,
791 step_process_id: int,
792 parent_step_id: int | None = None
793 ) -> dict[str, object]:
794 """Return JSON of ProcessStep to expect."""
796 'owner_id': owner_id,
797 'step_process_id': step_process_id,
798 'parent_step_id': parent_step_id}
801 def proc_as_dict(id_: int = 1,
802 title: None | str = None,
803 description: None | str = None,
804 effort: None | float = None,
805 conditions: None | list[int] = None,
806 disables: None | list[int] = None,
807 blockers: None | list[int] = None,
808 enables: None | list[int] = None,
809 explicit_steps: None | list[int] = None
810 ) -> dict[str, object]:
811 """Return JSON of Process to expect."""
812 # pylint: disable=too-many-arguments
813 versioned: dict[str, dict[str, object]]
814 versioned = {'title': {}, 'description': {}, 'effort': {}}
815 if title is not None:
816 versioned['title']['0'] = title
817 if description is not None:
818 versioned['description']['0'] = description
819 if effort is not None:
820 versioned['effort']['0'] = effort
822 'calendarize': False,
823 'suppressed_steps': [],
824 'explicit_steps': explicit_steps if explicit_steps else [],
825 '_versioned': versioned,
826 'conditions': conditions if conditions else [],
827 'disables': disables if disables else [],
828 'enables': enables if enables else [],
829 'blockers': blockers if blockers else []}
832 def set_proc_from_post(self, id_: int, d: dict[str, Any]) -> None:
833 """Set Process of id_ in library based on POST dict d."""
834 proc = self.lib_get('Process', id_)
836 for category in ['title', 'description', 'effort']:
837 history = proc['_versioned'][category]
839 last_i = sorted([int(k) for k in history.keys()])[-1]
840 if d[category] != history[str(last_i)]:
841 history[str(last_i + 1)] = d[category]
843 history['0'] = d[category]
845 proc = self.proc_as_dict(id_,
846 d['title'], d['description'], d['effort'])
847 ignore = {'title', 'description', 'effort', 'new_top_step', 'step_of',
849 for k, v in d.items():
851 or k.startswith('step_') or k.startswith('new_step_to'):
853 if k in {'calendarize'}:
855 elif k in {'suppressed_steps', 'explicit_steps', 'conditions',
856 'disables', 'enables', 'blockers'}:
857 if not isinstance(v, list):
860 self.lib_set('Process', [proc])
863 class TestCaseWithServer(TestCaseWithDB):
864 """Module tests against our HTTP server/handler (and database)."""
866 def setUp(self) -> None:
868 self.httpd = TaskServer(self.db_file, ('localhost', 0), TaskHandler)
869 self.server_thread = Thread(target=self.httpd.serve_forever)
870 self.server_thread.daemon = True
871 self.server_thread.start()
872 self.conn = HTTPConnection(str(self.httpd.server_address[0]),
873 self.httpd.server_address[1])
874 self.httpd.render_mode = 'json'
876 def tearDown(self) -> None:
877 self.httpd.shutdown()
878 self.httpd.server_close()
879 self.server_thread.join()
882 def post_exp_cond(self,
883 exps: list[Expected],
884 payload: dict[str, object],
886 post_to_id: bool = True,
887 redir_to_id: bool = True
889 """POST /condition(s), appropriately update Expecteds."""
890 # pylint: disable=too-many-arguments
891 target = f'/condition?id={id_}' if post_to_id else '/condition'
892 redir = f'/condition?id={id_}' if redir_to_id else '/conditions'
893 self.check_post(payload, target, redir=redir)
895 exp.set_cond_from_post(id_, payload)
897 def post_exp_day(self,
898 exps: list[Expected],
899 payload: dict[str, Any],
900 date: str = '2024-01-01'
902 """POST /day, appropriately update Expecteds."""
903 if 'make_type' not in payload:
904 payload['make_type'] = 'empty'
905 if 'day_comment' not in payload:
906 payload['day_comment'] = ''
907 target = f'/day?date={date}'
908 redir_to = f'{target}&make_type={payload["make_type"]}'
909 self.check_post(payload, target, 302, redir_to)
911 exp.set_day_from_post(date, payload)
913 def post_exp_process(self,
914 exps: list[Expected],
915 payload: dict[str, Any],
917 ) -> dict[str, object]:
918 """POST /process, appropriately update Expecteds."""
919 if 'title' not in payload:
920 payload['title'] = 'foo'
921 if 'description' not in payload:
922 payload['description'] = 'foo'
923 if 'effort' not in payload:
924 payload['effort'] = 1.1
925 self.check_post(payload, f'/process?id={id_}',
926 redir=f'/process?id={id_}')
928 exp.set_proc_from_post(id_, payload)
931 def check_filter(self, exp: Expected, category: str, key: str,
932 val: str, list_ids: list[int]) -> None:
933 """Check GET /{category}?{key}={val} sorts to list_ids."""
934 # pylint: disable=too-many-arguments
936 exp.force(category, list_ids)
937 self.check_json_get(f'/{category}?{key}={val}', exp)
def check_redirect(self, target: str) -> None:
    """Check pending response on self.conn is a 302 pointing to target.

    Expects a request to have been sent on self.conn already; only reads
    the response's status and 'Location' header.
    """
    reply = self.conn.getresponse()
    self.assertEqual(reply.status, 302)
    location = reply.getheader('Location')
    self.assertEqual(location, target)
def check_get(self, target: str, expected_code: int) -> None:
    """Send GET for target on self.conn, check status is expected_code."""
    self.conn.request('GET', target)
    status = self.conn.getresponse().status
    self.assertEqual(status, expected_code)
950 def check_minimal_inputs(self,
952 minimal_inputs: dict[str, Any]
954 """Check that url 400's unless all of minimal_inputs provided."""
955 for to_hide in minimal_inputs.keys():
956 to_post = {k: v for k, v in minimal_inputs.items() if k != to_hide}
957 self.check_post(to_post, url, 400)
959 def check_post(self, data: Mapping[str, object], target: str,
960 expected_code: int = 302, redir: str = '') -> None:
961 """Check that POST of data to target yields expected_code."""
962 encoded_form_data = urlencode(data, doseq=True).encode('utf-8')
963 headers = {'Content-Type': 'application/x-www-form-urlencoded',
964 'Content-Length': str(len(encoded_form_data))}
965 self.conn.request('POST', target,
966 body=encoded_form_data, headers=headers)
967 if 302 == expected_code:
968 redir = target if redir == '' else redir
969 self.check_redirect(redir)
971 self.assertEqual(self.conn.getresponse().status, expected_code)
973 def check_get_defaults(self,
975 default_id: str = '1',
978 """Some standard model paths to test."""
979 nonexist_status = 200 if self.checked_class.can_create_by_id else 404
980 self.check_get(path, nonexist_status)
981 self.check_get(f'{path}?{id_name}=', 400)
982 self.check_get(f'{path}?{id_name}=foo', 400)
983 self.check_get(f'/{path}?{id_name}=0', 400)
984 self.check_get(f'{path}?{id_name}={default_id}', nonexist_status)
986 def check_json_get(self, path: str, expected: Expected) -> None:
987 """Compare JSON on GET path with expected.
989 To simplify comparison of VersionedAttribute histories, transforms
990 timestamp keys of VersionedAttribute history keys into (strings of)
991 integers counting chronologically forward from 0.
994 def rewrite_history_keys_in(item: Any) -> Any:
995 if isinstance(item, dict):
996 if '_versioned' in item.keys():
997 for category in item['_versioned']:
998 vals = item['_versioned'][category].values()
1000 for i, val in enumerate(vals):
1001 history[str(i)] = val
1002 item['_versioned'][category] = history
1003 for category in list(item.keys()):
1004 rewrite_history_keys_in(item[category])
1005 elif isinstance(item, list):
1006 item[:] = [rewrite_history_keys_in(i) for i in item]
1009 def walk_diffs(path: str, cmp1: object, cmp2: object) -> None:
1010 # pylint: disable=too-many-branches
1011 def warn(intro: str, val: object) -> None:
1012 if isinstance(val, (str, int, float)):
1018 if isinstance(cmp1, dict) and isinstance(cmp2, dict):
1019 for k, v in cmp1.items():
1021 warn(f'DIFF {path}: retrieved lacks {k}', v)
1023 walk_diffs(f'{path}:{k}', v, cmp2[k])
1024 for k in [k for k in cmp2.keys() if k not in cmp1]:
1025 warn(f'DIFF {path}: expected lacks retrieved\'s {k}',
1027 elif isinstance(cmp1, list) and isinstance(cmp2, list):
1028 for i, v1 in enumerate(cmp1):
1030 warn(f'DIFF {path}[{i}] retrieved misses:', v1)
1032 walk_diffs(f'{path}[{i}]', v1, cmp2[i])
1033 if len(cmp2) > len(cmp1):
1034 for i, v2 in enumerate(cmp2[len(cmp1):]):
1035 warn(f'DIFF {path}[{len(cmp1)+i}] misses:', v2)
1037 warn(f'DIFF {path} – for expected:', cmp1)
1038 warn('… and for retrieved:', cmp2)
1040 self.conn.request('GET', path)
1041 response = self.conn.getresponse()
1042 self.assertEqual(response.status, 200)
1043 retrieved = json_loads(response.read().decode())
1044 rewrite_history_keys_in(retrieved)
1045 cmp = expected.as_dict
1047 self.assertEqual(cmp, retrieved)
1048 except AssertionError as e:
1053 walk_diffs('', cmp, retrieved)