home · contact · privacy
Refactor tests, expand Days testing.
[plomtask] / tests / utils.py
1 """Shared test utilities."""
2 from __future__ import annotations
3 from unittest import TestCase
4 from typing import Mapping, Any, Callable
5 from threading import Thread
6 from http.client import HTTPConnection
7 from json import loads as json_loads
8 from urllib.parse import urlencode
9 from uuid import uuid4
10 from os import remove as remove_file
11 from plomtask.db import DatabaseFile, DatabaseConnection
12 from plomtask.http import TaskHandler, TaskServer
13 from plomtask.processes import Process, ProcessStep
14 from plomtask.conditions import Condition
15 from plomtask.days import Day
16 from plomtask.todos import Todo
17 from plomtask.exceptions import NotFoundException, HandledException
18
19
def _within_checked_class(f: Callable[..., None]) -> Callable[..., None]:
    """Decorate test method f to run only if self has a .checked_class.

    The returned wrapper calls f(self) when the test case instance has an
    attribute named 'checked_class', and otherwise silently does nothing.
    """
    # Imported locally so this helper does not rely on a module-level import
    # that the rest of the file does not otherwise need.
    from functools import wraps  # pylint: disable=import-outside-toplevel

    @wraps(f)  # keep f's name and docstring, e.g. for verbose test reports
    def wrapper(self: TestCase) -> None:
        if hasattr(self, 'checked_class'):
            f(self)
    return wrapper
25
26
class TestCaseSansDB(TestCase):
    """Generic tests that work without any database setup."""
    # class whose instances get tested; while unset, the tests below are
    # no-ops (see _within_checked_class)
    checked_class: Any
    # positional arguments handed to checked_class after the ID
    default_init_args: list[Any] = []
    # attribute name -> value expected as its .newest on a fresh instance
    versioned_defaults_to_test: dict[str, str | float] = {}
    # IDs the constructor is expected to accept
    legal_ids = [1, 5]
    # IDs the constructor is expected to reject with HandledException
    illegal_ids = [0]

    @_within_checked_class
    def test_id_validation(self) -> None:
        """Check constructor rejects illegal IDs, and stores legal ones."""
        init_args = self.default_init_args
        for bad_id in self.illegal_ids:
            self.assertRaises(HandledException,
                              self.checked_class, bad_id, *init_args)
        for good_id in self.legal_ids:
            instance = self.checked_class(good_id, *init_args)
            self.assertEqual(instance.id_, good_id)

    @_within_checked_class
    def test_versioned_defaults(self) -> None:
        """Check .newest of versioned attributes on a fresh instance."""
        instance = self.checked_class(self.legal_ids[0],
                                      *self.default_init_args)
        for attr_name, default in self.versioned_defaults_to_test.items():
            versioned = getattr(instance, attr_name)
            self.assertEqual(versioned.newest, default)
