home · contact · privacy
d6c5b20ac7882281d4958e99e3dbcd6a35de708b
[plomtask] / tests / utils.py
1 """Shared test utilities."""
2 from unittest import TestCase
3 from threading import Thread
4 from http.client import HTTPConnection
5 from json import loads as json_loads
6 from urllib.parse import urlencode
7 from uuid import uuid4
8 from os import remove as remove_file
9 from typing import Mapping, Any
10 from plomtask.db import DatabaseFile, DatabaseConnection
11 from plomtask.http import TaskHandler, TaskServer
12 from plomtask.processes import Process, ProcessStep
13 from plomtask.conditions import Condition
14 from plomtask.days import Day
15 from plomtask.todos import Todo
16 from plomtask.exceptions import NotFoundException, HandledException
17
18
19 class TestCaseSansDB(TestCase):
20     """Tests requiring no DB setup."""
21     checked_class: Any
22     do_id_test: bool = False
23     default_init_args: list[Any] = []
24     versioned_defaults_to_test: dict[str, str | float] = {}
25
26     def test_id_setting(self) -> None:
27         """Test .id_ being set and its legal range being enforced."""
28         if not self.do_id_test:
29             return
30         with self.assertRaises(HandledException):
31             self.checked_class(0, *self.default_init_args)
32         obj = self.checked_class(5, *self.default_init_args)
33         self.assertEqual(obj.id_, 5)
34
35     def test_versioned_defaults(self) -> None:
36         """Test defaults of VersionedAttributes."""
37         if len(self.versioned_defaults_to_test) == 0:
38             return
39         obj = self.checked_class(1, *self.default_init_args)
40         for k, v in self.versioned_defaults_to_test.items():
41             self.assertEqual(getattr(obj, k).newest, v)
42
43
class TestCaseWithDB(TestCase):
    """Module tests requiring DB setup.

    Every test runs against a freshly created database file (made in setUp,
    removed in tearDown), with the model caches emptied beforehand.
    Subclasses configure the shared checks via these class attributes:

    checked_class -- model class under test
    default_ids -- three IDs used to construct test instances
    default_init_kwargs -- keyword arguments for constructing checked_class
    test_versioneds -- maps versioned attribute names to a type; str selects
        string test values, any other type selects float test values
    """
    checked_class: Any
    default_ids: tuple[int | str, int | str, int | str] = (1, 2, 3)
    default_init_kwargs: dict[str, Any] = {}
    test_versioneds: dict[str, type] = {}

    def setUp(self) -> None:
        # Empty the model caches so each test starts without stale objects.
        Condition.empty_cache()
        Day.empty_cache()
        Process.empty_cache()
        ProcessStep.empty_cache()
        Todo.empty_cache()
        # UUID suffix gives each test its own DB file name.
        self.db_file = DatabaseFile.create_at(f'test_db:{uuid4()}')
        self.db_conn = DatabaseConnection(self.db_file)

    def tearDown(self) -> None:
        # Remove the DB file even if closing the connection raises.
        try:
            self.db_conn.close()
        finally:
            remove_file(self.db_file.path)

    def test_saving_and_caching(self) -> None:
        """Test storage and initialization of instances and attributes."""
        # checked_class is only annotated here, so this base class skips.
        if not hasattr(self, 'checked_class'):
            return
        self.check_saving_and_caching(id_=1, **self.default_init_kwargs)
        # Saving with id_=None is expected to assign ID 2 (after ID 1 above).
        obj = self.checked_class(None, **self.default_init_kwargs)
        obj.save(self.db_conn)
        self.assertEqual(obj.id_, 2)
        for k, v in self.test_versioneds.items():
            self.check_saving_of_versioned(k, v)

    def check_storage(self, content: list[Any]) -> None:
        """Test cache and DB equal content.

        Asserts checked_class's cache equals content keyed by .id_, and that
        objects read back via .from_table_row from the DB rows matching each
        item's ID hash equal to the items of content.
        """
        expected_cache = {item.id_: item for item in content}
        self.assertEqual(self.checked_class.get_cache(), expected_cache)
        hashes_content = [hash(x) for x in content]
        db_found: list[Any] = []
        for item in content:
            # Also narrows item.id_ to the type used in default_ids.
            assert isinstance(item.id_, type(self.default_ids[0]))
            for row in self.db_conn.row_where(self.checked_class.table_name,
                                              'id', item.id_):
                db_found += [self.checked_class.from_table_row(self.db_conn,
                                                               row)]
        hashes_db_found = [hash(x) for x in db_found]
        self.assertEqual(sorted(hashes_content), sorted(hashes_db_found))

    def check_saving_and_caching(self, **kwargs: Any) -> None:
        """Test instance.save in its core without relations."""
        obj = self.checked_class(**kwargs)  # pylint: disable=not-callable
        # check object init itself doesn't store anything yet
        self.check_storage([])
        # check saving sets core attributes properly
        obj.save(self.db_conn)
        for key, value in kwargs.items():
            self.assertEqual(getattr(obj, key), value)
        # check saving stored properly in cache and DB
        self.check_storage([obj])

    def _save_with_versioned_history(self, attr_name: str,
                                     type_: type) -> tuple[Any, list[Any]]:
        """Save a new checked_class instance after setting attr_name twice.

        Returns the saved owner and the two values set (in ascending order,
        so they can be compared against a sorted history).
        """
        owner = self.checked_class(None)
        vals: list[Any] = ['t1', 't2'] if type_ is str else [0.9, 1.1]
        attr = getattr(owner, attr_name)
        attr.set(vals[0])
        attr.set(vals[1])
        owner.save(self.db_conn)
        return owner, vals

    def check_saving_of_versioned(self, attr_name: str, type_: type) -> None:
        """Test owner's versioned attributes are retrievable via .by_id()."""
        owner, vals = self._save_with_versioned_history(attr_name, type_)
        retrieved = owner.__class__.by_id(self.db_conn, owner.id_)
        attr = getattr(retrieved, attr_name)
        self.assertEqual(sorted(attr.history.values()), vals)

    def check_by_id(self) -> None:
        """Test .by_id(), including creation."""
        # check failure if not yet saved
        id1, id2 = self.default_ids[0], self.default_ids[1]
        obj = self.checked_class(id1)  # pylint: disable=not-callable
        with self.assertRaises(NotFoundException):
            self.checked_class.by_id(self.db_conn, id1)
        # check identity of saved and retrieved
        obj.save(self.db_conn)
        self.assertEqual(obj, self.checked_class.by_id(self.db_conn, id1))
        # check create=True acts like normal instantiation (sans saving)
        by_id_created = self.checked_class.by_id(self.db_conn, id2,
                                                 create=True)
        # pylint: disable=not-callable
        self.assertEqual(self.checked_class(id2), by_id_created)
        self.check_storage([obj])

    def check_from_table_row(self, *args: Any) -> None:
        """Test .from_table_row() properly reads in class from DB"""
        id_ = self.default_ids[0]
        obj = self.checked_class(id_, *args)  # pylint: disable=not-callable
        obj.save(self.db_conn)
        assert isinstance(obj.id_, type(self.default_ids[0]))
        for row in self.db_conn.row_where(self.checked_class.table_name,
                                          'id', obj.id_):
            hash_original = hash(obj)
            retrieved = self.checked_class.from_table_row(self.db_conn, row)
            self.assertEqual(hash_original, hash(retrieved))
            # the cache must then hold exactly the retrieved object
            self.assertEqual({retrieved.id_: retrieved},
                             self.checked_class.get_cache())

    def check_versioned_from_table_row(self, attr_name: str,
                                       type_: type) -> None:
        """Test .from_table_row() reads versioned attributes from DB."""
        owner, vals = self._save_with_versioned_history(attr_name, type_)
        for row in self.db_conn.row_where(owner.table_name, 'id', owner.id_):
            retrieved = owner.__class__.from_table_row(self.db_conn, row)
            attr = getattr(retrieved, attr_name)
            self.assertEqual(sorted(attr.history.values()), vals)

    def check_all(self) -> tuple[Any, Any, Any]:
        """Test .all(); returns the three items created along the way."""
        # pylint: disable=not-callable
        item1 = self.checked_class(self.default_ids[0])
        item2 = self.checked_class(self.default_ids[1])
        item3 = self.checked_class(self.default_ids[2])
        # check pre-save .all() returns empty list
        self.assertEqual(self.checked_class.all(self.db_conn), [])
        # check that all() shows all saved, but no unsaved items
        item1.save(self.db_conn)
        item3.save(self.db_conn)
        self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
                         sorted([item1, item3]))
        item2.save(self.db_conn)
        self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
                         sorted([item1, item2, item3]))
        return item1, item2, item3

    def check_singularity(self, defaulting_field: str,
                          non_default_value: Any, *args: Any) -> None:
        """Test pointers made for single object keep pointing to it.

        Changes defaulting_field on the saved object after saving, then
        expects the object fetched via .by_id() to show that change.
        """
        id1 = self.default_ids[0]
        obj = self.checked_class(id1, *args)  # pylint: disable=not-callable
        obj.save(self.db_conn)
        setattr(obj, defaulting_field, non_default_value)
        retrieved = self.checked_class.by_id(self.db_conn, id1)
        self.assertEqual(non_default_value,
                         getattr(retrieved, defaulting_field))

    def check_versioned_singularity(self) -> None:
        """Test singularity of VersionedAttributes on saving (with .title)."""
        obj = self.checked_class(None)  # pylint: disable=not-callable
        obj.save(self.db_conn)
        assert isinstance(obj.id_, int)
        # set after saving: .by_id() must still reflect the new history
        obj.title.set('named')
        retrieved = self.checked_class.by_id(self.db_conn, obj.id_)
        self.assertEqual(obj.title.history, retrieved.title.history)

    def check_remove(self, *args: Any) -> None:
        """Test .remove() effects on DB and cache.

        Removing a not-yet-saved object must raise HandledException; removing
        a saved one must leave neither cache nor DB holding it.
        """
        id_ = self.default_ids[0]
        obj = self.checked_class(id_, *args)  # pylint: disable=not-callable
        with self.assertRaises(HandledException):
            obj.remove(self.db_conn)
        obj.save(self.db_conn)
        obj.remove(self.db_conn)
        self.check_storage([])
