home · contact · privacy
Some test refactoring.
[plomtask] / tests / utils.py
1 """Shared test utilities."""
2 from __future__ import annotations
3 from unittest import TestCase
4 from typing import Mapping, Any, Callable
5 from threading import Thread
6 from http.client import HTTPConnection
7 from json import loads as json_loads
8 from urllib.parse import urlencode
9 from uuid import uuid4
10 from os import remove as remove_file
11 from plomtask.db import DatabaseFile, DatabaseConnection
12 from plomtask.http import TaskHandler, TaskServer
13 from plomtask.processes import Process, ProcessStep
14 from plomtask.conditions import Condition
15 from plomtask.days import Day
16 from plomtask.todos import Todo
17 from plomtask.exceptions import NotFoundException, HandledException
18
19
20 def _within_checked_class(f: Callable[..., None]) -> Callable[..., None]:
21     def wrapper(self: TestCase) -> None:
22         if hasattr(self, 'checked_class'):
23             f(self)
24     return wrapper
25
26
class TestCaseSansDB(TestCase):
    """Tests requiring no DB setup.

    Subclasses name the class under test in .checked_class. Positional
    arguments to pass after the ID on construction go into
    .default_init_args, and .versioned_defaults_to_test maps attribute names
    to the .newest value a freshly constructed object must show for them.
    """
    checked_class: Any
    default_init_args: list[Any] = []
    versioned_defaults_to_test: dict[str, str | float] = {}
    legal_ids = [1, 5]
    illegal_ids = [0]

    @_within_checked_class
    def test_id_validation(self) -> None:
        """Test that illegal IDs are refused and legal ones become .id_."""
        extra_args = self.default_init_args
        # construction with any illegal ID must fail …
        for rejected in self.illegal_ids:
            with self.assertRaises(HandledException):
                self.checked_class(rejected, *extra_args)
        # … while a legal ID must come out unchanged as .id_
        for accepted in self.legal_ids:
            instance = self.checked_class(accepted, *extra_args)
            self.assertEqual(instance.id_, accepted)

    @_within_checked_class
    def test_versioned_defaults(self) -> None:
        """Test fresh objects' VersionedAttributes start at their defaults."""
        instance = self.checked_class(self.legal_ids[0],
                                      *self.default_init_args)
        expectations = self.versioned_defaults_to_test
        for attr_name, expected_newest in expectations.items():
            actual_newest = getattr(instance, attr_name).newest
            self.assertEqual(actual_newest, expected_newest)
