home · contact · privacy
25cc9ba1e79d663ec692570f6f4c1fce4eaaf911
[plomtask] / tests / utils.py
1 """Shared test utilities."""
2 from __future__ import annotations
3 from unittest import TestCase
4 from typing import Mapping, Any, Callable
5 from threading import Thread
6 from http.client import HTTPConnection
7 from json import loads as json_loads
8 from urllib.parse import urlencode
9 from uuid import uuid4
10 from os import remove as remove_file
11 from plomtask.db import DatabaseFile, DatabaseConnection
12 from plomtask.http import TaskHandler, TaskServer
13 from plomtask.processes import Process, ProcessStep
14 from plomtask.conditions import Condition
15 from plomtask.days import Day
16 from plomtask.todos import Todo
17 from plomtask.exceptions import NotFoundException, HandledException
18
19
class TestCaseSansDB(TestCase):
    """Tests requiring no DB setup.

    Subclasses set .checked_class to the class under test and switch on the
    generic checks below via the class attributes that follow it.
    """
    checked_class: Any
    # if True, test_id_setting checks ID assignment and rejection of ID 0
    do_id_test: bool = False
    # positional arguments passed to .checked_class after the ID
    default_init_args: list[Any] = []
    # attribute names mapped to the expected .newest value each of these
    # versioned attributes has on a freshly constructed object
    versioned_defaults_to_test: dict[str, str | float] = {}

    def test_id_setting(self) -> None:
        """Check .id_ gets set, and that an ID of 0 raises HandledException."""
        if self.do_id_test:
            with self.assertRaises(HandledException):
                self.checked_class(0, *self.default_init_args)
            instance = self.checked_class(5, *self.default_init_args)
            self.assertEqual(instance.id_, 5)

    def test_versioned_defaults(self) -> None:
        """Check the initial .newest of each listed versioned attribute."""
        expectations = self.versioned_defaults_to_test
        if not expectations:
            return
        instance = self.checked_class(1, *self.default_init_args)
        for attr_name, expected in expectations.items():
            self.assertEqual(getattr(instance, attr_name).newest, expected)
