home · contact · privacy
60157104624ac79a59581757dd58a0344b30da6b
[plomtask] / tests / utils.py
1 """Shared test utilities."""
from __future__ import annotations
from unittest import TestCase
from typing import Mapping, Any, Callable
from threading import Thread
from http.client import HTTPConnection
from functools import wraps
from json import loads as json_loads
from urllib.parse import urlencode
from uuid import uuid4
from os import remove as remove_file
from plomtask.db import DatabaseFile, DatabaseConnection
from plomtask.http import TaskHandler, TaskServer
from plomtask.processes import Process, ProcessStep
from plomtask.conditions import Condition
from plomtask.days import Day
from plomtask.todos import Todo
from plomtask.exceptions import NotFoundException, HandledException
18
19
20 class TestCaseSansDB(TestCase):
21     """Tests requiring no DB setup."""
22     checked_class: Any
23     do_id_test: bool = False
24     default_init_args: list[Any] = []
25     versioned_defaults_to_test: dict[str, str | float] = {}
26
27     def test_id_setting(self) -> None:
28         """Test .id_ being set and its legal range being enforced."""
29         if not self.do_id_test:
30             return
31         with self.assertRaises(HandledException):
32             self.checked_class(0, *self.default_init_args)
33         obj = self.checked_class(5, *self.default_init_args)
34         self.assertEqual(obj.id_, 5)
35
36     def test_versioned_defaults(self) -> None:
37         """Test defaults of VersionedAttributes."""
38         if len(self.versioned_defaults_to_test) == 0:
39             return
40         obj = self.checked_class(1, *self.default_init_args)
41         for k, v in self.versioned_defaults_to_test.items():
42             self.assertEqual(getattr(obj, k).newest, v)
43
44
class TestCaseWithDB(TestCase):
    """Module tests requiring DB setup.

    Each test runs against a fresh database file named 'test_db:<uuid4>',
    with the model classes' caches emptied first. The generic test_* methods
    below only do something in subclasses that set .checked_class; on
    classes without it they return immediately.
    """
    checked_class: Any
    default_ids: tuple[int | str, int | str, int | str] = (1, 2, 3)
    default_init_kwargs: dict[str, Any] = {}
    test_versioneds: dict[str, type] = {}

    def setUp(self) -> None:
        """Empty model caches, then create and connect to a new DB file."""
        Condition.empty_cache()
        Day.empty_cache()
        Process.empty_cache()
        ProcessStep.empty_cache()
        Todo.empty_cache()
        self.db_file = DatabaseFile.create_at(f'test_db:{uuid4()}')
        self.db_conn = DatabaseConnection(self.db_file)

    def tearDown(self) -> None:
        """Close the DB connection and delete the test DB file."""
        self.db_conn.close()
        remove_file(self.db_file.path)

    @staticmethod
    def _within_checked_class(f: Callable[..., None]) -> Callable[..., None]:
        """Make test f a no-op unless the test case defines .checked_class.

        functools.wraps copies f's name and docstring onto the wrapper, so
        unittest still shows the wrapped test's own description.
        """
        @wraps(f)
        def wrapper(self: TestCaseWithDB) -> None:
            if hasattr(self, 'checked_class'):
                f(self)
        return wrapper

    def _load_from_db(self, id_: int | str) -> list[object]:
        """Return checked_class objects built from DB rows with 'id' == id_."""
        return [self.checked_class.from_table_row(self.db_conn, row)
                for row in self.db_conn.row_where(
                    self.checked_class.table_name, 'id', id_)]

    @_within_checked_class
    def test_saving_versioned(self) -> None:
        """Test storage and initialization of versioned attributes."""
        def retrieve_attr_vals(attr: Any, owner: Any) -> list[object]:
            # third column of attr's table rows whose 'parent' is owner.id_
            assert hasattr(owner, 'id_')
            return [row[2] for row in self.db_conn.row_where(
                attr.table_name, 'parent', owner.id_)]
        for attr_name, type_ in self.test_versioneds.items():
            # fail saving attributes on non-saved owner
            owner = self.checked_class(None, **self.default_init_kwargs)
            vals: list[Any] = ['t1', 't2'] if type_ is str else [0.9, 1.1]
            attr = getattr(owner, attr_name)
            attr.set(vals[0])
            attr.set(vals[1])
            with self.assertRaises(NotFoundException):
                attr.save(self.db_conn)
            owner.save(self.db_conn)
            # check stored attribute is as expected
            retrieved = self._load_from_db(owner.id_)[0]
            attr = getattr(retrieved, attr_name)
            self.assertEqual(sorted(attr.history.values()), vals)
            # check owner.save() created entries in attr table
            self.assertEqual(vals, retrieve_attr_vals(attr, retrieved))
            # check setting new val to attr inconsequential to DB without save
            attr.set(vals[0])
            self.assertEqual(vals, retrieve_attr_vals(attr, retrieved))
            # check save finally adds new val
            attr.save(self.db_conn)
            self.assertEqual(vals + [vals[0]],
                             retrieve_attr_vals(attr, retrieved))

    def check_identity_with_cache_and_db(self, content: list[Any]) -> None:
        """Test both cache and DB equal content."""
        expected_cache = {item.id_: item for item in content}
        self.assertEqual(self.checked_class.get_cache(), expected_cache)
        hashes_content = [hash(x) for x in content]
        db_found: list[Any] = []
        for item in content:
            assert isinstance(item.id_, type(self.default_ids[0]))
            db_found += self._load_from_db(item.id_)
        hashes_db_found = [hash(x) for x in db_found]
        self.assertEqual(sorted(hashes_content), sorted(hashes_db_found))

    @_within_checked_class
    def test_saving_and_caching(self) -> None:
        """Test effects of .cache() and .save()."""
        id1 = self.default_ids[0]
        # check failure to cache without ID (if None-ID input possible)
        if isinstance(id1, int):
            obj0 = self.checked_class(None, **self.default_init_kwargs)
            with self.assertRaises(HandledException):
                obj0.cache()
        # check mere object init itself doesn't even store in cache
        obj1 = self.checked_class(id1, **self.default_init_kwargs)
        self.assertEqual(self.checked_class.get_cache(), {})
        # check .cache() fills cache, but not DB
        obj1.cache()
        self.assertEqual(self.checked_class.get_cache(), {id1: obj1})
        self.assertEqual(self._load_from_db(id1), [])
        # check .save() sets ID (for int IDs), updates cache, and fills DB
        # (expect ID to be set to id1, despite obj1 already having that as ID:
        # it's generated by cursor.lastrowid on the DB table, and with obj1
        # not written there, obj2 should get it first!)
        id_input = None if isinstance(id1, int) else id1
        obj2 = self.checked_class(id_input, **self.default_init_kwargs)
        obj2.save(self.db_conn)
        obj2_hash = hash(obj2)
        self.assertEqual(self.checked_class.get_cache(), {id1: obj2})
        db_found = self._load_from_db(id1)
        self.assertEqual([hash(o) for o in db_found], [obj2_hash])
        # check we cannot overwrite obj2 with obj1 despite its same ID,
        # since it has disappeared now
        with self.assertRaises(HandledException):
            obj1.save(self.db_conn)

    @_within_checked_class
    def test_by_id(self) -> None:
        """Test .by_id()."""
        id1, id2, _ = self.default_ids
        # check failure if not yet saved
        obj1 = self.checked_class(id1, **self.default_init_kwargs)
        with self.assertRaises(NotFoundException):
            self.checked_class.by_id(self.db_conn, id1)
        # check identity of cached and retrieved
        obj1.cache()
        self.assertEqual(obj1, self.checked_class.by_id(self.db_conn, id1))
        # check identity of saved and retrieved
        obj2 = self.checked_class(id2, **self.default_init_kwargs)
        obj2.save(self.db_conn)
        self.assertEqual(obj2, self.checked_class.by_id(self.db_conn, id2))

    @_within_checked_class
    def test_by_id_or_create(self) -> None:
        """Test .by_id_or_create."""
        id_ = self.default_ids[0]
        if not self.checked_class.can_create_by_id:
            # check .by_id_or_create fails for classes that can't create by ID
            with self.assertRaises(HandledException):
                self.checked_class.by_id_or_create(self.db_conn, id_)
        else:
            # check .by_id_or_create acts like normal instantiation (sans
            # saving)
            by_id_created = self.checked_class.by_id_or_create(self.db_conn,
                                                               id_)
            with self.assertRaises(NotFoundException):
                self.checked_class.by_id(self.db_conn, id_)
            self.assertEqual(self.checked_class(id_), by_id_created)

    @_within_checked_class
    def test_from_table_row(self) -> None:
        """Test .from_table_row() properly reads in class directly from DB."""
        id_ = self.default_ids[0]
        obj = self.checked_class(id_, **self.default_init_kwargs)
        obj.save(self.db_conn)
        assert isinstance(obj.id_, type(self.default_ids[0]))
        for row in self.db_conn.row_where(self.checked_class.table_name,
                                          'id', obj.id_):
            # check .from_table_row reproduces state saved, no matter if obj
            # later changed (with caching even)
            hash_original = hash(obj)
            attr_name = self.checked_class.to_save[-1]
            attr = getattr(obj, attr_name)
            # bool must be tested before int: bool subclasses int, so an int
            # check placed first would also swallow bools
            if isinstance(attr, bool):
                setattr(obj, attr_name, not attr)
            elif isinstance(attr, (int, float)):
                setattr(obj, attr_name, attr + 1)
            elif isinstance(attr, str):
                setattr(obj, attr_name, attr + "_")
            obj.cache()
            to_cmp = getattr(obj, attr_name)
            retrieved = self.checked_class.from_table_row(self.db_conn, row)
            self.assertNotEqual(to_cmp, getattr(retrieved, attr_name))
            self.assertEqual(hash_original, hash(retrieved))
            # check cache contains what .from_table_row just produced
            self.assertEqual({retrieved.id_: retrieved},
                             self.checked_class.get_cache())

    def check_versioned_from_table_row(self, attr_name: str,
                                       type_: type) -> None:
        """Test .from_table_row() reads versioned attributes from DB."""
        owner = self.checked_class(None)
        vals: list[Any] = ['t1', 't2'] if type_ is str else [0.9, 1.1]
        attr = getattr(owner, attr_name)
        attr.set(vals[0])
        attr.set(vals[1])
        owner.save(self.db_conn)
        for row in self.db_conn.row_where(owner.table_name, 'id', owner.id_):
            retrieved = owner.__class__.from_table_row(self.db_conn, row)
            attr = getattr(retrieved, attr_name)
            self.assertEqual(sorted(attr.history.values()), vals)

    @_within_checked_class
    def test_all(self) -> None:
        """Test .all() and its relation to cache and savings."""
        id_1, id_2, id_3 = self.default_ids
        item1 = self.checked_class(id_1, **self.default_init_kwargs)
        item2 = self.checked_class(id_2, **self.default_init_kwargs)
        item3 = self.checked_class(id_3, **self.default_init_kwargs)
        # check .all() returns empty list on un-cached items
        self.assertEqual(self.checked_class.all(self.db_conn), [])
        # check that all() shows only cached/saved items
        item1.cache()
        item3.save(self.db_conn)
        self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
                         sorted([item1, item3]))
        item2.save(self.db_conn)
        self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
                         sorted([item1, item2, item3]))

    @_within_checked_class
    def test_singularity(self) -> None:
        """Test pointers made for single object keep pointing to it."""
        id1 = self.default_ids[0]
        obj = self.checked_class(id1, **self.default_init_kwargs)
        obj.save(self.db_conn)
        attr_name = self.checked_class.to_save[-1]
        attr = getattr(obj, attr_name)
        new_attr: str | int | float | bool
        # bool first, since bool is a subclass of int
        if isinstance(attr, bool):
            new_attr = not attr
        elif isinstance(attr, (int, float)):
            new_attr = attr + 1
        elif isinstance(attr, str):
            new_attr = attr + '_'
        else:
            self.fail(f'cannot vary attribute {attr_name} of type '
                      f'{type(attr).__name__}')
        setattr(obj, attr_name, new_attr)
        retrieved = self.checked_class.by_id(self.db_conn, id1)
        self.assertEqual(new_attr, getattr(retrieved, attr_name))

    def check_versioned_singularity(self) -> None:
        """Test singularity of VersionedAttributes on saving (with .title)."""
        obj = self.checked_class(None)  # pylint: disable=not-callable
        obj.save(self.db_conn)
        assert isinstance(obj.id_, int)
        obj.title.set('named')
        retrieved = self.checked_class.by_id(self.db_conn, obj.id_)
        self.assertEqual(obj.title.history, retrieved.title.history)

    def check_remove(self, *args: Any) -> None:
        """Test .remove() effects on DB and cache."""
        id_ = self.default_ids[0]
        obj = self.checked_class(id_, *args)  # pylint: disable=not-callable
        # removing an object never saved must fail
        with self.assertRaises(HandledException):
            obj.remove(self.db_conn)
        obj.save(self.db_conn)
        obj.remove(self.db_conn)
        self.check_identity_with_cache_and_db([])