208
209
class TestCaseWithServer(TestCaseWithDB):
    """Module tests against our HTTP server/handler (and database)."""

    def setUp(self) -> None:
        super().setUp()
        # Port 0 presumably lets the OS pick a free port; the actual address
        # is read back from server_address below.
        self.httpd = TaskServer(self.db_file, ('localhost', 0), TaskHandler)
        # Configure the server completely before its serving thread starts.
        self.httpd.set_json_mode()
        self.server_thread = Thread(target=self.httpd.serve_forever)
        self.server_thread.daemon = True
        self.server_thread.start()
        self.conn = HTTPConnection(str(self.httpd.server_address[0]),
                                   self.httpd.server_address[1])

    def tearDown(self) -> None:
        # Close the client side too, so its socket is not left open.
        self.conn.close()
        self.httpd.shutdown()
        self.httpd.server_close()
        self.server_thread.join()
        super().tearDown()

    def check_redirect(self, target: str) -> None:
        """Check that self.conn answers with a 302 redirect to target."""
        response = self.conn.getresponse()
        self.assertEqual(response.status, 302)
        self.assertEqual(response.getheader('Location'), target)

    def check_get(self, target: str, expected_code: int) -> None:
        """Check that a GET to target yields expected_code."""
        self.conn.request('GET', target)
        self.assertEqual(self.conn.getresponse().status, expected_code)

    def check_post(self, data: Mapping[str, object], target: str,
                   expected_code: int, redirect_location: str = '') -> None:
        """Check that POST of data to target yields expected_code.

        data is sent URL-form-encoded (sequence values as repeated keys).
        For an expected 302, also check the Location header equals
        redirect_location, or target itself if redirect_location is empty.
        """
        encoded_form_data = urlencode(data, doseq=True).encode('utf-8')
        headers = {'Content-Type': 'application/x-www-form-urlencoded',
                   'Content-Length': str(len(encoded_form_data))}
        self.conn.request('POST', target,
                          body=encoded_form_data, headers=headers)
        if 302 == expected_code:
            if redirect_location == '':
                redirect_location = target
            self.check_redirect(redirect_location)
        else:
            self.assertEqual(self.conn.getresponse().status, expected_code)

    def check_get_defaults(self, path: str) -> None:
        """Some standard model paths to test.

        Expects 200 for path with no, empty, or ID 1 ?id= values, 400 for a
        non-numeric ID, and 500 for the illegal ID 0.
        """
        self.check_get(path, 200)
        self.check_get(f'{path}?id=', 200)
        self.check_get(f'{path}?id=foo', 400)
        self.check_get(f'{path}?id=0', 500)
        self.check_get(f'{path}?id=1', 200)

    def post_process(self, id_: int = 1,
                     form_data: dict[str, Any] | None = None
                     ) -> dict[str, Any]:
        """POST basic Process, expecting a redirect back to its page.

        Any falsy form_data (None or an empty dict) is replaced by default
        title/description/effort values. Returns the form data posted.
        """
        if not form_data:
            form_data = {'title': 'foo', 'description': 'foo', 'effort': 1.1}
        self.check_post(form_data, f'/process?id={id_}', 302,
                        f'/process?id={id_}')
        return form_data

    def check_json_get(self, path: str, expected: dict[str, object]) -> None:
        """Compare JSON on GET path with expected.

        To simplify comparison of VersionedAttribute histories, replaces the
        timestamp keys of each history in the retrieved JSON with integers
        0, 1, 2, … in the order the entries appear in the response.
        """
        def rewrite_history_keys_in(item: Any) -> Any:
            if isinstance(item, dict):
                if '_versioned' in item.keys():
                    for k in item['_versioned']:
                        # NOTE(review): relies on the server emitting each
                        # history in chronological order; keys aren't sorted.
                        vals = item['_versioned'][k].values()
                        history = {}
                        for i, val in enumerate(vals):
                            history[i] = val
                        item['_versioned'][k] = history
                for k in list(item.keys()):
                    rewrite_history_keys_in(item[k])
            elif isinstance(item, list):
                item[:] = [rewrite_history_keys_in(i) for i in item]
            return item
        self.conn.request('GET', path)
        response = self.conn.getresponse()
        self.assertEqual(response.status, 200)
        retrieved = json_loads(response.read().decode())
        rewrite_history_keys_in(retrieved)
        self.assertEqual(expected, retrieved)