52
53
class TestCaseWithDB(TestCase):
    """Module tests requiring DB setup.

    Every test runs against a freshly created database file (removed again
    in tearDown), with the caches of Condition, Day, Process, ProcessStep
    and Todo emptied beforehand. Subclasses set .checked_class and, where
    needed, .default_ids, .default_init_kwargs, and .test_versioneds (names
    of VersionedAttributes mapped to the type of their values).
    """
    checked_class: Any
    default_ids: tuple[int | str, int | str, int | str] = (1, 2, 3)
    default_init_kwargs: dict[str, Any] = {}
    test_versioneds: dict[str, type] = {}

    def setUp(self) -> None:
        Condition.empty_cache()
        Day.empty_cache()
        Process.empty_cache()
        ProcessStep.empty_cache()
        Todo.empty_cache()
        self.db_file = DatabaseFile.create_at(f'test_db:{uuid4()}')
        self.db_conn = DatabaseConnection(self.db_file)

    def tearDown(self) -> None:
        self.db_conn.close()
        remove_file(self.db_file.path)

    @staticmethod
    def _versioned_test_vals(type_: type) -> list[Any]:
        """Return two test values for a VersionedAttribute of type_.

        The values are returned in ascending order, so they can be compared
        directly against a sorted attribute history.
        """
        return ['t1', 't2'] if type_ is str else [0.9, 1.1]

    def _load_from_db(self, id_: int | str) -> list[object]:
        """Return checked_class objects built from DB rows of ID id_."""
        return [self.checked_class.from_table_row(self.db_conn, row)
                for row in self.db_conn.row_where(
                    self.checked_class.table_name, 'id', id_)]

    def _change_obj(self, obj: object) -> str:
        """Change obj's last .to_save attribute in place, return its name.

        Raises TypeError if that attribute is no bool, int, float or str.
        """
        attr_name: str = self.checked_class.to_save[-1]
        attr = getattr(obj, attr_name)
        new_attr: str | int | float | bool
        # test bool first: it subclasses int, so would otherwise get +1
        if isinstance(attr, bool):
            new_attr = not attr
        elif isinstance(attr, (int, float)):
            new_attr = attr + 1
        elif isinstance(attr, str):
            new_attr = attr + '_'
        else:
            raise TypeError(f'cannot change attribute {attr_name} of type '
                            f'{type(attr).__name__}')
        setattr(obj, attr_name, new_attr)
        return attr_name

    def check_identity_with_cache_and_db(self, content: list[Any]) -> None:
        """Test both cache and DB equal content."""
        expected_cache = {item.id_: item for item in content}
        self.assertEqual(self.checked_class.get_cache(), expected_cache)
        hashes_content = [hash(x) for x in content]
        db_found: list[Any] = []
        for item in content:
            assert isinstance(item.id_, type(self.default_ids[0]))
            db_found += self._load_from_db(item.id_)
        hashes_db_found = [hash(x) for x in db_found]
        self.assertEqual(sorted(hashes_content), sorted(hashes_db_found))

    @_within_checked_class
    def test_saving_versioned(self) -> None:
        """Test storage and initialization of versioned attributes."""
        def retrieve_attr_vals(table_name: str, parent_id: Any
                               ) -> list[object]:
            # row[2] of an attribute table row is what the comparisons with
            # vals below expect to be the stored value
            rows = self.db_conn.row_where(table_name, 'parent', parent_id)
            return [row[2] for row in rows]
        for attr_name, type_ in self.test_versioneds.items():
            # fail saving attributes on non-saved owner
            owner = self.checked_class(None, **self.default_init_kwargs)
            vals = self._versioned_test_vals(type_)
            attr = getattr(owner, attr_name)
            attr.set(vals[0])
            attr.set(vals[1])
            with self.assertRaises(NotFoundException):
                attr.save(self.db_conn)
            owner.save(self.db_conn)
            # check stored attribute is as expected
            retrieved: Any = self._load_from_db(owner.id_)[0]
            attr = getattr(retrieved, attr_name)
            self.assertEqual(sorted(attr.history.values()), vals)
            # check owner.save() created entries in attr table
            self.assertEqual(
                vals, retrieve_attr_vals(attr.table_name, retrieved.id_))
            # check setting new val to attr inconsequential to DB without save
            attr.set(vals[0])
            self.assertEqual(
                vals, retrieve_attr_vals(attr.table_name, retrieved.id_))
            # check save finally adds new val
            attr.save(self.db_conn)
            self.assertEqual(
                vals + [vals[0]],
                retrieve_attr_vals(attr.table_name, retrieved.id_))

    @_within_checked_class
    def test_saving_and_caching(self) -> None:
        """Test effects of .cache() and .save()."""
        id1 = self.default_ids[0]
        # check failure to cache without ID (if None-ID input possible)
        if isinstance(id1, int):
            obj0 = self.checked_class(None, **self.default_init_kwargs)
            with self.assertRaises(HandledException):
                obj0.cache()
        # check mere object init itself doesn't even store in cache
        obj1 = self.checked_class(id1, **self.default_init_kwargs)
        self.assertEqual(self.checked_class.get_cache(), {})
        # check .cache() fills cache, but not DB
        obj1.cache()
        self.assertEqual(self.checked_class.get_cache(), {id1: obj1})
        db_found = self._load_from_db(id1)
        self.assertEqual(db_found, [])
        # check .save() sets ID (for int IDs), updates cache, and fills DB
        # (expect ID to be set to id1, despite obj1 already having that as ID:
        # it's generated by cursor.lastrowid on the DB table, and with obj1
        # not written there, obj2 should get it first!)
        id_input = None if isinstance(id1, int) else id1
        obj2 = self.checked_class(id_input, **self.default_init_kwargs)
        obj2.save(self.db_conn)
        obj2_hash = hash(obj2)
        self.assertEqual(self.checked_class.get_cache(), {id1: obj2})
        db_found += self._load_from_db(id1)
        self.assertEqual([hash(o) for o in db_found], [obj2_hash])
        # check we cannot overwrite obj2 with obj1 despite its same ID,
        # since it has disappeared now
        with self.assertRaises(HandledException):
            obj1.save(self.db_conn)

    @_within_checked_class
    def test_by_id(self) -> None:
        """Test .by_id()."""
        id1, id2, _ = self.default_ids
        # check failure if not yet saved
        obj1 = self.checked_class(id1, **self.default_init_kwargs)
        with self.assertRaises(NotFoundException):
            self.checked_class.by_id(self.db_conn, id1)
        # check identity of cached and retrieved
        obj1.cache()
        self.assertEqual(obj1, self.checked_class.by_id(self.db_conn, id1))
        # check identity of saved and retrieved
        obj2 = self.checked_class(id2, **self.default_init_kwargs)
        obj2.save(self.db_conn)
        self.assertEqual(obj2, self.checked_class.by_id(self.db_conn, id2))

    @_within_checked_class
    def test_by_id_or_create(self) -> None:
        """Test .by_id_or_create."""
        # check .by_id_or_create fails if wrong class
        if not self.checked_class.can_create_by_id:
            with self.assertRaises(HandledException):
                self.checked_class.by_id_or_create(self.db_conn, None)
            return
        # check ID input of None creates, on saving, ID=1,2,… for int IDs
        if isinstance(self.default_ids[0], int):
            for n in range(2):
                item = self.checked_class.by_id_or_create(self.db_conn, None)
                self.assertEqual(item.id_, None)
                item.save(self.db_conn)
                self.assertEqual(item.id_, n+1)
        # check .by_id_or_create acts like normal instantiation (sans saving)
        id_ = self.default_ids[2]
        item = self.checked_class.by_id_or_create(self.db_conn, id_)
        self.assertEqual(item.id_, id_)
        with self.assertRaises(NotFoundException):
            self.checked_class.by_id(self.db_conn, item.id_)
        self.assertEqual(self.checked_class(item.id_), item)

    @_within_checked_class
    def test_from_table_row(self) -> None:
        """Test .from_table_row() properly reads in class directly from DB."""
        id_ = self.default_ids[0]
        obj = self.checked_class(id_, **self.default_init_kwargs)
        obj.save(self.db_conn)
        assert isinstance(obj.id_, type(id_))
        for row in self.db_conn.row_where(self.checked_class.table_name,
                                          'id', obj.id_):
            # check .from_table_row reproduces state saved, no matter if obj
            # later changed (with caching even)
            hash_original = hash(obj)
            attr_name = self._change_obj(obj)
            obj.cache()
            to_cmp = getattr(obj, attr_name)
            retrieved = self.checked_class.from_table_row(self.db_conn, row)
            self.assertNotEqual(to_cmp, getattr(retrieved, attr_name))
            self.assertEqual(hash_original, hash(retrieved))
            # check cache contains what .from_table_row just produced
            self.assertEqual({retrieved.id_: retrieved},
                             self.checked_class.get_cache())
        # check .from_table_row also reads versioned attributes from DB
        for attr_name, type_ in self.test_versioneds.items():
            # construct like test_saving_versioned does for the same attrs
            owner = self.checked_class(None, **self.default_init_kwargs)
            vals = self._versioned_test_vals(type_)
            attr = getattr(owner, attr_name)
            attr.set(vals[0])
            attr.set(vals[1])
            owner.save(self.db_conn)
            for row in self.db_conn.row_where(owner.table_name, 'id',
                                              owner.id_):
                retrieved = owner.__class__.from_table_row(self.db_conn, row)
                attr = getattr(retrieved, attr_name)
                self.assertEqual(sorted(attr.history.values()), vals)

    @_within_checked_class
    def test_all(self) -> None:
        """Test .all() and its relation to cache and savings."""
        id_1, id_2, id_3 = self.default_ids
        item1 = self.checked_class(id_1, **self.default_init_kwargs)
        item2 = self.checked_class(id_2, **self.default_init_kwargs)
        item3 = self.checked_class(id_3, **self.default_init_kwargs)
        # check .all() returns empty list on un-cached items
        self.assertEqual(self.checked_class.all(self.db_conn), [])
        # check that all() shows only cached/saved items
        item1.cache()
        item3.save(self.db_conn)
        self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
                         sorted([item1, item3]))
        item2.save(self.db_conn)
        self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
                         sorted([item1, item2, item3]))

    @_within_checked_class
    def test_singularity(self) -> None:
        """Test pointers made for single object keep pointing to it."""
        id1 = self.default_ids[0]
        obj = self.checked_class(id1, **self.default_init_kwargs)
        obj.save(self.db_conn)
        # change object, expect retrieved through .by_id to carry change
        attr_name = self._change_obj(obj)
        new_attr = getattr(obj, attr_name)
        retrieved = self.checked_class.by_id(self.db_conn, id1)
        self.assertEqual(new_attr, getattr(retrieved, attr_name))

    @_within_checked_class
    def test_versioned_singularity_title(self) -> None:
        """Test singularity of VersionedAttributes on saving (with .title)."""
        if 'title' not in self.test_versioneds:
            return
        obj = self.checked_class(None, **self.default_init_kwargs)
        obj.save(self.db_conn)
        assert isinstance(obj.id_, int)
        # change obj, expect retrieved through .by_id to carry change
        obj.title.set('named')
        retrieved = self.checked_class.by_id(self.db_conn, obj.id_)
        self.assertEqual(obj.title.history, retrieved.title.history)

    @_within_checked_class
    def test_remove(self) -> None:
        """Test .remove() effects on DB and cache."""
        id_ = self.default_ids[0]
        obj = self.checked_class(id_, **self.default_init_kwargs)
        # check removal only works after saving
        with self.assertRaises(HandledException):
            obj.remove(self.db_conn)
        obj.save(self.db_conn)
        obj.remove(self.db_conn)
        # check access to obj fails after removal (no output on failure)
        with self.assertRaises(HandledException):
            _ = obj.id_
        # check DB and cache now empty
        self.check_identity_with_cache_and_db([])