43
44
class TestCaseWithDB(TestCase):
    """Module tests requiring DB setup.

    Each test starts with emptied model caches and its own fresh database
    file. The generic tests below only run in subclasses that set
    .checked_class (the DB-backed class under test); elsewhere they pass as
    no-ops.
    """
    checked_class: Any
    # IDs for up to three test objects; the type of the first one (int or
    # str) also decides whether None-ID construction gets exercised
    default_ids: tuple[int | str, int | str, int | str] = (1, 2, 3)
    # keyword arguments passed to .checked_class after the ID
    default_init_kwargs: dict[str, Any] = {}
    # names of versioned attributes to test, mapped to their value type
    # (str gets string test values, anything else gets float test values)
    test_versioneds: dict[str, type] = {}

    def setUp(self) -> None:
        """Empty all model caches, create and connect to a fresh DB file."""
        Condition.empty_cache()
        Day.empty_cache()
        Process.empty_cache()
        ProcessStep.empty_cache()
        Todo.empty_cache()
        self.db_file = DatabaseFile.create_at(f'test_db:{uuid4()}')
        self.db_conn = DatabaseConnection(self.db_file)

    def tearDown(self) -> None:
        """Close the DB connection and delete the test's DB file."""
        try:
            self.db_conn.close()
        finally:
            # remove the throwaway file even if closing the connection fails
            remove_file(self.db_file.path)

    @staticmethod
    def _within_checked_class(f: Callable[..., None]) -> Callable[..., None]:
        """Wrap test method f to only run if self has .checked_class set.

        functools.wraps keeps f's name and docstring on the wrapper, so test
        runners can still show the test's description.
        """
        from functools import wraps

        @wraps(f)
        def wrapper(self: TestCaseWithDB) -> None:
            if hasattr(self, 'checked_class'):
                f(self)
        return wrapper

    def _load_from_db(self, id_: int | str) -> list[object]:
        """Return .checked_class objects built from all DB rows of ID id_."""
        return [self.checked_class.from_table_row(self.db_conn, row)
                for row in self.db_conn.row_where(
                    self.checked_class.table_name, 'id', id_)]

    def _change_obj(self, obj: object) -> str:
        """Alter obj's last .to_save attribute, return that attribute's name.

        Booleans get negated, numbers incremented, strings suffixed by '_'.

        Raises:
            TypeError: if the attribute is of none of these types.
        """
        attr_name: str = self.checked_class.to_save[-1]
        attr = getattr(obj, attr_name)
        new_attr: str | int | float | bool
        # test bool before int: bool subclasses int, so otherwise True would
        # turn into 2 and False into 1 rather than being negated
        if isinstance(attr, bool):
            new_attr = not attr
        elif isinstance(attr, (int, float)):
            new_attr = attr + 1
        elif isinstance(attr, str):
            new_attr = attr + '_'
        else:
            raise TypeError(f'cannot change attribute {attr_name} of type '
                            f'{type(attr).__name__}')
        setattr(obj, attr_name, new_attr)
        return attr_name

    def check_identity_with_cache_and_db(self, content: list[Any]) -> None:
        """Test both cache and DB equal content.

        Cache must map exactly each item's .id_ to the item; DB rows for
        those IDs must hash the same as the items.
        """
        expected_cache = {item.id_: item for item in content}
        self.assertEqual(self.checked_class.get_cache(), expected_cache)
        hashes_content = [hash(x) for x in content]
        db_found: list[Any] = []
        for item in content:
            assert isinstance(item.id_, type(self.default_ids[0]))
            db_found += self._load_from_db(item.id_)
        hashes_db_found = [hash(x) for x in db_found]
        self.assertEqual(sorted(hashes_content), sorted(hashes_db_found))

    @_within_checked_class
    def test_saving_versioned(self) -> None:
        """Test storage and initialization of versioned attributes."""
        def retrieve_attr_vals() -> list[object]:
            # collect column 2 of all rows in the attribute's own table whose
            # 'parent' is the current `retrieved` owner; `attr` and
            # `retrieved` are looked up from the loop below at call time
            attr_vals_saved: list[object] = []
            assert hasattr(retrieved, 'id_')
            for row in self.db_conn.row_where(attr.table_name, 'parent',
                                              retrieved.id_):
                attr_vals_saved += [row[2]]
            return attr_vals_saved
        for attr_name, type_ in self.test_versioneds.items():
            # fail saving attributes on non-saved owner
            owner = self.checked_class(None, **self.default_init_kwargs)
            vals: list[Any] = ['t1', 't2'] if type_ is str else [0.9, 1.1]
            attr = getattr(owner, attr_name)
            attr.set(vals[0])
            attr.set(vals[1])
            with self.assertRaises(NotFoundException):
                attr.save(self.db_conn)
            owner.save(self.db_conn)
            # check stored attribute is as expected
            retrieved = self._load_from_db(owner.id_)[0]
            attr = getattr(retrieved, attr_name)
            self.assertEqual(sorted(attr.history.values()), vals)
            # check owner.save() created entries in attr table
            attr_vals_saved = retrieve_attr_vals()
            self.assertEqual(vals, attr_vals_saved)
            # check setting new val to attr inconsequential to DB without save
            attr.set(vals[0])
            attr_vals_saved = retrieve_attr_vals()
            self.assertEqual(vals, attr_vals_saved)
            # check save finally adds new val
            attr.save(self.db_conn)
            attr_vals_saved = retrieve_attr_vals()
            self.assertEqual(vals + [vals[0]], attr_vals_saved)

    @_within_checked_class
    def test_saving_and_caching(self) -> None:
        """Test effects of .cache() and .save()."""
        id1 = self.default_ids[0]
        # check failure to cache without ID (if None-ID input possible)
        if isinstance(id1, int):
            obj0 = self.checked_class(None, **self.default_init_kwargs)
            with self.assertRaises(HandledException):
                obj0.cache()
        # check mere object init itself doesn't even store in cache
        obj1 = self.checked_class(id1, **self.default_init_kwargs)
        self.assertEqual(self.checked_class.get_cache(), {})
        # check .cache() fills cache, but not DB
        obj1.cache()
        self.assertEqual(self.checked_class.get_cache(), {id1: obj1})
        db_found = self._load_from_db(id1)
        self.assertEqual(db_found, [])
        # check .save() sets ID (for int IDs), updates cache, and fills DB
        # (expect ID to be set to id1, despite obj1 already having that as ID:
        # it's generated by cursor.lastrowid on the DB table, and with obj1
        # not written there, obj2 should get it first!)
        id_input = None if isinstance(id1, int) else id1
        obj2 = self.checked_class(id_input, **self.default_init_kwargs)
        obj2.save(self.db_conn)
        obj2_hash = hash(obj2)
        self.assertEqual(self.checked_class.get_cache(), {id1: obj2})
        db_found += self._load_from_db(id1)
        self.assertEqual([hash(o) for o in db_found], [obj2_hash])
        # check we cannot overwrite obj2 with obj1 despite its same ID,
        # since it has disappeared now
        with self.assertRaises(HandledException):
            obj1.save(self.db_conn)

    @_within_checked_class
    def test_by_id(self) -> None:
        """Test .by_id()."""
        id1, id2, _ = self.default_ids
        # check failure if not yet saved
        obj1 = self.checked_class(id1, **self.default_init_kwargs)
        with self.assertRaises(NotFoundException):
            self.checked_class.by_id(self.db_conn, id1)
        # check identity of cached and retrieved
        obj1.cache()
        self.assertEqual(obj1, self.checked_class.by_id(self.db_conn, id1))
        # check identity of saved and retrieved
        obj2 = self.checked_class(id2, **self.default_init_kwargs)
        obj2.save(self.db_conn)
        self.assertEqual(obj2, self.checked_class.by_id(self.db_conn, id2))

    @_within_checked_class
    def test_by_id_or_create(self) -> None:
        """Test .by_id_or_create."""
        id_ = self.default_ids[0]
        if not self.checked_class.can_create_by_id:
            # check .by_id_or_create fails for classes not allowing it
            with self.assertRaises(HandledException):
                self.checked_class.by_id_or_create(self.db_conn, id_)
        else:
            # check .by_id_or_create acts like normal instantiation (sans
            # saving): result is not findable via .by_id, yet equals a
            # freshly constructed object of the same ID
            by_id_created = self.checked_class.by_id_or_create(self.db_conn,
                                                               id_)
            with self.assertRaises(NotFoundException):
                self.checked_class.by_id(self.db_conn, id_)
            self.assertEqual(self.checked_class(id_), by_id_created)

    @_within_checked_class
    def test_from_table_row(self) -> None:
        """Test .from_table_row() properly reads in class directly from DB."""
        id_ = self.default_ids[0]
        obj = self.checked_class(id_, **self.default_init_kwargs)
        obj.save(self.db_conn)
        assert isinstance(obj.id_, type(id_))
        for row in self.db_conn.row_where(self.checked_class.table_name,
                                          'id', obj.id_):
            # check .from_table_row reproduces state saved, no matter if obj
            # later changed (with caching even)
            hash_original = hash(obj)
            attr_name = self._change_obj(obj)
            obj.cache()
            to_cmp = getattr(obj, attr_name)
            retrieved = self.checked_class.from_table_row(self.db_conn, row)
            self.assertNotEqual(to_cmp, getattr(retrieved, attr_name))
            self.assertEqual(hash_original, hash(retrieved))
            # check cache contains what .from_table_row just produced
            self.assertEqual({retrieved.id_: retrieved},
                             self.checked_class.get_cache())
        # check .from_table_row also reads versioned attributes from DB
        for attr_name, type_ in self.test_versioneds.items():
            # construct like test_saving_versioned does, so classes with
            # required init kwargs work here too
            owner = self.checked_class(None, **self.default_init_kwargs)
            vals: list[Any] = ['t1', 't2'] if type_ is str else [0.9, 1.1]
            attr = getattr(owner, attr_name)
            attr.set(vals[0])
            attr.set(vals[1])
            owner.save(self.db_conn)
            for row in self.db_conn.row_where(owner.table_name, 'id',
                                              owner.id_):
                retrieved = owner.__class__.from_table_row(self.db_conn, row)
                attr = getattr(retrieved, attr_name)
                self.assertEqual(sorted(attr.history.values()), vals)

    @_within_checked_class
    def test_all(self) -> None:
        """Test .all() and its relation to cache and savings."""
        id_1, id_2, id_3 = self.default_ids
        item1 = self.checked_class(id_1, **self.default_init_kwargs)
        item2 = self.checked_class(id_2, **self.default_init_kwargs)
        item3 = self.checked_class(id_3, **self.default_init_kwargs)
        # check .all() returns empty list on un-cached items
        self.assertEqual(self.checked_class.all(self.db_conn), [])
        # check that all() shows only cached/saved items
        item1.cache()
        item3.save(self.db_conn)
        self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
                         sorted([item1, item3]))
        item2.save(self.db_conn)
        self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
                         sorted([item1, item2, item3]))

    @_within_checked_class
    def test_singularity(self) -> None:
        """Test pointers made for single object keep pointing to it."""
        id1 = self.default_ids[0]
        obj = self.checked_class(id1, **self.default_init_kwargs)
        obj.save(self.db_conn)
        # change object, expect retrieved through .by_id to carry change
        attr_name = self._change_obj(obj)
        new_attr = getattr(obj, attr_name)
        retrieved = self.checked_class.by_id(self.db_conn, id1)
        self.assertEqual(new_attr, getattr(retrieved, attr_name))

    @_within_checked_class
    def test_versioned_singularity_title(self) -> None:
        """Test singularity of VersionedAttributes on saving (with .title)."""
        if 'title' in self.test_versioneds:
            obj = self.checked_class(None, **self.default_init_kwargs)
            obj.save(self.db_conn)
            assert isinstance(obj.id_, int)
            # change obj, expect retrieved through .by_id to carry change
            obj.title.set('named')
            retrieved = self.checked_class.by_id(self.db_conn, obj.id_)
            self.assertEqual(obj.title.history, retrieved.title.history)

    @_within_checked_class
    def test_remove(self) -> None:
        """Test .remove() effects on DB and cache."""
        id_ = self.default_ids[0]
        obj = self.checked_class(id_, **self.default_init_kwargs)
        # check removal only works after saving
        with self.assertRaises(HandledException):
            obj.remove(self.db_conn)
        obj.save(self.db_conn)
        obj.remove(self.db_conn)
        # check access to obj fails after removal (merely reading .id_ must
        # raise; no need to print anything)
        with self.assertRaises(HandledException):
            _ = obj.id_
        # check DB and cache now empty
        self.check_identity_with_cache_and_db([])
