home · contact · privacy
Extend and refactor tests.
[plomtask] / tests / utils.py
1 """Shared test utilities."""
2 # pylint: disable=too-many-lines
from __future__ import annotations
from unittest import TestCase
from typing import Mapping, Any, Callable
from threading import Thread
from http.client import HTTPConnection
from datetime import datetime, timedelta
from functools import wraps
from time import sleep
from json import loads as json_loads, dumps as json_dumps
from urllib.parse import urlencode
from uuid import uuid4
from os import remove as remove_file
from pprint import pprint

from plomtask.db import DatabaseFile, DatabaseConnection
from plomtask.http import TaskHandler, TaskServer
from plomtask.processes import Process, ProcessStep
from plomtask.conditions import Condition
from plomtask.days import Day
from plomtask.dating import DATE_FORMAT
from plomtask.todos import Todo
from plomtask.versioned_attributes import VersionedAttribute, TIMESTAMP_FMT
from plomtask.exceptions import NotFoundException, HandledException
24
25
# test values to .set() on VersionedAttributes, by their .value_type_name
VERSIONED_VALS: dict[str, list[str] | list[float]] = {
    'str': ['A', 'B'],
    'float': [0.3, 1.1],
}
# (form) values to be understood as truthy
VALID_TRUES = {True, 'True', 'true', '1', 'on'}
30
31
class TestCaseAugmented(TestCase):
    """Tester core providing helpful basic internal decorators and methods.

    Concrete test classes set .checked_class to the model class under test
    (and .default_init_kwargs to what its constructor needs beyond an ID).
    Template test classes lacking .checked_class get their tests decorated
    by ._run_if_sans_db / ._run_if_with_db_but_not_server turned into no-ops.
    """
    checked_class: Any
    default_init_kwargs: dict[str, Any] = {}

    @staticmethod
    def _run_on_versioned_attributes(f: Callable[..., None]
                                     ) -> Callable[..., None]:
        """Make f run once per versioned attribute of .checked_class.

        For each name in .checked_class.to_save_versioned(), a fresh owner is
        built (with ID None and .default_init_kwargs), and f is called as
        f(self, owner, attr_name, attr, default, to_set), with default taken
        from .checked_class.versioned_defaults, and to_set from VERSIONED_VALS
        by the attribute's .value_type_name.
        """
        @wraps(f)
        def wrapper(self: TestCase) -> None:
            assert isinstance(self, TestCaseAugmented)
            for attr_name in self.checked_class.to_save_versioned():
                default = self.checked_class.versioned_defaults[attr_name]
                owner = self.checked_class(None, **self.default_init_kwargs)
                attr = getattr(owner, attr_name)
                to_set = VERSIONED_VALS[attr.value_type_name]
                f(self, owner, attr_name, attr, default, to_set)
        return wrapper

    @staticmethod
    def _run_if_sans_db(f: Callable[..., None]) -> Callable[..., None]:
        """Make f only run in concrete TestCaseSansDB test classes.

        NB: The check must be on the class of the test instance running f.
        A class captured at decoration time would be the one the decorator
        got looked up on (TestCaseAugmented), which never passes the check,
        silently turning every decorated test into a no-op.
        """
        @wraps(f)
        def wrapper(self: TestCaseSansDB) -> None:
            if isinstance(self, TestCaseSansDB) \
                    and hasattr(self, 'checked_class'):
                f(self)
        return wrapper

    @staticmethod
    def _run_if_with_db_but_not_server(f: Callable[..., None]
                                       ) -> Callable[..., None]:
        """Make f only run in concrete TestCaseWithDB (not server) classes.

        As in ._run_if_sans_db, the check is on the running test instance;
        TestCaseWithServer instances are skipped.
        """
        @wraps(f)
        def wrapper(self: TestCaseWithDB) -> None:
            if isinstance(self, TestCaseWithDB) \
                    and not isinstance(self, TestCaseWithServer) \
                    and hasattr(self, 'checked_class'):
                f(self)
        return wrapper

    @classmethod
    def _make_from_defaults(cls, id_: float | str | None) -> Any:
        """Return .checked_class instance of id_ and .default_init_kwargs."""
        return cls.checked_class(id_, **cls.default_init_kwargs)
