home · contact · privacy
Re-factor TestCaseSansDB methods.
[plomtask] / tests / utils.py
1 """Shared test utilities."""
2 from __future__ import annotations
3 from unittest import TestCase
4 from typing import Mapping, Any, Callable
5 from threading import Thread
6 from http.client import HTTPConnection
7 from json import loads as json_loads
8 from urllib.parse import urlencode
9 from uuid import uuid4
10 from os import remove as remove_file
11 from plomtask.db import DatabaseFile, DatabaseConnection
12 from plomtask.http import TaskHandler, TaskServer
13 from plomtask.processes import Process, ProcessStep
14 from plomtask.conditions import Condition
15 from plomtask.days import Day
16 from plomtask.todos import Todo
17 from plomtask.exceptions import NotFoundException, HandledException
18
19
20 def _within_checked_class(f: Callable[..., None]) -> Callable[..., None]:
21     def wrapper(self: TestCase) -> None:
22         if hasattr(self, 'checked_class'):
23             f(self)
24     return wrapper
25
26
class TestCaseSansDB(TestCase):
    """Tests requiring no DB setup.

    Subclasses assign .checked_class (the model class under test) and may
    override the class attributes below to fit it.
    """
    checked_class: Any
    # positional arguments passed after the ID on instantiating checked_class
    default_init_args: list[Any] = []
    # versioned attribute names mapped to the expected .newest value of a
    # freshly instantiated object
    versioned_defaults_to_test: dict[str, str | float] = {}
    # IDs instantiation is expected to accept
    legal_ids = [1, 5]
    # IDs instantiation is expected to reject with a HandledException
    illegal_ids = [0]

    @_within_checked_class
    def test_id_validation(self) -> None:
        """Test .id_ validation/setting."""
        def build(candidate_id: Any) -> Any:
            return self.checked_class(candidate_id, *self.default_init_args)

        for rejected in self.illegal_ids:
            with self.assertRaises(HandledException):
                build(rejected)
        for accepted in self.legal_ids:
            self.assertEqual(build(accepted).id_, accepted)

    @_within_checked_class
    def test_versioned_defaults(self) -> None:
        """Test defaults of VersionedAttributes."""
        first_legal = self.legal_ids[0]
        instance = self.checked_class(first_legal, *self.default_init_args)
        expectations = self.versioned_defaults_to_test
        for attr_name, expected in expectations.items():
            versioned = getattr(instance, attr_name)
            self.assertEqual(versioned.newest, expected)
