home · contact · privacy
Some test utils refactoring.
[plomtask] / tests / utils.py
1 """Shared test utilities."""
2 from __future__ import annotations
3 from unittest import TestCase
4 from typing import Mapping, Any, Callable
5 from threading import Thread
6 from http.client import HTTPConnection
7 from json import loads as json_loads
8 from urllib.parse import urlencode
9 from uuid import uuid4
10 from os import remove as remove_file
11 from plomtask.db import DatabaseFile, DatabaseConnection
12 from plomtask.http import TaskHandler, TaskServer
13 from plomtask.processes import Process, ProcessStep
14 from plomtask.conditions import Condition
15 from plomtask.days import Day
16 from plomtask.todos import Todo
17 from plomtask.exceptions import NotFoundException, HandledException
18
19
class TestCaseSansDB(TestCase):
    """Tests requiring no DB setup.

    Subclasses point .checked_class at the model class under test and opt
    into the generic checks below: .do_id_test switches on the id range
    test, and .versioned_defaults_to_test maps VersionedAttribute names to
    their expected newest (default) values. .default_init_args are passed
    to .checked_class after the id argument.
    """
    checked_class: Any
    do_id_test: bool = False
    default_init_args: list[Any] = []
    versioned_defaults_to_test: dict[str, str | float] = {}

    def test_id_setting(self) -> None:
        """Test .id_ being set and its legal range being enforced."""
        if self.do_id_test:
            extra_args = self.default_init_args
            # an id of 0 must be refused on construction
            with self.assertRaises(HandledException):
                self.checked_class(0, *extra_args)
            # a positive id must end up as .id_ unchanged
            instance = self.checked_class(5, *extra_args)
            self.assertEqual(instance.id_, 5)

    def test_versioned_defaults(self) -> None:
        """Test defaults of VersionedAttributes."""
        expected_defaults = self.versioned_defaults_to_test
        if not expected_defaults:
            return
        instance = self.checked_class(1, *self.default_init_args)
        for attr_name, default in expected_defaults.items():
            newest = getattr(instance, attr_name).newest
            self.assertEqual(newest, default)