70
71
class TestCaseSansDB(TestCaseAugmented):
    """Tests of .checked_class requiring no DB setup."""
    legal_ids: list[str] | list[int] = [1, 5]
    illegal_ids: list[str] | list[int] = [0]

    @TestCaseAugmented._run_if_sans_db
    def test_id_validation(self) -> None:
        """Test .id_ validation/setting."""
        for id_ in self.illegal_ids:
            with self.assertRaises(HandledException):
                self._make_from_defaults(id_)
        for id_ in self.legal_ids:
            obj = self._make_from_defaults(id_)
            self.assertEqual(obj.id_, id_)

    @TestCaseAugmented._run_if_sans_db
    @TestCaseAugmented._run_on_versioned_attributes
    def test_versioned_set(self,
                           _: Any,
                           __: str,
                           attr: VersionedAttribute,
                           default: str | float,
                           to_set: list[str] | list[float]
                           ) -> None:
        """Test VersionedAttribute.set() behaves as expected."""

        def timesorted_vals() -> list[str | float]:
            """Return attr's current history values, sorted by timestamp."""
            return [attr.history[t] for t in sorted(attr.history.keys())]

        attr.set(default)
        self.assertEqual(list(attr.history.values()), [default])
        # check same value does not get set twice in a row,
        # and that not even its timestamp gets updated
        timestamp = list(attr.history.keys())[0]
        attr.set(default)
        self.assertEqual(list(attr.history.values()), [default])
        self.assertEqual(list(attr.history.keys())[0], timestamp)
        # check that different value _will_ be set/added
        attr.set(to_set[0])
        self.assertEqual(timesorted_vals(), [default, to_set[0]])
        # check that a previously used value can be set if not most recent
        attr.set(default)
        self.assertEqual(timesorted_vals(), [default, to_set[0], default])
        # again check for same value not being set twice in a row, even for
        # later items (NB: history must be re-read after each .set(), as a
        # stale copy could never reflect a wrongful addition)
        attr.set(to_set[1])
        expected = [default, to_set[0], default, to_set[1]]
        self.assertEqual(timesorted_vals(), expected)
        attr.set(to_set[1])
        self.assertEqual(timesorted_vals(), expected)

    @TestCaseAugmented._run_if_sans_db
    @TestCaseAugmented._run_on_versioned_attributes
    def test_versioned_newest(self,
                              _: Any,
                              __: str,
                              attr: VersionedAttribute,
                              default: str | float,
                              to_set: list[str] | list[float]
                              ) -> None:
        """Test VersionedAttribute.newest."""
        # check .newest on empty history returns .default
        self.assertEqual(attr.newest, default)
        # check newest element always returned
        for v in [to_set[0], to_set[1]]:
            attr.set(v)
            self.assertEqual(attr.newest, v)
        # check newest element returned even if also early value
        attr.set(default)
        self.assertEqual(attr.newest, default)

    @TestCaseAugmented._run_if_sans_db
    @TestCaseAugmented._run_on_versioned_attributes
    def test_versioned_at(self,
                          _: Any,
                          __: str,
                          attr: VersionedAttribute,
                          default: str | float,
                          to_set: list[str] | list[float]
                          ) -> None:
        """Test .at() returns values nearest to queried time, or default."""
        # check .at() return default on empty history
        timestamp_a = datetime.now().strftime(TIMESTAMP_FMT)
        self.assertEqual(attr.at(timestamp_a), default)
        # check value exactly at timestamp returned
        attr.set(to_set[0])
        timestamp_b = list(attr.history.keys())[0]
        self.assertEqual(attr.at(timestamp_b), to_set[0])
        # check earliest value returned if exists, rather than default
        self.assertEqual(attr.at(timestamp_a), to_set[0])
        # check reverts to previous value for timestamps not indexed
        sleep(0.00001)
        timestamp_between = datetime.now().strftime(TIMESTAMP_FMT)
        sleep(0.00001)
        attr.set(to_set[1])
        timestamp_c = sorted(attr.history.keys())[-1]
        self.assertEqual(attr.at(timestamp_c), to_set[1])
        self.assertEqual(attr.at(timestamp_between), to_set[0])
        sleep(0.00001)
        timestamp_after_c = datetime.now().strftime(TIMESTAMP_FMT)
        self.assertEqual(attr.at(timestamp_after_c), to_set[1])
