home · contact · privacy
Add TaskHandler code to actually make previous commit work.
[plomtask] / tests / utils.py
1 """Shared test utilities."""
2 # pylint: disable=too-many-lines
3 from __future__ import annotations
4 from unittest import TestCase
5 from typing import Mapping, Any, Callable
6 from threading import Thread
7 from http.client import HTTPConnection
8 from datetime import datetime, timedelta
9 from time import sleep
10 from json import loads as json_loads, dumps as json_dumps
11 from urllib.parse import urlencode
12 from uuid import uuid4
13 from os import remove as remove_file
14 from pprint import pprint
15 from plomtask.db import DatabaseFile, DatabaseConnection
16 from plomtask.http import TaskHandler, TaskServer
17 from plomtask.processes import Process, ProcessStep
18 from plomtask.conditions import Condition
19 from plomtask.days import Day
20 from plomtask.dating import DATE_FORMAT
21 from plomtask.todos import Todo
22 from plomtask.versioned_attributes import VersionedAttribute, TIMESTAMP_FMT
23 from plomtask.exceptions import NotFoundException, HandledException
24
25
# Per value type (as named by VersionedAttribute.value_type_name), two
# distinct sample values for the versioned-attribute tests to set.
VERSIONED_VALS: dict[str, list[str] | list[float]] = {
    'str': ['A', 'B'],
    'float': [0.3, 1.1],
}
# Values to accept as meaning "true".
VALID_TRUES = {True, 'True', 'true', '1', 'on'}
30
31
class TestCaseAugmented(TestCase):
    """Tester core providing helpful basic internal decorators and methods.

    Subclasses set .checked_class (the class under test) and, where its
    constructor needs more than an ID, .default_init_kwargs. Test methods
    wrapped by the decorators below do nothing (and thus pass trivially) on
    test classes lacking a .checked_class, such as the base classes.
    """
    checked_class: Any
    default_init_kwargs: dict[str, Any] = {}

    @staticmethod
    def _run_if_checked_class(f: Callable[..., None]) -> Callable[..., None]:
        """Wrap test method f to only run if self has a .checked_class.

        functools.wraps keeps f's name and docstring on the wrapper, so
        unittest's verbose output and tracebacks show the actual test rather
        than an anonymous 'wrapper'.
        """
        from functools import wraps  # pylint: disable=import-outside-toplevel

        @wraps(f)
        def wrapper(self: TestCase) -> None:
            if hasattr(self, 'checked_class'):
                f(self)
        return wrapper

    @classmethod
    def _run_on_versioned_attributes(cls,
                                     f: Callable[..., None]
                                     ) -> Callable[..., None]:
        """Wrap test method f to run once per versioned attribute.

        For each name from .checked_class.to_save_versioned(), a fresh owner
        is made with ID None and .default_init_kwargs, and f gets called as
        f(self, owner, attr_name, attr, default, to_set), with default from
        .checked_class.versioned_defaults and to_set the VERSIONED_VALS entry
        matching the attribute's .value_type_name.
        """
        from functools import wraps  # pylint: disable=import-outside-toplevel

        @cls._run_if_checked_class
        @wraps(f)
        def wrapper(self: TestCase) -> None:
            assert isinstance(self, TestCaseAugmented)
            for attr_name in self.checked_class.to_save_versioned():
                default = self.checked_class.versioned_defaults[attr_name]
                owner = self.checked_class(None, **self.default_init_kwargs)
                attr = getattr(owner, attr_name)
                to_set = VERSIONED_VALS[attr.value_type_name]
                f(self, owner, attr_name, attr, default, to_set)
        return wrapper

    @classmethod
    def _make_from_defaults(cls, id_: float | str | None) -> Any:
        """Return new .checked_class instance of id_ and .default_init_kwargs."""
        return cls.checked_class(id_, **cls.default_init_kwargs)
