home · contact · privacy
Re-organize and extend/improve POST/GET /day tests.
[plomtask] / tests / utils.py
1 """Shared test utilities."""
2 from __future__ import annotations
3 from unittest import TestCase
4 from typing import Mapping, Any, Callable
5 from threading import Thread
6 from http.client import HTTPConnection
7 from json import loads as json_loads
8 from urllib.parse import urlencode
9 from uuid import uuid4
10 from os import remove as remove_file
11 from plomtask.db import DatabaseFile, DatabaseConnection
12 from plomtask.http import TaskHandler, TaskServer
13 from plomtask.processes import Process, ProcessStep
14 from plomtask.conditions import Condition
15 from plomtask.days import Day
16 from plomtask.todos import Todo
17 from plomtask.exceptions import NotFoundException, HandledException
18
19
20 def _within_checked_class(f: Callable[..., None]) -> Callable[..., None]:
21     def wrapper(self: TestCase) -> None:
22         if hasattr(self, 'checked_class'):
23             f(self)
24     return wrapper
25
26
class TestCaseSansDB(TestCase):
    """Tests requiring no DB setup."""
    checked_class: Any
    # positional arguments following the ID on instantiation of checked_class
    default_init_args: list[Any] = []
    # VersionedAttribute names mapped to the .newest value expected on init
    versioned_defaults_to_test: dict[str, str | float] = {}
    # IDs that instantiation must accept / reject (with HandledException)
    legal_ids = [1, 5]
    illegal_ids = [0]

    @_within_checked_class
    def test_id_validation(self) -> None:
        """Test .id_ validation/setting."""
        # rejected IDs must raise on instantiation …
        for bad_id in self.illegal_ids:
            with self.assertRaises(HandledException):
                self.checked_class(bad_id, *self.default_init_args)
        # … while accepted ones must end up as the object's .id_
        for good_id in self.legal_ids:
            created = self.checked_class(good_id, *self.default_init_args)
            self.assertEqual(created.id_, good_id)

    @_within_checked_class
    def test_versioned_defaults(self) -> None:
        """Test defaults of VersionedAttributes."""
        instance = self.checked_class(self.legal_ids[0],
                                      *self.default_init_args)
        for attr_name, expected in self.versioned_defaults_to_test.items():
            self.assertEqual(getattr(instance, attr_name).newest, expected)