300
301
class TestCaseWithServer(TestCaseWithDB):
    """Module tests against our HTTP server/handler (and database)."""

    def setUp(self) -> None:
        """Set up the DB, then serve it via a TaskServer in a daemon thread.

        The server is bound to ('localhost', 0); the client connection
        self.conn targets whatever address .server_address then reports
        (presumably an OS-assigned free port, as with socketserver servers).
        """
        super().setUp()
        self.httpd = TaskServer(self.db_file, ('localhost', 0), TaskHandler)
        self.server_thread = Thread(target=self.httpd.serve_forever)
        self.server_thread.daemon = True
        self.server_thread.start()
        self.conn = HTTPConnection(str(self.httpd.server_address[0]),
                                   self.httpd.server_address[1])
        self.httpd.set_json_mode()

    def tearDown(self) -> None:
        """Close client connection, stop server and its thread, drop DB."""
        # close the client side too, so its socket does not outlive the test
        self.conn.close()
        self.httpd.shutdown()
        self.httpd.server_close()
        self.server_thread.join()
        super().tearDown()

    def check_redirect(self, target: str) -> None:
        """Check that self.conn answers with a 302 redirect to target."""
        response = self.conn.getresponse()
        self.assertEqual(response.status, 302)
        self.assertEqual(response.getheader('Location'), target)

    def check_get(self, target: str, expected_code: int) -> None:
        """Check that a GET to target yields expected_code."""
        self.conn.request('GET', target)
        self.assertEqual(self.conn.getresponse().status, expected_code)

    def check_post(self, data: Mapping[str, object], target: str,
                   expected_code: int, redirect_location: str = '') -> None:
        """Check that POST of data to target yields expected_code.

        data is sent URL-encoded (sequence values as repeated keys). For an
        expected_code of 302 the Location header is checked too, against
        redirect_location or, if that is empty, against target itself.
        """
        encoded_form_data = urlencode(data, doseq=True).encode('utf-8')
        headers = {'Content-Type': 'application/x-www-form-urlencoded',
                   'Content-Length': str(len(encoded_form_data))}
        self.conn.request('POST', target,
                          body=encoded_form_data, headers=headers)
        if 302 == expected_code:
            if redirect_location == '':
                redirect_location = target
            self.check_redirect(redirect_location)
        else:
            self.assertEqual(self.conn.getresponse().status, expected_code)

    def check_get_defaults(self, path: str) -> None:
        """Check GET responses at path for some standard ?id= variants.

        Expects 200 for no id, an empty id, and id=1; 400 for the
        non-numeric id 'foo'; 500 for id=0.
        """
        self.check_get(path, 200)
        self.check_get(f'{path}?id=', 200)
        self.check_get(f'{path}?id=foo', 400)
        # same target form as the other checks (no extra leading '/', which
        # would have requested '//…' instead)
        self.check_get(f'{path}?id=0', 500)
        self.check_get(f'{path}?id=1', 200)

    def post_process(self, id_: int = 1,
                     form_data: dict[str, Any] | None = None
                     ) -> dict[str, Any]:
        """POST basic Process, expect redirect back to it; return form data.

        Note that any falsy form_data (None, but also an empty dict) is
        replaced by a default title/description/effort form.
        """
        if not form_data:
            form_data = {'title': 'foo', 'description': 'foo', 'effort': 1.1}
        self.check_post(form_data, f'/process?id={id_}', 302,
                        f'/process?id={id_}')
        return form_data

    def check_json_get(self, path: str, expected: dict[str, object]) -> None:
        """Compare JSON on GET path with expected.

        To simplify comparison of VersionedAttribute histories, transforms
        timestamp keys of VersionedAttribute history keys into integers
        counting forward from 0 in the order the JSON lists them.
        """
        def rewrite_history_keys_in(item: Any) -> Any:
            # rewrites item in place (recursing into dicts and lists) and
            # returns it
            if isinstance(item, dict):
                if '_versioned' in item.keys():
                    for k in item['_versioned']:
                        # NOTE(review): numbering follows the JSON's key
                        # order, assumed chronological — keys are not sorted
                        vals = item['_versioned'][k].values()
                        history = {}
                        for i, val in enumerate(vals):
                            history[i] = val
                        item['_versioned'][k] = history
                for k in list(item.keys()):
                    rewrite_history_keys_in(item[k])
            elif isinstance(item, list):
                item[:] = [rewrite_history_keys_in(i) for i in item]
            return item
        self.conn.request('GET', path)
        response = self.conn.getresponse()
        self.assertEqual(response.status, 200)
        retrieved = json_loads(response.read().decode())
        rewrite_history_keys_in(retrieved)
        self.assertEqual(expected, retrieved)