home · contact · privacy
a4f29ff45263cc2716f2299fa3f6d36d973d03a6
[plomtask] / tests / utils.py
1 """Shared test utilities."""
2 from unittest import TestCase
3 from threading import Thread
4 from http.client import HTTPConnection
5 from json import loads as json_loads
6 from urllib.parse import urlencode
7 from uuid import uuid4
8 from os import remove as remove_file
9 from typing import Mapping, Any
10 from plomtask.db import DatabaseFile, DatabaseConnection
11 from plomtask.http import TaskHandler, TaskServer
12 from plomtask.processes import Process, ProcessStep
13 from plomtask.conditions import Condition
14 from plomtask.days import Day
15 from plomtask.todos import Todo
16 from plomtask.exceptions import NotFoundException, HandledException
17
18
19 class TestCaseSansDB(TestCase):
20     """Tests requiring no DB setup."""
21     checked_class: Any
22     do_id_test: bool = False
23     default_init_args: list[Any] = []
24     versioned_defaults_to_test: dict[str, str | float] = {}
25
26     def test_id_setting(self) -> None:
27         """Test .id_ being set and its legal range being enforced."""
28         if not self.do_id_test:
29             return
30         with self.assertRaises(HandledException):
31             self.checked_class(0, *self.default_init_args)
32         obj = self.checked_class(5, *self.default_init_args)
33         self.assertEqual(obj.id_, 5)
34
35     def test_versioned_defaults(self) -> None:
36         """Test defaults of VersionedAttributes."""
37         if len(self.versioned_defaults_to_test) == 0:
38             return
39         obj = self.checked_class(1, *self.default_init_args)
40         for k, v in self.versioned_defaults_to_test.items():
41             self.assertEqual(getattr(obj, k).newest, v)
42
43
class TestCaseWithDB(TestCase):
    """Module tests requiring DB setup.

    setUp empties the caches of Condition, Day, Process, ProcessStep and
    Todo, then creates a fresh, uniquely named database file and opens a
    connection to it. tearDown closes that connection and deletes the file.
    The test_* methods run automatically and do nothing unless a subclass
    sets checked_class. The check_* methods are helpers that subclasses call
    explicitly.
    """
    checked_class: Any
    # IDs used when the generic checks instantiate checked_class
    default_ids: tuple[int | str, int | str, int | str] = (1, 2, 3)
    # keyword arguments passed (after the ID) when instantiating checked_class
    default_init_kwargs: dict[str, Any] = {}
    # versioned attribute names mapped to the type of value to store in them
    test_versioneds: dict[str, type] = {}

    def setUp(self) -> None:
        Condition.empty_cache()
        Day.empty_cache()
        Process.empty_cache()
        ProcessStep.empty_cache()
        Todo.empty_cache()
        self.db_file = DatabaseFile.create_at(f'test_db:{uuid4()}')
        self.db_conn = DatabaseConnection(self.db_file)

    def tearDown(self) -> None:
        # remove the DB file even if closing the connection fails
        try:
            self.db_conn.close()
        finally:
            remove_file(self.db_file.path)

    @staticmethod
    def _altered_value(value: Any) -> tuple[bool, Any]:
        """Derive a value of the same simple type that differs from value.

        Returns (True, new_value) for bool, int, float and str values, and
        (False, value) for anything else. bool is checked before int
        because bool subclasses int: testing int first would turn
        True/False into 2/1 instead of flipping it.
        """
        if isinstance(value, bool):
            return True, not value
        if isinstance(value, (int, float)):
            return True, value + 1
        if isinstance(value, str):
            return True, value + '_'
        return False, value

    def _save_versioned_owner(self, attr_name: str, type_: type
                              ) -> tuple[Any, list[Any]]:
        """Save a new checked_class instance with two values set on attr_name.

        Returns the saved owner and the values that were set, in setting
        order: 't1', 't2' if type_ is str, else 0.9, 1.1.
        """
        owner = self.checked_class(None)  # pylint: disable=not-callable
        vals: list[Any] = ['t1', 't2'] if type_ is str else [0.9, 1.1]
        attr = getattr(owner, attr_name)
        for val in vals:
            attr.set(val)
        owner.save(self.db_conn)
        return owner, vals

    def test_saving_and_caching(self) -> None:
        """Test storage and initialization of instances and attributes."""
        if not hasattr(self, 'checked_class'):
            return
        self.check_saving_and_caching(id_=1, **self.default_init_kwargs)
        obj = self.checked_class(None, **self.default_init_kwargs)
        obj.save(self.db_conn)
        # saving with id_=None is expected to pick the next free ID
        self.assertEqual(obj.id_, 2)
        for attr_name, type_ in self.test_versioneds.items():
            self.check_saving_of_versioned(attr_name, type_)

    def check_storage(self, content: list[Any]) -> None:
        """Test that the class cache and the DB both hold exactly content."""
        expected_cache = {item.id_: item for item in content}
        self.assertEqual(self.checked_class.get_cache(), expected_cache)
        # hash content before reading the DB, in case reading touches cache
        hashes_content = sorted(hash(x) for x in content)
        db_found: list[Any] = []
        for item in content:
            # narrows item.id_ for type checkers and guards the ID type
            assert isinstance(item.id_, type(self.default_ids[0]))
            for row in self.db_conn.row_where(self.checked_class.table_name,
                                              'id', item.id_):
                db_found.append(
                    self.checked_class.from_table_row(self.db_conn, row))
        self.assertEqual(hashes_content, sorted(hash(x) for x in db_found))

    def check_saving_and_caching(self, **kwargs: Any) -> None:
        """Test instance.save in its core without relations."""
        obj = self.checked_class(**kwargs)  # pylint: disable=not-callable
        # check object init itself doesn't store anything yet
        self.check_storage([])
        # check saving sets core attributes properly
        obj.save(self.db_conn)
        for key, value in kwargs.items():
            self.assertEqual(getattr(obj, key), value)
        # check saving stored properly in cache and DB
        self.check_storage([obj])

    def check_saving_of_versioned(self, attr_name: str, type_: type) -> None:
        """Test owner's versioned attributes survive saving and .by_id()."""
        owner, vals = self._save_versioned_owner(attr_name, type_)
        retrieved = owner.__class__.by_id(self.db_conn, owner.id_)
        attr = getattr(retrieved, attr_name)
        self.assertEqual(sorted(attr.history.values()), vals)

    def check_by_id(self) -> None:
        """Test .by_id(), including creation."""
        # check failure if not yet saved
        id1, id2 = self.default_ids[0], self.default_ids[1]
        obj = self.checked_class(id1)  # pylint: disable=not-callable
        with self.assertRaises(NotFoundException):
            self.checked_class.by_id(self.db_conn, id1)
        # check identity of saved and retrieved
        obj.save(self.db_conn)
        self.assertEqual(obj, self.checked_class.by_id(self.db_conn, id1))
        # check create=True acts like normal instantiation (sans saving)
        by_id_created = self.checked_class.by_id(self.db_conn, id2,
                                                 create=True)
        # pylint: disable=not-callable
        self.assertEqual(self.checked_class(id2), by_id_created)
        self.check_storage([obj])

    def test_from_table_row(self) -> None:
        """Test .from_table_row() properly reads in class from DB."""
        if not hasattr(self, 'checked_class'):
            return
        id_ = self.default_ids[0]
        obj = self.checked_class(id_, **self.default_init_kwargs)
        obj.save(self.db_conn)
        assert isinstance(obj.id_, type(self.default_ids[0]))
        for row in self.db_conn.row_where(self.checked_class.table_name,
                                          'id', obj.id_):
            # check .from_table_row reproduces state saved, no matter if obj
            # later changed (with caching even)
            hash_original = hash(obj)
            attr_name = self.checked_class.to_save[-1]
            can_alter, altered = self._altered_value(getattr(obj, attr_name))
            if can_alter:
                setattr(obj, attr_name, altered)
            obj.cache()
            to_cmp = getattr(obj, attr_name)
            retrieved = self.checked_class.from_table_row(self.db_conn, row)
            self.assertNotEqual(to_cmp, getattr(retrieved, attr_name))
            self.assertEqual(hash_original, hash(retrieved))
            # check cache contains what .from_table_row just produced
            self.assertEqual({retrieved.id_: retrieved},
                             self.checked_class.get_cache())

    def check_versioned_from_table_row(self, attr_name: str,
                                       type_: type) -> None:
        """Test .from_table_row() reads versioned attributes from DB."""
        owner, vals = self._save_versioned_owner(attr_name, type_)
        for row in self.db_conn.row_where(owner.table_name, 'id', owner.id_):
            retrieved = owner.__class__.from_table_row(self.db_conn, row)
            attr = getattr(retrieved, attr_name)
            self.assertEqual(sorted(attr.history.values()), vals)

    def check_all(self) -> tuple[Any, Any, Any]:
        """Test .all(); return the three saved items."""
        # pylint: disable=not-callable
        item1 = self.checked_class(self.default_ids[0])
        item2 = self.checked_class(self.default_ids[1])
        item3 = self.checked_class(self.default_ids[2])
        # check pre-save .all() returns empty list
        self.assertEqual(self.checked_class.all(self.db_conn), [])
        # check that all() shows all saved, but no unsaved items
        item1.save(self.db_conn)
        item3.save(self.db_conn)
        self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
                         sorted([item1, item3]))
        item2.save(self.db_conn)
        self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
                         sorted([item1, item2, item3]))
        return item1, item2, item3

    def test_singularity(self) -> None:
        """Test pointers made for single object keep pointing to it."""
        if not hasattr(self, 'checked_class'):
            return
        id1 = self.default_ids[0]
        obj = self.checked_class(id1, **self.default_init_kwargs)
        obj.save(self.db_conn)
        attr_name = self.checked_class.to_save[-1]
        attr = getattr(obj, attr_name)
        can_alter, new_attr = self._altered_value(attr)
        if not can_alter:
            self.fail(f'cannot derive a changed value for {attr_name!r} '
                      f'of type {type(attr).__name__}')
        setattr(obj, attr_name, new_attr)
        retrieved = self.checked_class.by_id(self.db_conn, id1)
        self.assertEqual(new_attr, getattr(retrieved, attr_name))

    def check_versioned_singularity(self) -> None:
        """Test singularity of VersionedAttributes on saving (with .title)."""
        obj = self.checked_class(None)  # pylint: disable=not-callable
        obj.save(self.db_conn)
        assert isinstance(obj.id_, int)
        obj.title.set('named')
        retrieved = self.checked_class.by_id(self.db_conn, obj.id_)
        self.assertEqual(obj.title.history, retrieved.title.history)

    def check_remove(self, *args: Any) -> None:
        """Test .remove() effects on DB and cache.

        Removing an unsaved object must raise HandledException. Removing a
        saved one must leave neither cache nor DB holding it.
        """
        id_ = self.default_ids[0]
        obj = self.checked_class(id_, *args)  # pylint: disable=not-callable
        with self.assertRaises(HandledException):
            obj.remove(self.db_conn)
        obj.save(self.db_conn)
        obj.remove(self.db_conn)
        self.check_storage([])