43
44
class TestCaseWithDB(TestCase):
    """Module tests requiring DB setup.

    Each test runs against a freshly created, uniquely named database file
    with an open connection (self.db_conn), and starts with the caches of
    Condition, Day, Process, ProcessStep and Todo emptied. Subclasses set
    .checked_class (and optionally the other class attributes below) to
    activate the generic test_* methods; without .checked_class those are
    no-ops. The check_* methods are helpers for subclasses to call.
    """
    checked_class: Any
    default_ids: tuple[int | str, int | str, int | str] = (1, 2, 3)
    default_init_kwargs: dict[str, Any] = {}
    test_versioneds: dict[str, type] = {}

    def setUp(self) -> None:
        # caches are emptied so no state leaks between tests
        Condition.empty_cache()
        Day.empty_cache()
        Process.empty_cache()
        ProcessStep.empty_cache()
        Todo.empty_cache()
        self.db_file = DatabaseFile.create_at(f'test_db:{uuid4()}')
        self.db_conn = DatabaseConnection(self.db_file)

    def tearDown(self) -> None:
        # remove the DB file even if closing the connection fails
        try:
            self.db_conn.close()
        finally:
            remove_file(self.db_file.path)

    @staticmethod
    def _within_checked_class(f: Callable[..., None]) -> Callable[..., None]:
        """Turn test method f into a no-op where .checked_class is unset.

        functools.wraps keeps f's name and docstring on the returned method.
        """
        from functools import wraps  # pylint: disable=import-outside-toplevel

        @wraps(f)
        def wrapper(self: TestCaseWithDB) -> None:
            if hasattr(self, 'checked_class'):
                f(self)
        return wrapper

    def _altered_value(self, value: Any) -> Any:
        """Return a value of value's type that differs from value.

        bool is checked before int/float: bool is a subclass of int, so
        otherwise True would become the int 2 instead of False. Fails the
        test for any other type rather than silently leaving value as is.
        """
        if isinstance(value, bool):
            return not value
        if isinstance(value, (int, float)):
            return value + 1
        if isinstance(value, str):
            return value + '_'
        self.fail(f'cannot derive altered value from {value!r} '
                  f'(type {type(value).__name__})')

    @_within_checked_class
    def test_saving_and_caching(self) -> None:
        """Test storage and initialization of instances and attributes."""
        self.check_saving_and_caching(id_=1, **self.default_init_kwargs)
        obj = self.checked_class(None, **self.default_init_kwargs)
        obj.save(self.db_conn)
        # saving with id_=None is expected to pick the id after the one
        # saved above
        self.assertEqual(obj.id_, 2)
        for attr_name, type_ in self.test_versioneds.items():
            owner = self.checked_class(None)
            vals: list[Any] = ['t1', 't2'] if type_ is str else [0.9, 1.1]
            attr = getattr(owner, attr_name)
            attr.set(vals[0])
            attr.set(vals[1])
            owner.save(self.db_conn)
            retrieved = owner.__class__.by_id(self.db_conn, owner.id_)
            attr = getattr(retrieved, attr_name)
            self.assertEqual(sorted(attr.history.values()), vals)

    def check_storage(self, content: list[Any]) -> None:
        """Test cache and DB equal content.

        The class cache must map exactly each item's .id_ to the item, and
        the DB rows for those ids, read back via .from_table_row, must hash
        equal to content.
        """
        expected_cache = {item.id_: item for item in content}
        self.assertEqual(self.checked_class.get_cache(), expected_cache)
        hashes_content = [hash(x) for x in content]
        db_found: list[Any] = []
        for item in content:
            assert isinstance(item.id_, type(self.default_ids[0]))
            for row in self.db_conn.row_where(self.checked_class.table_name,
                                              'id', item.id_):
                db_found.append(
                    self.checked_class.from_table_row(self.db_conn, row))
        hashes_db_found = [hash(x) for x in db_found]
        self.assertEqual(sorted(hashes_content), sorted(hashes_db_found))

    def check_saving_and_caching(self, **kwargs: Any) -> None:
        """Test instance.save in its core without relations."""
        obj = self.checked_class(**kwargs)  # pylint: disable=not-callable
        # check object init itself doesn't store anything yet
        self.check_storage([])
        # check saving sets core attributes properly
        obj.save(self.db_conn)
        for key, value in kwargs.items():
            self.assertEqual(getattr(obj, key), value)
        # check saving stored properly in cache and DB
        self.check_storage([obj])

    def check_by_id(self) -> None:
        """Test .by_id(), including creation."""
        # check failure if not yet saved
        id1, id2 = self.default_ids[0], self.default_ids[1]
        obj = self.checked_class(id1)  # pylint: disable=not-callable
        with self.assertRaises(NotFoundException):
            self.checked_class.by_id(self.db_conn, id1)
        # check identity of saved and retrieved
        obj.save(self.db_conn)
        self.assertEqual(obj, self.checked_class.by_id(self.db_conn, id1))
        # check create=True acts like normal instantiation (sans saving)
        by_id_created = self.checked_class.by_id(self.db_conn, id2,
                                                 create=True)
        # pylint: disable=not-callable
        self.assertEqual(self.checked_class(id2), by_id_created)
        self.check_storage([obj])

    @_within_checked_class
    def test_from_table_row(self) -> None:
        """Test .from_table_row() properly reads in class directly from DB."""
        id_ = self.default_ids[0]
        obj = self.checked_class(id_, **self.default_init_kwargs)
        obj.save(self.db_conn)
        assert isinstance(obj.id_, type(self.default_ids[0]))
        rows = list(self.db_conn.row_where(self.checked_class.table_name,
                                           'id', obj.id_))
        # without this, a missing row would let the test pass vacuously
        self.assertTrue(rows, 'saved object not found in DB')
        for row in rows:
            # check .from_table_row reproduces state saved, no matter if obj
            # later changed (with caching even)
            hash_original = hash(obj)
            attr_name = self.checked_class.to_save[-1]
            setattr(obj, attr_name,
                    self._altered_value(getattr(obj, attr_name)))
            obj.cache()
            to_cmp = getattr(obj, attr_name)
            retrieved = self.checked_class.from_table_row(self.db_conn, row)
            self.assertNotEqual(to_cmp, getattr(retrieved, attr_name))
            self.assertEqual(hash_original, hash(retrieved))
            # check cache contains what .from_table_row just produced
            self.assertEqual({retrieved.id_: retrieved},
                             self.checked_class.get_cache())

    def check_versioned_from_table_row(self, attr_name: str,
                                       type_: type) -> None:
        """Test .from_table_row() reads versioned attributes from DB."""
        owner = self.checked_class(None)
        vals: list[Any] = ['t1', 't2'] if type_ is str else [0.9, 1.1]
        attr = getattr(owner, attr_name)
        attr.set(vals[0])
        attr.set(vals[1])
        owner.save(self.db_conn)
        for row in self.db_conn.row_where(owner.table_name, 'id', owner.id_):
            retrieved = owner.__class__.from_table_row(self.db_conn, row)
            attr = getattr(retrieved, attr_name)
            self.assertEqual(sorted(attr.history.values()), vals)

    @_within_checked_class
    def test_all(self) -> None:
        """Test .all() and its relation to cache and savings."""
        id_1, id_2, id_3 = self.default_ids
        item1 = self.checked_class(id_1, **self.default_init_kwargs)
        item2 = self.checked_class(id_2, **self.default_init_kwargs)
        item3 = self.checked_class(id_3, **self.default_init_kwargs)
        # check .all() returns empty list on un-cached items
        self.assertEqual(self.checked_class.all(self.db_conn), [])
        # check that all() shows only cached/saved items
        item1.cache()
        item3.save(self.db_conn)
        self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
                         sorted([item1, item3]))
        item2.save(self.db_conn)
        self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
                         sorted([item1, item2, item3]))

    @_within_checked_class
    def test_singularity(self) -> None:
        """Test pointers made for single object keep pointing to it."""
        id1 = self.default_ids[0]
        obj = self.checked_class(id1, **self.default_init_kwargs)
        obj.save(self.db_conn)
        attr_name = self.checked_class.to_save[-1]
        new_attr = self._altered_value(getattr(obj, attr_name))
        setattr(obj, attr_name, new_attr)
        # change made on obj after saving must show on the retrieved object
        retrieved = self.checked_class.by_id(self.db_conn, id1)
        self.assertEqual(new_attr, getattr(retrieved, attr_name))

    def check_versioned_singularity(self) -> None:
        """Test singularity of VersionedAttributes on saving (with .title)."""
        obj = self.checked_class(None)  # pylint: disable=not-callable
        obj.save(self.db_conn)
        assert isinstance(obj.id_, int)
        obj.title.set('named')
        retrieved = self.checked_class.by_id(self.db_conn, obj.id_)
        self.assertEqual(obj.title.history, retrieved.title.history)

    def check_remove(self, *args: Any) -> None:
        """Test .remove() effects on DB and cache.

        Removing an unsaved object must raise HandledException; removing a
        saved one must leave cache and DB empty.
        """
        id_ = self.default_ids[0]
        obj = self.checked_class(id_, *args)  # pylint: disable=not-callable
        with self.assertRaises(HandledException):
            obj.remove(self.db_conn)
        obj.save(self.db_conn)
        obj.remove(self.db_conn)
        self.check_storage([])