49         obj = self.checked_class(id_, *self.default_init_args)
50         for k, v in self.versioned_defaults_to_test.items():
51             self.assertEqual(getattr(obj, k).newest, v)
52
53
class TestCaseWithDB(TestCase):
    """Module tests requiring DB setup.

    Every test gets a freshly created database file and emptied model
    caches (see setUp). Subclasses assign .checked_class (the model class
    under test) and may override the class attributes below to fit it.
    """
    checked_class: Any
    # three distinct IDs of the type .checked_class uses for its .id_
    default_ids: tuple[int | str, int | str, int | str] = (1, 2, 3)
    # keyword arguments passed after the ID on instantiating .checked_class
    default_init_kwargs: dict[str, Any] = {}
    # names of versioned attributes to test, mapped to their value type
    # (str yields string test values, anything else float test values)
    test_versioneds: dict[str, type] = {}

    def setUp(self) -> None:
        Condition.empty_cache()
        Day.empty_cache()
        Process.empty_cache()
        ProcessStep.empty_cache()
        Todo.empty_cache()
        # uuid4 in the file name keeps every test's DB file separate
        self.db_file = DatabaseFile.create_at(f'test_db:{uuid4()}')
        self.db_conn = DatabaseConnection(self.db_file)

    def tearDown(self) -> None:
        self.db_conn.close()
        remove_file(self.db_file.path)

    def _load_from_db(self, id_: int | str) -> list[object]:
        """Return .checked_class objects built from its DB rows of ID id_."""
        rows = self.db_conn.row_where(self.checked_class.table_name,
                                      'id', id_)
        return [self.checked_class.from_table_row(self.db_conn, row)
                for row in rows]

    def _change_obj(self, obj: object) -> str:
        """Change obj's last .to_save attribute in place, return its name.

        Booleans are negated, numbers incremented, strings get '_' appended.

        Raises:
            TypeError: if the attribute is of none of these types.
        """
        attr_name: str = self.checked_class.to_save[-1]
        attr = getattr(obj, attr_name)
        new_attr: str | int | float | bool
        # test bool first: bool subclasses int, so an (int, float) check
        # placed before it would turn True into 2 instead of False
        if isinstance(attr, bool):
            new_attr = not attr
        elif isinstance(attr, (int, float)):
            new_attr = attr + 1
        elif isinstance(attr, str):
            new_attr = attr + '_'
        else:
            raise TypeError(f'cannot change attribute {attr_name!r} of type '
                            f'{type(attr).__name__}')
        setattr(obj, attr_name, new_attr)
        return attr_name

    def check_identity_with_cache_and_db(self, content: list[Any]) -> None:
        """Test both cache and DB equal content."""
        expected_cache = {item.id_: item for item in content}
        self.assertEqual(self.checked_class.get_cache(), expected_cache)
        hashes_content = [hash(x) for x in content]
        db_found: list[Any] = []
        for item in content:
            assert isinstance(item.id_, type(self.default_ids[0]))
            db_found += self._load_from_db(item.id_)
        hashes_db_found = [hash(x) for x in db_found]
        self.assertEqual(sorted(hashes_content), sorted(hashes_db_found))

    @_within_checked_class
    def test_saving_versioned(self) -> None:
        """Test storage and initialization of versioned attributes."""
        def retrieve_attr_vals() -> list[object]:
            # reads `attr` and `retrieved` of the enclosing loop at call
            # time; row[2] is the stored value compared against vals below
            attr_vals_saved: list[object] = []
            assert hasattr(retrieved, 'id_')
            for row in self.db_conn.row_where(attr.table_name, 'parent',
                                              retrieved.id_):
                attr_vals_saved += [row[2]]
            return attr_vals_saved
        for attr_name, type_ in self.test_versioneds.items():
            # fail saving attributes on non-saved owner
            owner = self.checked_class(None, **self.default_init_kwargs)
            vals: list[Any] = ['t1', 't2'] if type_ == str else [0.9, 1.1]
            attr = getattr(owner, attr_name)
            attr.set(vals[0])
            attr.set(vals[1])
            with self.assertRaises(NotFoundException):
                attr.save(self.db_conn)
            owner.save(self.db_conn)
            # check stored attribute is as expected
            retrieved = self._load_from_db(owner.id_)[0]
            attr = getattr(retrieved, attr_name)
            self.assertEqual(sorted(attr.history.values()), vals)
            # check owner.save() created entries in attr table
            attr_vals_saved = retrieve_attr_vals()
            self.assertEqual(vals, attr_vals_saved)
            # check setting new val to attr inconsequential to DB without save
            attr.set(vals[0])
            attr_vals_saved = retrieve_attr_vals()
            self.assertEqual(vals, attr_vals_saved)
            # check save finally adds new val
            attr.save(self.db_conn)
            attr_vals_saved = retrieve_attr_vals()
            self.assertEqual(vals + [vals[0]], attr_vals_saved)

    @_within_checked_class
    def test_saving_and_caching(self) -> None:
        """Test effects of .cache() and .save()."""
        id1 = self.default_ids[0]
        # check failure to cache without ID (if None-ID input possible)
        if isinstance(id1, int):
            obj0 = self.checked_class(None, **self.default_init_kwargs)
            with self.assertRaises(HandledException):
                obj0.cache()
        # check mere object init itself doesn't even store in cache
        obj1 = self.checked_class(id1, **self.default_init_kwargs)
        self.assertEqual(self.checked_class.get_cache(), {})
        # check .cache() fills cache, but not DB
        obj1.cache()
        self.assertEqual(self.checked_class.get_cache(), {id1: obj1})
        db_found = self._load_from_db(id1)
        self.assertEqual(db_found, [])
        # check .save() sets ID (for int IDs), updates cache, and fills DB
        # (expect ID to be set to id1, despite obj1 already having that as ID:
        # it's generated by cursor.lastrowid on the DB table, and with obj1
        # not written there, obj2 should get it first!)
        id_input = None if isinstance(id1, int) else id1
        obj2 = self.checked_class(id_input, **self.default_init_kwargs)
        obj2.save(self.db_conn)
        obj2_hash = hash(obj2)
        self.assertEqual(self.checked_class.get_cache(), {id1: obj2})
        db_found += self._load_from_db(id1)
        self.assertEqual([hash(o) for o in db_found], [obj2_hash])
        # check we cannot overwrite obj2 with obj1 despite its same ID,
        # since it has disappeared now
        with self.assertRaises(HandledException):
            obj1.save(self.db_conn)

    @_within_checked_class
    def test_by_id(self) -> None:
        """Test .by_id()."""
        id1, id2, _ = self.default_ids
        # check failure if not yet saved
        obj1 = self.checked_class(id1, **self.default_init_kwargs)
        with self.assertRaises(NotFoundException):
            self.checked_class.by_id(self.db_conn, id1)
        # check identity of cached and retrieved
        obj1.cache()
        self.assertEqual(obj1, self.checked_class.by_id(self.db_conn, id1))
        # check identity of saved and retrieved
        obj2 = self.checked_class(id2, **self.default_init_kwargs)
        obj2.save(self.db_conn)
        self.assertEqual(obj2, self.checked_class.by_id(self.db_conn, id2))

    @_within_checked_class
    def test_by_id_or_create(self) -> None:
        """Test .by_id_or_create."""
        # check .by_id_or_create fails if wrong class
        if not self.checked_class.can_create_by_id:
            with self.assertRaises(HandledException):
                self.checked_class.by_id_or_create(self.db_conn, None)
            return
        # check ID input of None creates, on saving, ID=1,2,… for int IDs
        if isinstance(self.default_ids[0], int):
            for n in range(2):
                item = self.checked_class.by_id_or_create(self.db_conn, None)
                self.assertEqual(item.id_, None)
                item.save(self.db_conn)
                self.assertEqual(item.id_, n+1)
        # check .by_id_or_create acts like normal instantiation (sans saving)
        id_ = self.default_ids[2]
        item = self.checked_class.by_id_or_create(self.db_conn, id_)
        self.assertEqual(item.id_, id_)
        with self.assertRaises(NotFoundException):
            self.checked_class.by_id(self.db_conn, item.id_)
        # compared against an instantiation given nothing but the ID
        self.assertEqual(self.checked_class(item.id_), item)

    @_within_checked_class
    def test_from_table_row(self) -> None:
        """Test .from_table_row() properly reads in class directly from DB."""
        id_ = self.default_ids[0]
        obj = self.checked_class(id_, **self.default_init_kwargs)
        obj.save(self.db_conn)
        assert isinstance(obj.id_, type(id_))
        for row in self.db_conn.row_where(self.checked_class.table_name,
                                          'id', obj.id_):
            # check .from_table_row reproduces state saved, no matter if obj
            # later changed (with caching even)
            hash_original = hash(obj)
            attr_name = self._change_obj(obj)
            obj.cache()
            to_cmp = getattr(obj, attr_name)
            retrieved = self.checked_class.from_table_row(self.db_conn, row)
            self.assertNotEqual(to_cmp, getattr(retrieved, attr_name))
            self.assertEqual(hash_original, hash(retrieved))
            # check cache contains what .from_table_row just produced
            self.assertEqual({retrieved.id_: retrieved},
                             self.checked_class.get_cache())
        # check .from_table_row also reads versioned attributes from DB
        for attr_name, type_ in self.test_versioneds.items():
            # instantiate like test_saving_versioned does, so classes whose
            # init needs default_init_kwargs work here too
            owner = self.checked_class(None, **self.default_init_kwargs)
            vals: list[Any] = ['t1', 't2'] if type_ == str else [0.9, 1.1]
            attr = getattr(owner, attr_name)
            attr.set(vals[0])
            attr.set(vals[1])
            owner.save(self.db_conn)
            for row in self.db_conn.row_where(owner.table_name, 'id',
                                              owner.id_):
                retrieved = owner.__class__.from_table_row(self.db_conn, row)
                attr = getattr(retrieved, attr_name)
                self.assertEqual(sorted(attr.history.values()), vals)

    @_within_checked_class
    def test_all(self) -> None:
        """Test .all() and its relation to cache and savings."""
        id_1, id_2, id_3 = self.default_ids
        item1 = self.checked_class(id_1, **self.default_init_kwargs)
        item2 = self.checked_class(id_2, **self.default_init_kwargs)
        item3 = self.checked_class(id_3, **self.default_init_kwargs)
        # check .all() returns empty list on un-cached items
        self.assertEqual(self.checked_class.all(self.db_conn), [])
        # check that all() shows only cached/saved items
        item1.cache()
        item3.save(self.db_conn)
        self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
                         sorted([item1, item3]))
        item2.save(self.db_conn)
        self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
                         sorted([item1, item2, item3]))

    @_within_checked_class
    def test_singularity(self) -> None:
        """Test pointers made for single object keep pointing to it."""
        id1 = self.default_ids[0]
        obj = self.checked_class(id1, **self.default_init_kwargs)
        obj.save(self.db_conn)
        # change object, expect retrieved through .by_id to carry change
        attr_name = self._change_obj(obj)
        new_attr = getattr(obj, attr_name)
        retrieved = self.checked_class.by_id(self.db_conn, id1)
        self.assertEqual(new_attr, getattr(retrieved, attr_name))

    @_within_checked_class
    def test_versioned_singularity_title(self) -> None:
        """Test singularity of VersionedAttributes on saving (with .title)."""
        if 'title' in self.test_versioneds:
            obj = self.checked_class(None, **self.default_init_kwargs)
            obj.save(self.db_conn)
            assert isinstance(obj.id_, int)
            # change obj, expect retrieved through .by_id to carry change
            obj.title.set('named')
            retrieved = self.checked_class.by_id(self.db_conn, obj.id_)
            self.assertEqual(obj.title.history, retrieved.title.history)

    @_within_checked_class
    def test_remove(self) -> None:
        """Test .remove() effects on DB and cache."""
        id_ = self.default_ids[0]
        obj = self.checked_class(id_, **self.default_init_kwargs)
        # check removal only works after saving
        with self.assertRaises(HandledException):
            obj.remove(self.db_conn)
        obj.save(self.db_conn)
        obj.remove(self.db_conn)
        # check access to obj fails after removal (mere attribute access,
        # without printing anything into the test output)
        with self.assertRaises(HandledException):
            _ = obj.id_
        # check DB and cache now empty
        self.check_identity_with_cache_and_db([])