52
53
class TestCaseWithDB(TestCase):
    """Module tests requiring DB setup.

    setUp() empties the model classes' caches and creates a fresh database
    file plus a connection to it; tearDown() closes the connection and
    deletes the file.
    """
    # class whose instances get tested; while unset, the generic tests below
    # are no-ops (see _within_checked_class)
    checked_class: Any
    # three IDs for the tests to instantiate checked_class with
    default_ids: tuple[int | str, int | str, int | str] = (1, 2, 3)
    # keyword arguments handed to the checked_class constructor
    default_init_kwargs: dict[str, Any] = {}
    # versioned attribute name -> type of its values (str gets tested with
    # strings, anything else with floats)
    test_versioneds: dict[str, type] = {}

    def setUp(self) -> None:
        """Empty all caches, create fresh DB file and connection to it."""
        Condition.empty_cache()
        Day.empty_cache()
        Process.empty_cache()
        ProcessStep.empty_cache()
        Todo.empty_cache()
        # random name so each test works on a DB file of its own
        self.db_file = DatabaseFile.create_at(f'test_db:{uuid4()}')
        self.db_conn = DatabaseConnection(self.db_file)

    def tearDown(self) -> None:
        """Close DB connection and delete DB file."""
        self.db_conn.close()
        remove_file(self.db_file.path)

    def _load_from_db(self, id_: int | str) -> list[object]:
        """Return checked_class objects built from DB rows of ID id_."""
        return [self.checked_class.from_table_row(self.db_conn, row)
                for row in self.db_conn.row_where(
                    self.checked_class.table_name, 'id', id_)]

    def _change_obj(self, obj: object) -> str:
        """Change obj's attribute named last in .to_save, return its name.

        Booleans get negated, other numbers incremented by 1, and strings
        get a '_' appended.

        Raises TypeError if the attribute is of none of these types.
        """
        attr_name: str = self.checked_class.to_save[-1]
        attr = getattr(obj, attr_name)
        new_attr: str | int | float | bool
        # bool must be tested before int: as bool is a subclass of int, the
        # numeric branch would otherwise catch booleans and turn True into 2
        if isinstance(attr, bool):
            new_attr = not attr
        elif isinstance(attr, (int, float)):
            new_attr = attr + 1
        elif isinstance(attr, str):
            new_attr = attr + '_'
        else:
            raise TypeError(f'cannot change attribute {attr_name!r} '
                            f'of type {type(attr).__name__}')
        setattr(obj, attr_name, new_attr)
        return attr_name

    def check_identity_with_cache_and_db(self, content: list[Any]) -> None:
        """Test both cache and DB equal content.

        The cache is compared as a whole. The DB is only queried for the IDs
        of the items in content, so DB rows of other IDs go unnoticed.
        """
        expected_cache = {item.id_: item for item in content}
        self.assertEqual(self.checked_class.get_cache(), expected_cache)
        hashes_content = [hash(x) for x in content]
        db_found: list[Any] = []
        for item in content:
            assert isinstance(item.id_, type(self.default_ids[0]))
            db_found += self._load_from_db(item.id_)
        hashes_db_found = [hash(x) for x in db_found]
        self.assertEqual(sorted(hashes_content), sorted(hashes_db_found))

    @_within_checked_class
    def test_saving_versioned(self) -> None:
        """Test storage and initialization of versioned attributes."""
        def retrieve_attr_vals() -> list[object]:
            # attr and retrieved are looked up in the enclosing scope at
            # call time, i.e. refer to the loop's current values
            assert hasattr(retrieved, 'id_')
            # NOTE(review): row[2] presumably is the value column of the
            # attribute's table -- confirm against the DB schema
            return [row[2] for row in self.db_conn.row_where(
                        attr.table_name, 'parent', retrieved.id_)]
        for attr_name, type_ in self.test_versioneds.items():
            # fail saving attributes on non-saved owner
            owner = self.checked_class(None, **self.default_init_kwargs)
            vals: list[Any] = ['t1', 't2'] if type_ == str else [0.9, 1.1]
            attr = getattr(owner, attr_name)
            attr.set(vals[0])
            attr.set(vals[1])
            with self.assertRaises(NotFoundException):
                attr.save(self.db_conn)
            owner.save(self.db_conn)
            # check stored attribute is as expected
            retrieved = self._load_from_db(owner.id_)[0]
            attr = getattr(retrieved, attr_name)
            self.assertEqual(sorted(attr.history.values()), vals)
            # check owner.save() created entries in attr table
            attr_vals_saved = retrieve_attr_vals()
            self.assertEqual(vals, attr_vals_saved)
            # check setting new val to attr inconsequential to DB without save
            attr.set(vals[0])
            attr_vals_saved = retrieve_attr_vals()
            self.assertEqual(vals, attr_vals_saved)
            # check save finally adds new val
            attr.save(self.db_conn)
            attr_vals_saved = retrieve_attr_vals()
            self.assertEqual(vals + [vals[0]], attr_vals_saved)

    @_within_checked_class
    def test_saving_and_caching(self) -> None:
        """Test effects of .cache() and .save()."""
        id1 = self.default_ids[0]
        # check failure to cache without ID (if None-ID input possible)
        if isinstance(id1, int):
            obj0 = self.checked_class(None, **self.default_init_kwargs)
            with self.assertRaises(HandledException):
                obj0.cache()
        # check mere object init itself doesn't even store in cache
        obj1 = self.checked_class(id1, **self.default_init_kwargs)
        self.assertEqual(self.checked_class.get_cache(), {})
        # check .cache() fills cache, but not DB
        obj1.cache()
        self.assertEqual(self.checked_class.get_cache(), {id1: obj1})
        db_found = self._load_from_db(id1)
        self.assertEqual(db_found, [])
        # check .save() sets ID (for int IDs), updates cache, and fills DB
        # (expect ID to be set to id1, despite obj1 already having that as ID:
        # it's generated by cursor.lastrowid on the DB table, and with obj1
        # not written there, obj2 should get it first!)
        id_input = None if isinstance(id1, int) else id1
        obj2 = self.checked_class(id_input, **self.default_init_kwargs)
        obj2.save(self.db_conn)
        obj2_hash = hash(obj2)
        self.assertEqual(self.checked_class.get_cache(), {id1: obj2})
        db_found += self._load_from_db(id1)
        self.assertEqual([hash(o) for o in db_found], [obj2_hash])
        # check we cannot overwrite obj2 with obj1 despite its same ID,
        # since it has disappeared now
        with self.assertRaises(HandledException):
            obj1.save(self.db_conn)

    @_within_checked_class
    def test_by_id(self) -> None:
        """Test .by_id()."""
        id1, id2, _ = self.default_ids
        # check failure if not yet saved
        obj1 = self.checked_class(id1, **self.default_init_kwargs)
        with self.assertRaises(NotFoundException):
            self.checked_class.by_id(self.db_conn, id1)
        # check identity of cached and retrieved
        obj1.cache()
        self.assertEqual(obj1, self.checked_class.by_id(self.db_conn, id1))
        # check identity of saved and retrieved
        obj2 = self.checked_class(id2, **self.default_init_kwargs)
        obj2.save(self.db_conn)
        self.assertEqual(obj2, self.checked_class.by_id(self.db_conn, id2))

    @_within_checked_class
    def test_by_id_or_create(self) -> None:
        """Test .by_id_or_create."""
        # check .by_id_or_create fails if wrong class
        if not self.checked_class.can_create_by_id:
            with self.assertRaises(HandledException):
                self.checked_class.by_id_or_create(self.db_conn, None)
            return
        # check ID input of None creates, on saving, ID=1,2,… for int IDs
        if isinstance(self.default_ids[0], int):
            for n in range(2):
                item = self.checked_class.by_id_or_create(self.db_conn, None)
                self.assertEqual(item.id_, None)
                item.save(self.db_conn)
                self.assertEqual(item.id_, n+1)
        # check .by_id_or_create acts like normal instantiation (sans saving)
        id_ = self.default_ids[2]
        item = self.checked_class.by_id_or_create(self.db_conn, id_)
        self.assertEqual(item.id_, id_)
        with self.assertRaises(NotFoundException):
            self.checked_class.by_id(self.db_conn, item.id_)
        self.assertEqual(self.checked_class(item.id_), item)

    @_within_checked_class
    def test_from_table_row(self) -> None:
        """Test .from_table_row() properly reads in class directly from DB."""
        id_ = self.default_ids[0]
        obj = self.checked_class(id_, **self.default_init_kwargs)
        obj.save(self.db_conn)
        assert isinstance(obj.id_, type(id_))
        for row in self.db_conn.row_where(self.checked_class.table_name,
                                          'id', obj.id_):
            # check .from_table_row reproduces state saved, no matter if obj
            # later changed (with caching even)
            hash_original = hash(obj)
            attr_name = self._change_obj(obj)
            obj.cache()
            to_cmp = getattr(obj, attr_name)
            retrieved = self.checked_class.from_table_row(self.db_conn, row)
            self.assertNotEqual(to_cmp, getattr(retrieved, attr_name))
            self.assertEqual(hash_original, hash(retrieved))
            # check cache contains what .from_table_row just produced
            self.assertEqual({retrieved.id_: retrieved},
                             self.checked_class.get_cache())
        # check .from_table_row also reads versioned attributes from DB
        for attr_name, type_ in self.test_versioneds.items():
            owner = self.checked_class(None)
            vals: list[Any] = ['t1', 't2'] if type_ == str else [0.9, 1.1]
            attr = getattr(owner, attr_name)
            attr.set(vals[0])
            attr.set(vals[1])
            owner.save(self.db_conn)
            for row in self.db_conn.row_where(owner.table_name, 'id',
                                              owner.id_):
                retrieved = owner.__class__.from_table_row(self.db_conn, row)
                attr = getattr(retrieved, attr_name)
                self.assertEqual(sorted(attr.history.values()), vals)

    @_within_checked_class
    def test_all(self) -> None:
        """Test .all() and its relation to cache and savings."""
        id_1, id_2, id_3 = self.default_ids
        item1 = self.checked_class(id_1, **self.default_init_kwargs)
        item2 = self.checked_class(id_2, **self.default_init_kwargs)
        item3 = self.checked_class(id_3, **self.default_init_kwargs)
        # check .all() returns empty list on un-cached items
        self.assertEqual(self.checked_class.all(self.db_conn), [])
        # check that all() shows only cached/saved items
        item1.cache()
        item3.save(self.db_conn)
        self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
                         sorted([item1, item3]))
        item2.save(self.db_conn)
        self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
                         sorted([item1, item2, item3]))

    @_within_checked_class
    def test_singularity(self) -> None:
        """Test pointers made for single object keep pointing to it."""
        id1 = self.default_ids[0]
        obj = self.checked_class(id1, **self.default_init_kwargs)
        obj.save(self.db_conn)
        # change object, expect retrieved through .by_id to carry change
        attr_name = self._change_obj(obj)
        new_attr = getattr(obj, attr_name)
        retrieved = self.checked_class.by_id(self.db_conn, id1)
        self.assertEqual(new_attr, getattr(retrieved, attr_name))

    @_within_checked_class
    def test_versioned_singularity_title(self) -> None:
        """Test singularity of VersionedAttributes on saving (with .title)."""
        if 'title' in self.test_versioneds:
            obj = self.checked_class(None)
            obj.save(self.db_conn)
            assert isinstance(obj.id_, int)
            # change obj, expect retrieved through .by_id to carry change
            obj.title.set('named')
            retrieved = self.checked_class.by_id(self.db_conn, obj.id_)
            self.assertEqual(obj.title.history, retrieved.title.history)

    @_within_checked_class
    def test_remove(self) -> None:
        """Test .remove() effects on DB and cache."""
        id_ = self.default_ids[0]
        obj = self.checked_class(id_, **self.default_init_kwargs)
        # check removal only works after saving
        with self.assertRaises(HandledException):
            obj.remove(self.db_conn)
        obj.save(self.db_conn)
        obj.remove(self.db_conn)
        # check access to obj fails after removal
        with self.assertRaises(HandledException):
            print(obj.id_)
        # check cache now empty
        self.check_identity_with_cache_and_db([])
        # check DB holds no row of the removed ID anymore (the call above,
        # given no content, does not query the DB for any ID)
        self.assertEqual(self._load_from_db(id_), [])