234
235
class TestCaseWithServer(TestCaseWithDB):
    """Module tests against our HTTP server/handler (and database).

    setUp starts a TaskServer in JSON mode on a daemon thread, serving the
    per-test DB file, and opens self.conn to it; tearDown undoes all that.
    """

    def setUp(self) -> None:
        super().setUp()
        # port 0: presumably lets the OS pick a free port (socketserver
        # convention) -- the actual address is read back below
        self.httpd = TaskServer(self.db_file, ('localhost', 0), TaskHandler)
        # configure before the serving thread starts, so it never runs
        # against a not-yet-configured server
        self.httpd.set_json_mode()
        self.server_thread = Thread(target=self.httpd.serve_forever)
        self.server_thread.daemon = True
        self.server_thread.start()
        self.conn = HTTPConnection(str(self.httpd.server_address[0]),
                                   self.httpd.server_address[1])

    def tearDown(self) -> None:
        self.conn.close()
        self.httpd.shutdown()
        self.httpd.server_close()
        self.server_thread.join()
        super().tearDown()

    def check_redirect(self, target: str) -> None:
        """Check that self.conn answers with a 302 redirect to target."""
        response = self.conn.getresponse()
        self.assertEqual(response.status, 302)
        self.assertEqual(response.getheader('Location'), target)

    def check_get(self, target: str, expected_code: int) -> None:
        """Check that a GET to target yields expected_code."""
        self.conn.request('GET', target)
        self.assertEqual(self.conn.getresponse().status, expected_code)

    def check_post(self, data: Mapping[str, object], target: str,
                   expected_code: int, redirect_location: str = '') -> None:
        """Check that POST of data to target yields expected_code.

        data is sent form-urlencoded (sequence values as repeated keys). If
        expected_code is 302, the Location header must equal
        redirect_location, which defaults to target itself.
        """
        encoded_form_data = urlencode(data, doseq=True).encode('utf-8')
        headers = {'Content-Type': 'application/x-www-form-urlencoded',
                   'Content-Length': str(len(encoded_form_data))}
        self.conn.request('POST', target,
                          body=encoded_form_data, headers=headers)
        if expected_code == 302:
            self.check_redirect(redirect_location or target)
        else:
            self.assertEqual(self.conn.getresponse().status, expected_code)

    def check_get_defaults(self, path: str) -> None:
        """Check standard GET responses for model page path.

        Bare path and empty id: 200; non-numeric id: 400; id=0: 500;
        id=1: 200.
        """
        self.check_get(path, 200)
        self.check_get(f'{path}?id=', 200)
        self.check_get(f'{path}?id=foo', 400)
        # (previously f'/{path}…', i.e. a stray doubled leading slash)
        self.check_get(f'{path}?id=0', 500)
        self.check_get(f'{path}?id=1', 200)

    def post_process(self, id_: int = 1,
                     form_data: dict[str, Any] | None = None
                     ) -> dict[str, Any]:
        """POST basic Process of id_, expecting a redirect back to it.

        form_data defaults (only when None) to a minimal title/description/
        effort set. Returns the form data actually posted.
        """
        if form_data is None:
            form_data = {'title': 'foo', 'description': 'foo', 'effort': 1.1}
        self.check_post(form_data, f'/process?id={id_}', 302,
                        f'/process?id={id_}')
        return form_data

    def check_json_get(self, path: str, expected: dict[str, object]) -> None:
        """Compare JSON on GET path with expected.

        To simplify comparison of VersionedAttribute histories, replaces
        the (timestamp) keys of every history under a '_versioned' key by
        integers counting up from 0. Numbering follows the order in which
        entries appear in the JSON; they are not re-sorted here, so this
        presumably relies on the server emitting histories chronologically.
        """
        def rewrite_history_keys_in(item: Any) -> Any:
            if isinstance(item, dict):
                versioned = item.get('_versioned') \
                    if '_versioned' in item else None
                if '_versioned' in item:
                    for attr_name in versioned:
                        history = versioned[attr_name]
                        versioned[attr_name] = dict(
                            enumerate(history.values()))
                for value in list(item.values()):
                    rewrite_history_keys_in(value)
            elif isinstance(item, list):
                item[:] = [rewrite_history_keys_in(i) for i in item]
            return item
        self.conn.request('GET', path)
        response = self.conn.getresponse()
        self.assertEqual(response.status, 200)
        retrieved = json_loads(response.read().decode())
        rewrite_history_keys_in(retrieved)
        self.assertEqual(expected, retrieved)