home · contact · privacy
13e4f949d177a7d73456d2546db9f49de3e33fb1
[plomtask] / tests / utils.py
1 """Shared test utilities."""
2 from unittest import TestCase
3 from threading import Thread
4 from http.client import HTTPConnection
5 from urllib.parse import urlencode
6 from uuid import uuid4
7 from os import remove as remove_file
8 from typing import Mapping, Any
9 from plomtask.db import DatabaseFile, DatabaseConnection
10 from plomtask.http import TaskHandler, TaskServer
11 from plomtask.processes import Process, ProcessStep
12 from plomtask.conditions import Condition
13 from plomtask.days import Day
14 from plomtask.todos import Todo
15 from plomtask.exceptions import NotFoundException, HandledException
16
17
18 class TestCaseSansDB(TestCase):
19     """Tests requiring no DB setup."""
20     checked_class: Any
21     do_id_test: bool = False
22     default_init_args: list[Any] = []
23     versioned_defaults_to_test: dict[str, str | float] = {}
24
25     def test_id_setting(self) -> None:
26         """Test .id_ being set and its legal range being enforced."""
27         if not self.do_id_test:
28             return
29         with self.assertRaises(HandledException):
30             self.checked_class(0, *self.default_init_args)
31         obj = self.checked_class(5, *self.default_init_args)
32         self.assertEqual(obj.id_, 5)
33
34     def test_versioned_defaults(self) -> None:
35         """Test defaults of VersionedAttributes."""
36         if len(self.versioned_defaults_to_test) == 0:
37             return
38         obj = self.checked_class(1, *self.default_init_args)
39         for k, v in self.versioned_defaults_to_test.items():
40             self.assertEqual(getattr(obj, k).newest, v)
41
42
class TestCaseWithDB(TestCase):
    """Module tests requiring DB setup: a fresh DB file per test."""
    # model class under test; only annotated here, set by subclasses
    checked_class: Any
    # IDs used when a check needs up to three distinct objects
    default_ids: tuple[int | str, int | str, int | str] = (1, 2, 3)
    # keyword arguments passed to checked_class by test_saving_and_caching
    default_init_kwargs: dict[str, Any] = {}
    # versioned attribute name -> value type, checked by
    # test_saving_and_caching via check_saving_of_versioned
    test_versioneds: dict[str, type] = {}

    def setUp(self) -> None:
        """Empty model caches, create a fresh DB file and connect to it."""
        # reset per-class caches so objects from earlier tests cannot leak in
        Condition.empty_cache()
        Day.empty_cache()
        Process.empty_cache()
        ProcessStep.empty_cache()
        Todo.empty_cache()
        # uuid4 gives each test its own DB file name
        self.db_file = DatabaseFile.create_at(f'test_db:{uuid4()}')
        self.db_conn = DatabaseConnection(self.db_file)

    def tearDown(self) -> None:
        """Close the DB connection and delete the per-test DB file."""
        try:
            self.db_conn.close()
        finally:
            # remove the DB file even if closing the connection failed
            remove_file(self.db_file.path)

    def test_saving_and_caching(self) -> None:
        """Test storage and initialization of instances and attributes."""
        # base class itself has no checked_class, so nothing to test there
        if not hasattr(self, 'checked_class'):
            return
        self.check_saving_and_caching(id_=1, **self.default_init_kwargs)
        # saving with id_=None is expected to assign the next free ID (2)
        obj = self.checked_class(None, **self.default_init_kwargs)
        obj.save(self.db_conn)
        self.assertEqual(obj.id_, 2)
        for k, v in self.test_versioneds.items():
            self.check_saving_of_versioned(k, v)

    def check_storage(self, content: list[Any]) -> None:
        """Test cache and DB equal content.

        The class cache must equal {item.id_: item for item in content}.
        The DB rows for each item's id_ are loaded via .from_table_row(),
        and their hashes must match those of content.
        """
        expected_cache = {item.id_: item for item in content}
        self.assertEqual(self.checked_class.get_cache(), expected_cache)
        hashes_content = [hash(x) for x in content]
        db_found: list[Any] = []
        for item in content:
            # narrows the ID type (for type checkers) to that of default_ids
            assert isinstance(item.id_, type(self.default_ids[0]))
            for row in self.db_conn.row_where(self.checked_class.table_name,
                                              'id', item.id_):
                db_found.append(
                    self.checked_class.from_table_row(self.db_conn, row))
        hashes_db_found = [hash(x) for x in db_found]
        self.assertEqual(sorted(hashes_content), sorted(hashes_db_found))

    def check_saving_and_caching(self, **kwargs: Any) -> None:
        """Test instance.save in its core without relations.

        kwargs are passed to the constructor. Nothing may be stored before
        .save(). Afterwards each kwarg must be readable as an attribute, and
        cache and DB must contain exactly the new object.
        """
        obj = self.checked_class(**kwargs)  # pylint: disable=not-callable
        # check object init itself doesn't store anything yet
        self.check_storage([])
        # check saving sets core attributes properly
        obj.save(self.db_conn)
        for key, value in kwargs.items():
            self.assertEqual(getattr(obj, key), value)
        # check saving stored properly in cache and DB
        self.check_storage([obj])

    def _save_owner_with_versioned(self, attr_name: str,
                                   type_: type) -> tuple[Any, list[Any]]:
        """Save a new owner after setting two values on attr_name.

        Returns the owner and the values set. The values are listed in
        ascending order, so callers can compare them directly to the sorted
        history values.
        """
        owner = self.checked_class(None)
        vals: list[Any] = ['t1', 't2'] if type_ is str else [0.9, 1.1]
        attr = getattr(owner, attr_name)
        attr.set(vals[0])
        attr.set(vals[1])
        owner.save(self.db_conn)
        return owner, vals

    def check_saving_of_versioned(self, attr_name: str, type_: type) -> None:
        """Test owner's versioned attributes survive save and .by_id()."""
        owner, vals = self._save_owner_with_versioned(attr_name, type_)
        retrieved = owner.__class__.by_id(self.db_conn, owner.id_)
        attr = getattr(retrieved, attr_name)
        self.assertEqual(sorted(attr.history.values()), vals)

    def check_by_id(self) -> None:
        """Test .by_id(), including creation."""
        # check failure if not yet saved
        id1, id2 = self.default_ids[0], self.default_ids[1]
        obj = self.checked_class(id1)  # pylint: disable=not-callable
        with self.assertRaises(NotFoundException):
            self.checked_class.by_id(self.db_conn, id1)
        # check identity of saved and retrieved
        obj.save(self.db_conn)
        self.assertEqual(obj, self.checked_class.by_id(self.db_conn, id1))
        # check create=True acts like normal instantiation (sans saving)
        by_id_created = self.checked_class.by_id(self.db_conn, id2,
                                                 create=True)
        # pylint: disable=not-callable
        self.assertEqual(self.checked_class(id2), by_id_created)
        self.check_storage([obj])

    def check_from_table_row(self, *args: Any) -> None:
        """Test .from_table_row() properly reads in class from DB.

        args are passed to the constructor after the ID.
        """
        id_ = self.default_ids[0]
        obj = self.checked_class(id_, *args)  # pylint: disable=not-callable
        obj.save(self.db_conn)
        assert isinstance(obj.id_, type(self.default_ids[0]))
        rows = list(self.db_conn.row_where(self.checked_class.table_name,
                                           'id', obj.id_))
        # without this, a missing row would let the loop (and test) pass
        # without checking anything
        self.assertTrue(rows, 'saved object has no DB row')
        for row in rows:
            hash_original = hash(obj)
            retrieved = self.checked_class.from_table_row(self.db_conn, row)
            self.assertEqual(hash_original, hash(retrieved))
            self.assertEqual({retrieved.id_: retrieved},
                             self.checked_class.get_cache())

    def check_versioned_from_table_row(self, attr_name: str,
                                       type_: type) -> None:
        """Test .from_table_row() reads versioned attributes from DB."""
        owner, vals = self._save_owner_with_versioned(attr_name, type_)
        rows = list(self.db_conn.row_where(owner.table_name, 'id',
                                           owner.id_))
        # guard against the loop below passing vacuously
        self.assertTrue(rows, 'saved owner has no DB row')
        for row in rows:
            retrieved = owner.__class__.from_table_row(self.db_conn, row)
            attr = getattr(retrieved, attr_name)
            self.assertEqual(sorted(attr.history.values()), vals)

    def check_all(self) -> tuple[Any, Any, Any]:
        """Test .all(); return the three objects created for it."""
        # pylint: disable=not-callable
        item1 = self.checked_class(self.default_ids[0])
        item2 = self.checked_class(self.default_ids[1])
        item3 = self.checked_class(self.default_ids[2])
        # check pre-save .all() returns empty list
        self.assertEqual(self.checked_class.all(self.db_conn), [])
        # check that all() shows all saved, but no unsaved items
        item1.save(self.db_conn)
        item3.save(self.db_conn)
        self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
                         sorted([item1, item3]))
        item2.save(self.db_conn)
        self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
                         sorted([item1, item2, item3]))
        return item1, item2, item3

    def check_singularity(self, defaulting_field: str,
                          non_default_value: Any, *args: Any) -> None:
        """Test pointers made for single object keep pointing to it.

        After saving, defaulting_field is changed on the object without
        re-saving. .by_id() must then return an object showing that change,
        presumably because it hands back the same cached instance.
        """
        id1 = self.default_ids[0]
        obj = self.checked_class(id1, *args)  # pylint: disable=not-callable
        obj.save(self.db_conn)
        setattr(obj, defaulting_field, non_default_value)
        retrieved = self.checked_class.by_id(self.db_conn, id1)
        self.assertEqual(non_default_value,
                         getattr(retrieved, defaulting_field))

    def check_versioned_singularity(self) -> None:
        """Test singularity of VersionedAttributes on saving (with .title)."""
        obj = self.checked_class(None)  # pylint: disable=not-callable
        obj.save(self.db_conn)
        assert isinstance(obj.id_, int)
        # set after saving: retrieved must still reflect it
        obj.title.set('named')
        retrieved = self.checked_class.by_id(self.db_conn, obj.id_)
        self.assertEqual(obj.title.history, retrieved.title.history)

    def check_remove(self, *args: Any) -> None:
        """Test .remove() effects on DB and cache.

        Removing an unsaved object must raise HandledException. Removing a
        saved one must leave cache and DB empty.
        """
        id_ = self.default_ids[0]
        obj = self.checked_class(id_, *args)  # pylint: disable=not-callable
        with self.assertRaises(HandledException):
            obj.remove(self.db_conn)
        obj.save(self.db_conn)
        obj.remove(self.db_conn)
        self.check_storage([])