52
53
class TestCaseWithDB(TestCase):
    """Module tests requiring DB setup.

    Each test runs against emptied model caches and a freshly created DB
    file. The generic test_* methods below only do anything in subclasses
    that set .checked_class (see _within_checked_class).
    """
    checked_class: Any
    # IDs to instantiate checked_class with; whether the first one is an int
    # or a str decides which ID-generation checks below apply
    default_ids: tuple[int | str, int | str, int | str] = (1, 2, 3)
    # keyword arguments for instantiating checked_class in the generic tests
    default_init_kwargs: dict[str, Any] = {}
    # names of VersionedAttributes to test, mapped to their value types
    test_versioneds: dict[str, type] = {}

    def setUp(self) -> None:
        # start each test with empty caches and a new, uniquely named DB file
        Condition.empty_cache()
        Day.empty_cache()
        Process.empty_cache()
        ProcessStep.empty_cache()
        Todo.empty_cache()
        self.db_file = DatabaseFile.create_at(f'test_db:{uuid4()}')
        self.db_conn = DatabaseConnection(self.db_file)

    def tearDown(self) -> None:
        # remove the DB file even if closing the connection fails
        try:
            self.db_conn.close()
        finally:
            remove_file(self.db_file.path)

    def _load_from_db(self, id_: int | str) -> list[object]:
        """Return checked_class objects built from DB table rows of id_."""
        rows = self.db_conn.row_where(self.checked_class.table_name, 'id', id_)
        return [self.checked_class.from_table_row(self.db_conn, row)
                for row in rows]

    def _change_obj(self, obj: object) -> str:
        """Change obj's last .to_save attribute in place, return its name.

        Booleans get negated, numbers incremented, strings get '_' appended.

        Raises TypeError if the attribute is of none of these types.
        """
        attr_name: str = self.checked_class.to_save[-1]
        attr = getattr(obj, attr_name)
        new_attr: str | int | float | bool
        # check bool first: bool subclasses int, so the int/float branch
        # would otherwise catch it (and turn e.g. True into 2)
        if isinstance(attr, bool):
            new_attr = not attr
        elif isinstance(attr, (int, float)):
            new_attr = attr + 1
        elif isinstance(attr, str):
            new_attr = attr + '_'
        else:
            raise TypeError(f'cannot change attribute {attr_name} of type '
                            f'{type(attr)}')
        setattr(obj, attr_name, new_attr)
        return attr_name

    def check_identity_with_cache_and_db(self, content: list[Any]) -> None:
        """Test both cache and DB equal content."""
        expected_cache = {item.id_: item for item in content}
        self.assertEqual(self.checked_class.get_cache(), expected_cache)
        hashes_content = [hash(x) for x in content]
        db_found: list[Any] = []
        for item in content:
            assert isinstance(item.id_, type(self.default_ids[0]))
            db_found += self._load_from_db(item.id_)
        hashes_db_found = [hash(x) for x in db_found]
        self.assertEqual(sorted(hashes_content), sorted(hashes_db_found))

    @_within_checked_class
    def test_saving_versioned(self) -> None:
        """Test storage and initialization of versioned attributes."""
        def retrieve_attr_vals() -> list[object]:
            # collect values stored in attr's table for retrieved's ID (row[2]
            # presumably being the value column); note attr and retrieved are
            # looked up from the enclosing loop at call time
            attr_vals_saved: list[object] = []
            assert hasattr(retrieved, 'id_')
            for row in self.db_conn.row_where(attr.table_name, 'parent',
                                              retrieved.id_):
                attr_vals_saved += [row[2]]
            return attr_vals_saved
        for attr_name, type_ in self.test_versioneds.items():
            # fail saving attributes on non-saved owner
            owner = self.checked_class(None, **self.default_init_kwargs)
            vals: list[Any] = ['t1', 't2'] if type_ is str else [0.9, 1.1]
            attr = getattr(owner, attr_name)
            attr.set(vals[0])
            attr.set(vals[1])
            with self.assertRaises(NotFoundException):
                attr.save(self.db_conn)
            owner.save(self.db_conn)
            # check stored attribute is as expected
            retrieved = self._load_from_db(owner.id_)[0]
            attr = getattr(retrieved, attr_name)
            self.assertEqual(sorted(attr.history.values()), vals)
            # check owner.save() created entries in attr table
            attr_vals_saved = retrieve_attr_vals()
            self.assertEqual(vals, attr_vals_saved)
            # check setting new val to attr inconsequential to DB without save
            attr.set(vals[0])
            attr_vals_saved = retrieve_attr_vals()
            self.assertEqual(vals, attr_vals_saved)
            # check save finally adds new val
            attr.save(self.db_conn)
            attr_vals_saved = retrieve_attr_vals()
            self.assertEqual(vals + [vals[0]], attr_vals_saved)

    @_within_checked_class
    def test_saving_and_caching(self) -> None:
        """Test effects of .cache() and .save()."""
        id1 = self.default_ids[0]
        # check failure to cache without ID (if None-ID input possible)
        if isinstance(id1, int):
            obj0 = self.checked_class(None, **self.default_init_kwargs)
            with self.assertRaises(HandledException):
                obj0.cache()
        # check mere object init itself doesn't even store in cache
        obj1 = self.checked_class(id1, **self.default_init_kwargs)
        self.assertEqual(self.checked_class.get_cache(), {})
        # check .cache() fills cache, but not DB
        obj1.cache()
        self.assertEqual(self.checked_class.get_cache(), {id1: obj1})
        db_found = self._load_from_db(id1)
        self.assertEqual(db_found, [])
        # check .save() sets ID (for int IDs), updates cache, and fills DB
        # (expect ID to be set to id1, despite obj1 already having that as ID:
        # it's generated by cursor.lastrowid on the DB table, and with obj1
        # not written there, obj2 should get it first!)
        id_input = None if isinstance(id1, int) else id1
        obj2 = self.checked_class(id_input, **self.default_init_kwargs)
        obj2.save(self.db_conn)
        obj2_hash = hash(obj2)
        self.assertEqual(self.checked_class.get_cache(), {id1: obj2})
        db_found += self._load_from_db(id1)
        self.assertEqual([hash(o) for o in db_found], [obj2_hash])
        # check we cannot overwrite obj2 with obj1 despite its same ID,
        # since it has disappeared now
        with self.assertRaises(HandledException):
            obj1.save(self.db_conn)

    @_within_checked_class
    def test_by_id(self) -> None:
        """Test .by_id()."""
        id1, id2, _ = self.default_ids
        # check failure if not yet saved
        obj1 = self.checked_class(id1, **self.default_init_kwargs)
        with self.assertRaises(NotFoundException):
            self.checked_class.by_id(self.db_conn, id1)
        # check identity of cached and retrieved
        obj1.cache()
        self.assertEqual(obj1, self.checked_class.by_id(self.db_conn, id1))
        # check identity of saved and retrieved
        obj2 = self.checked_class(id2, **self.default_init_kwargs)
        obj2.save(self.db_conn)
        self.assertEqual(obj2, self.checked_class.by_id(self.db_conn, id2))

    @_within_checked_class
    def test_by_id_or_create(self) -> None:
        """Test .by_id_or_create."""
        # check .by_id_or_create fails if wrong class
        if not self.checked_class.can_create_by_id:
            with self.assertRaises(HandledException):
                self.checked_class.by_id_or_create(self.db_conn, None)
            return
        # check ID input of None creates, on saving, ID=1,2,… for int IDs
        if isinstance(self.default_ids[0], int):
            for n in range(2):
                item = self.checked_class.by_id_or_create(self.db_conn, None)
                self.assertEqual(item.id_, None)
                item.save(self.db_conn)
                self.assertEqual(item.id_, n+1)
        # check .by_id_or_create acts like normal instantiation (sans saving)
        id_ = self.default_ids[2]
        item = self.checked_class.by_id_or_create(self.db_conn, id_)
        self.assertEqual(item.id_, id_)
        with self.assertRaises(NotFoundException):
            self.checked_class.by_id(self.db_conn, item.id_)
        self.assertEqual(self.checked_class(item.id_), item)

    @_within_checked_class
    def test_from_table_row(self) -> None:
        """Test .from_table_row() properly reads in class directly from DB."""
        id_ = self.default_ids[0]
        obj = self.checked_class(id_, **self.default_init_kwargs)
        obj.save(self.db_conn)
        assert isinstance(obj.id_, type(id_))
        for row in self.db_conn.row_where(self.checked_class.table_name,
                                          'id', obj.id_):
            # check .from_table_row reproduces state saved, no matter if obj
            # later changed (with caching even)
            hash_original = hash(obj)
            attr_name = self._change_obj(obj)
            obj.cache()
            to_cmp = getattr(obj, attr_name)
            retrieved = self.checked_class.from_table_row(self.db_conn, row)
            self.assertNotEqual(to_cmp, getattr(retrieved, attr_name))
            self.assertEqual(hash_original, hash(retrieved))
            # check cache contains what .from_table_row just produced
            self.assertEqual({retrieved.id_: retrieved},
                             self.checked_class.get_cache())
        # check .from_table_row also reads versioned attributes from DB
        for attr_name, type_ in self.test_versioneds.items():
            owner = self.checked_class(None, **self.default_init_kwargs)
            vals: list[Any] = ['t1', 't2'] if type_ is str else [0.9, 1.1]
            attr = getattr(owner, attr_name)
            attr.set(vals[0])
            attr.set(vals[1])
            owner.save(self.db_conn)
            for row in self.db_conn.row_where(owner.table_name, 'id',
                                              owner.id_):
                retrieved = owner.__class__.from_table_row(self.db_conn, row)
                attr = getattr(retrieved, attr_name)
                self.assertEqual(sorted(attr.history.values()), vals)

    @_within_checked_class
    def test_all(self) -> None:
        """Test .all() and its relation to cache and savings."""
        id_1, id_2, id_3 = self.default_ids
        item1 = self.checked_class(id_1, **self.default_init_kwargs)
        item2 = self.checked_class(id_2, **self.default_init_kwargs)
        item3 = self.checked_class(id_3, **self.default_init_kwargs)
        # check .all() returns empty list on un-cached items
        self.assertEqual(self.checked_class.all(self.db_conn), [])
        # check that all() shows only cached/saved items
        item1.cache()
        item3.save(self.db_conn)
        self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
                         sorted([item1, item3]))
        item2.save(self.db_conn)
        self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
                         sorted([item1, item2, item3]))

    @_within_checked_class
    def test_singularity(self) -> None:
        """Test pointers made for single object keep pointing to it."""
        id1 = self.default_ids[0]
        obj = self.checked_class(id1, **self.default_init_kwargs)
        obj.save(self.db_conn)
        # change object, expect retrieved through .by_id to carry change
        attr_name = self._change_obj(obj)
        new_attr = getattr(obj, attr_name)
        retrieved = self.checked_class.by_id(self.db_conn, id1)
        self.assertEqual(new_attr, getattr(retrieved, attr_name))

    @_within_checked_class
    def test_versioned_singularity_title(self) -> None:
        """Test singularity of VersionedAttributes on saving (with .title)."""
        if 'title' in self.test_versioneds:
            obj = self.checked_class(None, **self.default_init_kwargs)
            obj.save(self.db_conn)
            assert isinstance(obj.id_, int)
            # change obj, expect retrieved through .by_id to carry change
            obj.title.set('named')
            retrieved = self.checked_class.by_id(self.db_conn, obj.id_)
            self.assertEqual(obj.title.history, retrieved.title.history)

    @_within_checked_class
    def test_remove(self) -> None:
        """Test .remove() effects on DB and cache."""
        id_ = self.default_ids[0]
        obj = self.checked_class(id_, **self.default_init_kwargs)
        # check removal only works after saving
        with self.assertRaises(HandledException):
            obj.remove(self.db_conn)
        obj.save(self.db_conn)
        obj.remove(self.db_conn)
        # check access to obj fails after removal
        with self.assertRaises(HandledException):
            print(obj.id_)
        # check DB and cache now empty
        self.check_identity_with_cache_and_db([])