62
63
class TestCaseSansDB(TestCaseAugmented):
    """Tests requiring no DB setup.

    Subclasses may override .legal_ids / .illegal_ids to match the ID type
    and validation rules of their .checked_class.
    """
    legal_ids: list[str] | list[int] = [1, 5]
    illegal_ids: list[str] | list[int] = [0]

    @TestCaseAugmented._run_if_checked_class
    def test_id_validation(self) -> None:
        """Test .id_ validation/setting."""
        for id_ in self.illegal_ids:
            with self.assertRaises(HandledException):
                self._make_from_defaults(id_)
        for id_ in self.legal_ids:
            obj = self._make_from_defaults(id_)
            self.assertEqual(obj.id_, id_)

    @TestCaseAugmented._run_on_versioned_attributes
    def test_versioned_set(self,
                           _: Any,
                           __: str,
                           attr: VersionedAttribute,
                           default: str | float,
                           to_set: list[str] | list[float]
                           ) -> None:
        """Test VersionedAttribute.set() behaves as expected."""

        def timesorted_vals() -> list[object]:
            # re-read on every call, so each check sees the current history
            return [attr.history[t] for t in sorted(attr.history.keys())]

        attr.set(default)
        self.assertEqual(list(attr.history.values()), [default])
        # check same value does not get set twice in a row,
        # and that not even its timestamp gets updated
        timestamp = list(attr.history.keys())[0]
        attr.set(default)
        self.assertEqual(list(attr.history.values()), [default])
        self.assertEqual(list(attr.history.keys())[0], timestamp)
        # check that different value _will_ be set/added
        attr.set(to_set[0])
        self.assertEqual(timesorted_vals(), [default, to_set[0]])
        # check that a previously used value can be set if not most recent
        attr.set(default)
        self.assertEqual(timesorted_vals(), [default, to_set[0], default])
        # again check for same value not being set twice in a row, even for
        # later items
        attr.set(to_set[1])
        expected = [default, to_set[0], default, to_set[1]]
        self.assertEqual(timesorted_vals(), expected)
        attr.set(to_set[1])
        # history must be re-read here, else this check could never fail
        self.assertEqual(timesorted_vals(), expected)

    @TestCaseAugmented._run_on_versioned_attributes
    def test_versioned_newest(self,
                              _: Any,
                              __: str,
                              attr: VersionedAttribute,
                              default: str | float,
                              to_set: list[str] | list[float]
                              ) -> None:
        """Test VersionedAttribute.newest."""
        # check .newest on empty history returns .default
        self.assertEqual(attr.newest, default)
        # check newest element always returned
        for v in [to_set[0], to_set[1]]:
            attr.set(v)
            self.assertEqual(attr.newest, v)
        # check newest element returned even if also early value
        attr.set(default)
        self.assertEqual(attr.newest, default)

    @TestCaseAugmented._run_on_versioned_attributes
    def test_versioned_at(self,
                          _: Any,
                          __: str,
                          attr: VersionedAttribute,
                          default: str | float,
                          to_set: list[str] | list[float]
                          ) -> None:
        """Test .at() returns values nearest to queried time, or default."""
        # check .at() return default on empty history
        timestamp_a = datetime.now().strftime(TIMESTAMP_FMT)
        self.assertEqual(attr.at(timestamp_a), default)
        # check value exactly at timestamp returned
        attr.set(to_set[0])
        timestamp_b = list(attr.history.keys())[0]
        self.assertEqual(attr.at(timestamp_b), to_set[0])
        # check earliest value returned if exists, rather than default
        self.assertEqual(attr.at(timestamp_a), to_set[0])
        # check reverts to previous value for timestamps not indexed; the
        # short sleeps are meant to keep the compared timestamps apart
        sleep(0.00001)
        timestamp_between = datetime.now().strftime(TIMESTAMP_FMT)
        sleep(0.00001)
        attr.set(to_set[1])
        timestamp_c = sorted(attr.history.keys())[-1]
        self.assertEqual(attr.at(timestamp_c), to_set[1])
        self.assertEqual(attr.at(timestamp_between), to_set[0])
        sleep(0.00001)
        timestamp_after_c = datetime.now().strftime(TIMESTAMP_FMT)
        self.assertEqual(attr.at(timestamp_after_c), to_set[1])
