home · contact · privacy
Minor tests refactoring.
[plomtask] / tests / utils.py
1 """Shared test utilities."""
2 # pylint: disable=too-many-lines
from __future__ import annotations
from unittest import TestCase
from typing import Mapping, Any, Callable
from functools import wraps
from threading import Thread
from http.client import HTTPConnection
from datetime import datetime, timedelta
from time import sleep
from json import loads as json_loads, dumps as json_dumps
from urllib.parse import urlencode
from uuid import uuid4
from os import remove as remove_file
from pprint import pprint

from plomtask.db import DatabaseFile, DatabaseConnection
from plomtask.http import TaskHandler, TaskServer
from plomtask.processes import Process, ProcessStep
from plomtask.conditions import Condition
from plomtask.days import Day
from plomtask.dating import DATE_FORMAT
from plomtask.todos import Todo
from plomtask.versioned_attributes import VersionedAttribute, TIMESTAMP_FMT
from plomtask.exceptions import NotFoundException, HandledException
24
25
26 VERSIONED_VALS: dict[str,
27                      list[str] | list[float]] = {'str': ['A', 'B'],
28                                                  'float': [0.3, 1.1]}
29 VALID_TRUES = {True, 'True', 'true', '1', 'on'}
30
31
32 class TestCaseAugmented(TestCase):
33     """Tester core providing helpful basic internal decorators and methods."""
34     checked_class: Any
35     default_init_kwargs: dict[str, Any] = {}
36
37     @staticmethod
38     def _run_if_checked_class(f: Callable[..., None]) -> Callable[..., None]:
39         def wrapper(self: TestCase) -> None:
40             if hasattr(self, 'checked_class'):
41                 f(self)
42         return wrapper
43
44     @classmethod
45     def _run_on_versioned_attributes(cls,
46                                      f: Callable[..., None]
47                                      ) -> Callable[..., None]:
48         @cls._run_if_checked_class
49         def wrapper(self: TestCase) -> None:
50             assert isinstance(self, TestCaseAugmented)
51             for attr_name in self.checked_class.to_save_versioned():
52                 default = self.checked_class.versioned_defaults[attr_name]
53                 owner = self.checked_class(None, **self.default_init_kwargs)
54                 attr = getattr(owner, attr_name)
55                 to_set = VERSIONED_VALS[attr.value_type_name]
56                 f(self, owner, attr_name, attr, default, to_set)
57         return wrapper
58
59     @classmethod
60     def _make_from_defaults(cls, id_: float | str | None) -> Any:
61         return cls.checked_class(id_, **cls.default_init_kwargs)
62
63
class TestCaseSansDB(TestCaseAugmented):
    """Tests requiring no DB setup."""
    legal_ids: list[str] | list[int] = [1, 5]
    illegal_ids: list[str] | list[int] = [0]

    @TestCaseAugmented._run_if_checked_class
    def test_id_validation(self) -> None:
        """Test .id_ validation/setting."""
        for id_ in self.illegal_ids:
            with self.assertRaises(HandledException):
                self._make_from_defaults(id_)
        for id_ in self.legal_ids:
            obj = self._make_from_defaults(id_)
            self.assertEqual(obj.id_, id_)

    @TestCaseAugmented._run_on_versioned_attributes
    def test_versioned_set(self,
                           _: Any,
                           __: str,
                           attr: VersionedAttribute,
                           default: str | float,
                           to_set: list[str] | list[float]
                           ) -> None:
        """Test VersionedAttribute.set() behaves as expected."""

        def timesorted_vals() -> list[str | float]:
            # attr's current history values, ordered by their timestamps
            return [attr.history[t] for t in sorted(attr.history.keys())]

        attr.set(default)
        self.assertEqual(list(attr.history.values()), [default])
        # check same value does not get set twice in a row,
        # and that not even its timestamp get updated
        timestamp = list(attr.history.keys())[0]
        attr.set(default)
        self.assertEqual(list(attr.history.values()), [default])
        self.assertEqual(list(attr.history.keys())[0], timestamp)
        # check that different value _will_ be set/added
        attr.set(to_set[0])
        self.assertEqual(timesorted_vals(), [default, to_set[0]])
        # check that a previously used value can be set if not most recent
        attr.set(default)
        self.assertEqual(timesorted_vals(), [default, to_set[0], default])
        # again check for same value not being set twice in a row, even for
        # later items
        attr.set(to_set[1])
        expected = [default, to_set[0], default, to_set[1]]
        self.assertEqual(timesorted_vals(), expected)
        attr.set(to_set[1])
        # NB: re-read the history, else we'd only compare the stale list above
        self.assertEqual(timesorted_vals(), expected)

    @TestCaseAugmented._run_on_versioned_attributes
    def test_versioned_newest(self,
                              _: Any,
                              __: str,
                              attr: VersionedAttribute,
                              default: str | float,
                              to_set: list[str] | list[float]
                              ) -> None:
        """Test VersionedAttribute.newest."""
        # check .newest on empty history returns .default
        self.assertEqual(attr.newest, default)
        # check newest element always returned
        for v in [to_set[0], to_set[1]]:
            attr.set(v)
            self.assertEqual(attr.newest, v)
        # check newest element returned even if also early value
        attr.set(default)
        self.assertEqual(attr.newest, default)

    @TestCaseAugmented._run_on_versioned_attributes
    def test_versioned_at(self,
                          _: Any,
                          __: str,
                          attr: VersionedAttribute,
                          default: str | float,
                          to_set: list[str] | list[float]
                          ) -> None:
        """Test .at() returns values nearest to queried time, or default."""
        # check .at() return default on empty history
        timestamp_a = datetime.now().strftime(TIMESTAMP_FMT)
        self.assertEqual(attr.at(timestamp_a), default)
        # check value exactly at timestamp returned
        attr.set(to_set[0])
        timestamp_b = list(attr.history.keys())[0]
        self.assertEqual(attr.at(timestamp_b), to_set[0])
        # check earliest value returned if exists, rather than default
        self.assertEqual(attr.at(timestamp_a), to_set[0])
        # check reverts to previous value for timestamps not indexed
        sleep(0.00001)
        timestamp_between = datetime.now().strftime(TIMESTAMP_FMT)
        sleep(0.00001)
        attr.set(to_set[1])
        timestamp_c = sorted(attr.history.keys())[-1]
        self.assertEqual(attr.at(timestamp_c), to_set[1])
        self.assertEqual(attr.at(timestamp_between), to_set[0])
        sleep(0.00001)
        timestamp_after_c = datetime.now().strftime(TIMESTAMP_FMT)
        self.assertEqual(attr.at(timestamp_after_c), to_set[1])