232
233
class TestCaseWithServer(TestCaseWithDB):
    """Module tests against our HTTP server/handler (and database)."""

    def setUp(self) -> None:
        """Start a TaskServer in a daemon thread and connect a client to it.

        The server binds to ('localhost', 0), so the OS picks a free port.
        The client connects to whatever address the server actually bound.
        """
        super().setUp()
        self.httpd = TaskServer(self.db_file, ('localhost', 0), TaskHandler)
        self.server_thread = Thread(target=self.httpd.serve_forever)
        self.server_thread.daemon = True
        self.server_thread.start()
        self.conn = HTTPConnection(str(self.httpd.server_address[0]),
                                   self.httpd.server_address[1])
        self.httpd.set_json_mode()

    def tearDown(self) -> None:
        """Close the client, stop the server thread, then clean up the DB."""
        try:
            self.conn.close()  # previously never closed
            self.httpd.shutdown()
            self.httpd.server_close()
            self.server_thread.join()
        finally:
            super().tearDown()

    def _read_status(self) -> int:
        """Fetch the pending response on self.conn and return its status.

        The body is read fully so that the underlying socket is released
        rather than left to garbage collection. This also keeps self.conn
        usable for further requests if the server keeps connections alive.
        """
        response = self.conn.getresponse()
        response.read()
        return response.status

    def check_redirect(self, target: str) -> None:
        """Check that self.conn answers with a 302 redirect to target."""
        response = self.conn.getresponse()
        response.read()  # drain body, as in _read_status
        self.assertEqual(response.status, 302)
        self.assertEqual(response.getheader('Location'), target)

    def check_get(self, target: str, expected_code: int) -> None:
        """Check that a GET to target yields expected_code."""
        self.conn.request('GET', target)
        self.assertEqual(self._read_status(), expected_code)

    def check_post(self, data: Mapping[str, object], target: str,
                   expected_code: int, redirect_location: str = '') -> None:
        """Check that POST of data to target yields expected_code.

        data is sent form-urlencoded (sequence values become repeated keys).
        If expected_code is 302, the Location header must equal
        redirect_location, which defaults to target when left empty.
        """
        encoded_form_data = urlencode(data, doseq=True).encode('utf-8')
        headers = {'Content-Type': 'application/x-www-form-urlencoded',
                   'Content-Length': str(len(encoded_form_data))}
        self.conn.request('POST', target,
                          body=encoded_form_data, headers=headers)
        if expected_code == 302:
            self.check_redirect(redirect_location or target)
        else:
            self.assertEqual(self._read_status(), expected_code)

    def check_get_defaults(self, path: str) -> None:
        """Some standard model paths to test.

        Covers: no id and an empty id (200), a non-numeric id (400), the
        illegal id 0 (500), and id 1 (200).
        """
        self.check_get(path, 200)
        self.check_get(f'{path}?id=', 200)
        self.check_get(f'{path}?id=foo', 400)
        # previously f'/{path}?id=0', i.e. a stray extra leading slash
        self.check_get(f'{path}?id=0', 500)
        self.check_get(f'{path}?id=1', 200)

    def post_process(self, id_: int = 1,
                     form_data: dict[str, Any] | None = None
                     ) -> dict[str, Any]:
        """POST basic Process; return the form data that was posted.

        Any falsy form_data (None or an empty dict) is replaced by a default
        title/description/effort set. The POST must redirect back to the
        Process's own page.
        """
        if not form_data:
            form_data = {'title': 'foo', 'description': 'foo', 'effort': 1.1}
        self.check_post(form_data, f'/process?id={id_}', 302,
                        f'/process?id={id_}')
        return form_data

    def check_json_get(self, path: str, expected: dict[str, object]) -> None:
        """Compare JSON on GET path with expected.

        To simplify comparison of VersionedAttribute histories, replaces
        the timestamp keys of every '_versioned' history with integers
        0, 1, 2, ... The numbering follows the order in which the keys appear
        in the JSON response. That order is presumably chronological as
        serialized by the server; it is not re-sorted here.
        """
        def rewrite_history_keys_in(item: Any) -> Any:
            # mutates item in place (recursively) and also returns it
            if isinstance(item, dict):
                if '_versioned' in item.keys():
                    for k in item['_versioned']:
                        vals = item['_versioned'][k].values()
                        history = {}
                        for i, val in enumerate(vals):
                            history[i] = val
                        item['_versioned'][k] = history
                for k in list(item.keys()):
                    rewrite_history_keys_in(item[k])
            elif isinstance(item, list):
                item[:] = [rewrite_history_keys_in(i) for i in item]
            return item
        self.conn.request('GET', path)
        response = self.conn.getresponse()
        self.assertEqual(response.status, 200)
        retrieved = json_loads(response.read().decode())
        rewrite_history_keys_in(retrieved)
        self.assertEqual(expected, retrieved)