home · contact · privacy
Split BaseModel.by_id into .by_id and by_id_or_create, refactor tests.
[plomtask] / tests / utils.py
1 """Shared test utilities."""
2 from __future__ import annotations
3 from unittest import TestCase
4 from typing import Mapping, Any, Callable
5 from threading import Thread
6 from http.client import HTTPConnection
7 from json import loads as json_loads
8 from urllib.parse import urlencode
9 from uuid import uuid4
10 from os import remove as remove_file
11 from plomtask.db import DatabaseFile, DatabaseConnection
12 from plomtask.http import TaskHandler, TaskServer
13 from plomtask.processes import Process, ProcessStep
14 from plomtask.conditions import Condition
15 from plomtask.days import Day
16 from plomtask.todos import Todo
17 from plomtask.exceptions import NotFoundException, HandledException
18
19
class TestCaseSansDB(TestCase):
    """Tests requiring no DB setup.

    Subclasses point checked_class at the model class under test and opt
    into the generic checks below via do_id_test (ID range enforcement)
    and versioned_defaults_to_test (attribute name -> expected .newest).
    default_init_args are passed to checked_class after the ID.
    """
    checked_class: Any
    do_id_test: bool = False
    default_init_args: list[Any] = []
    versioned_defaults_to_test: dict[str, str | float] = {}

    def test_id_setting(self) -> None:
        """Test .id_ being set, and an ID of 0 being rejected."""
        if self.do_id_test:
            init_args = self.default_init_args
            with self.assertRaises(HandledException):
                self.checked_class(0, *init_args)
            self.assertEqual(self.checked_class(5, *init_args).id_, 5)

    def test_versioned_defaults(self) -> None:
        """Test .newest of fresh VersionedAttributes matches the defaults."""
        if not self.versioned_defaults_to_test:
            return
        instance = self.checked_class(1, *self.default_init_args)
        for attr_name, expected in self.versioned_defaults_to_test.items():
            self.assertEqual(getattr(instance, attr_name).newest, expected)
