home · contact · privacy
0925b2d5b2adc0e415293526a4b01c04fc42b178
[plomtask] / tests / utils.py
1 """Shared test utilities."""
2 from __future__ import annotations
3 from unittest import TestCase
4 from typing import Mapping, Any, Callable
5 from threading import Thread
6 from http.client import HTTPConnection
7 from json import loads as json_loads
8 from urllib.parse import urlencode
9 from uuid import uuid4
10 from os import remove as remove_file
11 from plomtask.db import DatabaseFile, DatabaseConnection
12 from plomtask.http import TaskHandler, TaskServer
13 from plomtask.processes import Process, ProcessStep
14 from plomtask.conditions import Condition
15 from plomtask.days import Day
16 from plomtask.todos import Todo
17 from plomtask.exceptions import NotFoundException, HandledException
18
19
20 class TestCaseSansDB(TestCase):
21     """Tests requiring no DB setup."""
22     checked_class: Any
23     do_id_test: bool = False
24     default_init_args: list[Any] = []
25     versioned_defaults_to_test: dict[str, str | float] = {}
26
27     def test_id_setting(self) -> None:
28         """Test .id_ being set and its legal range being enforced."""
29         if not self.do_id_test:
30             return
31         with self.assertRaises(HandledException):
32             self.checked_class(0, *self.default_init_args)
33         obj = self.checked_class(5, *self.default_init_args)
34         self.assertEqual(obj.id_, 5)
35
36     def test_versioned_defaults(self) -> None:
37         """Test defaults of VersionedAttributes."""
38         if len(self.versioned_defaults_to_test) == 0:
39             return
40         obj = self.checked_class(1, *self.default_init_args)
41         for k, v in self.versioned_defaults_to_test.items():
42             self.assertEqual(getattr(obj, k).newest, v)
43
44
class TestCaseWithDB(TestCase):
    """Module tests requiring DB setup.

    Every test gets a freshly created database file (deleted again in
    tearDown), and the Condition/Day/Process/ProcessStep/Todo caches are
    emptied before it runs. The generic test_* methods below are no-ops
    unless a subclass sets checked_class; they further read:

    - default_ids: three IDs to use for distinct checked_class objects
    - default_init_kwargs: keyword args passed to checked_class(id_, ...)
    - test_versioneds: versioned attribute name -> value type (str gets
      string test values, anything else gets float test values)
    """
    checked_class: Any
    default_ids: tuple[int | str, int | str, int | str] = (1, 2, 3)
    default_init_kwargs: dict[str, Any] = {}
    test_versioneds: dict[str, type] = {}

    def setUp(self) -> None:
        # Empty the model classes' caches so no objects leak in from
        # earlier tests.
        Condition.empty_cache()
        Day.empty_cache()
        Process.empty_cache()
        ProcessStep.empty_cache()
        Todo.empty_cache()
        # uuid4 keeps the DB file name unique per test.
        self.db_file = DatabaseFile.create_at(f'test_db:{uuid4()}')
        self.db_conn = DatabaseConnection(self.db_file)

    def tearDown(self) -> None:
        # Remove the DB file even if closing the connection fails.
        try:
            self.db_conn.close()
        finally:
            remove_file(self.db_file.path)

    @staticmethod
    def _within_checked_class(f: Callable[..., None]) -> Callable[..., None]:
        """Make test method f run only if the test case has checked_class.

        This lets the generic tests live on this base class without failing
        when a class lacking checked_class (like this one) is collected.
        functools.wraps preserves f's name and docstring, so test runners
        still report the original test description.
        """
        from functools import wraps

        @wraps(f)
        def wrapper(self: TestCaseWithDB) -> None:
            if hasattr(self, 'checked_class'):
                f(self)
        return wrapper

    def _load_from_db(self, id_: int | str) -> list[object]:
        """Return checked_class objects built from DB rows with 'id' == id_."""
        return [self.checked_class.from_table_row(self.db_conn, row)
                for row in self.db_conn.row_where(
                    self.checked_class.table_name, 'id', id_)]

    def _change_obj(self, obj: object) -> str:
        """Change obj's last .to_save attribute in place; return its name.

        bools are negated, other numbers incremented, strings extended.
        bool must be checked before int/float: bool subclasses int, so it
        would otherwise be turned into an int instead of being negated.

        Fails the test if the attribute is of any other type.
        """
        attr_name: str = self.checked_class.to_save[-1]
        attr = getattr(obj, attr_name)
        new_attr: str | int | float | bool
        if isinstance(attr, bool):
            new_attr = not attr
        elif isinstance(attr, (int, float)):
            new_attr = attr + 1
        elif isinstance(attr, str):
            new_attr = attr + '_'
        else:
            self.fail(f'cannot auto-change attribute {attr_name!r} of type '
                      f'{type(attr).__name__}')
        setattr(obj, attr_name, new_attr)
        return attr_name

    def check_identity_with_cache_and_db(self, content: list[Any]) -> None:
        """Test both cache and DB equal content."""
        expected_cache = {item.id_: item for item in content}
        self.assertEqual(self.checked_class.get_cache(), expected_cache)
        hashes_content = [hash(x) for x in content]
        db_found: list[Any] = []
        for item in content:
            assert isinstance(item.id_, type(self.default_ids[0]))
            db_found += self._load_from_db(item.id_)
        hashes_db_found = [hash(x) for x in db_found]
        self.assertEqual(sorted(hashes_content), sorted(hashes_db_found))

    @_within_checked_class
    def test_saving_versioned(self) -> None:
        """Test storage and initialization of versioned attributes."""
        def retrieve_attr_vals(attr: Any, owner_id: int | str
                               ) -> list[object]:
            # Values stored for owner_id in attr's table (row column 2).
            return [row[2] for row in
                    self.db_conn.row_where(attr.table_name, 'parent',
                                           owner_id)]
        for attr_name, type_ in self.test_versioneds.items():
            # fail saving attributes on non-saved owner
            owner = self.checked_class(None, **self.default_init_kwargs)
            vals: list[Any] = ['t1', 't2'] if type_ == str else [0.9, 1.1]
            attr = getattr(owner, attr_name)
            attr.set(vals[0])
            attr.set(vals[1])
            with self.assertRaises(NotFoundException):
                attr.save(self.db_conn)
            owner.save(self.db_conn)
            # check stored attribute is as expected
            retrieved = self._load_from_db(owner.id_)[0]
            assert hasattr(retrieved, 'id_')
            attr = getattr(retrieved, attr_name)
            self.assertEqual(sorted(attr.history.values()), vals)
            # check owner.save() created entries in attr table
            self.assertEqual(vals, retrieve_attr_vals(attr, retrieved.id_))
            # check setting new val to attr inconsequential to DB without save
            attr.set(vals[0])
            self.assertEqual(vals, retrieve_attr_vals(attr, retrieved.id_))
            # check save finally adds new val
            attr.save(self.db_conn)
            self.assertEqual(vals + [vals[0]],
                             retrieve_attr_vals(attr, retrieved.id_))

    @_within_checked_class
    def test_saving_and_caching(self) -> None:
        """Test effects of .cache() and .save()."""
        id1 = self.default_ids[0]
        # check failure to cache without ID (if None-ID input possible)
        if isinstance(id1, int):
            obj0 = self.checked_class(None, **self.default_init_kwargs)
            with self.assertRaises(HandledException):
                obj0.cache()
        # check mere object init itself doesn't even store in cache
        obj1 = self.checked_class(id1, **self.default_init_kwargs)
        self.assertEqual(self.checked_class.get_cache(), {})
        # check .cache() fills cache, but not DB
        obj1.cache()
        self.assertEqual(self.checked_class.get_cache(), {id1: obj1})
        db_found = self._load_from_db(id1)
        self.assertEqual(db_found, [])
        # check .save() sets ID (for int IDs), updates cache, and fills DB
        # (expect ID to be set to id1, despite obj1 already having that as ID:
        # it's generated by cursor.lastrowid on the DB table, and with obj1
        # not written there, obj2 should get it first!)
        id_input = None if isinstance(id1, int) else id1
        obj2 = self.checked_class(id_input, **self.default_init_kwargs)
        obj2.save(self.db_conn)
        obj2_hash = hash(obj2)
        self.assertEqual(self.checked_class.get_cache(), {id1: obj2})
        db_found += self._load_from_db(id1)
        self.assertEqual([hash(o) for o in db_found], [obj2_hash])
        # check we cannot overwrite obj2 with obj1 despite its same ID,
        # since it has disappeared now
        with self.assertRaises(HandledException):
            obj1.save(self.db_conn)

    @_within_checked_class
    def test_by_id(self) -> None:
        """Test .by_id()."""
        id1, id2, _ = self.default_ids
        # check failure if not yet saved
        obj1 = self.checked_class(id1, **self.default_init_kwargs)
        with self.assertRaises(NotFoundException):
            self.checked_class.by_id(self.db_conn, id1)
        # check identity of cached and retrieved
        obj1.cache()
        self.assertEqual(obj1, self.checked_class.by_id(self.db_conn, id1))
        # check identity of saved and retrieved
        obj2 = self.checked_class(id2, **self.default_init_kwargs)
        obj2.save(self.db_conn)
        self.assertEqual(obj2, self.checked_class.by_id(self.db_conn, id2))

    @_within_checked_class
    def test_by_id_or_create(self) -> None:
        """Test .by_id_or_create."""
        # check .by_id_or_create fails if wrong class
        if not self.checked_class.can_create_by_id:
            with self.assertRaises(HandledException):
                self.checked_class.by_id_or_create(self.db_conn, None)
            return
        # check ID input of None creates, on saving, ID=1,2,… for int IDs
        if isinstance(self.default_ids[0], int):
            for n in range(2):
                item = self.checked_class.by_id_or_create(self.db_conn, None)
                self.assertEqual(item.id_, None)
                item.save(self.db_conn)
                self.assertEqual(item.id_, n+1)
        # check .by_id_or_create acts like normal instantiation (sans saving)
        id_ = self.default_ids[2]
        item = self.checked_class.by_id_or_create(self.db_conn, id_)
        self.assertEqual(item.id_, id_)
        with self.assertRaises(NotFoundException):
            self.checked_class.by_id(self.db_conn, item.id_)
        self.assertEqual(self.checked_class(item.id_), item)

    @_within_checked_class
    def test_from_table_row(self) -> None:
        """Test .from_table_row() properly reads in class directly from DB."""
        id_ = self.default_ids[0]
        obj = self.checked_class(id_, **self.default_init_kwargs)
        obj.save(self.db_conn)
        assert isinstance(obj.id_, type(id_))
        for row in self.db_conn.row_where(self.checked_class.table_name,
                                          'id', obj.id_):
            # check .from_table_row reproduces state saved, no matter if obj
            # later changed (with caching even)
            hash_original = hash(obj)
            attr_name = self._change_obj(obj)
            obj.cache()
            to_cmp = getattr(obj, attr_name)
            retrieved = self.checked_class.from_table_row(self.db_conn, row)
            self.assertNotEqual(to_cmp, getattr(retrieved, attr_name))
            self.assertEqual(hash_original, hash(retrieved))
            # check cache contains what .from_table_row just produced
            self.assertEqual({retrieved.id_: retrieved},
                             self.checked_class.get_cache())
        # check .from_table_row also reads versioned attributes from DB
        for attr_name, type_ in self.test_versioneds.items():
            # same construction as in test_saving_versioned
            owner = self.checked_class(None, **self.default_init_kwargs)
            vals: list[Any] = ['t1', 't2'] if type_ == str else [0.9, 1.1]
            attr = getattr(owner, attr_name)
            attr.set(vals[0])
            attr.set(vals[1])
            owner.save(self.db_conn)
            for row in self.db_conn.row_where(owner.table_name, 'id',
                                              owner.id_):
                retrieved = owner.__class__.from_table_row(self.db_conn, row)
                attr = getattr(retrieved, attr_name)
                self.assertEqual(sorted(attr.history.values()), vals)

    @_within_checked_class
    def test_all(self) -> None:
        """Test .all() and its relation to cache and savings."""
        id_1, id_2, id_3 = self.default_ids
        item1 = self.checked_class(id_1, **self.default_init_kwargs)
        item2 = self.checked_class(id_2, **self.default_init_kwargs)
        item3 = self.checked_class(id_3, **self.default_init_kwargs)
        # check .all() returns empty list on un-cached items
        self.assertEqual(self.checked_class.all(self.db_conn), [])
        # check that all() shows only cached/saved items
        item1.cache()
        item3.save(self.db_conn)
        self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
                         sorted([item1, item3]))
        item2.save(self.db_conn)
        self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
                         sorted([item1, item2, item3]))

    @_within_checked_class
    def test_singularity(self) -> None:
        """Test pointers made for single object keep pointing to it."""
        id1 = self.default_ids[0]
        obj = self.checked_class(id1, **self.default_init_kwargs)
        obj.save(self.db_conn)
        # change object, expect retrieved through .by_id to carry change
        attr_name = self._change_obj(obj)
        new_attr = getattr(obj, attr_name)
        retrieved = self.checked_class.by_id(self.db_conn, id1)
        self.assertEqual(new_attr, getattr(retrieved, attr_name))

    @_within_checked_class
    def test_versioned_singularity_title(self) -> None:
        """Test singularity of VersionedAttributes on saving (with .title)."""
        if 'title' in self.test_versioneds:
            obj = self.checked_class(None, **self.default_init_kwargs)
            obj.save(self.db_conn)
            assert isinstance(obj.id_, int)
            # change obj, expect retrieved through .by_id to carry change
            obj.title.set('named')
            retrieved = self.checked_class.by_id(self.db_conn, obj.id_)
            self.assertEqual(obj.title.history, retrieved.title.history)

    @_within_checked_class
    def test_remove(self) -> None:
        """Test .remove() effects on DB and cache."""
        id_ = self.default_ids[0]
        obj = self.checked_class(id_, **self.default_init_kwargs)
        # check removal only works after saving
        with self.assertRaises(HandledException):
            obj.remove(self.db_conn)
        obj.save(self.db_conn)
        obj.remove(self.db_conn)
        # check access to obj fails after removal (value itself unused)
        with self.assertRaises(HandledException):
            _ = obj.id_
        # check DB and cache now empty
        self.check_identity_with_cache_and_db([])