309
310
class TestCaseWithServer(TestCaseWithDB):
    """Module tests against our HTTP server/handler (and database)."""

    def setUp(self) -> None:
        """Start TaskServer in a background thread, open connection to it."""
        super().setUp()
        # NOTE(review): port 0 presumably lets the OS pick a free port, which
        # is then read back from .server_address below -- confirm
        self.httpd = TaskServer(self.db_file, ('localhost', 0), TaskHandler)
        self.server_thread = Thread(target=self.httpd.serve_forever)
        self.server_thread.daemon = True
        self.server_thread.start()
        self.conn = HTTPConnection(str(self.httpd.server_address[0]),
                                   self.httpd.server_address[1])
        self.httpd.set_json_mode()

    def tearDown(self) -> None:
        """Close client connection, stop server, then clean up DB."""
        # close our client socket too, so it does not leak between tests
        self.conn.close()
        self.httpd.shutdown()
        self.httpd.server_close()
        self.server_thread.join()
        super().tearDown()

    @staticmethod
    def proc_as_dict(id_: int = 1,
                     title: str = 'A',
                     description: str = '',
                     effort: float = 1.0,
                     enables: None | list[dict[str, object]] = None,
                     disables: None | list[dict[str, object]] = None,
                     conditions: None | list[dict[str, object]] = None,
                     blockers: None | list[dict[str, object]] = None
                     ) -> dict[str, object]:
        """Return JSON of Process to expect.

        The versioned attributes each get a history of a single entry under
        key 0; list arguments left at None (or empty) become empty lists.
        """
        # pylint: disable=too-many-arguments
        d = {'id': id_,
             'calendarize': False,
             'suppressed_steps': [],
             'explicit_steps': [],
             '_versioned': {
                 'title': {0: title},
                 'description': {0: description},
                 'effort': {0: effort}
                 },
             'conditions': conditions or [],
             'disables': disables or [],
             'enables': enables or [],
             'blockers': blockers or []}
        return d

    def check_redirect(self, target: str) -> None:
        """Check that self.conn answers with a 302 redirect to target."""
        response = self.conn.getresponse()
        self.assertEqual(response.status, 302)
        self.assertEqual(response.getheader('Location'), target)

    def check_get(self, target: str, expected_code: int) -> None:
        """Check that a GET to target yields expected_code."""
        self.conn.request('GET', target)
        self.assertEqual(self.conn.getresponse().status, expected_code)

    def check_post(self, data: Mapping[str, object], target: str,
                   expected_code: int, redirect_location: str = '') -> None:
        """Check that POST of data to target yields expected_code.

        If expected_code is 302, also check the redirect's Location, which
        is expected to be redirect_location, or target if that is empty.
        """
        encoded_form_data = urlencode(data, doseq=True).encode('utf-8')
        headers = {'Content-Type': 'application/x-www-form-urlencoded',
                   'Content-Length': str(len(encoded_form_data))}
        self.conn.request('POST', target,
                          body=encoded_form_data, headers=headers)
        if 302 == expected_code:
            if redirect_location == '':
                redirect_location = target
            self.check_redirect(redirect_location)
        else:
            self.assertEqual(self.conn.getresponse().status, expected_code)

    def check_get_defaults(self, path: str) -> None:
        """Some standard model paths to test."""
        self.check_get(path, 200)
        self.check_get(f'{path}?id=', 200)
        self.check_get(f'{path}?id=foo', 400)
        # NOTE(review): unlike its siblings, this target gets an additional
        # leading '/' -- confirm whether that is intended
        self.check_get(f'/{path}?id=0', 500)
        self.check_get(f'{path}?id=1', 200)

    def post_process(self, id_: int = 1,
                     form_data: dict[str, Any] | None = None
                     ) -> dict[str, Any]:
        """POST basic Process, return the form data sent.

        If form_data is None or empty, some default values are sent instead.
        """
        if not form_data:
            form_data = {'title': 'foo', 'description': 'foo', 'effort': 1.1}
        self.check_post(form_data, f'/process?id={id_}', 302,
                        f'/process?id={id_}')
        return form_data

    def check_json_get(self, path: str, expected: dict[str, object]) -> None:
        """Compare JSON on GET path with expected.

        To simplify comparison of VersionedAttribute histories, transforms
        timestamp keys of VersionedAttribute history keys into integers
        counting chronologically forward from 0.
        """
        def rewrite_history_keys_in(item: Any) -> Any:
            # walks dicts and lists recursively, changing them in place
            if isinstance(item, dict):
                if '_versioned' in item:
                    versioned = item['_versioned']
                    for k in versioned:
                        # re-key history by position: 0, 1, 2, …
                        versioned[k] = dict(enumerate(versioned[k].values()))
                for k in list(item.keys()):
                    rewrite_history_keys_in(item[k])
            elif isinstance(item, list):
                item[:] = [rewrite_history_keys_in(i) for i in item]
            return item
        self.conn.request('GET', path)
        response = self.conn.getresponse()
        self.assertEqual(response.status, 200)
        retrieved = json_loads(response.read().decode())
        rewrite_history_keys_in(retrieved)
        self.assertEqual(expected, retrieved)