174         sleep(0.00001)
175         timestamp_after_c = datetime.now().strftime(TIMESTAMP_FMT)
176         self.assertEqual(attr.at(timestamp_after_c), to_set[1])
177
178
class TestCaseWithDB(TestCaseAugmented):
    """Module tests requiring DB setup.

    .setUp empties the model classes' caches and creates a fresh DB file plus
    connection, which .tearDown closes and removes again.
    """
    default_ids: tuple[int, int, int] | tuple[str, str, str] = (1, 2, 3)

    def setUp(self) -> None:
        Condition.empty_cache()
        Day.empty_cache()
        Process.empty_cache()
        ProcessStep.empty_cache()
        Todo.empty_cache()
        self.db_file = DatabaseFile.create_at(f'test_db:{uuid4()}')
        self.db_conn = DatabaseConnection(self.db_file)

    def tearDown(self) -> None:
        self.db_conn.close()
        remove_file(self.db_file.path)

    def _load_from_db(self, id_: int | str) -> list[object]:
        """Return .checked_class objects made from the DB rows of id_."""
        rows = self.db_conn.row_where(self.checked_class.table_name,
                                      'id', id_)
        return [self.checked_class.from_table_row(self.db_conn, row)
                for row in rows]

    def _change_obj(self, obj: object) -> str:
        """Change obj's last .to_save_simples attribute, return its name.

        Booleans get negated, numbers incremented, strings suffixed with '_'.
        Raises TypeError for any other attribute type.
        """
        attr_name: str = self.checked_class.to_save_simples[-1]
        attr = getattr(obj, attr_name)
        new_attr: str | int | float | bool
        # NB: test for bool first, as bool is a subclass of int and would
        # otherwise get incremented (True -> 2) rather than negated
        if isinstance(attr, bool):
            new_attr = not attr
        elif isinstance(attr, (int, float)):
            new_attr = attr + 1
        elif isinstance(attr, str):
            new_attr = attr + '_'
        else:
            raise TypeError(f'cannot change attribute {attr_name} '
                            f'of type {type(attr).__name__}')
        setattr(obj, attr_name, new_attr)
        return attr_name

    def check_identity_with_cache_and_db(self, content: list[Any]) -> None:
        """Test both cache and DB equal content."""
        expected_cache = {item.id_: item for item in content}
        self.assertEqual(self.checked_class.get_cache(), expected_cache)
        hashes_content = [hash(x) for x in content]
        db_found: list[Any] = []
        for item in content:
            assert isinstance(item.id_, type(self.default_ids[0]))
            db_found += self._load_from_db(item.id_)
        hashes_db_found = [hash(x) for x in db_found]
        self.assertEqual(sorted(hashes_content), sorted(hashes_db_found))

    def check_by_date_range_with_limits(self,
                                        date_col: str,
                                        set_id_field: bool = True
                                        ) -> None:
        """Test .by_date_range_with_limits."""
        # pylint: disable=too-many-locals
        f = self.checked_class.by_date_range_with_limits
        # check illegal ranges
        legal_range = ('yesterday', 'tomorrow')
        for i in [0, 1]:
            for bad_date in ['foo', '2024-02-30', '2024-01-01 12:00:00']:
                date_range = list(legal_range[:])
                date_range[i] = bad_date
                with self.assertRaises(HandledException):
                    f(self.db_conn, date_range, date_col)
        # check empty, translation of 'yesterday' and 'tomorrow'
        items, start, end = f(self.db_conn, legal_range, date_col)
        self.assertEqual(items, [])
        yesterday = datetime.now() + timedelta(days=-1)
        tomorrow = datetime.now() + timedelta(days=+1)
        self.assertEqual(start, yesterday.strftime(DATE_FORMAT))
        self.assertEqual(end, tomorrow.strftime(DATE_FORMAT))
        # prepare dated items for non-empty results
        kwargs_with_date = self.default_init_kwargs.copy()
        if set_id_field:
            kwargs_with_date['id_'] = None
        objs = []
        dates = ['2024-01-01', '2024-01-02', '2024-01-04']
        for date in dates:
            kwargs_with_date['date'] = date
            obj = self.checked_class(**kwargs_with_date)
            objs += [obj]
        # check ranges still empty before saving
        date_range = [dates[0], dates[-1]]
        self.assertEqual(f(self.db_conn, date_range, date_col)[0], [])
        # check all objs displayed within closed interval
        for obj in objs:
            obj.save(self.db_conn)
        self.assertEqual(f(self.db_conn, date_range, date_col)[0], objs)
        # check that only displayed what exists within interval
        date_range = ['2023-12-20', '2024-01-03']
        expected = [objs[0], objs[1]]
        self.assertEqual(f(self.db_conn, date_range, date_col)[0], expected)
        date_range = ['2024-01-03', '2024-01-30']
        expected = [objs[2]]
        self.assertEqual(f(self.db_conn, date_range, date_col)[0], expected)
        # check that inverted interval displays nothing
        date_range = [dates[-1], dates[0]]
        self.assertEqual(f(self.db_conn, date_range, date_col)[0], [])
        # check that "today" is interpreted, and single-element interval
        today_date = datetime.now().strftime(DATE_FORMAT)
        kwargs_with_date['date'] = today_date
        obj_today = self.checked_class(**kwargs_with_date)
        obj_today.save(self.db_conn)
        date_range = ['today', 'today']
        items, start, end = f(self.db_conn, date_range, date_col)
        self.assertEqual(start, today_date)
        self.assertEqual(start, end)
        self.assertEqual(items, [obj_today])

    @TestCaseAugmented._run_if_with_db_but_not_server
    @TestCaseAugmented._run_on_versioned_attributes
    def test_saving_versioned_attributes(self,
                                         owner: Any,
                                         attr_name: str,
                                         attr: VersionedAttribute,
                                         _: str | float,
                                         to_set: list[str] | list[float]
                                         ) -> None:
        """Test storage and initialization of versioned attributes."""

        def retrieve_attr_vals(attr: VersionedAttribute) -> list[object]:
            attr_vals_saved: list[object] = []
            for row in self.db_conn.row_where(attr.table_name, 'parent',
                                              owner.id_):
                attr_vals_saved += [row[2]]
            return attr_vals_saved

        attr.set(to_set[0])
        # check that without attr.save() no rows in DB
        rows = self.db_conn.row_where(attr.table_name, 'parent', owner.id_)
        self.assertEqual([], rows)
        # fail saving attributes on non-saved owner
        with self.assertRaises(NotFoundException):
            attr.save(self.db_conn)
        # check owner.save() created entries as expected in attr table
        owner.save(self.db_conn)
        attr_vals_saved = retrieve_attr_vals(attr)
        self.assertEqual([to_set[0]], attr_vals_saved)
        # check changing attr val without save affects owner in memory …
        attr.set(to_set[1])
        cmp_attr = getattr(owner, attr_name)
        self.assertEqual(to_set, list(cmp_attr.history.values()))
        self.assertEqual(cmp_attr.history, attr.history)
        # … but does not yet affect DB
        attr_vals_saved = retrieve_attr_vals(attr)
        self.assertEqual([to_set[0]], attr_vals_saved)
        # check individual attr.save also stores new val to DB
        attr.save(self.db_conn)
        attr_vals_saved = retrieve_attr_vals(attr)
        self.assertEqual(to_set, attr_vals_saved)

    @TestCaseAugmented._run_if_with_db_but_not_server
    def test_saving_and_caching(self) -> None:
        """Test effects of .cache() and .save()."""
        id1 = self.default_ids[0]
        # check failure to cache without ID (if None-ID input possible)
        if isinstance(id1, int):
            obj0 = self._make_from_defaults(None)
            with self.assertRaises(HandledException):
                obj0.cache()
        # check mere object init itself doesn't even store in cache
        obj1 = self._make_from_defaults(id1)
        self.assertEqual(self.checked_class.get_cache(), {})
        # check .cache() fills cache, but not DB
        obj1.cache()
        self.assertEqual(self.checked_class.get_cache(), {id1: obj1})
        found_in_db = self._load_from_db(id1)
        self.assertEqual(found_in_db, [])
        # check .save() sets ID (for int IDs), updates cache, and fills DB
        # (expect ID to be set to id1, despite obj1 already having that as ID:
        # it's generated by cursor.lastrowid on the DB table, and with obj1
        # not written there, obj2 should get it first!)
        id_input = None if isinstance(id1, int) else id1
        obj2 = self._make_from_defaults(id_input)
        obj2.save(self.db_conn)
        self.assertEqual(self.checked_class.get_cache(), {id1: obj2})
        # NB: we'll only compare hashes because obj2 itself disappears on
        # .from_table_row-triggered database reload
        obj2_hash = hash(obj2)
        found_in_db += self._load_from_db(id1)
        self.assertEqual([hash(o) for o in found_in_db], [obj2_hash])
        # check we cannot overwrite obj2 with obj1 despite its same ID,
        # since it has disappeared now
        with self.assertRaises(HandledException):
            obj1.save(self.db_conn)

    @TestCaseAugmented._run_if_with_db_but_not_server
    def test_by_id(self) -> None:
        """Test .by_id()."""
        id1, id2, _ = self.default_ids
        # check failure if not yet saved
        obj1 = self._make_from_defaults(id1)
        with self.assertRaises(NotFoundException):
            self.checked_class.by_id(self.db_conn, id1)
        # check identity of cached and retrieved
        obj1.cache()
        self.assertEqual(obj1, self.checked_class.by_id(self.db_conn, id1))
        # check identity of saved and retrieved
        obj2 = self._make_from_defaults(id2)
        obj2.save(self.db_conn)
        self.assertEqual(obj2, self.checked_class.by_id(self.db_conn, id2))

    @TestCaseAugmented._run_if_with_db_but_not_server
    def test_by_id_or_create(self) -> None:
        """Test .by_id_or_create."""
        # check .by_id_or_create fails if wrong class
        if not self.checked_class.can_create_by_id:
            with self.assertRaises(HandledException):
                self.checked_class.by_id_or_create(self.db_conn, None)
            return
        # check ID input of None creates, on saving, ID=1,2,… for int IDs
        if isinstance(self.default_ids[0], int):
            for n in range(2):
                item = self.checked_class.by_id_or_create(self.db_conn, None)
                self.assertEqual(item.id_, None)
                item.save(self.db_conn)
                self.assertEqual(item.id_, n+1)
        # check .by_id_or_create acts like normal instantiation (sans saving)
        id_ = self.default_ids[2]
        item = self.checked_class.by_id_or_create(self.db_conn, id_)
        self.assertEqual(item.id_, id_)
        with self.assertRaises(NotFoundException):
            self.checked_class.by_id(self.db_conn, item.id_)
        self.assertEqual(self.checked_class(item.id_), item)

    @TestCaseAugmented._run_if_with_db_but_not_server
    def test_from_table_row(self) -> None:
        """Test .from_table_row() properly reads in class directly from DB."""
        id_ = self.default_ids[0]
        obj = self._make_from_defaults(id_)
        obj.save(self.db_conn)
        assert isinstance(obj.id_, type(id_))
        for row in self.db_conn.row_where(self.checked_class.table_name,
                                          'id', obj.id_):
            # check .from_table_row reproduces state saved, no matter if obj
            # later changed (with caching even)
            # NB: we'll only compare hashes because obj itself disappears on
            # .from_table_row-triggered database reload
            hash_original = hash(obj)
            attr_name = self._change_obj(obj)
            obj.cache()
            to_cmp = getattr(obj, attr_name)
            retrieved = self.checked_class.from_table_row(self.db_conn, row)
            self.assertNotEqual(to_cmp, getattr(retrieved, attr_name))
            self.assertEqual(hash_original, hash(retrieved))
            # check cache contains what .from_table_row just produced
            self.assertEqual({retrieved.id_: retrieved},
                             self.checked_class.get_cache())

    @TestCaseAugmented._run_if_with_db_but_not_server
    @TestCaseAugmented._run_on_versioned_attributes
    def test_versioned_history_from_row(self,
                                        owner: Any,
                                        _: str,
                                        attr: VersionedAttribute,
                                        default: str | float,
                                        to_set: list[str] | list[float]
                                        ) -> None:
        """Test VersionedAttribute.history_from_row() knows its DB rows."""
        attr.set(to_set[0])
        attr.set(to_set[1])
        owner.save(self.db_conn)
        # make empty VersionedAttribute, fill from rows, compare to owner's
        # (the outer loop only runs if the owner's own row got saved)
        for _owner_row in self.db_conn.row_where(owner.table_name, 'id',
                                                 owner.id_):
            loaded_attr = VersionedAttribute(owner, attr.table_name, default)
            for attr_row in self.db_conn.row_where(attr.table_name, 'parent',
                                                   owner.id_):
                loaded_attr.history_from_row(attr_row)
            self.assertEqual(len(attr.history.keys()),
                             len(loaded_attr.history.keys()))
            for timestamp, value in attr.history.items():
                self.assertEqual(value, loaded_attr.history[timestamp])

    @TestCaseAugmented._run_if_with_db_but_not_server
    def test_all(self) -> None:
        """Test .all() and its relation to cache and savings."""
        id1, id2, id3 = self.default_ids
        item1 = self._make_from_defaults(id1)
        item2 = self._make_from_defaults(id2)
        item3 = self._make_from_defaults(id3)
        # check .all() returns empty list on un-cached items
        self.assertEqual(self.checked_class.all(self.db_conn), [])
        # check that all() shows only cached/saved items
        item1.cache()
        item3.save(self.db_conn)
        self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
                         sorted([item1, item3]))
        item2.save(self.db_conn)
        self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
                         sorted([item1, item2, item3]))

    @TestCaseAugmented._run_if_with_db_but_not_server
    def test_singularity(self) -> None:
        """Test pointers made for single object keep pointing to it."""
        id1 = self.default_ids[0]
        obj = self._make_from_defaults(id1)
        obj.save(self.db_conn)
        # change object, expect retrieved through .by_id to carry change
        attr_name = self._change_obj(obj)
        new_attr = getattr(obj, attr_name)
        retrieved = self.checked_class.by_id(self.db_conn, id1)
        self.assertEqual(new_attr, getattr(retrieved, attr_name))

    @TestCaseAugmented._run_if_with_db_but_not_server
    @TestCaseAugmented._run_on_versioned_attributes
    def test_versioned_singularity(self,
                                   owner: Any,
                                   attr_name: str,
                                   attr: VersionedAttribute,
                                   _: str | float,
                                   to_set: list[str] | list[float]
                                   ) -> None:
        """Test singularity of VersionedAttributes on saving."""
        owner.save(self.db_conn)
        # change obj, expect retrieved through .by_id to carry change
        attr.set(to_set[0])
        retrieved = self.checked_class.by_id(self.db_conn, owner.id_)
        attr_retrieved = getattr(retrieved, attr_name)
        self.assertEqual(attr.history, attr_retrieved.history)

    @TestCaseAugmented._run_if_with_db_but_not_server
    def test_remove(self) -> None:
        """Test .remove() effects on DB and cache."""
        id_ = self.default_ids[0]
        obj = self._make_from_defaults(id_)
        # check removal only works after saving
        with self.assertRaises(HandledException):
            obj.remove(self.db_conn)
        obj.save(self.db_conn)
        obj.remove(self.db_conn)
        # check access to obj fails after removal
        with self.assertRaises(HandledException):
            print(obj.id_)
        # check DB and cache now empty
        self.check_identity_with_cache_and_db([])