309
310
class TestCaseWithServer(TestCaseWithDB):
    """Module tests against our HTTP server/handler (and database)."""

    def setUp(self) -> None:
        super().setUp()
        # port 0: the actually bound port is read back from server_address
        self.httpd = TaskServer(self.db_file, ('localhost', 0), TaskHandler)
        self.server_thread = Thread(target=self.httpd.serve_forever)
        self.server_thread.daemon = True
        self.server_thread.start()
        self.conn = HTTPConnection(str(self.httpd.server_address[0]),
                                   self.httpd.server_address[1])
        self.httpd.set_json_mode()

    def tearDown(self) -> None:
        # close the client side first so its socket is not leaked
        self.conn.close()
        self.httpd.shutdown()
        self.httpd.server_close()
        self.server_thread.join()
        super().tearDown()

    def check_redirect(self, target: str) -> None:
        """Check that self.conn answers with a 302 redirect to target."""
        response = self.conn.getresponse()
        self.assertEqual(response.status, 302)
        self.assertEqual(response.getheader('Location'), target)

    def check_get(self, target: str, expected_code: int) -> None:
        """Check that a GET to target yields expected_code."""
        self.conn.request('GET', target)
        self.assertEqual(self.conn.getresponse().status, expected_code)

    def check_post(self, data: Mapping[str, object], target: str,
                   expected_code: int, redirect_location: str = '') -> None:
        """Check that POST of data to target yields expected_code.

        data is sent form-urlencoded (sequence values as repeated keys). For
        an expected 302, the Location header must equal redirect_location,
        which defaults to target itself when left empty.
        """
        encoded_form_data = urlencode(data, doseq=True).encode('utf-8')
        headers = {'Content-Type': 'application/x-www-form-urlencoded',
                   'Content-Length': str(len(encoded_form_data))}
        self.conn.request('POST', target,
                          body=encoded_form_data, headers=headers)
        if expected_code == 302:
            self.check_redirect(redirect_location or target)
        else:
            self.assertEqual(self.conn.getresponse().status, expected_code)

    def check_get_defaults(self, path: str) -> None:
        """Some standard model paths to test.

        path is expected to start with '/' (e.g. '/process'); all requests
        are made against it with varying ?id= query values.
        """
        self.check_get(path, 200)
        self.check_get(f'{path}?id=', 200)
        self.check_get(f'{path}?id=foo', 400)
        self.check_get(f'{path}?id=0', 500)
        self.check_get(f'{path}?id=1', 200)

    def post_process(self, id_: int = 1,
                     form_data: dict[str, Any] | None = None
                     ) -> dict[str, Any]:
        """POST basic Process, expecting redirect back to its page.

        A falsy form_data (None or empty) is replaced by a default form.
        Returns the form data actually posted.
        """
        if not form_data:
            form_data = {'title': 'foo', 'description': 'foo', 'effort': 1.1}
        self.check_post(form_data, f'/process?id={id_}', 302,
                        f'/process?id={id_}')
        return form_data

    def check_json_get(self, path: str, expected: dict[str, object]) -> None:
        """Compare JSON on GET path with expected.

        To simplify comparison of VersionedAttribute histories, transforms
        timestamp keys of VersionedAttribute history keys into integers
        counting chronologically forward from 0.
        """
        def rewrite_history_keys_in(item: Any) -> Any:
            # rewrites in place (recursing into nested dicts and lists) and
            # also returns item, so the list branch can rebuild its contents
            if isinstance(item, dict):
                if '_versioned' in item:
                    versioned = item['_versioned']
                    for k in versioned:
                        # renumber by order of appearance in the JSON
                        versioned[k] = dict(enumerate(versioned[k].values()))
                for k in list(item.keys()):
                    rewrite_history_keys_in(item[k])
            elif isinstance(item, list):
                item[:] = [rewrite_history_keys_in(i) for i in item]
            return item
        self.conn.request('GET', path)
        response = self.conn.getresponse()
        self.assertEqual(response.status, 200)
        retrieved = json_loads(response.read().decode())
        rewrite_history_keys_in(retrieved)
        self.assertEqual(expected, retrieved)