166
167
class TestCaseWithDB(TestCaseAugmented):
    """Module tests requiring DB setup.

    .setUp empties the caches of all model classes and creates a fresh,
    uniquely named DB file plus connection, which .tearDown closes and
    removes again.
    """
    default_ids: tuple[int, int, int] | tuple[str, str, str] = (1, 2, 3)

    def setUp(self) -> None:
        Condition.empty_cache()
        Day.empty_cache()
        Process.empty_cache()
        ProcessStep.empty_cache()
        Todo.empty_cache()
        # uuid4 keeps each test's DB file name unique
        self.db_file = DatabaseFile.create_at(f'test_db:{uuid4()}')
        self.db_conn = DatabaseConnection(self.db_file)

    def tearDown(self) -> None:
        self.db_conn.close()
        remove_file(self.db_file.path)

    def _load_from_db(self, id_: int | str) -> list[object]:
        """Return .checked_class objects made from DB rows of 'id' == id_."""
        db_found: list[object] = []
        for row in self.db_conn.row_where(self.checked_class.table_name,
                                          'id', id_):
            db_found += [self.checked_class.from_table_row(self.db_conn,
                                                           row)]
        return db_found

    def _change_obj(self, obj: object) -> str:
        """Change obj's last .to_save_simples attribute, return its name.

        Bools get flipped, other numbers incremented, strings get '_'
        appended.

        Raises:
            TypeError: if the attribute is of none of these types.
        """
        attr_name: str = self.checked_class.to_save_simples[-1]
        attr = getattr(obj, attr_name)
        new_attr: str | int | float | bool
        # NB: test bool before int, since bool is a subclass of int
        if isinstance(attr, bool):
            new_attr = not attr
        elif isinstance(attr, (int, float)):
            new_attr = attr + 1
        elif isinstance(attr, str):
            new_attr = attr + '_'
        else:
            raise TypeError(f'cannot change attribute {attr_name} of type '
                            f'{type(attr).__name__}')
        setattr(obj, attr_name, new_attr)
        return attr_name

    def check_identity_with_cache_and_db(self, content: list[Any]) -> None:
        """Test both cache and DB equal content."""
        expected_cache = {}
        for item in content:
            expected_cache[item.id_] = item
        self.assertEqual(self.checked_class.get_cache(), expected_cache)
        # DB reloads make new objects, so only compare hashes
        hashes_content = [hash(x) for x in content]
        db_found: list[Any] = []
        for item in content:
            assert isinstance(item.id_, type(self.default_ids[0]))
            db_found += self._load_from_db(item.id_)
        hashes_db_found = [hash(x) for x in db_found]
        self.assertEqual(sorted(hashes_content), sorted(hashes_db_found))

    def check_by_date_range_with_limits(self,
                                        date_col: str,
                                        set_id_field: bool = True
                                        ) -> None:
        """Test .by_date_range_with_limits."""
        # pylint: disable=too-many-locals
        f = self.checked_class.by_date_range_with_limits
        # check illegal ranges
        legal_range = ('yesterday', 'tomorrow')
        for i in [0, 1]:
            for bad_date in ['foo', '2024-02-30', '2024-01-01 12:00:00']:
                date_range = list(legal_range)
                date_range[i] = bad_date
                with self.assertRaises(HandledException):
                    f(self.db_conn, date_range, date_col)
        # check empty, translation of 'yesterday' and 'tomorrow'
        items, start, end = f(self.db_conn, legal_range, date_col)
        self.assertEqual(items, [])
        yesterday = datetime.now() + timedelta(days=-1)
        tomorrow = datetime.now() + timedelta(days=+1)
        self.assertEqual(start, yesterday.strftime(DATE_FORMAT))
        self.assertEqual(end, tomorrow.strftime(DATE_FORMAT))
        # prepare dated items for non-empty results
        kwargs_with_date = self.default_init_kwargs.copy()
        if set_id_field:
            kwargs_with_date['id_'] = None
        objs: list[Any] = []
        dates = ['2024-01-01', '2024-01-02', '2024-01-04']
        for date in dates:
            kwargs_with_date['date'] = date
            obj = self.checked_class(**kwargs_with_date)
            objs += [obj]
        # check ranges still empty before saving
        date_range = [dates[0], dates[-1]]
        self.assertEqual(f(self.db_conn, date_range, date_col)[0], [])
        # check all objs displayed within closed interval
        for obj in objs:
            obj.save(self.db_conn)
        self.assertEqual(f(self.db_conn, date_range, date_col)[0], objs)
        # check that only displayed what exists within interval
        date_range = ['2023-12-20', '2024-01-03']
        expected = [objs[0], objs[1]]
        self.assertEqual(f(self.db_conn, date_range, date_col)[0], expected)
        date_range = ['2024-01-03', '2024-01-30']
        expected = [objs[2]]
        self.assertEqual(f(self.db_conn, date_range, date_col)[0], expected)
        # check that inverted interval displays nothing
        date_range = [dates[-1], dates[0]]
        self.assertEqual(f(self.db_conn, date_range, date_col)[0], [])
        # check that "today" is interpreted, and single-element interval
        today_date = datetime.now().strftime(DATE_FORMAT)
        kwargs_with_date['date'] = today_date
        obj_today = self.checked_class(**kwargs_with_date)
        obj_today.save(self.db_conn)
        date_range = ['today', 'today']
        items, start, end = f(self.db_conn, date_range, date_col)
        self.assertEqual(start, today_date)
        self.assertEqual(start, end)
        self.assertEqual(items, [obj_today])

    @TestCaseAugmented._run_on_versioned_attributes
    def test_saving_versioned_attributes(self,
                                         owner: Any,
                                         attr_name: str,
                                         attr: VersionedAttribute,
                                         _: str | float,
                                         to_set: list[str] | list[float]
                                         ) -> None:
        """Test storage and initialization of versioned attributes."""

        def retrieve_attr_vals(attr: VersionedAttribute) -> list[object]:
            attr_vals_saved: list[object] = []
            for row in self.db_conn.row_where(attr.table_name, 'parent',
                                              owner.id_):
                attr_vals_saved += [row[2]]
            return attr_vals_saved

        attr.set(to_set[0])
        # check that without attr.save() no rows in DB
        rows = self.db_conn.row_where(attr.table_name, 'parent', owner.id_)
        self.assertEqual([], rows)
        # fail saving attributes on non-saved owner
        with self.assertRaises(NotFoundException):
            attr.save(self.db_conn)
        # check owner.save() created entries as expected in attr table
        owner.save(self.db_conn)
        attr_vals_saved = retrieve_attr_vals(attr)
        self.assertEqual([to_set[0]], attr_vals_saved)
        # check changing attr val without save affects owner in memory …
        attr.set(to_set[1])
        cmp_attr = getattr(owner, attr_name)
        self.assertEqual(to_set, list(cmp_attr.history.values()))
        self.assertEqual(cmp_attr.history, attr.history)
        # … but does not yet affect DB
        attr_vals_saved = retrieve_attr_vals(attr)
        self.assertEqual([to_set[0]], attr_vals_saved)
        # check individual attr.save also stores new val to DB
        attr.save(self.db_conn)
        attr_vals_saved = retrieve_attr_vals(attr)
        self.assertEqual(to_set, attr_vals_saved)

    @TestCaseAugmented._run_if_checked_class
    def test_saving_and_caching(self) -> None:
        """Test effects of .cache() and .save()."""
        id1 = self.default_ids[0]
        # check failure to cache without ID (if None-ID input possible)
        if isinstance(id1, int):
            obj0 = self._make_from_defaults(None)
            with self.assertRaises(HandledException):
                obj0.cache()
        # check mere object init itself doesn't even store in cache
        obj1 = self._make_from_defaults(id1)
        self.assertEqual(self.checked_class.get_cache(), {})
        # check .cache() fills cache, but not DB
        obj1.cache()
        self.assertEqual(self.checked_class.get_cache(), {id1: obj1})
        found_in_db = self._load_from_db(id1)
        self.assertEqual(found_in_db, [])
        # check .save() sets ID (for int IDs), updates cache, and fills DB
        # (expect ID to be set to id1, despite obj1 already having that as ID:
        # it's generated by cursor.lastrowid on the DB table, and with obj1
        # not written there, obj2 should get it first!)
        id_input = None if isinstance(id1, int) else id1
        obj2 = self._make_from_defaults(id_input)
        obj2.save(self.db_conn)
        self.assertEqual(self.checked_class.get_cache(), {id1: obj2})
        # NB: we'll only compare hashes because obj2 itself disappears on
        # .from_table_row-triggered database reload
        obj2_hash = hash(obj2)
        found_in_db += self._load_from_db(id1)
        self.assertEqual([hash(o) for o in found_in_db], [obj2_hash])
        # check we cannot overwrite obj2 with obj1 despite its same ID,
        # since it has disappeared now
        with self.assertRaises(HandledException):
            obj1.save(self.db_conn)

    @TestCaseAugmented._run_if_checked_class
    def test_by_id(self) -> None:
        """Test .by_id()."""
        id1, id2, _ = self.default_ids
        # check failure if not yet saved
        obj1 = self._make_from_defaults(id1)
        with self.assertRaises(NotFoundException):
            self.checked_class.by_id(self.db_conn, id1)
        # check identity of cached and retrieved
        obj1.cache()
        self.assertEqual(obj1, self.checked_class.by_id(self.db_conn, id1))
        # check identity of saved and retrieved
        obj2 = self._make_from_defaults(id2)
        obj2.save(self.db_conn)
        self.assertEqual(obj2, self.checked_class.by_id(self.db_conn, id2))

    @TestCaseAugmented._run_if_checked_class
    def test_by_id_or_create(self) -> None:
        """Test .by_id_or_create."""
        # check .by_id_or_create fails if wrong class
        if not self.checked_class.can_create_by_id:
            with self.assertRaises(HandledException):
                self.checked_class.by_id_or_create(self.db_conn, None)
            return
        # check ID input of None creates, on saving, ID=1,2,… for int IDs
        if isinstance(self.default_ids[0], int):
            for n in range(2):
                item = self.checked_class.by_id_or_create(self.db_conn, None)
                self.assertEqual(item.id_, None)
                item.save(self.db_conn)
                self.assertEqual(item.id_, n+1)
        # check .by_id_or_create acts like normal instantiation (sans saving)
        id_ = self.default_ids[2]
        item = self.checked_class.by_id_or_create(self.db_conn, id_)
        self.assertEqual(item.id_, id_)
        with self.assertRaises(NotFoundException):
            self.checked_class.by_id(self.db_conn, item.id_)
        self.assertEqual(self.checked_class(item.id_), item)

    @TestCaseAugmented._run_if_checked_class
    def test_from_table_row(self) -> None:
        """Test .from_table_row() properly reads in class directly from DB."""
        id_ = self.default_ids[0]
        obj = self._make_from_defaults(id_)
        obj.save(self.db_conn)
        assert isinstance(obj.id_, type(id_))
        for row in self.db_conn.row_where(self.checked_class.table_name,
                                          'id', obj.id_):
            # check .from_table_row reproduces state saved, no matter if obj
            # later changed (with caching even)
            # NB: we'll only compare hashes because obj itself disappears on
            # .from_table_row-triggered database reload
            hash_original = hash(obj)
            attr_name = self._change_obj(obj)
            obj.cache()
            to_cmp = getattr(obj, attr_name)
            retrieved = self.checked_class.from_table_row(self.db_conn, row)
            self.assertNotEqual(to_cmp, getattr(retrieved, attr_name))
            self.assertEqual(hash_original, hash(retrieved))
            # check cache contains what .from_table_row just produced
            self.assertEqual({retrieved.id_: retrieved},
                             self.checked_class.get_cache())

    @TestCaseAugmented._run_on_versioned_attributes
    def test_versioned_history_from_row(self,
                                        owner: Any,
                                        _: str,
                                        attr: VersionedAttribute,
                                        default: str | float,
                                        to_set: list[str] | list[float]
                                        ) -> None:
        """Test VersionedAttribute.history_from_row() knows its DB rows."""
        attr.set(to_set[0])
        attr.set(to_set[1])
        owner.save(self.db_conn)
        # make empty VersionedAttribute, fill from rows, compare to owner's
        # (outer loop runs once per saved owner row; the row itself is unused)
        for _owner_row in self.db_conn.row_where(owner.table_name, 'id',
                                                 owner.id_):
            loaded_attr = VersionedAttribute(owner, attr.table_name, default)
            for attr_row in self.db_conn.row_where(attr.table_name, 'parent',
                                                   owner.id_):
                loaded_attr.history_from_row(attr_row)
            self.assertEqual(len(attr.history.keys()),
                             len(loaded_attr.history.keys()))
            for timestamp, value in attr.history.items():
                self.assertEqual(value, loaded_attr.history[timestamp])

    @TestCaseAugmented._run_if_checked_class
    def test_all(self) -> None:
        """Test .all() and its relation to cache and savings."""
        id1, id2, id3 = self.default_ids
        item1 = self._make_from_defaults(id1)
        item2 = self._make_from_defaults(id2)
        item3 = self._make_from_defaults(id3)
        # check .all() returns empty list on un-cached items
        self.assertEqual(self.checked_class.all(self.db_conn), [])
        # check that all() shows only cached/saved items
        item1.cache()
        item3.save(self.db_conn)
        self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
                         sorted([item1, item3]))
        item2.save(self.db_conn)
        self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
                         sorted([item1, item2, item3]))

    @TestCaseAugmented._run_if_checked_class
    def test_singularity(self) -> None:
        """Test pointers made for single object keep pointing to it."""
        id1 = self.default_ids[0]
        obj = self._make_from_defaults(id1)
        obj.save(self.db_conn)
        # change object, expect retrieved through .by_id to carry change
        attr_name = self._change_obj(obj)
        new_attr = getattr(obj, attr_name)
        retrieved = self.checked_class.by_id(self.db_conn, id1)
        self.assertEqual(new_attr, getattr(retrieved, attr_name))

    @TestCaseAugmented._run_on_versioned_attributes
    def test_versioned_singularity(self,
                                   owner: Any,
                                   attr_name: str,
                                   attr: VersionedAttribute,
                                   _: str | float,
                                   to_set: list[str] | list[float]
                                   ) -> None:
        """Test singularity of VersionedAttributes on saving."""
        owner.save(self.db_conn)
        # change obj, expect retrieved through .by_id to carry change
        attr.set(to_set[0])
        retrieved = self.checked_class.by_id(self.db_conn, owner.id_)
        attr_retrieved = getattr(retrieved, attr_name)
        self.assertEqual(attr.history, attr_retrieved.history)

    @TestCaseAugmented._run_if_checked_class
    def test_remove(self) -> None:
        """Test .remove() effects on DB and cache."""
        id_ = self.default_ids[0]
        obj = self._make_from_defaults(id_)
        # check removal only works after saving
        with self.assertRaises(HandledException):
            obj.remove(self.db_conn)
        obj.save(self.db_conn)
        obj.remove(self.db_conn)
        # check access to obj fails after removal
        with self.assertRaises(HandledException):
            print(obj.id_)
        # check DB and cache now empty
        self.check_identity_with_cache_and_db([])