293
294
class TestCaseWithServer(TestCaseWithDB):
    """Module tests against our HTTP server/handler (and database)."""

    def setUp(self) -> None:
        """Start a TaskServer on the test DB in a daemon thread, connect."""
        super().setUp()
        # port 0: have the OS pick a free port; the actual one is read back
        # from server_address below
        self.httpd = TaskServer(self.db_file, ('localhost', 0), TaskHandler)
        self.server_thread = Thread(target=self.httpd.serve_forever)
        self.server_thread.daemon = True
        self.server_thread.start()
        self.conn = HTTPConnection(str(self.httpd.server_address[0]),
                                   self.httpd.server_address[1])
        self.httpd.set_json_mode()

    def tearDown(self) -> None:
        """Close client connection, stop and join the server, drop the DB."""
        # release the client socket so it does not outlive the test
        self.conn.close()
        self.httpd.shutdown()
        self.httpd.server_close()
        self.server_thread.join()
        super().tearDown()

    def check_redirect(self, target: str) -> None:
        """Check that self.conn answers with a 302 redirect to target."""
        response = self.conn.getresponse()
        self.assertEqual(response.status, 302)
        self.assertEqual(response.getheader('Location'), target)

    def check_get(self, target: str, expected_code: int) -> None:
        """Check that a GET to target yields expected_code."""
        self.conn.request('GET', target)
        self.assertEqual(self.conn.getresponse().status, expected_code)

    def check_post(self, data: Mapping[str, object], target: str,
                   expected_code: int, redirect_location: str = '') -> None:
        """Check that POST of data to target yields expected_code.

        For an expected 302, also check the Location header, which must equal
        redirect_location or, if that is empty, target itself.
        """
        encoded_form_data = urlencode(data, doseq=True).encode('utf-8')
        headers = {'Content-Type': 'application/x-www-form-urlencoded',
                   'Content-Length': str(len(encoded_form_data))}
        self.conn.request('POST', target,
                          body=encoded_form_data, headers=headers)
        if expected_code == 302:
            self.check_redirect(redirect_location or target)
        else:
            self.assertEqual(self.conn.getresponse().status, expected_code)

    def check_get_defaults(self, path: str) -> None:
        """Some standard model paths to test.

        path is a page path such as '/process'; each probe appends an ?id=
        query to that same path.
        """
        self.check_get(path, 200)
        self.check_get(f'{path}?id=', 200)
        self.check_get(f'{path}?id=foo', 400)
        self.check_get(f'{path}?id=0', 500)
        self.check_get(f'{path}?id=1', 200)

    def post_process(self, id_: int = 1,
                     form_data: dict[str, Any] | None = None
                     ) -> dict[str, Any]:
        """POST basic Process (default form data if none/empty given).

        Returns the form data actually posted.
        """
        if not form_data:
            form_data = {'title': 'foo', 'description': 'foo', 'effort': 1.1}
        self.check_post(form_data, f'/process?id={id_}', 302,
                        f'/process?id={id_}')
        return form_data

    def check_json_get(self, path: str, expected: dict[str, object]) -> None:
        """Compare JSON on GET path with expected.

        To simplify comparison of VersionedAttribute histories, transforms
        timestamp keys of VersionedAttribute history keys into integers
        counting chronologically forward from 0.
        """
        def rewrite_history_keys_in(item: Any) -> Any:
            # recursively rewrites item in place, returning it
            if isinstance(item, dict):
                if '_versioned' in item:
                    for k in item['_versioned']:
                        # renumber in the order the JSON listed the entries
                        vals = item['_versioned'][k].values()
                        item['_versioned'][k] = dict(enumerate(vals))
                for k in list(item.keys()):
                    rewrite_history_keys_in(item[k])
            elif isinstance(item, list):
                item[:] = [rewrite_history_keys_in(i) for i in item]
            return item
        self.conn.request('GET', path)
        response = self.conn.getresponse()
        self.assertEqual(response.status, 200)
        retrieved = json_loads(response.read().decode())
        rewrite_history_keys_in(retrieved)
        self.assertEqual(expected, retrieved)