207
208
class TestCaseWithServer(TestCaseWithDB):
    """Module tests against our HTTP server/handler (and database)."""

    def setUp(self) -> None:
        """Start a TaskServer on the test DB in a daemon thread.

        Also opens self.conn, an HTTPConnection to that server, and
        switches the server to JSON mode.
        """
        super().setUp()
        # port 0: presumably lets the OS pick a free port (socketserver
        # semantics); the port actually bound is read from server_address
        self.httpd = TaskServer(self.db_file, ('localhost', 0), TaskHandler)
        self.server_thread = Thread(target=self.httpd.serve_forever)
        # daemon, so a hanging server thread cannot keep the process alive
        self.server_thread.daemon = True
        self.server_thread.start()
        self.conn = HTTPConnection(str(self.httpd.server_address[0]),
                                   self.httpd.server_address[1])
        self.httpd.set_json_mode()

    def tearDown(self) -> None:
        """Close the client connection, stop the server, clean up the DB."""
        try:
            # close the client side too, so its socket does not outlive
            # the test
            self.conn.close()
            self.httpd.shutdown()
            self.httpd.server_close()
            self.server_thread.join()
        finally:
            # always remove the DB file, even if server shutdown failed
            super().tearDown()

    def check_redirect(self, target: str) -> None:
        """Check that self.conn answers with a 302 redirect to target."""
        response = self.conn.getresponse()
        self.assertEqual(response.status, 302)
        self.assertEqual(response.getheader('Location'), target)

    def check_get(self, target: str, expected_code: int) -> None:
        """Check that a GET to target yields expected_code."""
        self.conn.request('GET', target)
        self.assertEqual(self.conn.getresponse().status, expected_code)

    def check_post(self, data: Mapping[str, object], target: str,
                   expected_code: int, redirect_location: str = '') -> None:
        """Check that POST of data to target yields expected_code.

        data is sent form-urlencoded (doseq=True, so sequence values become
        repeated keys). If expected_code is 302, the Location header must
        equal redirect_location, or target if redirect_location is ''.
        """
        encoded_form_data = urlencode(data, doseq=True).encode('utf-8')
        headers = {'Content-Type': 'application/x-www-form-urlencoded',
                   'Content-Length': str(len(encoded_form_data))}
        self.conn.request('POST', target,
                          body=encoded_form_data, headers=headers)
        if expected_code == 302:
            self.check_redirect(redirect_location or target)
        else:
            self.assertEqual(self.conn.getresponse().status, expected_code)

    def check_get_defaults(self, path: str) -> None:
        """GET path with standard id variants and check the status codes.

        Expects 200 for no id, an empty id and id=1; 400 for the
        non-numeric id 'foo'; 500 for id=0.
        """
        self.check_get(path, 200)
        self.check_get(f'{path}?id=', 200)
        self.check_get(f'{path}?id=foo', 400)
        # same path form as the other variants: no extra leading '/'
        self.check_get(f'{path}?id=0', 500)
        self.check_get(f'{path}?id=1', 200)

    def post_process(self, id_: int = 1,
                     form_data: dict[str, Any] | None = None
                     ) -> dict[str, Any]:
        """POST basic Process, expecting a 302 redirect back to its page.

        If form_data is None, a default form (title, description, effort)
        is posted. Returns the form data actually posted.
        """
        if form_data is None:
            form_data = {'title': 'foo', 'description': 'foo', 'effort': 1.1}
        target = f'/process?id={id_}'
        self.check_post(form_data, target, 302, target)
        return form_data

    @staticmethod
    def blank_history_keys_in(d: dict[str, object]) -> None:
        """Re-write "history" object keys to bracketed integer strings.

        Walks d in place, recursing through nested dicts and lists. In every
        dict with a 'history' key, that history dict's keys are replaced by
        '[0]', '[1]', ... following the order of its values. Expectations
        then need not depend on the actual history keys.
        """
        def walk_tree(node: Any) -> Any:
            if isinstance(node, dict):
                if 'history' in node:
                    node['history'] = {
                        f'[{i}]': val
                        for i, val in enumerate(node['history'].values())}
                for key in list(node):
                    walk_tree(node[key])
            elif isinstance(node, list):
                node[:] = [walk_tree(item) for item in node]
            return node
        walk_tree(d)