503
504
505 class Expected:
506     """Builder of (JSON-like) dict to compare against responses of test server.
507
508     Collects all items and relations we expect expressed in the server's JSON
509     responses and puts them into the proper json.dumps-friendly dict structure,
510 accessible via .as_dict, to compare them in TestsWithServer.check_json_get.
511
512     On its own provides for .as_dict output only {"_library": …}, initialized
513     from .__init__ and to be directly manipulated via the .lib* methods.
514     Further structures of the expected response may be added and kept
515     up-to-date by subclassing .__init__, .recalc, and .d.
516
517     NB: Lots of expectations towards server behavior will be made explicit here
518     (or in the subclasses) rather than in the actual TestCase methods' code.
519     """
520     _default_dict: dict[str, Any]
521     _forced: dict[str, Any]
522     _fields: dict[str, Any]
523     _on_empty_make_temp: tuple[str, str]
524
525     def __init__(self,
526                  todos: list[dict[str, Any]] | None = None,
527                  procs: list[dict[str, Any]] | None = None,
528                  procsteps: list[dict[str, Any]] | None = None,
529                  conds: list[dict[str, Any]] | None = None,
530                  days: list[dict[str, Any]] | None = None
531                  ) -> None:
532         # pylint: disable=too-many-arguments
533         for name in ['_default_dict', '_fields', '_forced']:
534             if not hasattr(self, name):
535                 setattr(self, name, {})
536         self._lib = {}
537         for title, items in [('Todo', todos),
538                              ('Process', procs),
539                              ('ProcessStep', procsteps),
540                              ('Condition', conds),
541                              ('Day', days)]:
542             if items:
543                 self._lib[title] = self._as_refs(items)
544         for k, v in self._default_dict.items():
545             if k not in self._fields:
546                 self._fields[k] = v
547
548     def recalc(self) -> None:
549         """Update internal dictionary by subclass-specific rules."""
550         todos = self.lib_all('Todo')
551         for todo in todos:
552             todo['parents'] = []
553         for todo in todos:
554             for child_id in todo['children']:
555                 self.lib_get('Todo', child_id)['parents'] += [todo['id']]
556             todo['children'].sort()
557         procsteps = self.lib_all('ProcessStep')
558         procs = self.lib_all('Process')
559         for proc in procs:
560             proc['explicit_steps'] = [s['id'] for s in procsteps
561                                       if s['owner_id'] == proc['id']]
562
563     @property
564     def as_dict(self) -> dict[str, Any]:
565         """Return dict to compare against test server JSON responses."""
566         make_temp = False
567         if hasattr(self, '_on_empty_make_temp'):
568             category, dicter = getattr(self, '_on_empty_make_temp')
569             id_ = self._fields[category.lower()]
570             make_temp = not bool(self.lib_get(category, id_))
571             if make_temp:
572                 f = getattr(self, dicter)
573                 self.lib_set(category, [f(id_)])
574         self.recalc()
575         d = {'_library': self._lib}
576         for k, v in self._fields.items():
577             # we expect everything sortable to be sorted
578             if isinstance(v, list) and k not in self._forced:
579                 # NB: if we don't test for v being list, sorted() on an empty
580                 # dict may return an empty list
581                 try:
582                     v = sorted(v)
583                 except TypeError:
584                     pass
585             d[k] = v
586         for k, v in self._forced.items():
587             d[k] = v
588         if make_temp:
589             json = json_dumps(d)
590             self.lib_del(category, id_)
591             d = json_loads(json)
592         return d
593
594     def lib_get(self, category: str, id_: str | int) -> dict[str, Any]:
595         """From library, return item of category and id_, or empty dict."""
596         str_id = str(id_)
597         if category in self._lib and str_id in self._lib[category]:
598             return self._lib[category][str_id]
599         return {}
600
601     def lib_all(self, category: str) -> list[dict[str, Any]]:
602         """From library, return items of category, or [] if none."""
603         if category in self._lib:
604             return list(self._lib[category].values())
605         return []
606
607     def lib_set(self, category: str, items: list[dict[str, object]]) -> None:
608         """Update library for category with items."""
609         if category not in self._lib:
610             self._lib[category] = {}
611         for k, v in self._as_refs(items).items():
612             self._lib[category][k] = v
613
614     def lib_del(self, category: str, id_: str | int) -> None:
615         """Remove category element of id_ from library."""
616         del self._lib[category][str(id_)]
617         if 0 == len(self._lib[category]):
618             del self._lib[category]
619
620     def lib_wipe(self, category: str) -> None:
621         """Remove category from library."""
622         if category in self._lib:
623             del self._lib[category]
624
625     def set(self, field_name: str, value: object) -> None:
626         """Set top-level .as_dict field."""
627         self._fields[field_name] = value
628
629     def force(self, field_name: str, value: object) -> None:
630         """Set ._forced field to ensure value in .as_dict."""
631         self._forced[field_name] = value
632
633     def unforce(self, field_name: str) -> None:
634         """Unset ._forced field."""
635         del self._forced[field_name]
636
637     @staticmethod
638     def _as_refs(items: list[dict[str, object]]
639                  ) -> dict[str, dict[str, object]]:
640         """Return dictionary of items by their 'id' fields."""
641         refs = {}
642         for item in items:
643             refs[str(item['id'])] = item
644         return refs
645
646     @staticmethod
647     def as_ids(items: list[dict[str, Any]]) -> list[int] | list[str]:
648         """Return list of only 'id' fields of items."""
649         return [item['id'] for item in items]
650
651     @staticmethod
652     def day_as_dict(date: str, comment: str = '') -> dict[str, object]:
653         """Return JSON of Day to expect."""
654         return {'id': date, 'comment': comment, 'todos': []}
655
656     def set_day_from_post(self, date: str, d: dict[str, Any]) -> None:
657         """Set Day of date in library based on POST dict d."""
658         day = self.day_as_dict(date)
659         for k, v in d.items():
660             if 'day_comment' == k:
661                 day['comment'] = v
662             elif 'new_todo' == k:
663                 next_id = 1
664                 for todo in self.lib_all('Todo'):
665                     if next_id <= todo['id']:
666                         next_id = todo['id'] + 1
667                 for proc_id in sorted(v):
668                     todo = self.todo_as_dict(next_id, proc_id, date)
669                     self.lib_set('Todo', [todo])
670                     next_id += 1
671             elif 'done' == k:
672                 for todo_id in v:
673                     self.lib_get('Todo', todo_id)['is_done'] = True
674             elif 'todo_id' == k:
675                 for i, todo_id in enumerate(v):
676                     t = self.lib_get('Todo', todo_id)
677                     if 'comment' in d:
678                         t['comment'] = d['comment'][i]
679                     if 'effort' in d:
680                         effort = d['effort'][i] if d['effort'][i] else None
681                         t['effort'] = effort
682         self.lib_set('Day', [day])
683
684     @staticmethod
685     def cond_as_dict(id_: int = 1,
686                      is_active: bool = False,
687                      title: None | str = None,
688                      description: None | str = None,
689                      ) -> dict[str, object]:
690         """Return JSON of Condition to expect."""
691         versioned: dict[str, dict[str, object]]
692         versioned = {'title': {}, 'description': {}}
693         if title is not None:
694             versioned['title']['0'] = title
695         if description is not None:
696             versioned['description']['0'] = description
697         return {'id': id_, 'is_active': is_active, '_versioned': versioned}
698
699     def set_cond_from_post(self, id_: int, d: dict[str, Any]) -> None:
700         """Set Condition of id_ in library based on POST dict d."""
701         if d == {'delete': ''}:
702             self.lib_del('Condition', id_)
703             return
704         cond = self.lib_get('Condition', id_)
705         if cond:
706             cond['is_active'] = d['is_active']
707             for category in ['title', 'description']:
708                 history = cond['_versioned'][category]
709                 if len(history) > 0:
710                     last_i = sorted([int(k) for k in history.keys()])[-1]
711                     if d[category] != history[str(last_i)]:
712                         history[str(last_i + 1)] = d[category]
713                 else:
714                     history['0'] = d[category]
715         else:
716             cond = self.cond_as_dict(
717                     id_, d['is_active'], d['title'], d['description'])
718         self.lib_set('Condition', [cond])
719
720     @staticmethod
721     def todo_as_dict(id_: int = 1,
722                      process_id: int = 1,
723                      date: str = '2024-01-01',
724                      conditions: None | list[int] = None,
725                      disables: None | list[int] = None,
726                      blockers: None | list[int] = None,
727                      enables: None | list[int] = None,
728                      calendarize: bool = False,
729                      comment: str = '',
730                      is_done: bool = False,
731                      effort: float | None = None,
732                      children: list[int] | None = None,
733                      parents: list[int] | None = None,
734                      ) -> dict[str, object]:
735         """Return JSON of Todo to expect."""
736         # pylint: disable=too-many-arguments
737         d = {'id': id_,
738              'date': date,
739              'process_id': process_id,
740              'is_done': is_done,
741              'calendarize': calendarize,
742              'comment': comment,
743              'children': children if children else [],
744              'parents': parents if parents else [],
745              'effort': effort,
746              'conditions': conditions if conditions else [],
747              'disables': disables if disables else [],
748              'blockers': blockers if blockers else [],
749              'enables': enables if enables else []}
750         return d
751
752     def set_todo_from_post(self, id_: int, d: dict[str, Any]) -> None:
753         """Set Todo of id_ in library based on POST dict d."""
754         corrected_kwargs: dict[str, Any] = {'children': []}
755         for k, v in d.items():
756             if k.startswith('step_filler_to_'):
757                 continue
758             elif 'adopt' == k:
759                 new_children = v if isinstance(v, list) else [v]
760                 corrected_kwargs['children'] += new_children
761                 continue
762             elif k in {'is_done', 'calendarize'}:
763                 v = v in VALID_TRUES
764             corrected_kwargs[k] = v
765         todo = self.lib_get('Todo', id_)
766         if todo:
767             for k, v in corrected_kwargs.items():
768                 todo[k] = v
769         else:
770             todo = self.todo_as_dict(id_, **corrected_kwargs)
771         self.lib_set('Todo', [todo])
772
773     @staticmethod
774     def procstep_as_dict(id_: int,
775                          owner_id: int,
776                          step_process_id: int,
777                          parent_step_id: int | None = None
778                          ) -> dict[str, object]:
779         """Return JSON of ProcessStep to expect."""
780         return {'id': id_,
781                 'owner_id': owner_id,
782                 'step_process_id': step_process_id,
783                 'parent_step_id': parent_step_id}
784
785     @staticmethod
786     def proc_as_dict(id_: int = 1,
787                      title: None | str = None,
788                      description: None | str = None,
789                      effort: None | float = None,
790                      conditions: None | list[int] = None,
791                      disables: None | list[int] = None,
792                      blockers: None | list[int] = None,
793                      enables: None | list[int] = None,
794                      explicit_steps: None | list[int] = None
795                      ) -> dict[str, object]:
796         """Return JSON of Process to expect."""
797         # pylint: disable=too-many-arguments
798         versioned: dict[str, dict[str, object]]
799         versioned = {'title': {}, 'description': {}, 'effort': {}}
800         if title is not None:
801             versioned['title']['0'] = title
802         if description is not None:
803             versioned['description']['0'] = description
804         if effort is not None:
805             versioned['effort']['0'] = effort
806         d = {'id': id_,
807              'calendarize': False,
808              'suppressed_steps': [],
809              'explicit_steps': explicit_steps if explicit_steps else [],
810              '_versioned': versioned,
811              'conditions': conditions if conditions else [],
812              'disables': disables if disables else [],
813              'enables': enables if enables else [],
814              'blockers': blockers if blockers else []}
815         return d
816
817     def set_proc_from_post(self, id_: int, d: dict[str, Any]) -> None:
818         """Set Process of id_ in library based on POST dict d."""
819         proc = self.lib_get('Process', id_)
820         if proc:
821             for category in ['title', 'description', 'effort']:
822                 history = proc['_versioned'][category]
823                 if len(history) > 0:
824                     last_i = sorted([int(k) for k in history.keys()])[-1]
825                     if d[category] != history[str(last_i)]:
826                         history[str(last_i + 1)] = d[category]
827                 else:
828                     history['0'] = d[category]
829         else:
830             proc = self.proc_as_dict(id_,
831                                      d['title'], d['description'], d['effort'])
832         ignore = {'title', 'description', 'effort', 'new_top_step', 'step_of',
833                   'kept_steps'}
834         for k, v in d.items():
835             if k in ignore\
836                     or k.startswith('step_') or k.startswith('new_step_to'):
837                 continue
838             if k in {'calendarize'}:
839                 v = v in VALID_TRUES
840             elif k in {'suppressed_steps', 'explicit_steps', 'conditions',
841                        'disables', 'enables', 'blockers'}:
842                 if not isinstance(v, list):
843                     v = [v]
844             proc[k] = v
845         self.lib_set('Process', [proc])
846
847
class TestCaseWithServer(TestCaseWithDB):
    """Module tests against our HTTP server/handler (and database)."""

    def setUp(self) -> None:
        super().setUp()
        self.httpd = TaskServer(self.db_file, ('localhost', 0), TaskHandler)
        # Configure the server fully before its thread starts serving.
        self.httpd.render_mode = 'json'
        self.server_thread = Thread(target=self.httpd.serve_forever)
        self.server_thread.daemon = True
        self.server_thread.start()
        self.conn = HTTPConnection(str(self.httpd.server_address[0]),
                                   self.httpd.server_address[1])

    def tearDown(self) -> None:
        # Release the client connection before taking down the server.
        self.conn.close()
        self.httpd.shutdown()
        self.httpd.server_close()
        self.server_thread.join()
        super().tearDown()

    def post_exp_cond(self,
                      exps: list[Expected],
                      id_: int,
                      payload: dict[str, object],
                      path_suffix: str = '',
                      redir_suffix: str = ''
                      ) -> None:
        """POST /condition(s), appropriately update Expecteds.

        POSTs payload to /condition{path_suffix}, expects a 302 redirect to
        /condition{redir_suffix}, then mirrors the POST into each of exps
        as targeting the Condition of id_.
        """
        # pylint: disable=too-many-arguments
        path = f'/condition{path_suffix}'
        redir = f'/condition{redir_suffix}'
        self.check_post(payload, path, redir=redir)
        for exp in exps:
            exp.set_cond_from_post(id_, payload)

    def post_exp_day(self,
                     exps: list[Expected],
                     payload: dict[str, Any],
                     date: str = '2024-01-01'
                     ) -> None:
        """POST /day, appropriately update Expecteds.

        Fills missing 'make_type' ('empty') and 'day_comment' ('') into
        payload (mutating the caller's dict), POSTs it to /day?date={date},
        expects a 302 redirect to that URL plus &make_type=…, then mirrors
        the POST into each of exps.
        """
        if 'make_type' not in payload:
            payload['make_type'] = 'empty'
        if 'day_comment' not in payload:
            payload['day_comment'] = ''
        target = f'/day?date={date}'
        redir_to = f'{target}&make_type={payload["make_type"]}'
        self.check_post(payload, target, 302, redir_to)
        for exp in exps:
            exp.set_day_from_post(date, payload)

    def post_exp_process(self,
                         exps: list[Expected],
                         payload: dict[str, Any],
                         id_: int,
                         ) -> dict[str, object]:
        """POST /process, appropriately update Expecteds.

        Fills missing 'title' / 'description' ('foo') and 'effort' (1.1)
        into payload (mutating the caller's dict), POSTs it to
        /process?id={id_} expecting a redirect back there, mirrors the POST
        into each of exps, and returns the completed payload.
        """
        if 'title' not in payload:
            payload['title'] = 'foo'
        if 'description' not in payload:
            payload['description'] = 'foo'
        if 'effort' not in payload:
            payload['effort'] = 1.1
        self.check_post(payload, f'/process?id={id_}',
                        redir=f'/process?id={id_}')
        for exp in exps:
            exp.set_proc_from_post(id_, payload)
        return payload

    def check_filter(self, exp: Expected, category: str, key: str,
                     val: str, list_ids: list[int]) -> None:
        """Check GET /{category}?{key}={val} sorts to list_ids.

        Records key=val as a top-level field of exp and forces exp's
        category field to list_ids before comparing.
        """
        # pylint: disable=too-many-arguments
        exp.set(key, val)
        exp.force(category, list_ids)
        self.check_json_get(f'/{category}?{key}={val}', exp)

    def check_redirect(self, target: str) -> None:
        """Check that self.conn answers with a 302 redirect to target."""
        response = self.conn.getresponse()
        self.assertEqual(response.status, 302)
        self.assertEqual(response.getheader('Location'), target)

    def check_get(self, target: str, expected_code: int) -> None:
        """Check that a GET to target yields expected_code."""
        self.conn.request('GET', target)
        self.assertEqual(self.conn.getresponse().status, expected_code)

    def check_post(self, data: Mapping[str, object], target: str,
                   expected_code: int = 302, redir: str = '') -> None:
        """Check that POST of data to target yields expected_code.

        data is form-encoded with doseq=True, so list values become
        repeated keys. For the default expected_code of 302, also checks
        the Location header equals redir, which defaults to target.
        """
        encoded_form_data = urlencode(data, doseq=True).encode('utf-8')
        headers = {'Content-Type': 'application/x-www-form-urlencoded',
                   'Content-Length': str(len(encoded_form_data))}
        self.conn.request('POST', target,
                          body=encoded_form_data, headers=headers)
        if 302 == expected_code:
            redir = target if redir == '' else redir
            self.check_redirect(redir)
        else:
            self.assertEqual(self.conn.getresponse().status, expected_code)

    def check_get_defaults(self, path: str) -> None:
        """Some standard model paths to test.

        Expects path to be absolute (e.g. '/todo'); query strings are
        appended directly to it.
        """
        self.check_get(path, 200)
        self.check_get(f'{path}?id=', 200)
        self.check_get(f'{path}?id=foo', 400)
        self.check_get(f'{path}?id=0', 500)
        self.check_get(f'{path}?id=1', 200)

    def check_json_get(self, path: str, expected: Expected) -> None:
        """Compare JSON on GET path with expected.

        To simplify comparison of VersionedAttribute histories, transforms
        timestamp keys of VersionedAttribute history keys into (strings of)
        integers counting chronologically forward from 0.
        """

        def rewrite_history_keys_in(item: Any) -> Any:
            # Relabels history entries in the order the JSON lists them;
            # assumes the server emits each history chronologically.
            if isinstance(item, dict):
                if '_versioned' in item.keys():
                    for category in item['_versioned']:
                        vals = item['_versioned'][category].values()
                        history = {}
                        for i, val in enumerate(vals):
                            history[str(i)] = val
                        item['_versioned'][category] = history
                for category in list(item.keys()):
                    rewrite_history_keys_in(item[category])
            elif isinstance(item, list):
                item[:] = [rewrite_history_keys_in(i) for i in item]
            return item

        def walk_diffs(path: str, cmp1: object, cmp2: object) -> None:
            # Print where expected (cmp1) and retrieved (cmp2) diverge.
            # pylint: disable=too-many-branches
            def warn(intro: str, val: object) -> None:
                if isinstance(val, (str, int, float)):
                    print(intro, val)
                else:
                    print(intro)
                    pprint(val)
            if cmp1 != cmp2:
                if isinstance(cmp1, dict) and isinstance(cmp2, dict):
                    for k, v in cmp1.items():
                        if k not in cmp2:
                            warn(f'DIFF {path}: retrieved lacks {k}', v)
                        elif v != cmp2[k]:
                            walk_diffs(f'{path}:{k}', v, cmp2[k])
                    for k in [k for k in cmp2.keys() if k not in cmp1]:
                        warn(f'DIFF {path}: expected lacks retrieved\'s {k}',
                             cmp2[k])
                elif isinstance(cmp1, list) and isinstance(cmp2, list):
                    for i, v1 in enumerate(cmp1):
                        if i >= len(cmp2):
                            warn(f'DIFF {path}[{i}] retrieved misses:', v1)
                        elif v1 != cmp2[i]:
                            walk_diffs(f'{path}[{i}]', v1, cmp2[i])
                    if len(cmp2) > len(cmp1):
                        for i, v2 in enumerate(cmp2[len(cmp1):]):
                            warn(f'DIFF {path}[{len(cmp1)+i}] expected '
                                 'misses:', v2)
                else:
                    warn(f'DIFF {path} – for expected:', cmp1)
                    warn('… and for retrieved:', cmp2)

        self.conn.request('GET', path)
        response = self.conn.getresponse()
        self.assertEqual(response.status, 200)
        retrieved = json_loads(response.read().decode())
        rewrite_history_keys_in(retrieved)
        cmp = expected.as_dict
        try:
            self.assertEqual(cmp, retrieved)
        except AssertionError as e:
            print('EXPECTED:')
            pprint(cmp)
            print('RETRIEVED:')
            pprint(retrieved)
            walk_diffs('', cmp, retrieved)
            raise e