1 """Shared test utilities."""
2 # pylint: disable=too-many-lines
3 from __future__ import annotations
4 from unittest import TestCase
5 from typing import Mapping, Any, Callable
6 from threading import Thread
7 from http.client import HTTPConnection
8 from datetime import datetime, timedelta
10 from json import loads as json_loads, dumps as json_dumps
11 from urllib.parse import urlencode
12 from uuid import uuid4
13 from os import remove as remove_file
14 from pprint import pprint
15 from plomtask.db import DatabaseFile, DatabaseConnection
16 from plomtask.http import TaskHandler, TaskServer
17 from plomtask.processes import Process, ProcessStep
18 from plomtask.conditions import Condition
19 from plomtask.days import Day
20 from plomtask.dating import DATE_FORMAT
21 from plomtask.todos import Todo
22 from plomtask.versioned_attributes import VersionedAttribute, TIMESTAMP_FMT
23 from plomtask.exceptions import NotFoundException, HandledException
26 VERSIONED_VALS: dict[str,
27 list[str] | list[float]] = {'str': ['A', 'B'],
# POST form values that the Expected builders below treat as boolean truth
# (checked via `v in VALID_TRUES`); note True == 1, so the int 1 matches too.
VALID_TRUES = {'on', '1', 'true', 'True', True}
32 class TestCaseAugmented(TestCase):
33 """Tester core providing helpful basic internal decorators and methods."""
35 default_init_kwargs: dict[str, Any] = {}
38 def _run_on_versioned_attributes(f: Callable[..., None]
39 ) -> Callable[..., None]:
40 def wrapper(self: TestCase) -> None:
41 assert isinstance(self, TestCaseAugmented)
42 for attr_name in self.checked_class.to_save_versioned():
43 default = self.checked_class.versioned_defaults[attr_name]
44 owner = self.checked_class(None, **self.default_init_kwargs)
45 attr = getattr(owner, attr_name)
46 to_set = VERSIONED_VALS[attr.value_type_name]
47 f(self, owner, attr_name, attr, default, to_set)
51 def _run_if_sans_db(cls, f: Callable[..., None]) -> Callable[..., None]:
52 def wrapper(self: TestCaseSansDB) -> None:
53 if issubclass(cls, TestCaseSansDB):
58 def _run_if_with_db_but_not_server(cls,
59 f: Callable[..., None]
60 ) -> Callable[..., None]:
61 def wrapper(self: TestCaseWithDB) -> None:
62 if issubclass(cls, TestCaseWithDB) and\
63 not issubclass(cls, TestCaseWithServer):
68 def _make_from_defaults(cls, id_: float | str | None) -> Any:
69 return cls.checked_class(id_, **cls.default_init_kwargs)
72 class TestCaseSansDB(TestCaseAugmented):
73 """Tests requiring no DB setup."""
74 legal_ids: list[str] | list[int] = [1, 5]
75 illegal_ids: list[str] | list[int] = [0]
@TestCaseAugmented._run_if_sans_db
def test_id_validation(self) -> None:
    """Test .id_ validation/setting."""
    # construction with any illegal ID must raise a HandledException …
    for bad_id in self.illegal_ids:
        with self.assertRaises(HandledException):
            self._make_from_defaults(bad_id)
    # … while every legal ID must be accepted and kept as .id_ unchanged
    for good_id in self.legal_ids:
        self.assertEqual(self._make_from_defaults(good_id).id_, good_id)
87 @TestCaseAugmented._run_if_sans_db
88 @TestCaseAugmented._run_on_versioned_attributes
89 def test_versioned_set(self,
92 attr: VersionedAttribute,
94 to_set: list[str] | list[float]
96 """Test VersionedAttribute.set() behaves as expected."""
98 self.assertEqual(list(attr.history.values()), [default])
99 # check same value does not get set twice in a row,
100 # and that not even its timestamp get updated
101 timestamp = list(attr.history.keys())[0]
103 self.assertEqual(list(attr.history.values()), [default])
104 self.assertEqual(list(attr.history.keys())[0], timestamp)
105 # check that different value _will_ be set/added
107 timesorted_vals = [attr.history[t] for
108 t in sorted(attr.history.keys())]
109 expected = [default, to_set[0]]
110 self.assertEqual(timesorted_vals, expected)
111 # check that a previously used value can be set if not most recent
113 timesorted_vals = [attr.history[t] for
114 t in sorted(attr.history.keys())]
115 expected = [default, to_set[0], default]
116 self.assertEqual(timesorted_vals, expected)
117 # again check for same value not being set twice in a row, even for
120 timesorted_vals = [attr.history[t] for
121 t in sorted(attr.history.keys())]
122 expected = [default, to_set[0], default, to_set[1]]
123 self.assertEqual(timesorted_vals, expected)
125 self.assertEqual(timesorted_vals, expected)
127 @TestCaseAugmented._run_if_sans_db
128 @TestCaseAugmented._run_on_versioned_attributes
129 def test_versioned_newest(self,
132 attr: VersionedAttribute,
133 default: str | float,
134 to_set: list[str] | list[float]
136 """Test VersionedAttribute.newest."""
137 # check .newest on empty history returns .default
138 self.assertEqual(attr.newest, default)
139 # check newest element always returned
140 for v in [to_set[0], to_set[1]]:
142 self.assertEqual(attr.newest, v)
143 # check newest element returned even if also early value
145 self.assertEqual(attr.newest, default)
147 @TestCaseAugmented._run_if_sans_db
148 @TestCaseAugmented._run_on_versioned_attributes
149 def test_versioned_at(self,
152 attr: VersionedAttribute,
153 default: str | float,
154 to_set: list[str] | list[float]
156 """Test .at() returns values nearest to queried time, or default."""
157 # check .at() return default on empty history
158 timestamp_a = datetime.now().strftime(TIMESTAMP_FMT)
159 self.assertEqual(attr.at(timestamp_a), default)
160 # check value exactly at timestamp returned
162 timestamp_b = list(attr.history.keys())[0]
163 self.assertEqual(attr.at(timestamp_b), to_set[0])
164 # check earliest value returned if exists, rather than default
165 self.assertEqual(attr.at(timestamp_a), to_set[0])
166 # check reverts to previous value for timestamps not indexed
168 timestamp_between = datetime.now().strftime(TIMESTAMP_FMT)
171 timestamp_c = sorted(attr.history.keys())[-1]
172 self.assertEqual(attr.at(timestamp_c), to_set[1])
173 self.assertEqual(attr.at(timestamp_between), to_set[0])
175 timestamp_after_c = datetime.now().strftime(TIMESTAMP_FMT)
176 self.assertEqual(attr.at(timestamp_after_c), to_set[1])
179 class TestCaseWithDB(TestCaseAugmented):
180 """Module tests not requiring DB setup."""
181 default_ids: tuple[int, int, int] | tuple[str, str, str] = (1, 2, 3)
183 def setUp(self) -> None:
184 Condition.empty_cache()
186 Process.empty_cache()
187 ProcessStep.empty_cache()
189 self.db_file = DatabaseFile.create_at(f'test_db:{uuid4()}')
190 self.db_conn = DatabaseConnection(self.db_file)
192 def tearDown(self) -> None:
194 remove_file(self.db_file.path)
196 def _load_from_db(self, id_: int | str) -> list[object]:
197 db_found: list[object] = []
198 for row in self.db_conn.row_where(self.checked_class.table_name,
200 db_found += [self.checked_class.from_table_row(self.db_conn,
204 def _change_obj(self, obj: object) -> str:
205 attr_name: str = self.checked_class.to_save_simples[-1]
206 attr = getattr(obj, attr_name)
207 new_attr: str | int | float | bool
208 if isinstance(attr, (int, float)):
210 elif isinstance(attr, str):
211 new_attr = attr + '_'
212 elif isinstance(attr, bool):
214 setattr(obj, attr_name, new_attr)
217 def check_identity_with_cache_and_db(self, content: list[Any]) -> None:
218 """Test both cache and DB equal content."""
221 expected_cache[item.id_] = item
222 self.assertEqual(self.checked_class.get_cache(), expected_cache)
223 hashes_content = [hash(x) for x in content]
224 db_found: list[Any] = []
226 assert isinstance(item.id_, type(self.default_ids[0]))
227 db_found += self._load_from_db(item.id_)
228 hashes_db_found = [hash(x) for x in db_found]
229 self.assertEqual(sorted(hashes_content), sorted(hashes_db_found))
231 def check_by_date_range_with_limits(self,
233 set_id_field: bool = True
235 """Test .by_date_range_with_limits."""
236 # pylint: disable=too-many-locals
237 f = self.checked_class.by_date_range_with_limits
238 # check illegal ranges
239 legal_range = ('yesterday', 'tomorrow')
241 for bad_date in ['foo', '2024-02-30', '2024-01-01 12:00:00']:
242 date_range = list(legal_range[:])
243 date_range[i] = bad_date
244 with self.assertRaises(HandledException):
245 f(self.db_conn, date_range, date_col)
246 # check empty, translation of 'yesterday' and 'tomorrow'
247 items, start, end = f(self.db_conn, legal_range, date_col)
248 self.assertEqual(items, [])
249 yesterday = datetime.now() + timedelta(days=-1)
250 tomorrow = datetime.now() + timedelta(days=+1)
251 self.assertEqual(start, yesterday.strftime(DATE_FORMAT))
252 self.assertEqual(end, tomorrow.strftime(DATE_FORMAT))
253 # prepare dated items for non-empty results
254 kwargs_with_date = self.default_init_kwargs.copy()
256 kwargs_with_date['id_'] = None
258 dates = ['2024-01-01', '2024-01-02', '2024-01-04']
259 for date in ['2024-01-01', '2024-01-02', '2024-01-04']:
260 kwargs_with_date['date'] = date
261 obj = self.checked_class(**kwargs_with_date)
263 # check ranges still empty before saving
264 date_range = [dates[0], dates[-1]]
265 self.assertEqual(f(self.db_conn, date_range, date_col)[0], [])
266 # check all objs displayed within closed interval
268 obj.save(self.db_conn)
269 self.assertEqual(f(self.db_conn, date_range, date_col)[0], objs)
270 # check that only displayed what exists within interval
271 date_range = ['2023-12-20', '2024-01-03']
272 expected = [objs[0], objs[1]]
273 self.assertEqual(f(self.db_conn, date_range, date_col)[0], expected)
274 date_range = ['2024-01-03', '2024-01-30']
276 self.assertEqual(f(self.db_conn, date_range, date_col)[0], expected)
277 # check that inverted interval displays nothing
278 date_range = [dates[-1], dates[0]]
279 self.assertEqual(f(self.db_conn, date_range, date_col)[0], [])
280 # check that "today" is interpreted, and single-element interval
281 today_date = datetime.now().strftime(DATE_FORMAT)
282 kwargs_with_date['date'] = today_date
283 obj_today = self.checked_class(**kwargs_with_date)
284 obj_today.save(self.db_conn)
285 date_range = ['today', 'today']
286 items, start, end = f(self.db_conn, date_range, date_col)
287 self.assertEqual(start, today_date)
288 self.assertEqual(start, end)
289 self.assertEqual(items, [obj_today])
291 @TestCaseAugmented._run_if_with_db_but_not_server
292 @TestCaseAugmented._run_on_versioned_attributes
293 def test_saving_versioned_attributes(self,
296 attr: VersionedAttribute,
298 to_set: list[str] | list[float]
300 """Test storage and initialization of versioned attributes."""
302 def retrieve_attr_vals(attr: VersionedAttribute) -> list[object]:
303 attr_vals_saved: list[object] = []
304 for row in self.db_conn.row_where(attr.table_name, 'parent',
306 attr_vals_saved += [row[2]]
307 return attr_vals_saved
310 # check that without attr.save() no rows in DB
311 rows = self.db_conn.row_where(attr.table_name, 'parent', owner.id_)
312 self.assertEqual([], rows)
313 # fail saving attributes on non-saved owner
314 with self.assertRaises(NotFoundException):
315 attr.save(self.db_conn)
316 # check owner.save() created entries as expected in attr table
317 owner.save(self.db_conn)
318 attr_vals_saved = retrieve_attr_vals(attr)
319 self.assertEqual([to_set[0]], attr_vals_saved)
320 # check changing attr val without save affects owner in memory …
322 cmp_attr = getattr(owner, attr_name)
323 self.assertEqual(to_set, list(cmp_attr.history.values()))
324 self.assertEqual(cmp_attr.history, attr.history)
325 # … but does not yet affect DB
326 attr_vals_saved = retrieve_attr_vals(attr)
327 self.assertEqual([to_set[0]], attr_vals_saved)
328 # check individual attr.save also stores new val to DB
329 attr.save(self.db_conn)
330 attr_vals_saved = retrieve_attr_vals(attr)
331 self.assertEqual(to_set, attr_vals_saved)
333 @TestCaseAugmented._run_if_with_db_but_not_server
334 def test_saving_and_caching(self) -> None:
335 """Test effects of .cache() and .save()."""
336 id1 = self.default_ids[0]
337 # check failure to cache without ID (if None-ID input possible)
338 if isinstance(id1, int):
339 obj0 = self._make_from_defaults(None)
340 with self.assertRaises(HandledException):
342 # check mere object init itself doesn't even store in cache
343 obj1 = self._make_from_defaults(id1)
344 self.assertEqual(self.checked_class.get_cache(), {})
345 # check .cache() fills cache, but not DB
347 self.assertEqual(self.checked_class.get_cache(), {id1: obj1})
348 found_in_db = self._load_from_db(id1)
349 self.assertEqual(found_in_db, [])
350 # check .save() sets ID (for int IDs), updates cache, and fills DB
351 # (expect ID to be set to id1, despite obj1 already having that as ID:
352 # it's generated by cursor.lastrowid on the DB table, and with obj1
353 # not written there, obj2 should get it first!)
354 id_input = None if isinstance(id1, int) else id1
355 obj2 = self._make_from_defaults(id_input)
356 obj2.save(self.db_conn)
357 self.assertEqual(self.checked_class.get_cache(), {id1: obj2})
358 # NB: we'll only compare hashes because obj2 itself disappears on
359 # .from_table_row-triggered database reload
360 obj2_hash = hash(obj2)
361 found_in_db += self._load_from_db(id1)
362 self.assertEqual([hash(o) for o in found_in_db], [obj2_hash])
363 # check we cannot overwrite obj2 with obj1 despite its same ID,
364 # since it has disappeared now
365 with self.assertRaises(HandledException):
366 obj1.save(self.db_conn)
368 @TestCaseAugmented._run_if_with_db_but_not_server
369 def test_by_id(self) -> None:
371 id1, id2, _ = self.default_ids
372 # check failure if not yet saved
373 obj1 = self._make_from_defaults(id1)
374 with self.assertRaises(NotFoundException):
375 self.checked_class.by_id(self.db_conn, id1)
376 # check identity of cached and retrieved
378 self.assertEqual(obj1, self.checked_class.by_id(self.db_conn, id1))
379 # check identity of saved and retrieved
380 obj2 = self._make_from_defaults(id2)
381 obj2.save(self.db_conn)
382 self.assertEqual(obj2, self.checked_class.by_id(self.db_conn, id2))
384 @TestCaseAugmented._run_if_with_db_but_not_server
385 def test_by_id_or_create(self) -> None:
386 """Test .by_id_or_create."""
387 # check .by_id_or_create fails if wrong class
388 if not self.checked_class.can_create_by_id:
389 with self.assertRaises(HandledException):
390 self.checked_class.by_id_or_create(self.db_conn, None)
392 # check ID input of None creates, on saving, ID=1,2,… for int IDs
393 if isinstance(self.default_ids[0], int):
395 item = self.checked_class.by_id_or_create(self.db_conn, None)
396 self.assertEqual(item.id_, None)
397 item.save(self.db_conn)
398 self.assertEqual(item.id_, n+1)
399 # check .by_id_or_create acts like normal instantiation (sans saving)
400 id_ = self.default_ids[2]
401 item = self.checked_class.by_id_or_create(self.db_conn, id_)
402 self.assertEqual(item.id_, id_)
403 with self.assertRaises(NotFoundException):
404 self.checked_class.by_id(self.db_conn, item.id_)
405 self.assertEqual(self.checked_class(item.id_), item)
407 @TestCaseAugmented._run_if_with_db_but_not_server
408 def test_from_table_row(self) -> None:
409 """Test .from_table_row() properly reads in class directly from DB."""
410 id_ = self.default_ids[0]
411 obj = self._make_from_defaults(id_)
412 obj.save(self.db_conn)
413 assert isinstance(obj.id_, type(id_))
414 for row in self.db_conn.row_where(self.checked_class.table_name,
416 # check .from_table_row reproduces state saved, no matter if obj
417 # later changed (with caching even)
418 # NB: we'll only compare hashes because obj itself disappears on
419 # .from_table_row-triggered database reload
420 hash_original = hash(obj)
421 attr_name = self._change_obj(obj)
423 to_cmp = getattr(obj, attr_name)
424 retrieved = self.checked_class.from_table_row(self.db_conn, row)
425 self.assertNotEqual(to_cmp, getattr(retrieved, attr_name))
426 self.assertEqual(hash_original, hash(retrieved))
427 # check cache contains what .from_table_row just produced
428 self.assertEqual({retrieved.id_: retrieved},
429 self.checked_class.get_cache())
431 @TestCaseAugmented._run_if_with_db_but_not_server
432 @TestCaseAugmented._run_on_versioned_attributes
433 def test_versioned_history_from_row(self,
436 attr: VersionedAttribute,
437 default: str | float,
438 to_set: list[str] | list[float]
440 """"Test VersionedAttribute.history_from_row() knows its DB rows."""
443 owner.save(self.db_conn)
444 # make empty VersionedAttribute, fill from rows, compare to owner's
445 for row in self.db_conn.row_where(owner.table_name, 'id', owner.id_):
446 loaded_attr = VersionedAttribute(owner, attr.table_name, default)
447 for row in self.db_conn.row_where(attr.table_name, 'parent',
449 loaded_attr.history_from_row(row)
450 self.assertEqual(len(attr.history.keys()),
451 len(loaded_attr.history.keys()))
452 for timestamp, value in attr.history.items():
453 self.assertEqual(value, loaded_attr.history[timestamp])
455 @TestCaseAugmented._run_if_with_db_but_not_server
456 def test_all(self) -> None:
457 """Test .all() and its relation to cache and savings."""
458 id1, id2, id3 = self.default_ids
459 item1 = self._make_from_defaults(id1)
460 item2 = self._make_from_defaults(id2)
461 item3 = self._make_from_defaults(id3)
462 # check .all() returns empty list on un-cached items
463 self.assertEqual(self.checked_class.all(self.db_conn), [])
464 # check that all() shows only cached/saved items
466 item3.save(self.db_conn)
467 self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
468 sorted([item1, item3]))
469 item2.save(self.db_conn)
470 self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
471 sorted([item1, item2, item3]))
@TestCaseAugmented._run_if_with_db_but_not_server
def test_singularity(self) -> None:
    """Test pointers made for single object keep pointing to it."""
    first_id = self.default_ids[0]
    original = self._make_from_defaults(first_id)
    original.save(self.db_conn)
    # mutate the saved object in memory only (no second .save()) …
    changed_name = self._change_obj(original)
    changed_value = getattr(original, changed_name)
    # … and expect a fresh .by_id lookup to already carry that change
    fetched = self.checked_class.by_id(self.db_conn, first_id)
    self.assertEqual(changed_value, getattr(fetched, changed_name))
485 @TestCaseAugmented._run_if_with_db_but_not_server
486 @TestCaseAugmented._run_on_versioned_attributes
487 def test_versioned_singularity(self,
490 attr: VersionedAttribute,
492 to_set: list[str] | list[float]
494 """Test singularity of VersionedAttributes on saving."""
495 owner.save(self.db_conn)
496 # change obj, expect retrieved through .by_id to carry change
498 retrieved = self.checked_class.by_id(self.db_conn, owner.id_)
499 attr_retrieved = getattr(retrieved, attr_name)
500 self.assertEqual(attr.history, attr_retrieved.history)
502 @TestCaseAugmented._run_if_with_db_but_not_server
503 def test_remove(self) -> None:
504 """Test .remove() effects on DB and cache."""
505 id_ = self.default_ids[0]
506 obj = self._make_from_defaults(id_)
507 # check removal only works after saving
508 with self.assertRaises(HandledException):
509 obj.remove(self.db_conn)
510 obj.save(self.db_conn)
511 obj.remove(self.db_conn)
512 # check access to obj fails after removal
513 with self.assertRaises(HandledException):
515 # check DB and cache now empty
516 self.check_identity_with_cache_and_db([])
520 """Builder of (JSON-like) dict to compare against responses of test server.
522 Collects all items and relations we expect expressed in the server's JSON
523 responses and puts them into the proper json.dumps-friendly dict structure,
524 accessibla via .as_dict, to compare them in TestsWithServer.check_json_get.
526 On its own provides for .as_dict output only {"_library": …}, initialized
527 from .__init__ and to be directly manipulated via the .lib* methods.
528 Further structures of the expected response may be added and kept
529 up-to-date by subclassing .__init__, .recalc, and .d.
531 NB: Lots of expectations towards server behavior will be made explicit here
532 (or in the subclasses) rather than in the actual TestCase methods' code.
534 _default_dict: dict[str, Any]
535 _forced: dict[str, Any]
536 _fields: dict[str, Any]
537 _on_empty_make_temp: tuple[str, str]
540 todos: list[dict[str, Any]] | None = None,
541 procs: list[dict[str, Any]] | None = None,
542 procsteps: list[dict[str, Any]] | None = None,
543 conds: list[dict[str, Any]] | None = None,
544 days: list[dict[str, Any]] | None = None
546 # pylint: disable=too-many-arguments
547 for name in ['_default_dict', '_fields', '_forced']:
548 if not hasattr(self, name):
549 setattr(self, name, {})
551 for title, items in [('Todo', todos),
553 ('ProcessStep', procsteps),
554 ('Condition', conds),
557 self._lib[title] = self._as_refs(items)
558 for k, v in self._default_dict.items():
559 if k not in self._fields:
562 def recalc(self) -> None:
563 """Update internal dictionary by subclass-specific rules."""
564 todos = self.lib_all('Todo')
568 for child_id in todo['children']:
569 self.lib_get('Todo', child_id)['parents'] += [todo['id']]
570 todo['children'].sort()
571 procsteps = self.lib_all('ProcessStep')
572 procs = self.lib_all('Process')
574 proc['explicit_steps'] = [s['id'] for s in procsteps
575 if s['owner_id'] == proc['id']]
578 def as_dict(self) -> dict[str, Any]:
579 """Return dict to compare against test server JSON responses."""
581 if hasattr(self, '_on_empty_make_temp'):
582 category, dicter = getattr(self, '_on_empty_make_temp')
583 id_ = self._fields[category.lower()]
584 make_temp = not bool(self.lib_get(category, id_))
586 self.lib_set(category, [getattr(self, dicter)(id_)])
588 d = {'_library': self._lib}
589 for k, v in self._fields.items():
590 # we expect everything sortable to be sorted
591 if isinstance(v, list) and k not in self._forced:
592 # NB: if we don't test for v being list, sorted() on an empty
593 # dict may return an empty list
599 for k, v in self._forced.items():
603 id_ = id_ if id_ is not None else '?'
604 self.lib_del(category, id_)
608 def lib_get(self, category: str, id_: str | int) -> dict[str, Any]:
609 """From library, return item of category and id_, or empty dict."""
611 if category in self._lib and str_id in self._lib[category]:
612 return self._lib[category][str_id]
615 def lib_all(self, category: str) -> list[dict[str, Any]]:
616 """From library, return items of category, or [] if none."""
617 if category in self._lib:
618 return list(self._lib[category].values())
def lib_set(self, category: str, items: list[dict[str, object]]) -> None:
    """Merge items into the library under category.

    The category is created empty if absent; items are keyed via
    self._as_refs, so an item whose key already exists replaces the old one.
    """
    self._lib.setdefault(category, {}).update(self._as_refs(items))
def lib_del(self, category: str, id_: str | int) -> None:
    """Delete the library entry of id_ (stringified) from category.

    Raises KeyError if category or entry is missing. A category left
    without entries is removed from the library altogether.
    """
    entries = self._lib[category]
    del entries[str(id_)]
    if not entries:
        del self._lib[category]
def lib_wipe(self, category: str) -> None:
    """Drop category and all its entries from the library; no-op if absent."""
    self._lib.pop(category, None)
def set(self, field_name: str, value: object) -> None:
    """Store value as top-level field field_name, overwriting any prior one."""
    self._fields.update({field_name: value})
def force(self, field_name: str, value: object) -> None:
    """Record value in ._forced under field_name, replacing any earlier one."""
    self._forced.update({field_name: value})
def unforce(self, field_name: str) -> None:
    """Remove field_name from ._forced; raises KeyError if it was not set."""
    self._forced.pop(field_name)
652 def _as_refs(items: list[dict[str, object]]
653 ) -> dict[str, dict[str, object]]:
654 """Return dictionary of items by their 'id' fields."""
657 id_ = str(item['id']) if item['id'] is not None else '?'
def as_ids(items: list[dict[str, Any]]) -> list[int] | list[str]:
    """Map items to their 'id' values, keeping the input order."""
    return [entry['id'] for entry in items]
def day_as_dict(date: str, comment: str = '') -> dict[str, object]:
    """Build the expected JSON of a Day: keyed by date, with no Todos yet."""
    return dict(id=date, comment=comment, todos=[])
671 def set_day_from_post(self, date: str, d: dict[str, Any]) -> None:
672 """Set Day of date in library based on POST dict d."""
673 day = self.day_as_dict(date)
674 for k, v in d.items():
675 if 'day_comment' == k:
677 elif 'new_todo' == k:
679 for todo in self.lib_all('Todo'):
680 if next_id <= todo['id']:
681 next_id = todo['id'] + 1
682 for proc_id in sorted([id_ for id_ in v if id_]):
683 todo = self.todo_as_dict(next_id, proc_id, date)
684 self.lib_set('Todo', [todo])
688 self.lib_get('Todo', todo_id)['is_done'] = True
690 for i, todo_id in enumerate(v):
691 t = self.lib_get('Todo', todo_id)
693 t['comment'] = d['comment'][i]
695 effort = d['effort'][i] if d['effort'][i] else None
697 self.lib_set('Day', [day])
def cond_as_dict(id_: int = 1,
                 is_active: bool = False,
                 title: None | str = None,
                 description: None | str = None,
                 ) -> dict[str, object]:
    """Build the expected JSON of a Condition.

    Each versioned field gets a history whose single entry '0' holds the
    given value, or an empty history when that value is None.
    """
    initial = {'title': title, 'description': description}
    versioned: dict[str, dict[str, object]] = {
        field: ({} if val is None else {'0': val})
        for field, val in initial.items()}
    return {'id': id_, 'is_active': is_active, '_versioned': versioned}
714 def set_cond_from_post(self, id_: int, d: dict[str, Any]) -> None:
715 """Set Condition of id_ in library based on POST dict d."""
717 self.lib_del('Condition', id_)
719 cond = self.lib_get('Condition', id_)
721 cond['is_active'] = 'is_active' in d and\
722 d['is_active'] in VALID_TRUES
723 for category in ['title', 'description']:
724 history = cond['_versioned'][category]
726 last_i = sorted([int(k) for k in history.keys()])[-1]
727 if d[category] != history[str(last_i)]:
728 history[str(last_i + 1)] = d[category]
730 history['0'] = d[category]
732 cond = self.cond_as_dict(id_, **d)
733 self.lib_set('Condition', [cond])
736 def todo_as_dict(id_: int = 1,
738 date: str = '2024-01-01',
739 conditions: None | list[int] = None,
740 disables: None | list[int] = None,
741 blockers: None | list[int] = None,
742 enables: None | list[int] = None,
743 calendarize: bool = False,
745 is_done: bool = False,
746 effort: float | None = None,
747 children: list[int] | None = None,
748 parents: list[int] | None = None,
749 ) -> dict[str, object]:
750 """Return JSON of Todo to expect."""
751 # pylint: disable=too-many-arguments
754 'process_id': process_id,
756 'calendarize': calendarize,
758 'children': children if children else [],
759 'parents': parents if parents else [],
761 'conditions': conditions if conditions else [],
762 'disables': disables if disables else [],
763 'blockers': blockers if blockers else [],
764 'enables': enables if enables else []}
767 def set_todo_from_post(self, id_: int, d: dict[str, Any]) -> None:
768 """Set Todo of id_ in library based on POST dict d."""
769 corrected_kwargs: dict[str, Any] = {
770 'children': [], 'is_done': 0, 'calendarize': 0, 'comment': ''}
771 for k, v in d.items():
772 if k.startswith('step_filler_to_'):
775 new_children = v if isinstance(v, list) else [v]
776 corrected_kwargs['children'] += new_children
778 if k in {'is_done', 'calendarize'} and v in VALID_TRUES:
780 corrected_kwargs[k] = v
781 todo = self.lib_get('Todo', id_)
783 for k, v in corrected_kwargs.items():
786 todo = self.todo_as_dict(id_, **corrected_kwargs)
787 self.lib_set('Todo', [todo])
790 def procstep_as_dict(id_: int,
792 step_process_id: int,
793 parent_step_id: int | None = None
794 ) -> dict[str, object]:
795 """Return JSON of ProcessStep to expect."""
797 'owner_id': owner_id,
798 'step_process_id': step_process_id,
799 'parent_step_id': parent_step_id}
802 def proc_as_dict(id_: int = 1,
803 title: None | str = None,
804 description: None | str = None,
805 effort: None | float = None,
806 conditions: None | list[int] = None,
807 disables: None | list[int] = None,
808 blockers: None | list[int] = None,
809 enables: None | list[int] = None,
810 explicit_steps: None | list[int] = None
811 ) -> dict[str, object]:
812 """Return JSON of Process to expect."""
813 # pylint: disable=too-many-arguments
814 versioned: dict[str, dict[str, object]]
815 versioned = {'title': {}, 'description': {}, 'effort': {}}
816 if title is not None:
817 versioned['title']['0'] = title
818 if description is not None:
819 versioned['description']['0'] = description
820 if effort is not None:
821 versioned['effort']['0'] = effort
823 'calendarize': False,
824 'suppressed_steps': [],
825 'explicit_steps': explicit_steps if explicit_steps else [],
826 '_versioned': versioned,
827 'conditions': conditions if conditions else [],
828 'disables': disables if disables else [],
829 'enables': enables if enables else [],
830 'blockers': blockers if blockers else []}
833 def set_proc_from_post(self, id_: int, d: dict[str, Any]) -> None:
834 """Set Process of id_ in library based on POST dict d."""
835 proc = self.lib_get('Process', id_)
837 for category in ['title', 'description', 'effort']:
838 history = proc['_versioned'][category]
840 last_i = sorted([int(k) for k in history.keys()])[-1]
841 if d[category] != history[str(last_i)]:
842 history[str(last_i + 1)] = d[category]
844 history['0'] = d[category]
846 proc = self.proc_as_dict(id_,
847 d['title'], d['description'], d['effort'])
848 ignore = {'title', 'description', 'effort', 'new_top_step', 'step_of',
850 proc['calendarize'] = False
851 for k, v in d.items():
853 or k.startswith('step_') or k.startswith('new_step_to'):
855 if k in {'calendarize'} and v in VALID_TRUES:
857 elif k in {'suppressed_steps', 'explicit_steps', 'conditions',
858 'disables', 'enables', 'blockers'}:
859 if not isinstance(v, list):
862 self.lib_set('Process', [proc])
865 class TestCaseWithServer(TestCaseWithDB):
866 """Module tests against our HTTP server/handler (and database)."""
868 def setUp(self) -> None:
870 self.httpd = TaskServer(self.db_file, ('localhost', 0), TaskHandler)
871 self.server_thread = Thread(target=self.httpd.serve_forever)
872 self.server_thread.daemon = True
873 self.server_thread.start()
874 self.conn = HTTPConnection(str(self.httpd.server_address[0]),
875 self.httpd.server_address[1])
876 self.httpd.render_mode = 'json'
878 def tearDown(self) -> None:
879 self.httpd.shutdown()
880 self.httpd.server_close()
881 self.server_thread.join()
884 def post_exp_cond(self,
885 exps: list[Expected],
886 payload: dict[str, object],
888 post_to_id: bool = True,
889 redir_to_id: bool = True
891 """POST /condition(s), appropriately update Expecteds."""
892 # pylint: disable=too-many-arguments
893 target = f'/condition?id={id_}' if post_to_id else '/condition'
894 redir = f'/condition?id={id_}' if redir_to_id else '/conditions'
895 if 'title' not in payload:
896 payload['title'] = 'foo'
897 if 'description' not in payload:
898 payload['description'] = 'foo'
899 self.check_post(payload, target, redir=redir)
901 exp.set_cond_from_post(id_, payload)
903 def post_exp_day(self,
904 exps: list[Expected],
905 payload: dict[str, Any],
906 date: str = '2024-01-01'
908 """POST /day, appropriately update Expecteds."""
909 if 'make_type' not in payload:
910 payload['make_type'] = 'empty'
911 if 'day_comment' not in payload:
912 payload['day_comment'] = ''
913 target = f'/day?date={date}'
914 redir_to = f'{target}&make_type={payload["make_type"]}'
915 self.check_post(payload, target, 302, redir_to)
917 exp.set_day_from_post(date, payload)
919 def post_exp_process(self,
920 exps: list[Expected],
921 payload: dict[str, Any],
923 ) -> dict[str, object]:
924 """POST /process, appropriately update Expecteds."""
925 if 'title' not in payload:
926 payload['title'] = 'foo'
927 if 'description' not in payload:
928 payload['description'] = 'foo'
929 if 'effort' not in payload:
930 payload['effort'] = 1.1
931 self.check_post(payload, f'/process?id={id_}',
932 redir=f'/process?id={id_}')
934 exp.set_proc_from_post(id_, payload)
937 def post_exp_todo(self,
938 exps: list[Expected],
939 payload: dict[str, Any],
942 """POST /todo, appropriately updated Expecteds."""
943 self.check_post(payload, f'/todo?id={id_}')
945 exp.set_todo_from_post(id_, payload)
947 def check_filter(self, exp: Expected, category: str, key: str,
948 val: str, list_ids: list[int]) -> None:
949 """Check GET /{category}?{key}={val} sorts to list_ids."""
950 # pylint: disable=too-many-arguments
952 exp.force(category, list_ids)
953 self.check_json_get(f'/{category}?{key}={val}', exp)
def check_redirect(self, target: str) -> None:
    """Check that self.conn answers with a 302 redirect to target.

    Expects a request to have been sent on self.conn already (as done by
    check_post). The response body is drained before checking, because
    http.client requires the whole response to be read before the shared
    connection may be used for the next request.
    """
    response = self.conn.getresponse()
    response.read()  # leave self.conn ready for the next request
    self.assertEqual(response.status, 302)
    self.assertEqual(response.getheader('Location'), target)
def check_get(self, target: str, expected_code: int) -> None:
    """Check that a GET to target yields expected_code.

    The response body is read to its end before the status is compared:
    http.client requires the whole response to be consumed before the
    shared self.conn can send another request.
    """
    self.conn.request('GET', target)
    response = self.conn.getresponse()
    response.read()  # leave self.conn ready for the next request
    self.assertEqual(response.status, expected_code)
966 def check_minimal_inputs(self,
968 minimal_inputs: dict[str, Any]
970 """Check that url 400's unless all of minimal_inputs provided."""
971 for to_hide in minimal_inputs.keys():
972 to_post = {k: v for k, v in minimal_inputs.items() if k != to_hide}
973 self.check_post(to_post, url, 400)
975 def check_post(self, data: Mapping[str, object], target: str,
976 expected_code: int = 302, redir: str = '') -> None:
977 """Check that POST of data to target yields expected_code."""
978 encoded_form_data = urlencode(data, doseq=True).encode('utf-8')
979 headers = {'Content-Type': 'application/x-www-form-urlencoded',
980 'Content-Length': str(len(encoded_form_data))}
981 self.conn.request('POST', target,
982 body=encoded_form_data, headers=headers)
983 if 302 == expected_code:
984 redir = target if redir == '' else redir
985 self.check_redirect(redir)
987 self.assertEqual(self.conn.getresponse().status, expected_code)
989 def check_get_defaults(self,
991 default_id: str = '1',
994 """Some standard model paths to test."""
995 nonexist_status = 200 if self.checked_class.can_create_by_id else 404
996 self.check_get(path, nonexist_status)
997 self.check_get(f'{path}?{id_name}=', 400)
998 self.check_get(f'{path}?{id_name}=foo', 400)
999 self.check_get(f'/{path}?{id_name}=0', 400)
1000 self.check_get(f'{path}?{id_name}={default_id}', nonexist_status)
1002 def check_json_get(self, path: str, expected: Expected) -> None:
1003 """Compare JSON on GET path with expected.
1005 To simplify comparison of VersionedAttribute histories, transforms
1006 timestamp keys of VersionedAttribute history keys into (strings of)
1007 integers counting chronologically forward from 0.
1010 def rewrite_history_keys_in(item: Any) -> Any:
1011 if isinstance(item, dict):
1012 if '_versioned' in item.keys():
1013 for category in item['_versioned']:
1014 vals = item['_versioned'][category].values()
1016 for i, val in enumerate(vals):
1017 history[str(i)] = val
1018 item['_versioned'][category] = history
1019 for category in list(item.keys()):
1020 rewrite_history_keys_in(item[category])
1021 elif isinstance(item, list):
1022 item[:] = [rewrite_history_keys_in(i) for i in item]
1025 def walk_diffs(path: str, cmp1: object, cmp2: object) -> None:
1026 # pylint: disable=too-many-branches
1027 def warn(intro: str, val: object) -> None:
1028 if isinstance(val, (str, int, float)):
1034 if isinstance(cmp1, dict) and isinstance(cmp2, dict):
1035 for k, v in cmp1.items():
1037 warn(f'DIFF {path}: retrieved lacks {k}', v)
1039 walk_diffs(f'{path}:{k}', v, cmp2[k])
1040 for k in [k for k in cmp2.keys() if k not in cmp1]:
1041 warn(f'DIFF {path}: expected lacks retrieved\'s {k}',
1043 elif isinstance(cmp1, list) and isinstance(cmp2, list):
1044 for i, v1 in enumerate(cmp1):
1046 warn(f'DIFF {path}[{i}] retrieved misses:', v1)
1048 walk_diffs(f'{path}[{i}]', v1, cmp2[i])
1049 if len(cmp2) > len(cmp1):
1050 for i, v2 in enumerate(cmp2[len(cmp1):]):
1051 warn(f'DIFF {path}[{len(cmp1)+i}] misses:', v2)
1053 warn(f'DIFF {path} – for expected:', cmp1)
1054 warn('… and for retrieved:', cmp2)
1056 self.conn.request('GET', path)
1057 response = self.conn.getresponse()
1058 self.assertEqual(response.status, 200)
1059 retrieved = json_loads(response.read().decode())
1060 rewrite_history_keys_in(retrieved)
1061 cmp = expected.as_dict
1063 self.assertEqual(cmp, retrieved)
1064 except AssertionError as e:
1069 walk_diffs('', cmp, retrieved)