309
310
class TestCaseWithServer(TestCaseWithDB):
    """Module tests against our HTTP server/handler (and database)."""

    def setUp(self) -> None:
        super().setUp()
        # port 0: presumably (as with socketserver) the OS picks a free port,
        # which is then read back from .server_address for the client
        self.httpd = TaskServer(self.db_file, ('localhost', 0), TaskHandler)
        self.server_thread = Thread(target=self.httpd.serve_forever)
        self.server_thread.daemon = True
        self.server_thread.start()
        self.conn = HTTPConnection(str(self.httpd.server_address[0]),
                                   self.httpd.server_address[1])
        self.httpd.set_json_mode()

    def tearDown(self) -> None:
        # close the client connection opened in setUp, so no socket leaks
        self.conn.close()
        self.httpd.shutdown()
        self.httpd.server_close()
        self.server_thread.join()
        super().tearDown()

    @staticmethod
    def as_id_list(items: list[dict[str, object]]) -> list[int | str]:
        """Return list of only 'id' fields of items."""
        id_list: list[int | str] = []
        for item in items:
            id_ = item['id']
            assert isinstance(id_, (int, str))
            id_list += [id_]
        return id_list

    @staticmethod
    def as_refs(items: list[dict[str, object]]
                ) -> dict[str, dict[str, object]]:
        """Return dictionary of items by their 'id' fields (as strings)."""
        return {str(item['id']): item for item in items}

    @staticmethod
    def proc_as_dict(id_: int = 1,
                     title: str = 'A',
                     description: str = '',
                     effort: float = 1.0,
                     enables: None | list[dict[str, object]] = None,
                     disables: None | list[dict[str, object]] = None,
                     conditions: None | list[dict[str, object]] = None,
                     blockers: None | list[dict[str, object]] = None
                     ) -> dict[str, object]:
        """Return JSON of Process to expect.

        Versioned fields get a single-entry history keyed 0, matching the
        renumbering done by check_json_get; the Condition lists are reduced
        to their items' IDs.
        """
        # pylint: disable=too-many-arguments
        as_id_list = TestCaseWithServer.as_id_list
        return {'id': id_,
                'calendarize': False,
                'suppressed_steps': [],
                'explicit_steps': [],
                '_versioned': {
                    'title': {0: title},
                    'description': {0: description},
                    'effort': {0: effort}},
                'conditions': as_id_list(conditions) if conditions else [],
                'disables': as_id_list(disables) if disables else [],
                'enables': as_id_list(enables) if enables else [],
                'blockers': as_id_list(blockers) if blockers else []}

    def check_redirect(self, target: str) -> None:
        """Check that self.conn answers with a 302 redirect to target."""
        response = self.conn.getresponse()
        self.assertEqual(response.status, 302)
        self.assertEqual(response.getheader('Location'), target)

    def check_get(self, target: str, expected_code: int) -> None:
        """Check that a GET to target yields expected_code."""
        self.conn.request('GET', target)
        self.assertEqual(self.conn.getresponse().status, expected_code)

    def check_post(self, data: Mapping[str, object], target: str,
                   expected_code: int, redirect_location: str = '') -> None:
        """Check that POST of data to target yields expected_code.

        data is sent form-urlencoded. For an expected 302, additionally
        check the redirect goes to redirect_location (default: target).
        """
        encoded_form_data = urlencode(data, doseq=True).encode('utf-8')
        headers = {'Content-Type': 'application/x-www-form-urlencoded',
                   'Content-Length': str(len(encoded_form_data))}
        self.conn.request('POST', target,
                          body=encoded_form_data, headers=headers)
        if expected_code == 302:
            self.check_redirect(redirect_location or target)
        else:
            self.assertEqual(self.conn.getresponse().status, expected_code)

    def check_get_defaults(self, path: str) -> None:
        """Check standard GET responses of the view at path.

        Expects 200 without ID, with empty ID and with ID 1, 400 for a
        non-numeric ID, and 500 for ID 0.
        """
        self.check_get(path, 200)
        self.check_get(f'{path}?id=', 200)
        self.check_get(f'{path}?id=foo', 400)
        self.check_get(f'{path}?id=0', 500)
        self.check_get(f'{path}?id=1', 200)

    def post_process(self, id_: int = 1,
                     form_data: dict[str, Any] | None = None
                     ) -> dict[str, Any]:
        """POST basic Process, return the form data used.

        Without (or with empty) form_data, a default title/description/effort
        is posted; success is expected as a redirect back to the same URL.
        """
        if not form_data:
            form_data = {'title': 'foo', 'description': 'foo', 'effort': 1.1}
        self.check_post(form_data, f'/process?id={id_}', 302,
                        f'/process?id={id_}')
        return form_data

    def check_json_get(self, path: str, expected: dict[str, object]) -> None:
        """Compare JSON on GET path with expected.

        To simplify comparison of VersionedAttribute histories, transforms
        timestamp keys of VersionedAttribute history keys into integers
        counting chronologically forward from 0.
        """
        def rewrite_history_keys_in(item: Any) -> Any:
            if isinstance(item, dict):
                if '_versioned' in item:
                    versioned = item['_versioned']
                    # NOTE(review): keys are renumbered in the order the JSON
                    # delivers them; assumes the server emits each history
                    # chronologically.
                    for k in versioned:
                        versioned[k] = dict(enumerate(versioned[k].values()))
                for k in list(item.keys()):
                    rewrite_history_keys_in(item[k])
            elif isinstance(item, list):
                item[:] = [rewrite_history_keys_in(i) for i in item]
            return item
        self.conn.request('GET', path)
        response = self.conn.getresponse()
        self.assertEqual(response.status, 200)
        retrieved = json_loads(response.read().decode())
        rewrite_history_keys_in(retrieved)
        self.assertEqual(expected, retrieved)