514             print(obj.id_)
515         # check DB and cache now empty
516         self.check_identity_with_cache_and_db([])
517
518
519 class Expected:
520     """Builder of (JSON-like) dict to compare against responses of test server.
521
522     Collects all items and relations we expect expressed in the server's JSON
523     responses and puts them into the proper json.dumps-friendly dict structure,
    accessible via .as_dict, to compare them in TestCaseWithServer.check_json_get.
525
526     On its own provides for .as_dict output only {"_library": …}, initialized
527     from .__init__ and to be directly manipulated via the .lib* methods.
528     Further structures of the expected response may be added and kept
529     up-to-date by subclassing .__init__, .recalc, and .d.
530
531     NB: Lots of expectations towards server behavior will be made explicit here
532     (or in the subclasses) rather than in the actual TestCase methods' code.
533     """
534     _default_dict: dict[str, Any]
535     _forced: dict[str, Any]
536     _fields: dict[str, Any]
537     _on_empty_make_temp: tuple[str, str]
538
539     def __init__(self,
540                  todos: list[dict[str, Any]] | None = None,
541                  procs: list[dict[str, Any]] | None = None,
542                  procsteps: list[dict[str, Any]] | None = None,
543                  conds: list[dict[str, Any]] | None = None,
544                  days: list[dict[str, Any]] | None = None
545                  ) -> None:
546         # pylint: disable=too-many-arguments
547         for name in ['_default_dict', '_fields', '_forced']:
548             if not hasattr(self, name):
549                 setattr(self, name, {})
550         self._lib = {}
551         for title, items in [('Todo', todos),
552                              ('Process', procs),
553                              ('ProcessStep', procsteps),
554                              ('Condition', conds),
555                              ('Day', days)]:
556             if items:
557                 self._lib[title] = self._as_refs(items)
558         for k, v in self._default_dict.items():
559             if k not in self._fields:
560                 self._fields[k] = v
561
562     def recalc(self) -> None:
563         """Update internal dictionary by subclass-specific rules."""
564         todos = self.lib_all('Todo')
565         for todo in todos:
566             todo['parents'] = []
567         for todo in todos:
568             for child_id in todo['children']:
569                 self.lib_get('Todo', child_id)['parents'] += [todo['id']]
570             todo['children'].sort()
571         procsteps = self.lib_all('ProcessStep')
572         procs = self.lib_all('Process')
573         for proc in procs:
574             proc['explicit_steps'] = [s['id'] for s in procsteps
575                                       if s['owner_id'] == proc['id']]
576
577     @property
578     def as_dict(self) -> dict[str, Any]:
579         """Return dict to compare against test server JSON responses."""
580         make_temp = False
581         if hasattr(self, '_on_empty_make_temp'):
582             category, dicter = getattr(self, '_on_empty_make_temp')
583             id_ = self._fields[category.lower()]
584             make_temp = not bool(self.lib_get(category, id_))
585             if make_temp:
586                 self.lib_set(category, [getattr(self, dicter)(id_)])
587         self.recalc()
588         d = {'_library': self._lib}
589         for k, v in self._fields.items():
590             # we expect everything sortable to be sorted
591             if isinstance(v, list) and k not in self._forced:
592                 # NB: if we don't test for v being list, sorted() on an empty
593                 # dict may return an empty list
594                 try:
595                     v = sorted(v)
596                 except TypeError:
597                     pass
598             d[k] = v
599         for k, v in self._forced.items():
600             d[k] = v
601         if make_temp:
602             json = json_dumps(d)
603             id_ = id_ if id_ is not None else '?'
604             self.lib_del(category, id_)
605             d = json_loads(json)
606         return d
607
608     def lib_get(self, category: str, id_: str | int) -> dict[str, Any]:
609         """From library, return item of category and id_, or empty dict."""
610         str_id = str(id_)
611         if category in self._lib and str_id in self._lib[category]:
612             return self._lib[category][str_id]
613         return {}
614
615     def lib_all(self, category: str) -> list[dict[str, Any]]:
616         """From library, return items of category, or [] if none."""
617         if category in self._lib:
618             return list(self._lib[category].values())
619         return []
620
621     def lib_set(self, category: str, items: list[dict[str, object]]) -> None:
622         """Update library for category with items."""
623         if category not in self._lib:
624             self._lib[category] = {}
625         for k, v in self._as_refs(items).items():
626             self._lib[category][k] = v
627
628     def lib_del(self, category: str, id_: str | int) -> None:
629         """Remove category element of id_ from library."""
630         del self._lib[category][str(id_)]
631         if 0 == len(self._lib[category]):
632             del self._lib[category]
633
634     def lib_wipe(self, category: str) -> None:
635         """Remove category from library."""
636         if category in self._lib:
637             del self._lib[category]
638
639     def set(self, field_name: str, value: object) -> None:
640         """Set top-level .as_dict field."""
641         self._fields[field_name] = value
642
643     def force(self, field_name: str, value: object) -> None:
644         """Set ._forced field to ensure value in .as_dict."""
645         self._forced[field_name] = value
646
647     def unforce(self, field_name: str) -> None:
648         """Unset ._forced field."""
649         del self._forced[field_name]
650
651     @staticmethod
652     def _as_refs(items: list[dict[str, object]]
653                  ) -> dict[str, dict[str, object]]:
654         """Return dictionary of items by their 'id' fields."""
655         refs = {}
656         for item in items:
657             id_ = str(item['id']) if item['id'] is not None else '?'
658             refs[id_] = item
659         return refs
660
661     @staticmethod
662     def as_ids(items: list[dict[str, Any]]) -> list[int] | list[str]:
663         """Return list of only 'id' fields of items."""
664         return [item['id'] for item in items]
665
666     @staticmethod
667     def day_as_dict(date: str, comment: str = '') -> dict[str, object]:
668         """Return JSON of Day to expect."""
669         return {'id': date, 'comment': comment, 'todos': []}
670
671     def set_day_from_post(self, date: str, d: dict[str, Any]) -> None:
672         """Set Day of date in library based on POST dict d."""
673         day = self.day_as_dict(date)
674         for k, v in d.items():
675             if 'day_comment' == k:
676                 day['comment'] = v
677             elif 'new_todo' == k:
678                 next_id = 1
679                 for todo in self.lib_all('Todo'):
680                     if next_id <= todo['id']:
681                         next_id = todo['id'] + 1
682                 for proc_id in sorted([id_ for id_ in v if id_]):
683                     todo = self.todo_as_dict(next_id, proc_id, date)
684                     self.lib_set('Todo', [todo])
685                     next_id += 1
686             elif 'done' == k:
687                 for todo_id in v:
688                     self.lib_get('Todo', todo_id)['is_done'] = True
689             elif 'todo_id' == k:
690                 for i, todo_id in enumerate(v):
691                     t = self.lib_get('Todo', todo_id)
692                     if 'comment' in d:
693                         t['comment'] = d['comment'][i]
694                     if 'effort' in d:
695                         effort = d['effort'][i] if d['effort'][i] else None
696                         t['effort'] = effort
697         self.lib_set('Day', [day])
698
699     @staticmethod
700     def cond_as_dict(id_: int = 1,
701                      is_active: bool = False,
702                      title: None | str = None,
703                      description: None | str = None,
704                      ) -> dict[str, object]:
705         """Return JSON of Condition to expect."""
706         versioned: dict[str, dict[str, object]]
707         versioned = {'title': {}, 'description': {}}
708         if title is not None:
709             versioned['title']['0'] = title
710         if description is not None:
711             versioned['description']['0'] = description
712         return {'id': id_, 'is_active': is_active, '_versioned': versioned}
713
714     def set_cond_from_post(self, id_: int, d: dict[str, Any]) -> None:
715         """Set Condition of id_ in library based on POST dict d."""
716         if d == {'delete': ''}:
717             self.lib_del('Condition', id_)
718             return
719         cond = self.lib_get('Condition', id_)
720         if cond:
721             if 'is_active' in d:
722                 cond['is_active'] = d['is_active']
723             for category in ['title', 'description']:
724                 history = cond['_versioned'][category]
725                 if len(history) > 0:
726                     last_i = sorted([int(k) for k in history.keys()])[-1]
727                     if d[category] != history[str(last_i)]:
728                         history[str(last_i + 1)] = d[category]
729                 else:
730                     history['0'] = d[category]
731         else:
732             cond = self.cond_as_dict(id_, **d)
733         self.lib_set('Condition', [cond])
734
735     @staticmethod
736     def todo_as_dict(id_: int = 1,
737                      process_id: int = 1,
738                      date: str = '2024-01-01',
739                      conditions: None | list[int] = None,
740                      disables: None | list[int] = None,
741                      blockers: None | list[int] = None,
742                      enables: None | list[int] = None,
743                      calendarize: bool = False,
744                      comment: str = '',
745                      is_done: bool = False,
746                      effort: float | None = None,
747                      children: list[int] | None = None,
748                      parents: list[int] | None = None,
749                      ) -> dict[str, object]:
750         """Return JSON of Todo to expect."""
751         # pylint: disable=too-many-arguments
752         d = {'id': id_,
753              'date': date,
754              'process_id': process_id,
755              'is_done': is_done,
756              'calendarize': calendarize,
757              'comment': comment,
758              'children': children if children else [],
759              'parents': parents if parents else [],
760              'effort': effort,
761              'conditions': conditions if conditions else [],
762              'disables': disables if disables else [],
763              'blockers': blockers if blockers else [],
764              'enables': enables if enables else []}
765         return d
766
767     def set_todo_from_post(self, id_: int, d: dict[str, Any]) -> None:
768         """Set Todo of id_ in library based on POST dict d."""
769         corrected_kwargs: dict[str, Any] = {'children': []}
770         for k, v in d.items():
771             if k.startswith('step_filler_to_'):
772                 continue
773             if 'adopt' == k:
774                 new_children = v if isinstance(v, list) else [v]
775                 corrected_kwargs['children'] += new_children
776                 continue
777             if k in {'is_done', 'calendarize'}:
778                 v = v in VALID_TRUES
779             corrected_kwargs[k] = v
780         todo = self.lib_get('Todo', id_)
781         if todo:
782             for k, v in corrected_kwargs.items():
783                 todo[k] = v
784         else:
785             todo = self.todo_as_dict(id_, **corrected_kwargs)
786         self.lib_set('Todo', [todo])
787
788     @staticmethod
789     def procstep_as_dict(id_: int,
790                          owner_id: int,
791                          step_process_id: int,
792                          parent_step_id: int | None = None
793                          ) -> dict[str, object]:
794         """Return JSON of ProcessStep to expect."""
795         return {'id': id_,
796                 'owner_id': owner_id,
797                 'step_process_id': step_process_id,
798                 'parent_step_id': parent_step_id}
799
800     @staticmethod
801     def proc_as_dict(id_: int = 1,
802                      title: None | str = None,
803                      description: None | str = None,
804                      effort: None | float = None,
805                      conditions: None | list[int] = None,
806                      disables: None | list[int] = None,
807                      blockers: None | list[int] = None,
808                      enables: None | list[int] = None,
809                      explicit_steps: None | list[int] = None
810                      ) -> dict[str, object]:
811         """Return JSON of Process to expect."""
812         # pylint: disable=too-many-arguments
813         versioned: dict[str, dict[str, object]]
814         versioned = {'title': {}, 'description': {}, 'effort': {}}
815         if title is not None:
816             versioned['title']['0'] = title
817         if description is not None:
818             versioned['description']['0'] = description
819         if effort is not None:
820             versioned['effort']['0'] = effort
821         d = {'id': id_,
822              'calendarize': False,
823              'suppressed_steps': [],
824              'explicit_steps': explicit_steps if explicit_steps else [],
825              '_versioned': versioned,
826              'conditions': conditions if conditions else [],
827              'disables': disables if disables else [],
828              'enables': enables if enables else [],
829              'blockers': blockers if blockers else []}
830         return d
831
832     def set_proc_from_post(self, id_: int, d: dict[str, Any]) -> None:
833         """Set Process of id_ in library based on POST dict d."""
834         proc = self.lib_get('Process', id_)
835         if proc:
836             for category in ['title', 'description', 'effort']:
837                 history = proc['_versioned'][category]
838                 if len(history) > 0:
839                     last_i = sorted([int(k) for k in history.keys()])[-1]
840                     if d[category] != history[str(last_i)]:
841                         history[str(last_i + 1)] = d[category]
842                 else:
843                     history['0'] = d[category]
844         else:
845             proc = self.proc_as_dict(id_,
846                                      d['title'], d['description'], d['effort'])
847         ignore = {'title', 'description', 'effort', 'new_top_step', 'step_of',
848                   'kept_steps'}
849         for k, v in d.items():
850             if k in ignore\
851                     or k.startswith('step_') or k.startswith('new_step_to'):
852                 continue
853             if k in {'calendarize'}:
854                 v = v in VALID_TRUES
855             elif k in {'suppressed_steps', 'explicit_steps', 'conditions',
856                        'disables', 'enables', 'blockers'}:
857                 if not isinstance(v, list):
858                     v = [v]
859             proc[k] = v
860         self.lib_set('Process', [proc])
861
862
863 class TestCaseWithServer(TestCaseWithDB):
864     """Module tests against our HTTP server/handler (and database)."""
865
866     def setUp(self) -> None:
867         super().setUp()
868         self.httpd = TaskServer(self.db_file, ('localhost', 0), TaskHandler)
869         self.server_thread = Thread(target=self.httpd.serve_forever)
870         self.server_thread.daemon = True
871         self.server_thread.start()
872         self.conn = HTTPConnection(str(self.httpd.server_address[0]),
873                                    self.httpd.server_address[1])
874         self.httpd.render_mode = 'json'
875
876     def tearDown(self) -> None:
877         self.httpd.shutdown()
878         self.httpd.server_close()
879         self.server_thread.join()
880         super().tearDown()
881
882     def post_exp_cond(self,
883                       exps: list[Expected],
884                       payload: dict[str, object],
885                       id_: int = 1,
886                       post_to_id: bool = True,
887                       redir_to_id: bool = True
888                       ) -> None:
889         """POST /condition(s), appropriately update Expecteds."""
890         # pylint: disable=too-many-arguments
891         target = f'/condition?id={id_}' if post_to_id else '/condition'
892         redir = f'/condition?id={id_}' if redir_to_id else '/conditions'
893         self.check_post(payload, target, redir=redir)
894         for exp in exps:
895             exp.set_cond_from_post(id_, payload)
896
897     def post_exp_day(self,
898                      exps: list[Expected],
899                      payload: dict[str, Any],
900                      date: str = '2024-01-01'
901                      ) -> None:
902         """POST /day, appropriately update Expecteds."""
903         if 'make_type' not in payload:
904             payload['make_type'] = 'empty'
905         if 'day_comment' not in payload:
906             payload['day_comment'] = ''
907         target = f'/day?date={date}'
908         redir_to = f'{target}&make_type={payload["make_type"]}'
909         self.check_post(payload, target, 302, redir_to)
910         for exp in exps:
911             exp.set_day_from_post(date, payload)
912
913     def post_exp_process(self,
914                          exps: list[Expected],
915                          payload: dict[str, Any],
916                          id_: int,
917                          ) -> dict[str, object]:
918         """POST /process, appropriately update Expecteds."""
919         if 'title' not in payload:
920             payload['title'] = 'foo'
921         if 'description' not in payload:
922             payload['description'] = 'foo'
923         if 'effort' not in payload:
924             payload['effort'] = 1.1
925         self.check_post(payload, f'/process?id={id_}',
926                         redir=f'/process?id={id_}')
927         for exp in exps:
928             exp.set_proc_from_post(id_, payload)
929         return payload
930
931     def check_filter(self, exp: Expected, category: str, key: str,
932                      val: str, list_ids: list[int]) -> None:
933         """Check GET /{category}?{key}={val} sorts to list_ids."""
934         # pylint: disable=too-many-arguments
935         exp.set(key, val)
936         exp.force(category, list_ids)
937         self.check_json_get(f'/{category}?{key}={val}', exp)
938
939     def check_redirect(self, target: str) -> None:
940         """Check that self.conn answers with a 302 redirect to target."""
941         response = self.conn.getresponse()
942         self.assertEqual(response.status, 302)
943         self.assertEqual(response.getheader('Location'), target)
944
945     def check_get(self, target: str, expected_code: int) -> None:
946         """Check that a GET to target yields expected_code."""
947         self.conn.request('GET', target)
948         self.assertEqual(self.conn.getresponse().status, expected_code)
949
950     def check_minimal_inputs(self,
951                              url: str,
952                              minimal_inputs: dict[str, Any]
953                              ) -> None:
954         """Check that url 400's unless all of minimal_inputs provided."""
955         for to_hide in minimal_inputs.keys():
956             to_post = {k: v for k, v in minimal_inputs.items() if k != to_hide}
957             self.check_post(to_post, url, 400)
958
959     def check_post(self, data: Mapping[str, object], target: str,
960                    expected_code: int = 302, redir: str = '') -> None:
961         """Check that POST of data to target yields expected_code."""
962         encoded_form_data = urlencode(data, doseq=True).encode('utf-8')
963         headers = {'Content-Type': 'application/x-www-form-urlencoded',
964                    'Content-Length': str(len(encoded_form_data))}
965         self.conn.request('POST', target,
966                           body=encoded_form_data, headers=headers)
967         if 302 == expected_code:
968             redir = target if redir == '' else redir
969             self.check_redirect(redir)
970         else:
971             self.assertEqual(self.conn.getresponse().status, expected_code)
972
973     def check_get_defaults(self,
974                            path: str,
975                            default_id: str = '1',
976                            id_name: str = 'id'
977                            ) -> None:
978         """Some standard model paths to test."""
979         nonexist_status = 200 if self.checked_class.can_create_by_id else 404
980         self.check_get(path, nonexist_status)
981         self.check_get(f'{path}?{id_name}=', 400)
982         self.check_get(f'{path}?{id_name}=foo', 400)
983         self.check_get(f'/{path}?{id_name}=0', 400)
984         self.check_get(f'{path}?{id_name}={default_id}', nonexist_status)
985
    def check_json_get(self, path: str, expected: Expected) -> None:
        """Compare JSON on GET path with expected.

        To simplify comparison of VersionedAttribute histories, transforms
        timestamp keys of VersionedAttribute history keys into (strings of)
        integers counting chronologically forward from 0.

        Asserts a 200 status, then equality of expected.as_dict with the
        rewritten retrieved JSON. On mismatch, pretty-prints both structures
        plus a per-path report of the differences to stdout, then re-raises
        the AssertionError.
        """

        def rewrite_history_keys_in(item: Any) -> Any:
            # Recursively walks dicts and lists, mutating them in place, and
            # returns item. In any dict holding '_versioned', each category's
            # history is re-keyed '0', '1', ... in the order its entries
            # appear in the JSON. NOTE(review): this relies on the server
            # emitting history entries in chronological order.
            if isinstance(item, dict):
                if '_versioned' in item.keys():
                    for category in item['_versioned']:
                        vals = item['_versioned'][category].values()
                        history = {}
                        for i, val in enumerate(vals):
                            history[str(i)] = val
                        item['_versioned'][category] = history
                # recurse into all values, to reach nested versioned items
                for category in list(item.keys()):
                    rewrite_history_keys_in(item[category])
            elif isinstance(item, list):
                item[:] = [rewrite_history_keys_in(i) for i in item]
            return item

        def walk_diffs(path: str, cmp1: object, cmp2: object) -> None:
            # Prints (never raises) where cmp1 (expected) and cmp2
            # (retrieved) differ, descending into dict pairs and list pairs;
            # path accumulates ':key' and '[index]' segments.
            # pylint: disable=too-many-branches
            def warn(intro: str, val: object) -> None:
                # scalars go on the intro line, anything else pretty-printed
                if isinstance(val, (str, int, float)):
                    print(intro, val)
                else:
                    print(intro)
                    pprint(val)
            if cmp1 != cmp2:
                if isinstance(cmp1, dict) and isinstance(cmp2, dict):
                    for k, v in cmp1.items():
                        if k not in cmp2:
                            warn(f'DIFF {path}: retrieved lacks {k}', v)
                        elif v != cmp2[k]:
                            walk_diffs(f'{path}:{k}', v, cmp2[k])
                    for k in [k for k in cmp2.keys() if k not in cmp1]:
                        warn(f'DIFF {path}: expected lacks retrieved\'s {k}',
                             cmp2[k])
                elif isinstance(cmp1, list) and isinstance(cmp2, list):
                    for i, v1 in enumerate(cmp1):
                        if i >= len(cmp2):
                            warn(f'DIFF {path}[{i}] retrieved misses:', v1)
                        elif v1 != cmp2[i]:
                            walk_diffs(f'{path}[{i}]', v1, cmp2[i])
                    # report surplus elements only present in retrieved
                    if len(cmp2) > len(cmp1):
                        for i, v2 in enumerate(cmp2[len(cmp1):]):
                            warn(f'DIFF {path}[{len(cmp1)+i}] misses:', v2)
                else:
                    # differing scalars, or mismatched container types
                    warn(f'DIFF {path} – for expected:', cmp1)
                    warn('… and for retrieved:', cmp2)

        self.conn.request('GET', path)
        response = self.conn.getresponse()
        self.assertEqual(response.status, 200)
        retrieved = json_loads(response.read().decode())
        rewrite_history_keys_in(retrieved)
        cmp = expected.as_dict
        try:
            self.assertEqual(cmp, retrieved)
        except AssertionError as e:
            # dump full context for debugging before failing the test
            print('EXPECTED:')
            pprint(cmp)
            print('RETRIEVED:')
            pprint(retrieved)
            walk_diffs('', cmp, retrieved)
            raise e