166
167
class TestCaseWithDB(TestCaseAugmented):
    """Module tests requiring DB setup.

    Before each test, empties the caches of Condition, Day, Process,
    ProcessStep and Todo, and creates a fresh, uniquely named DB file to
    connect to; after each test, closes that connection and removes the file.
    """
    default_ids: tuple[int, int, int] | tuple[str, str, str] = (1, 2, 3)

    def setUp(self) -> None:
        Condition.empty_cache()
        Day.empty_cache()
        Process.empty_cache()
        ProcessStep.empty_cache()
        Todo.empty_cache()
        self.db_file = DatabaseFile.create_at(f'test_db:{uuid4()}')
        self.db_conn = DatabaseConnection(self.db_file)

    def tearDown(self) -> None:
        # remove the DB file even if closing the connection should fail
        try:
            self.db_conn.close()
        finally:
            remove_file(self.db_file.path)

    def _load_from_db(self, id_: int | str) -> list[object]:
        """Return .checked_class objects made from DB rows of ID id_."""
        db_found: list[object] = []
        for row in self.db_conn.row_where(self.checked_class.table_name,
                                          'id', id_):
            db_found += [self.checked_class.from_table_row(self.db_conn,
                                                           row)]
        return db_found

    def _change_obj(self, obj: object) -> str:
        """Change obj's last .to_save_simples attribute, return its name.

        Booleans get flipped, numbers incremented, strings '_'-appended; any
        other attribute type raises TypeError.
        """
        attr_name: str = self.checked_class.to_save_simples[-1]
        attr = getattr(obj, attr_name)
        new_attr: str | int | float | bool
        # NB: test bool first, as bool is a subclass of int
        if isinstance(attr, bool):
            new_attr = not attr
        elif isinstance(attr, (int, float)):
            new_attr = attr + 1
        elif isinstance(attr, str):
            new_attr = attr + '_'
        else:
            raise TypeError(f'cannot change {attr_name} of type '
                            f'{type(attr).__name__}')
        setattr(obj, attr_name, new_attr)
        return attr_name

    def check_identity_with_cache_and_db(self, content: list[Any]) -> None:
        """Test both cache and DB equal content."""
        expected_cache = {item.id_: item for item in content}
        self.assertEqual(self.checked_class.get_cache(), expected_cache)
        hashes_content = [hash(x) for x in content]
        db_found: list[Any] = []
        for item in content:
            assert isinstance(item.id_, type(self.default_ids[0]))
            db_found += self._load_from_db(item.id_)
        hashes_db_found = [hash(x) for x in db_found]
        self.assertEqual(sorted(hashes_content), sorted(hashes_db_found))

    def check_by_date_range_with_limits(self,
                                        date_col: str,
                                        set_id_field: bool = True
                                        ) -> None:
        """Test .by_date_range_with_limits.

        date_col is passed on as its third argument (presumably the name of
        the DB date column); set_id_field decides whether the dated test
        objects get constructed with id_=None.
        """
        # pylint: disable=too-many-locals
        f = self.checked_class.by_date_range_with_limits
        # check illegal ranges
        legal_range = ('yesterday', 'tomorrow')
        for i in [0, 1]:
            for bad_date in ['foo', '2024-02-30', '2024-01-01 12:00:00']:
                date_range = list(legal_range)
                date_range[i] = bad_date
                with self.assertRaises(HandledException):
                    f(self.db_conn, date_range, date_col)
        # check empty, translation of 'yesterday' and 'tomorrow'
        items, start, end = f(self.db_conn, legal_range, date_col)
        self.assertEqual(items, [])
        yesterday = datetime.now() + timedelta(days=-1)
        tomorrow = datetime.now() + timedelta(days=+1)
        self.assertEqual(start, yesterday.strftime(DATE_FORMAT))
        self.assertEqual(end, tomorrow.strftime(DATE_FORMAT))
        # prepare dated items for non-empty results
        kwargs_with_date = self.default_init_kwargs.copy()
        if set_id_field:
            kwargs_with_date['id_'] = None
        objs = []
        dates = ['2024-01-01', '2024-01-02', '2024-01-04']
        for date in dates:
            kwargs_with_date['date'] = date
            obj = self.checked_class(**kwargs_with_date)
            objs += [obj]
        # check ranges still empty before saving
        date_range = [dates[0], dates[-1]]
        self.assertEqual(f(self.db_conn, date_range, date_col)[0], [])
        # check all objs displayed within closed interval
        for obj in objs:
            obj.save(self.db_conn)
        self.assertEqual(f(self.db_conn, date_range, date_col)[0], objs)
        # check that only displayed what exists within interval
        date_range = ['2023-12-20', '2024-01-03']
        expected = [objs[0], objs[1]]
        self.assertEqual(f(self.db_conn, date_range, date_col)[0], expected)
        date_range = ['2024-01-03', '2024-01-30']
        expected = [objs[2]]
        self.assertEqual(f(self.db_conn, date_range, date_col)[0], expected)
        # check that inverted interval displays nothing
        date_range = [dates[-1], dates[0]]
        self.assertEqual(f(self.db_conn, date_range, date_col)[0], [])
        # check that "today" is interpreted, and single-element interval
        today_date = datetime.now().strftime(DATE_FORMAT)
        kwargs_with_date['date'] = today_date
        obj_today = self.checked_class(**kwargs_with_date)
        obj_today.save(self.db_conn)
        date_range = ['today', 'today']
        items, start, end = f(self.db_conn, date_range, date_col)
        self.assertEqual(start, today_date)
        self.assertEqual(start, end)
        self.assertEqual(items, [obj_today])

    @TestCaseAugmented._run_on_versioned_attributes
    def test_saving_versioned_attributes(self,
                                         owner: Any,
                                         attr_name: str,
                                         attr: VersionedAttribute,
                                         _: str | float,
                                         to_set: list[str] | list[float]
                                         ) -> None:
        """Test storage and initialization of versioned attributes."""

        def retrieve_attr_vals(attr: VersionedAttribute) -> list[object]:
            attr_vals_saved: list[object] = []
            for row in self.db_conn.row_where(attr.table_name, 'parent',
                                              owner.id_):
                attr_vals_saved += [row[2]]
            return attr_vals_saved

        attr.set(to_set[0])
        # check that without attr.save() no rows in DB
        rows = self.db_conn.row_where(attr.table_name, 'parent', owner.id_)
        self.assertEqual([], rows)
        # fail saving attributes on non-saved owner
        with self.assertRaises(NotFoundException):
            attr.save(self.db_conn)
        # check owner.save() created entries as expected in attr table
        owner.save(self.db_conn)
        attr_vals_saved = retrieve_attr_vals(attr)
        self.assertEqual([to_set[0]], attr_vals_saved)
        # check changing attr val without save affects owner in memory …
        attr.set(to_set[1])
        cmp_attr = getattr(owner, attr_name)
        self.assertEqual(to_set, list(cmp_attr.history.values()))
        self.assertEqual(cmp_attr.history, attr.history)
        # … but does not yet affect DB
        attr_vals_saved = retrieve_attr_vals(attr)
        self.assertEqual([to_set[0]], attr_vals_saved)
        # check individual attr.save also stores new val to DB
        attr.save(self.db_conn)
        attr_vals_saved = retrieve_attr_vals(attr)
        self.assertEqual(to_set, attr_vals_saved)

    @TestCaseAugmented._run_if_checked_class
    def test_saving_and_caching(self) -> None:
        """Test effects of .cache() and .save()."""
        id1 = self.default_ids[0]
        # check failure to cache without ID (if None-ID input possible)
        if isinstance(id1, int):
            obj0 = self._make_from_defaults(None)
            with self.assertRaises(HandledException):
                obj0.cache()
        # check mere object init itself doesn't even store in cache
        obj1 = self._make_from_defaults(id1)
        self.assertEqual(self.checked_class.get_cache(), {})
        # check .cache() fills cache, but not DB
        obj1.cache()
        self.assertEqual(self.checked_class.get_cache(), {id1: obj1})
        found_in_db = self._load_from_db(id1)
        self.assertEqual(found_in_db, [])
        # check .save() sets ID (for int IDs), updates cache, and fills DB
        # (expect ID to be set to id1, despite obj1 already having that as ID:
        # it's generated by cursor.lastrowid on the DB table, and with obj1
        # not written there, obj2 should get it first!)
        id_input = None if isinstance(id1, int) else id1
        obj2 = self._make_from_defaults(id_input)
        obj2.save(self.db_conn)
        self.assertEqual(self.checked_class.get_cache(), {id1: obj2})
        # NB: we'll only compare hashes because obj2 itself disappears on
        # .from_table_row-triggered database reload
        obj2_hash = hash(obj2)
        found_in_db += self._load_from_db(id1)
        self.assertEqual([hash(o) for o in found_in_db], [obj2_hash])
        # check we cannot overwrite obj2 with obj1 despite its same ID,
        # since it has disappeared now
        with self.assertRaises(HandledException):
            obj1.save(self.db_conn)

    @TestCaseAugmented._run_if_checked_class
    def test_by_id(self) -> None:
        """Test .by_id()."""
        id1, id2, _ = self.default_ids
        # check failure if not yet saved
        obj1 = self._make_from_defaults(id1)
        with self.assertRaises(NotFoundException):
            self.checked_class.by_id(self.db_conn, id1)
        # check identity of cached and retrieved
        obj1.cache()
        self.assertEqual(obj1, self.checked_class.by_id(self.db_conn, id1))
        # check identity of saved and retrieved
        obj2 = self._make_from_defaults(id2)
        obj2.save(self.db_conn)
        self.assertEqual(obj2, self.checked_class.by_id(self.db_conn, id2))

    @TestCaseAugmented._run_if_checked_class
    def test_by_id_or_create(self) -> None:
        """Test .by_id_or_create."""
        # check .by_id_or_create fails if wrong class
        if not self.checked_class.can_create_by_id:
            with self.assertRaises(HandledException):
                self.checked_class.by_id_or_create(self.db_conn, None)
            return
        # check ID input of None creates, on saving, ID=1,2,… for int IDs
        if isinstance(self.default_ids[0], int):
            for n in range(2):
                item = self.checked_class.by_id_or_create(self.db_conn, None)
                self.assertEqual(item.id_, None)
                item.save(self.db_conn)
                self.assertEqual(item.id_, n+1)
        # check .by_id_or_create acts like normal instantiation (sans saving)
        id_ = self.default_ids[2]
        item = self.checked_class.by_id_or_create(self.db_conn, id_)
        self.assertEqual(item.id_, id_)
        with self.assertRaises(NotFoundException):
            self.checked_class.by_id(self.db_conn, item.id_)
        self.assertEqual(self.checked_class(item.id_), item)

    @TestCaseAugmented._run_if_checked_class
    def test_from_table_row(self) -> None:
        """Test .from_table_row() properly reads in class directly from DB."""
        id_ = self.default_ids[0]
        obj = self._make_from_defaults(id_)
        obj.save(self.db_conn)
        assert isinstance(obj.id_, type(id_))
        for row in self.db_conn.row_where(self.checked_class.table_name,
                                          'id', obj.id_):
            # check .from_table_row reproduces state saved, no matter if obj
            # later changed (with caching even)
            # NB: we'll only compare hashes because obj itself disappears on
            # .from_table_row-triggered database reload
            hash_original = hash(obj)
            attr_name = self._change_obj(obj)
            obj.cache()
            to_cmp = getattr(obj, attr_name)
            retrieved = self.checked_class.from_table_row(self.db_conn, row)
            self.assertNotEqual(to_cmp, getattr(retrieved, attr_name))
            self.assertEqual(hash_original, hash(retrieved))
            # check cache contains what .from_table_row just produced
            self.assertEqual({retrieved.id_: retrieved},
                             self.checked_class.get_cache())

    @TestCaseAugmented._run_on_versioned_attributes
    def test_versioned_history_from_row(self,
                                        owner: Any,
                                        _: str,
                                        attr: VersionedAttribute,
                                        default: str | float,
                                        to_set: list[str] | list[float]
                                        ) -> None:
        """Test VersionedAttribute.history_from_row() knows its DB rows."""
        attr.set(to_set[0])
        attr.set(to_set[1])
        owner.save(self.db_conn)
        # for the owner's row in its table, make empty VersionedAttribute,
        # fill from attr's rows, compare to owner's
        for _owner_row in self.db_conn.row_where(owner.table_name, 'id',
                                                 owner.id_):
            loaded_attr = VersionedAttribute(owner, attr.table_name, default)
            for attr_row in self.db_conn.row_where(attr.table_name, 'parent',
                                                   owner.id_):
                loaded_attr.history_from_row(attr_row)
            self.assertEqual(len(attr.history.keys()),
                             len(loaded_attr.history.keys()))
            for timestamp, value in attr.history.items():
                self.assertEqual(value, loaded_attr.history[timestamp])

    @TestCaseAugmented._run_if_checked_class
    def test_all(self) -> None:
        """Test .all() and its relation to cache and savings."""
        id1, id2, id3 = self.default_ids
        item1 = self._make_from_defaults(id1)
        item2 = self._make_from_defaults(id2)
        item3 = self._make_from_defaults(id3)
        # check .all() returns empty list on un-cached items
        self.assertEqual(self.checked_class.all(self.db_conn), [])
        # check that all() shows only cached/saved items
        item1.cache()
        item3.save(self.db_conn)
        self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
                         sorted([item1, item3]))
        item2.save(self.db_conn)
        self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
                         sorted([item1, item2, item3]))

    @TestCaseAugmented._run_if_checked_class
    def test_singularity(self) -> None:
        """Test pointers made for single object keep pointing to it."""
        id1 = self.default_ids[0]
        obj = self._make_from_defaults(id1)
        obj.save(self.db_conn)
        # change object, expect retrieved through .by_id to carry change
        attr_name = self._change_obj(obj)
        new_attr = getattr(obj, attr_name)
        retrieved = self.checked_class.by_id(self.db_conn, id1)
        self.assertEqual(new_attr, getattr(retrieved, attr_name))

    @TestCaseAugmented._run_on_versioned_attributes
    def test_versioned_singularity(self,
                                   owner: Any,
                                   attr_name: str,
                                   attr: VersionedAttribute,
                                   _: str | float,
                                   to_set: list[str] | list[float]
                                   ) -> None:
        """Test singularity of VersionedAttributes on saving."""
        owner.save(self.db_conn)
        # change obj, expect retrieved through .by_id to carry change
        attr.set(to_set[0])
        retrieved = self.checked_class.by_id(self.db_conn, owner.id_)
        attr_retrieved = getattr(retrieved, attr_name)
        self.assertEqual(attr.history, attr_retrieved.history)

    @TestCaseAugmented._run_if_checked_class
    def test_remove(self) -> None:
        """Test .remove() effects on DB and cache."""
        id_ = self.default_ids[0]
        obj = self._make_from_defaults(id_)
        # check removal only works after saving
        with self.assertRaises(HandledException):
            obj.remove(self.db_conn)
        obj.save(self.db_conn)
        obj.remove(self.db_conn)
        # check access to obj fails after removal
        with self.assertRaises(HandledException):
            print(obj.id_)
        # check DB and cache now empty
        self.check_identity_with_cache_and_db([])