40         obj = self.checked_class(1, *self.default_init_args)
41         for k, v in self.versioned_defaults_to_test.items():
42             self.assertEqual(getattr(obj, k).newest, v)
43
44
45 class TestCaseWithDB(TestCase):
46     """Module tests not requiring DB setup."""
47     checked_class: Any
48     default_ids: tuple[int | str, int | str, int | str] = (1, 2, 3)
49     default_init_kwargs: dict[str, Any] = {}
50     test_versioneds: dict[str, type] = {}
51
52     def setUp(self) -> None:
53         Condition.empty_cache()
54         Day.empty_cache()
55         Process.empty_cache()
56         ProcessStep.empty_cache()
57         Todo.empty_cache()
58         self.db_file = DatabaseFile.create_at(f'test_db:{uuid4()}')
59         self.db_conn = DatabaseConnection(self.db_file)
60
61     def tearDown(self) -> None:
62         self.db_conn.close()
63         remove_file(self.db_file.path)
64
65     @staticmethod
66     def _within_checked_class(f: Callable[..., None]) -> Callable[..., None]:
67         def wrapper(self: TestCaseWithDB) -> None:
68             if hasattr(self, 'checked_class'):
69                 f(self)
70         return wrapper
71
72     @_within_checked_class
73     def test_saving_and_caching(self) -> None:
74         """Test storage and initialization of instances and attributes."""
75         self.check_saving_and_caching(id_=1, **self.default_init_kwargs)
76         obj = self.checked_class(None, **self.default_init_kwargs)
77         obj.save(self.db_conn)
78         self.assertEqual(obj.id_, 2)
79         for attr_name, type_ in self.test_versioneds.items():
80             owner = self.checked_class(None)
81             vals: list[Any] = ['t1', 't2'] if type_ == str else [0.9, 1.1]
82             attr = getattr(owner, attr_name)
83             attr.set(vals[0])
84             attr.set(vals[1])
85             owner.save(self.db_conn)
86             retrieved = owner.__class__.by_id(self.db_conn, owner.id_)
87             attr = getattr(retrieved, attr_name)
88             self.assertEqual(sorted(attr.history.values()), vals)
89
90     def check_identity_with_cache_and_db(self, content: list[Any]) -> None:
91         """Test both cache and DB equal content."""
92         expected_cache = {}
93         for item in content:
94             expected_cache[item.id_] = item
95         self.assertEqual(self.checked_class.get_cache(), expected_cache)
96         hashes_content = [hash(x) for x in content]
97         db_found: list[Any] = []
98         for item in content:
99             assert isinstance(item.id_, type(self.default_ids[0]))
100             for row in self.db_conn.row_where(self.checked_class.table_name,
101                                               'id', item.id_):
102                 db_found += [self.checked_class.from_table_row(self.db_conn,
103                                                                row)]
104         hashes_db_found = [hash(x) for x in db_found]
105         self.assertEqual(sorted(hashes_content), sorted(hashes_db_found))
106
107     def check_saving_and_caching(self, **kwargs: Any) -> None:
108         """Test instance.save in its core without relations."""
109         obj = self.checked_class(**kwargs)  # pylint: disable=not-callable
110         # check object init itself doesn't store anything yet
111         self.check_identity_with_cache_and_db([])
112         # check saving sets core attributes properly
113         obj.save(self.db_conn)
114         for key, value in kwargs.items():
115             self.assertEqual(getattr(obj, key), value)
116         # check saving stored properly in cache and DB
117         self.check_identity_with_cache_and_db([obj])
118
119     @_within_checked_class
120     def test_by_id(self) -> None:
121         """Test .by_id()."""
122         id1, id2, _ = self.default_ids
123         # check failure if not yet saved
124         obj1 = self.checked_class(id1, **self.default_init_kwargs)
125         with self.assertRaises(NotFoundException):
126             self.checked_class.by_id(self.db_conn, id1)
127         # check identity of cached and retrieved
128         obj1.cache()
129         self.assertEqual(obj1, self.checked_class.by_id(self.db_conn, id1))
130         # check identity of saved and retrieved
131         obj2 = self.checked_class(id2, **self.default_init_kwargs)
132         obj2.save(self.db_conn)
133         self.assertEqual(obj2, self.checked_class.by_id(self.db_conn, id2))
134         # obj1.save(self.db_conn)
135         # self.check_identity_with_cache_and_db([obj1, obj2])
136
137     @_within_checked_class
138     def test_by_id_or_create(self) -> None:
139         """Test .by_id_or_create."""
140         # check .by_id_or_create acts like normal instantiation (sans saving)
141         id_ = self.default_ids[0]
142         if not self.checked_class.can_create_by_id:
143             with self.assertRaises(HandledException):
144                 self.checked_class.by_id_or_create(self.db_conn, id_)
145         # check .by_id_or_create fails if wrong class
146         else:
147             by_id_created = self.checked_class.by_id_or_create(self.db_conn,
148                                                                id_)
149             with self.assertRaises(NotFoundException):
150                 self.checked_class.by_id(self.db_conn, id_)
151             self.assertEqual(self.checked_class(id_), by_id_created)
152
153     @_within_checked_class
154     def test_from_table_row(self) -> None:
155         """Test .from_table_row() properly reads in class directly from DB."""
156         id_ = self.default_ids[0]
157         obj = self.checked_class(id_, **self.default_init_kwargs)
158         obj.save(self.db_conn)
159         assert isinstance(obj.id_, type(self.default_ids[0]))
160         for row in self.db_conn.row_where(self.checked_class.table_name,
161                                           'id', obj.id_):
162             # check .from_table_row reproduces state saved, no matter if obj
163             # later changed (with caching even)
164             hash_original = hash(obj)
165             attr_name = self.checked_class.to_save[-1]
166             attr = getattr(obj, attr_name)
167             if isinstance(attr, (int, float)):
168                 setattr(obj, attr_name, attr + 1)
169             elif isinstance(attr, str):
170                 setattr(obj, attr_name, attr + "_")
171             elif isinstance(attr, bool):
172                 setattr(obj, attr_name, not attr)
173             obj.cache()
174             to_cmp = getattr(obj, attr_name)
175             retrieved = self.checked_class.from_table_row(self.db_conn, row)
176             self.assertNotEqual(to_cmp, getattr(retrieved, attr_name))
177             self.assertEqual(hash_original, hash(retrieved))
178             # check cache contains what .from_table_row just produced
179             self.assertEqual({retrieved.id_: retrieved},
180                              self.checked_class.get_cache())
181
182     def check_versioned_from_table_row(self, attr_name: str,
183                                        type_: type) -> None:
184         """Test .from_table_row() reads versioned attributes from DB."""
185         owner = self.checked_class(None)
186         vals: list[Any] = ['t1', 't2'] if type_ == str else [0.9, 1.1]
187         attr = getattr(owner, attr_name)
188         attr.set(vals[0])
189         attr.set(vals[1])
190         owner.save(self.db_conn)
191         for row in self.db_conn.row_where(owner.table_name, 'id', owner.id_):
192             retrieved = owner.__class__.from_table_row(self.db_conn, row)
193             attr = getattr(retrieved, attr_name)
194             self.assertEqual(sorted(attr.history.values()), vals)
195
196     @_within_checked_class
197     def test_all(self) -> None:
198         """Test .all() and its relation to cache and savings."""
199         id_1, id_2, id_3 = self.default_ids
200         item1 = self.checked_class(id_1, **self.default_init_kwargs)
201         item2 = self.checked_class(id_2, **self.default_init_kwargs)
202         item3 = self.checked_class(id_3, **self.default_init_kwargs)
203         # check .all() returns empty list on un-cached items
204         self.assertEqual(self.checked_class.all(self.db_conn), [])
205         # check that all() shows only cached/saved items
206         item1.cache()
207         item3.save(self.db_conn)
208         self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
209                          sorted([item1, item3]))
210         item2.save(self.db_conn)
211         self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
212                          sorted([item1, item2, item3]))
213
214     @_within_checked_class
215     def test_singularity(self) -> None:
216         """Test pointers made for single object keep pointing to it."""
217         id1 = self.default_ids[0]
218         obj = self.checked_class(id1, **self.default_init_kwargs)
219         obj.save(self.db_conn)
220         attr_name = self.checked_class.to_save[-1]
221         attr = getattr(obj, attr_name)
222         new_attr: str | int | float | bool
223         if isinstance(attr, (int, float)):
224             new_attr = attr + 1
225         elif isinstance(attr, str):
226             new_attr = attr + '_'
227         elif isinstance(attr, bool):
228             new_attr = not attr
229         setattr(obj, attr_name, new_attr)
230         retrieved = self.checked_class.by_id(self.db_conn, id1)
231         self.assertEqual(new_attr, getattr(retrieved, attr_name))
232
233     def check_versioned_singularity(self) -> None:
234         """Test singularity of VersionedAttributes on saving (with .title)."""
235         obj = self.checked_class(None)  # pylint: disable=not-callable
236         obj.save(self.db_conn)
237         assert isinstance(obj.id_, int)
238         obj.title.set('named')
239         retrieved = self.checked_class.by_id(self.db_conn, obj.id_)
240         self.assertEqual(obj.title.history, retrieved.title.history)
241
242     def check_remove(self, *args: Any) -> None:
243         """Test .remove() effects on DB and cache."""
244         id_ = self.default_ids[0]
245         obj = self.checked_class(id_, *args)  # pylint: disable=not-callable
246         with self.assertRaises(HandledException):
247             obj.remove(self.db_conn)
248         obj.save(self.db_conn)
249         obj.remove(self.db_conn)
250         self.check_identity_with_cache_and_db([])
251
252
class TestCaseWithServer(TestCaseWithDB):
    """Module tests against our HTTP server/handler (and database).

    On top of the per-test database of TestCaseWithDB, each test gets a
    TaskServer on localhost (port 0, i.e. chosen by the OS) served from a
    daemon thread and with .set_json_mode() called on it, plus an
    HTTPConnection (self.conn) to that server.
    """

    def setUp(self) -> None:
        super().setUp()
        self.httpd = TaskServer(self.db_file, ('localhost', 0), TaskHandler)
        self.server_thread = Thread(target=self.httpd.serve_forever)
        self.server_thread.daemon = True
        self.server_thread.start()
        self.conn = HTTPConnection(str(self.httpd.server_address[0]),
                                   self.httpd.server_address[1])
        self.httpd.set_json_mode()

    def tearDown(self) -> None:
        # close the client side first, so its socket does not leak
        self.conn.close()
        self.httpd.shutdown()
        self.httpd.server_close()
        self.server_thread.join()
        super().tearDown()

    def check_redirect(self, target: str) -> None:
        """Check that self.conn answers with a 302 redirect to target."""
        response = self.conn.getresponse()
        self.assertEqual(response.status, 302)
        self.assertEqual(response.getheader('Location'), target)

    def check_get(self, target: str, expected_code: int) -> None:
        """Check that a GET to target yields expected_code."""
        self.conn.request('GET', target)
        self.assertEqual(self.conn.getresponse().status, expected_code)

    def check_post(self, data: Mapping[str, object], target: str,
                   expected_code: int, redirect_location: str = '') -> None:
        """Check that POST of data to target yields expected_code.

        data is sent URL-encoded as a form (sequence values become repeated
        keys). If expected_code is 302, additionally checks the redirect
        points to redirect_location, which defaults to target itself.
        """
        encoded_form_data = urlencode(data, doseq=True).encode('utf-8')
        headers = {'Content-Type': 'application/x-www-form-urlencoded',
                   'Content-Length': str(len(encoded_form_data))}
        self.conn.request('POST', target,
                          body=encoded_form_data, headers=headers)
        if 302 == expected_code:
            if redirect_location == '':
                redirect_location = target
            self.check_redirect(redirect_location)
        else:
            self.assertEqual(self.conn.getresponse().status, expected_code)

    def check_get_defaults(self, path: str) -> None:
        """Check GET responses to path for its standard ?id= variants."""
        self.check_get(path, 200)
        self.check_get(f'{path}?id=', 200)
        self.check_get(f'{path}?id=foo', 400)
        self.check_get(f'{path}?id=0', 500)
        self.check_get(f'{path}?id=1', 200)

    def post_process(self, id_: int = 1,
                     form_data: dict[str, Any] | None = None
                     ) -> dict[str, Any]:
        """POST basic Process.

        Posts form_data (replaced by a basic title/description/effort set
        if None or empty) to /process?id={id_}, expects a redirect back to
        that same URL, and returns the form data posted.
        """
        if not form_data:
            form_data = {'title': 'foo', 'description': 'foo', 'effort': 1.1}
        self.check_post(form_data, f'/process?id={id_}', 302,
                        f'/process?id={id_}')
        return form_data

    def check_json_get(self, path: str, expected: dict[str, object]) -> None:
        """Compare JSON on GET path with expected (and expect status 200).

        To simplify comparison of VersionedAttribute histories, the
        timestamp keys of every history found under a '_versioned' key of
        the retrieved JSON are replaced by integers counting forward from 0
        in the order the history entries arrive in, which is assumed to be
        chronological.
        """
        def rewrite_history_keys_in(item: Any) -> Any:
            if isinstance(item, dict):
                if '_versioned' in item:
                    versioneds = item['_versioned']
                    for k in versioneds:
                        # re-keying an existing key while iterating is safe
                        versioneds[k] = dict(enumerate(
                            versioneds[k].values()))
                for k in list(item.keys()):
                    rewrite_history_keys_in(item[k])
            elif isinstance(item, list):
                item[:] = [rewrite_history_keys_in(i) for i in item]
            return item
        self.conn.request('GET', path)
        response = self.conn.getresponse()
        self.assertEqual(response.status, 200)
        retrieved = json_loads(response.read().decode())
        rewrite_history_keys_in(retrieved)
        self.assertEqual(expected, retrieved)
340         retrieved = json_loads(response.read().decode())
341         rewrite_history_keys_in(retrieved)
342         self.assertEqual(expected, retrieved)