309
310
class TestCaseWithServer(TestCaseWithDB):
    """Module tests against our HTTP server/handler (and database)."""

    def setUp(self) -> None:
        super().setUp()
        # serve on an OS-chosen free port (0) from a daemon thread, so the
        # serving loop cannot keep the test process alive on its own
        self.httpd = TaskServer(self.db_file, ('localhost', 0), TaskHandler)
        self.server_thread = Thread(target=self.httpd.serve_forever)
        self.server_thread.daemon = True
        self.server_thread.start()
        self.conn = HTTPConnection(str(self.httpd.server_address[0]),
                                   self.httpd.server_address[1])
        self.httpd.set_json_mode()

    def tearDown(self) -> None:
        # close the client side first, so no socket is left dangling
        self.conn.close()
        self.httpd.shutdown()
        self.httpd.server_close()
        self.server_thread.join()
        super().tearDown()

    @staticmethod
    def as_id_list(items: list[dict[str, object]]) -> list[int | str]:
        """Return list of only 'id' fields of items."""
        id_list = []
        for item in items:
            assert isinstance(item['id'], (int, str))
            id_list += [item['id']]
        return id_list

    @staticmethod
    def as_refs(items: list[dict[str, object]]
                ) -> dict[str, dict[str, object]]:
        """Return dictionary of items by their 'id' fields (as strings)."""
        return {str(item['id']): item for item in items}

    @staticmethod
    def cond_as_dict(id_: int = 1,
                     is_active: bool = False,
                     titles: None | list[str] = None,
                     descriptions: None | list[str] = None
                     ) -> dict[str, object]:
        """Return JSON of Condition to expect.

        Histories of titles and descriptions are keyed by their position
        (0, 1, …), matching how check_json_get rewrites retrieved histories.
        """
        versioned: dict[str, dict[int, str]] = {
            'title': dict(enumerate(titles or [])),
            'description': dict(enumerate(descriptions or []))}
        return {'id': id_,
                'is_active': is_active,
                '_versioned': versioned}

    @staticmethod
    def proc_as_dict(id_: int = 1,
                     title: str = 'A',
                     description: str = '',
                     effort: float = 1.0,
                     conditions: None | list[int] = None,
                     disables: None | list[int] = None,
                     blockers: None | list[int] = None,
                     enables: None | list[int] = None
                     ) -> dict[str, object]:
        """Return JSON of Process to expect.

        Versioned attributes get single-entry histories keyed 0; unset
        condition/step lists default to empty lists.
        """
        # pylint: disable=too-many-arguments
        d = {'id': id_,
             'calendarize': False,
             'suppressed_steps': [],
             'explicit_steps': [],
             '_versioned': {
                 'title': {0: title},
                 'description': {0: description},
                 'effort': {0: effort}},
             'conditions': conditions if conditions else [],
             'disables': disables if disables else [],
             'enables': enables if enables else [],
             'blockers': blockers if blockers else []}
        return d

    def check_redirect(self, target: str) -> None:
        """Check that self.conn answers with a 302 redirect to target."""
        response = self.conn.getresponse()
        self.assertEqual(response.status, 302)
        self.assertEqual(response.getheader('Location'), target)

    def check_get(self, target: str, expected_code: int) -> None:
        """Check that a GET to target yields expected_code."""
        self.conn.request('GET', target)
        self.assertEqual(self.conn.getresponse().status, expected_code)

    def check_post(self, data: Mapping[str, object], target: str,
                   expected_code: int, redirect_location: str = '') -> None:
        """Check that POST of data to target yields expected_code.

        Data is sent URL-encoded as a form (sequence values as repeated
        keys). For an expected 302, the redirect's Location is also checked
        against redirect_location, which defaults to target itself.
        """
        encoded_form_data = urlencode(data, doseq=True).encode('utf-8')
        headers = {'Content-Type': 'application/x-www-form-urlencoded',
                   'Content-Length': str(len(encoded_form_data))}
        self.conn.request('POST', target,
                          body=encoded_form_data, headers=headers)
        if 302 == expected_code:
            if redirect_location == '':
                redirect_location = target
            self.check_redirect(redirect_location)
        else:
            self.assertEqual(self.conn.getresponse().status, expected_code)

    def check_get_defaults(self, path: str) -> None:
        """Check GET responses on path for standard 'id' query values."""
        self.check_get(path, 200)
        self.check_get(f'{path}?id=', 200)
        self.check_get(f'{path}?id=foo', 400)
        # same path as in the other checks (no extra leading slash)
        self.check_get(f'{path}?id=0', 500)
        self.check_get(f'{path}?id=1', 200)

    def post_process(self, id_: int = 1,
                     form_data: dict[str, Any] | None = None
                     ) -> dict[str, Any]:
        """POST basic Process, expecting redirect back to it.

        Falls back to default form data if form_data is empty or None;
        returns the form data actually posted.
        """
        if not form_data:
            form_data = {'title': 'foo', 'description': 'foo', 'effort': 1.1}
        self.check_post(form_data, f'/process?id={id_}', 302,
                        f'/process?id={id_}')
        return form_data

    def check_json_get(self, path: str, expected: dict[str, object]) -> None:
        """Compare JSON on GET path with expected.

        To simplify comparison of VersionedAttribute histories, transforms
        timestamp keys of VersionedAttribute history keys into integers
        counting chronologically forward from 0.
        """
        def rewrite_history_keys_in(item: Any) -> Any:
            # NOTE(review): renumbering follows the order in which the JSON
            # lists history entries; this is only chronological if the server
            # emits them sorted by time — confirm
            if isinstance(item, dict):
                if '_versioned' in item.keys():
                    for k in item['_versioned']:
                        vals = item['_versioned'][k].values()
                        history = {}
                        for i, val in enumerate(vals):
                            history[i] = val
                        item['_versioned'][k] = history
                for k in list(item.keys()):
                    rewrite_history_keys_in(item[k])
            elif isinstance(item, list):
                item[:] = [rewrite_history_keys_in(i) for i in item]
            return item
        self.conn.request('GET', path)
        response = self.conn.getresponse()
        self.assertEqual(response.status, 200)
        retrieved = json_loads(response.read().decode())
        rewrite_history_keys_in(retrieved)
        self.assertEqual(expected, retrieved)