503
504
505 class Expected:
506     """Builder of (JSON-like) dict to compare against responses of test server.
507
508     Collects all items and relations we expect expressed in the server's JSON
509     responses and puts them into the proper json.dumps-friendly dict structure,
    accessible via .as_dict, to compare them in TestsWithServer.check_json_get.
511
512     On its own provides for .as_dict output only {"_library": …}, initialized
513     from .__init__ and to be directly manipulated via the .lib* methods.
514     Further structures of the expected response may be added and kept
515     up-to-date by subclassing .__init__, .recalc, and .d.
516
517     NB: Lots of expectations towards server behavior will be made explicit here
518     (or in the subclasses) rather than in the actual TestCase methods' code.
519     """
520     _default_dict: dict[str, Any]
521     _forced: dict[str, Any]
522     _fields: dict[str, Any]
523     _on_empty_make_temp: tuple[str, str]
524
525     def __init__(self,
526                  todos: list[dict[str, Any]] | None = None,
527                  procs: list[dict[str, Any]] | None = None,
528                  procsteps: list[dict[str, Any]] | None = None,
529                  conds: list[dict[str, Any]] | None = None,
530                  days: list[dict[str, Any]] | None = None
531                  ) -> None:
532         # pylint: disable=too-many-arguments
533         for name in ['_default_dict', '_fields', '_forced']:
534             if not hasattr(self, name):
535                 setattr(self, name, {})
536         self._lib = {}
537         for title, items in [('Todo', todos),
538                              ('Process', procs),
539                              ('ProcessStep', procsteps),
540                              ('Condition', conds),
541                              ('Day', days)]:
542             if items:
543                 self._lib[title] = self._as_refs(items)
544         for k, v in self._default_dict.items():
545             if k not in self._fields:
546                 self._fields[k] = v
547
548     def recalc(self) -> None:
549         """Update internal dictionary by subclass-specific rules."""
550         todos = self.lib_all('Todo')
551         for todo in todos:
552             todo['parents'] = []
553         for todo in todos:
554             for child_id in todo['children']:
555                 self.lib_get('Todo', child_id)['parents'] += [todo['id']]
556             todo['children'].sort()
557         procsteps = self.lib_all('ProcessStep')
558         procs = self.lib_all('Process')
559         for proc in procs:
560             proc['explicit_steps'] = [s['id'] for s in procsteps
561                                       if s['owner_id'] == proc['id']]
562
563     @property
564     def as_dict(self) -> dict[str, Any]:
565         """Return dict to compare against test server JSON responses."""
566         make_temp = False
567         if hasattr(self, '_on_empty_make_temp'):
568             category, dicter = getattr(self, '_on_empty_make_temp')
569             id_ = self._fields[category.lower()]
570             make_temp = not bool(self.lib_get(category, id_))
571             if make_temp:
572                 f = getattr(self, dicter)
573                 self.lib_set(category, [f(id_)])
574         self.recalc()
575         d = {'_library': self._lib}
576         for k, v in self._fields.items():
577             # we expect everything sortable to be sorted
578             if isinstance(v, list) and k not in self._forced:
579                 # NB: if we don't test for v being list, sorted() on an empty
580                 # dict may return an empty list
581                 try:
582                     v = sorted(v)
583                 except TypeError:
584                     pass
585             d[k] = v
586         for k, v in self._forced.items():
587             d[k] = v
588         if make_temp:
589             json = json_dumps(d)
590             self.lib_del(category, id_)
591             d = json_loads(json)
592         return d
593
594     def lib_get(self, category: str, id_: str | int) -> dict[str, Any]:
595         """From library, return item of category and id_, or empty dict."""
596         str_id = str(id_)
597         if category in self._lib and str_id in self._lib[category]:
598             return self._lib[category][str_id]
599         return {}
600
601     def lib_all(self, category: str) -> list[dict[str, Any]]:
602         """From library, return items of category, or [] if none."""
603         if category in self._lib:
604             return list(self._lib[category].values())
605         return []
606
607     def lib_set(self, category: str, items: list[dict[str, object]]) -> None:
608         """Update library for category with items."""
609         if category not in self._lib:
610             self._lib[category] = {}
611         for k, v in self._as_refs(items).items():
612             self._lib[category][k] = v
613
614     def lib_del(self, category: str, id_: str | int) -> None:
615         """Remove category element of id_ from library."""
616         del self._lib[category][str(id_)]
617         if 0 == len(self._lib[category]):
618             del self._lib[category]
619
620     def lib_wipe(self, category: str) -> None:
621         """Remove category from library."""
622         if category in self._lib:
623             del self._lib[category]
624
625     def set(self, field_name: str, value: object) -> None:
626         """Set top-level .as_dict field."""
627         self._fields[field_name] = value
628
629     def force(self, field_name: str, value: object) -> None:
630         """Set ._forced field to ensure value in .as_dict."""
631         self._forced[field_name] = value
632
633     def unforce(self, field_name: str) -> None:
634         """Unset ._forced field."""
635         del self._forced[field_name]
636
637     @staticmethod
638     def _as_refs(items: list[dict[str, object]]
639                  ) -> dict[str, dict[str, object]]:
640         """Return dictionary of items by their 'id' fields."""
641         refs = {}
642         for item in items:
643             refs[str(item['id'])] = item
644         return refs
645
646     @staticmethod
647     def as_ids(items: list[dict[str, Any]]) -> list[int] | list[str]:
648         """Return list of only 'id' fields of items."""
649         return [item['id'] for item in items]
650
651     @staticmethod
652     def day_as_dict(date: str, comment: str = '') -> dict[str, object]:
653         """Return JSON of Day to expect."""
654         return {'id': date, 'comment': comment, 'todos': []}
655
656     def set_day_from_post(self, date: str, d: dict[str, Any]) -> None:
657         """Set Day of date in library based on POST dict d."""
658         day = self.day_as_dict(date)
659         for k, v in d.items():
660             if 'day_comment' == k:
661                 day['comment'] = v
662             elif 'new_todo' == k:
663                 next_id = 1
664                 for todo in self.lib_all('Todo'):
665                     if next_id <= todo['id']:
666                         next_id = todo['id'] + 1
667                 for proc_id in sorted(v):
668                     todo = self.todo_as_dict(next_id, proc_id, date)
669                     self.lib_set('Todo', [todo])
670                     next_id += 1
671             elif 'done' == k:
672                 for todo_id in v:
673                     self.lib_get('Todo', todo_id)['is_done'] = True
674             elif 'todo_id' == k:
675                 for i, todo_id in enumerate(v):
676                     t = self.lib_get('Todo', todo_id)
677                     if 'comment' in d:
678                         t['comment'] = d['comment'][i]
679                     if 'effort' in d:
680                         effort = d['effort'][i] if d['effort'][i] else None
681                         t['effort'] = effort
682         self.lib_set('Day', [day])
683
684     @staticmethod
685     def cond_as_dict(id_: int = 1,
686                      is_active: bool = False,
687                      title: None | str = None,
688                      description: None | str = None,
689                      ) -> dict[str, object]:
690         """Return JSON of Condition to expect."""
691         versioned: dict[str, dict[str, object]]
692         versioned = {'title': {}, 'description': {}}
693         if title is not None:
694             versioned['title']['0'] = title
695         if description is not None:
696             versioned['description']['0'] = description
697         return {'id': id_, 'is_active': is_active, '_versioned': versioned}
698
699     def set_cond_from_post(self, id_: int, d: dict[str, Any]) -> None:
700         """Set Condition of id_ in library based on POST dict d."""
701         if d == {'delete': ''}:
702             self.lib_del('Condition', id_)
703             return
704         cond = self.lib_get('Condition', id_)
705         if cond:
706             cond['is_active'] = d['is_active']
707             for category in ['title', 'description']:
708                 history = cond['_versioned'][category]
709                 if len(history) > 0:
710                     last_i = sorted([int(k) for k in history.keys()])[-1]
711                     if d[category] != history[str(last_i)]:
712                         history[str(last_i + 1)] = d[category]
713                 else:
714                     history['0'] = d[category]
715         else:
716             cond = self.cond_as_dict(
717                     id_, d['is_active'], d['title'], d['description'])
718         self.lib_set('Condition', [cond])
719
720     @staticmethod
721     def todo_as_dict(id_: int = 1,
722                      process_id: int = 1,
723                      date: str = '2024-01-01',
724                      conditions: None | list[int] = None,
725                      disables: None | list[int] = None,
726                      blockers: None | list[int] = None,
727                      enables: None | list[int] = None,
728                      calendarize: bool = False,
729                      comment: str = '',
730                      is_done: bool = False,
731                      effort: float | None = None,
732                      children: list[int] | None = None,
733                      parents: list[int] | None = None,
734                      ) -> dict[str, object]:
735         """Return JSON of Todo to expect."""
736         # pylint: disable=too-many-arguments
737         d = {'id': id_,
738              'date': date,
739              'process_id': process_id,
740              'is_done': is_done,
741              'calendarize': calendarize,
742              'comment': comment,
743              'children': children if children else [],
744              'parents': parents if parents else [],
745              'effort': effort,
746              'conditions': conditions if conditions else [],
747              'disables': disables if disables else [],
748              'blockers': blockers if blockers else [],
749              'enables': enables if enables else []}
750         return d
751
752     def set_todo_from_post(self, id_: int, d: dict[str, Any]) -> None:
753         """Set Todo of id_ in library based on POST dict d."""
754         corrected_kwargs: dict[str, Any] = {'children': []}
755         for k, v in d.items():
756             if k in {'adopt', 'step_filler'}:
757                 new_children = v if isinstance(v, list) else [v]
758                 corrected_kwargs['children'] += new_children
759                 continue
760             if k in {'is_done', 'calendarize'}:
761                 v = v in VALID_TRUES
762             corrected_kwargs[k] = v
763         todo = self.lib_get('Todo', id_)
764         if todo:
765             for k, v in corrected_kwargs.items():
766                 todo[k] = v
767         else:
768             todo = self.todo_as_dict(id_, **corrected_kwargs)
769         self.lib_set('Todo', [todo])
770
771     @staticmethod
772     def procstep_as_dict(id_: int,
773                          owner_id: int,
774                          step_process_id: int,
775                          parent_step_id: int | None = None
776                          ) -> dict[str, object]:
777         """Return JSON of ProcessStep to expect."""
778         return {'id': id_,
779                 'owner_id': owner_id,
780                 'step_process_id': step_process_id,
781                 'parent_step_id': parent_step_id}
782
783     @staticmethod
784     def proc_as_dict(id_: int = 1,
785                      title: None | str = None,
786                      description: None | str = None,
787                      effort: None | float = None,
788                      conditions: None | list[int] = None,
789                      disables: None | list[int] = None,
790                      blockers: None | list[int] = None,
791                      enables: None | list[int] = None,
792                      explicit_steps: None | list[int] = None
793                      ) -> dict[str, object]:
794         """Return JSON of Process to expect."""
795         # pylint: disable=too-many-arguments
796         versioned: dict[str, dict[str, object]]
797         versioned = {'title': {}, 'description': {}, 'effort': {}}
798         if title is not None:
799             versioned['title']['0'] = title
800         if description is not None:
801             versioned['description']['0'] = description
802         if effort is not None:
803             versioned['effort']['0'] = effort
804         d = {'id': id_,
805              'calendarize': False,
806              'suppressed_steps': [],
807              'explicit_steps': explicit_steps if explicit_steps else [],
808              '_versioned': versioned,
809              'conditions': conditions if conditions else [],
810              'disables': disables if disables else [],
811              'enables': enables if enables else [],
812              'blockers': blockers if blockers else []}
813         return d
814
815     def set_proc_from_post(self, id_: int, d: dict[str, Any]) -> None:
816         """Set Process of id_ in library based on POST dict d."""
817         proc = self.lib_get('Process', id_)
818         if proc:
819             for category in ['title', 'description', 'effort']:
820                 history = proc['_versioned'][category]
821                 if len(history) > 0:
822                     last_i = sorted([int(k) for k in history.keys()])[-1]
823                     if d[category] != history[str(last_i)]:
824                         history[str(last_i + 1)] = d[category]
825                 else:
826                     history['0'] = d[category]
827         else:
828             proc = self.proc_as_dict(id_,
829                                      d['title'], d['description'], d['effort'])
830         ignore = {'title', 'description', 'effort', 'new_top_step', 'step_of',
831                   'kept_steps'}
832         for k, v in d.items():
833             if k in ignore\
834                     or k.startswith('step_') or k.startswith('new_step_to'):
835                 continue
836             if k in {'calendarize'}:
837                 v = v in VALID_TRUES
838             elif k in {'suppressed_steps', 'explicit_steps', 'conditions',
839                        'disables', 'enables', 'blockers'}:
840                 if not isinstance(v, list):
841                     v = [v]
842             proc[k] = v
843         self.lib_set('Process', [proc])
844
845
846 class TestCaseWithServer(TestCaseWithDB):
847     """Module tests against our HTTP server/handler (and database)."""
848
849     def setUp(self) -> None:
850         super().setUp()
851         self.httpd = TaskServer(self.db_file, ('localhost', 0), TaskHandler)
852         self.server_thread = Thread(target=self.httpd.serve_forever)
853         self.server_thread.daemon = True
854         self.server_thread.start()
855         self.conn = HTTPConnection(str(self.httpd.server_address[0]),
856                                    self.httpd.server_address[1])
857         self.httpd.render_mode = 'json'
858
859     def tearDown(self) -> None:
860         self.httpd.shutdown()
861         self.httpd.server_close()
862         self.server_thread.join()
863         super().tearDown()
864
865     def post_exp_cond(self,
866                       exps: list[Expected],
867                       id_: int,
868                       payload: dict[str, object],
869                       path_suffix: str = '',
870                       redir_suffix: str = ''
871                       ) -> None:
872         """POST /condition(s), appropriately update Expecteds."""
873         # pylint: disable=too-many-arguments
874         path = f'/condition{path_suffix}'
875         redir = f'/condition{redir_suffix}'
876         self.check_post(payload, path, redir=redir)
877         for exp in exps:
878             exp.set_cond_from_post(id_, payload)
879
880     def post_exp_day(self,
881                      exps: list[Expected],
882                      payload: dict[str, Any],
883                      date: str = '2024-01-01'
884                      ) -> None:
885         """POST /day, appropriately update Expecteds."""
886         if 'make_type' not in payload:
887             payload['make_type'] = 'empty'
888         if 'day_comment' not in payload:
889             payload['day_comment'] = ''
890         target = f'/day?date={date}'
891         redir_to = f'{target}&make_type={payload["make_type"]}'
892         self.check_post(payload, target, 302, redir_to)
893         for exp in exps:
894             exp.set_day_from_post(date, payload)
895
896     def post_exp_process(self,
897                          exps: list[Expected],
898                          payload: dict[str, Any],
899                          id_: int,
900                          ) -> dict[str, object]:
901         """POST /process, appropriately update Expecteds."""
902         if 'title' not in payload:
903             payload['title'] = 'foo'
904         if 'description' not in payload:
905             payload['description'] = 'foo'
906         if 'effort' not in payload:
907             payload['effort'] = 1.1
908         self.check_post(payload, f'/process?id={id_}',
909                         redir=f'/process?id={id_}')
910         for exp in exps:
911             exp.set_proc_from_post(id_, payload)
912         return payload
913
914     def check_filter(self, exp: Expected, category: str, key: str,
915                      val: str, list_ids: list[int]) -> None:
916         """Check GET /{category}?{key}={val} sorts to list_ids."""
917         # pylint: disable=too-many-arguments
918         exp.set(key, val)
919         exp.force(category, list_ids)
920         self.check_json_get(f'/{category}?{key}={val}', exp)
921
922     def check_redirect(self, target: str) -> None:
923         """Check that self.conn answers with a 302 redirect to target."""
924         response = self.conn.getresponse()
925         self.assertEqual(response.status, 302)
926         self.assertEqual(response.getheader('Location'), target)
927
928     def check_get(self, target: str, expected_code: int) -> None:
929         """Check that a GET to target yields expected_code."""
930         self.conn.request('GET', target)
931         self.assertEqual(self.conn.getresponse().status, expected_code)
932
933     def check_post(self, data: Mapping[str, object], target: str,
934                    expected_code: int = 302, redir: str = '') -> None:
935         """Check that POST of data to target yields expected_code."""
936         encoded_form_data = urlencode(data, doseq=True).encode('utf-8')
937         headers = {'Content-Type': 'application/x-www-form-urlencoded',
938                    'Content-Length': str(len(encoded_form_data))}
939         self.conn.request('POST', target,
940                           body=encoded_form_data, headers=headers)
941         if 302 == expected_code:
942             redir = target if redir == '' else redir
943             self.check_redirect(redir)
944         else:
945             self.assertEqual(self.conn.getresponse().status, expected_code)
946
947     def check_get_defaults(self, path: str) -> None:
948         """Some standard model paths to test."""
949         self.check_get(path, 200)
950         self.check_get(f'{path}?id=', 200)
951         self.check_get(f'{path}?id=foo', 400)
952         self.check_get(f'/{path}?id=0', 500)
953         self.check_get(f'{path}?id=1', 200)
954
955     def check_json_get(self, path: str, expected: Expected) -> None:
956         """Compare JSON on GET path with expected.
957
958         To simplify comparison of VersionedAttribute histories, transforms
959         timestamp keys of VersionedAttribute history keys into (strings of)
960         integers counting chronologically forward from 0.
961         """
962
963         def rewrite_history_keys_in(item: Any) -> Any:
964             if isinstance(item, dict):
965                 if '_versioned' in item.keys():
966                     for category in item['_versioned']:
967                         vals = item['_versioned'][category].values()
968                         history = {}
969                         for i, val in enumerate(vals):
970                             history[str(i)] = val
971                         item['_versioned'][category] = history
972                 for category in list(item.keys()):
973                     rewrite_history_keys_in(item[category])
974             elif isinstance(item, list):
975                 item[:] = [rewrite_history_keys_in(i) for i in item]
976             return item
977
978         def walk_diffs(path: str, cmp1: object, cmp2: object) -> None:
979             # pylint: disable=too-many-branches
980             def warn(intro: str, val: object) -> None:
981                 if isinstance(val, (str, int, float)):
982                     print(intro, val)
983                 else:
984                     print(intro)
985                     pprint(val)
986             if cmp1 != cmp2:
987                 if isinstance(cmp1, dict) and isinstance(cmp2, dict):
988                     for k, v in cmp1.items():
989                         if k not in cmp2:
990                             warn(f'DIFF {path}: retrieved lacks {k}', v)
991                         elif v != cmp2[k]:
992                             walk_diffs(f'{path}:{k}', v, cmp2[k])
993                     for k in [k for k in cmp2.keys() if k not in cmp1]:
994                         warn(f'DIFF {path}: expected lacks retrieved\'s {k}',
995                              cmp2[k])
996                 elif isinstance(cmp1, list) and isinstance(cmp2, list):
997                     for i, v1 in enumerate(cmp1):
998                         if i >= len(cmp2):
999                             warn(f'DIFF {path}[{i}] retrieved misses:', v1)
1000                         elif v1 != cmp2[i]:
1001                             walk_diffs(f'{path}[{i}]', v1, cmp2[i])
1002                     if len(cmp2) > len(cmp1):
1003                         for i, v2 in enumerate(cmp2[len(cmp1):]):
1004                             warn(f'DIFF {path}[{len(cmp1)+i}] misses:', v2)
1005                 else:
1006                     warn(f'DIFF {path} – for expected:', cmp1)
1007                     warn('… and for retrieved:', cmp2)
1008
1009         self.conn.request('GET', path)
1010         response = self.conn.getresponse()
1011         self.assertEqual(response.status, 200)
1012         retrieved = json_loads(response.read().decode())
1013         rewrite_history_keys_in(retrieved)
1014         cmp = expected.as_dict
1015         try:
1016             self.assertEqual(cmp, retrieved)
1017         except AssertionError as e:
1018             print('EXPECTED:')
1019             pprint(cmp)
1020             print('RETRIEVED:')
1021             pprint(retrieved)
1022             walk_diffs('', cmp, retrieved)
1023             raise e