307
308
class TestCaseWithServer(TestCaseWithDB):
    """Module tests against our HTTP server/handler (and database).

    setUp serves the per-test database via TaskServer on localhost in a
    daemon thread, switches it to JSON mode via set_json_mode(), and opens
    self.conn as the client connection to it. tearDown closes all of these
    again.
    """

    def setUp(self) -> None:
        super().setUp()
        # Port 0 lets the OS pick a free port; the one actually bound is
        # read back from server_address.
        self.httpd = TaskServer(self.db_file, ('localhost', 0), TaskHandler)
        self.server_thread = Thread(target=self.httpd.serve_forever)
        self.server_thread.daemon = True
        self.server_thread.start()
        self.conn = HTTPConnection(str(self.httpd.server_address[0]),
                                   self.httpd.server_address[1])
        self.httpd.set_json_mode()

    def tearDown(self) -> None:
        # Close the client connection first so its socket does not outlive
        # the test (it was previously never closed).
        self.conn.close()
        self.httpd.shutdown()
        self.httpd.server_close()
        self.server_thread.join()
        super().tearDown()

    def check_redirect(self, target: str) -> None:
        """Check that self.conn answers with a 302 redirect to target."""
        response = self.conn.getresponse()
        self.assertEqual(response.status, 302)
        self.assertEqual(response.getheader('Location'), target)

    def check_get(self, target: str, expected_code: int) -> None:
        """Check that a GET to target yields expected_code."""
        self.conn.request('GET', target)
        self.assertEqual(self.conn.getresponse().status, expected_code)

    def check_post(self, data: Mapping[str, object], target: str,
                   expected_code: int, redirect_location: str = '') -> None:
        """Check that POST of data to target yields expected_code.

        data is sent URL-encoded as a form (sequence values become repeated
        keys). If expected_code is 302, the Location header is additionally
        checked against redirect_location, which defaults to target.
        """
        encoded_form_data = urlencode(data, doseq=True).encode('utf-8')
        headers = {'Content-Type': 'application/x-www-form-urlencoded',
                   'Content-Length': str(len(encoded_form_data))}
        self.conn.request('POST', target,
                          body=encoded_form_data, headers=headers)
        if expected_code == 302:
            if redirect_location == '':
                redirect_location = target
            self.check_redirect(redirect_location)
        else:
            self.assertEqual(self.conn.getresponse().status, expected_code)

    def check_get_defaults(self, path: str) -> None:
        """Some standard model paths to test.

        path is used directly as the request target, so it must include its
        leading slash (e.g. '/process'). Checks GET responses for: no id
        (200), empty id (200), non-numeric id (400), id 0 (500), id 1 (200).
        """
        self.check_get(path, 200)
        self.check_get(f'{path}?id=', 200)
        self.check_get(f'{path}?id=foo', 400)
        # Previously f'/{path}?id=0', which requested '//...' instead of path.
        self.check_get(f'{path}?id=0', 500)
        self.check_get(f'{path}?id=1', 200)

    def post_process(self, id_: int = 1,
                     form_data: dict[str, Any] | None = None
                     ) -> dict[str, Any]:
        """POST basic Process; return the form data used.

        Empty or missing form_data is replaced by a default title,
        description and effort. Expects a 302 redirect back to the Process.
        """
        if not form_data:
            form_data = {'title': 'foo', 'description': 'foo', 'effort': 1.1}
        self.check_post(form_data, f'/process?id={id_}', 302,
                        f'/process?id={id_}')
        return form_data

    def check_json_get(self, path: str, expected: dict[str, object]) -> None:
        """Compare JSON on GET path with expected.

        To simplify comparison of VersionedAttribute histories, transforms
        timestamp keys of VersionedAttribute history keys into integers
        counting chronologically forward from 0.
        """
        def rewrite_history_keys_in(item: Any) -> Any:
            # Rewrites in place (and returns item). Renumbering follows the
            # key order of the parsed JSON object — NOTE(review): this
            # assumes the server emits histories in chronological order.
            if isinstance(item, dict):
                if '_versioned' in item:
                    for k in item['_versioned']:
                        vals = item['_versioned'][k].values()
                        item['_versioned'][k] = dict(enumerate(vals))
                for k in list(item.keys()):
                    rewrite_history_keys_in(item[k])
            elif isinstance(item, list):
                item[:] = [rewrite_history_keys_in(i) for i in item]
            return item
        self.conn.request('GET', path)
        response = self.conn.getresponse()
        self.assertEqual(response.status, 200)
        retrieved = json_loads(response.read().decode())
        rewrite_history_keys_in(retrieved)
        self.assertEqual(expected, retrieved)