1 """Shared test utilities."""
from unittest import TestCase
from threading import Thread
from http.client import HTTPConnection
from json import loads as json_loads
from urllib.parse import urlencode
from uuid import uuid4
from os import remove as remove_file
from typing import Mapping, Any

from plomtask.db import DatabaseFile, DatabaseConnection
from plomtask.http import TaskHandler, TaskServer
from plomtask.processes import Process, ProcessStep
from plomtask.conditions import Condition
from plomtask.days import Day
from plomtask.todos import Todo
from plomtask.exceptions import NotFoundException, HandledException
19 class TestCaseSansDB(TestCase):
20 """Tests requiring no DB setup."""
22 do_id_test: bool = False
23 default_init_args: list[Any] = []
24 versioned_defaults_to_test: dict[str, str | float] = {}
26 def test_id_setting(self) -> None:
27 """Test .id_ being set and its legal range being enforced."""
28 if not self.do_id_test:
30 with self.assertRaises(HandledException):
31 self.checked_class(0, *self.default_init_args)
32 obj = self.checked_class(5, *self.default_init_args)
33 self.assertEqual(obj.id_, 5)
35 def test_versioned_defaults(self) -> None:
36 """Test defaults of VersionedAttributes."""
37 if len(self.versioned_defaults_to_test) == 0:
39 obj = self.checked_class(1, *self.default_init_args)
40 for k, v in self.versioned_defaults_to_test.items():
41 self.assertEqual(getattr(obj, k).newest, v)
class TestCaseWithDB(TestCase):
    """Module tests requiring DB setup.

    Every test runs against a fresh, uniquely named database file with all
    model caches emptied. The generic test_* methods only do work once a
    subclass sets ``checked_class``; the check_* helpers are meant to be
    called from such subclasses.
    """
    # Declared but deliberately left unassigned, so the hasattr() guards
    # below can tell whether a subclass configured a model class.
    checked_class: Any
    default_ids: tuple[int | str, int | str, int | str] = (1, 2, 3)
    default_init_kwargs: dict[str, Any] = {}
    # Maps versioned attribute names to the type of values they hold.
    test_versioneds: dict[str, type] = {}

    def setUp(self) -> None:
        # reset all model caches so no objects leak over from earlier tests
        Condition.empty_cache()
        Day.empty_cache()
        Process.empty_cache()
        ProcessStep.empty_cache()
        Todo.empty_cache()
        self.db_file = DatabaseFile.create_at(f'test_db:{uuid4()}')
        self.db_conn = DatabaseConnection(self.db_file)

    def tearDown(self) -> None:
        # close the connection before deleting the file it points to, and
        # delete the file even if closing fails
        try:
            self.db_conn.close()
        finally:
            remove_file(self.db_file.path)

    @staticmethod
    def _altered_value(value: Any) -> Any:
        """Return a value of the same type as value but differing from it.

        bool is checked before int/float on purpose: bool subclasses int,
        so the reverse order would turn True/False into the ints 2/1.

        Raises:
            TypeError: if value is not a bool, int, float, or str.
        """
        if isinstance(value, bool):
            return not value
        if isinstance(value, (int, float)):
            return value + 1
        if isinstance(value, str):
            return value + '_'
        raise TypeError(f'cannot alter value of type {type(value)}')

    def _save_owner_with_versions(self, attr_name: str,
                                  type_: type) -> tuple[Any, list[Any]]:
        """Save new checked_class instance with two values set on attr_name.

        Returns the saved owner and the values set, in the order set.
        """
        owner = self.checked_class(None)
        vals: list[Any] = ['t1', 't2'] if type_ is str else [0.9, 1.1]
        attr = getattr(owner, attr_name)
        for val in vals:
            attr.set(val)
        owner.save(self.db_conn)
        return owner, vals

    def test_saving_and_caching(self) -> None:
        """Test storage and initialization of instances and attributes."""
        if not hasattr(self, 'checked_class'):
            return
        self.check_saving_and_caching(id_=1, **self.default_init_kwargs)
        obj = self.checked_class(None, **self.default_init_kwargs)
        obj.save(self.db_conn)
        # ID 1 is taken above, so an ID-less save is expected to yield 2
        self.assertEqual(obj.id_, 2)
        for attr_name, type_ in self.test_versioneds.items():
            self.check_saving_of_versioned(attr_name, type_)

    def check_storage(self, content: list[Any]) -> None:
        """Test cache and DB equal content.

        Asserts that checked_class's cache maps exactly the IDs of content
        to its items, and that reading each item's DB row back via
        .from_table_row yields objects with the same hashes as content
        (order-insensitive).
        """
        expected_cache = {item.id_: item for item in content}
        self.assertEqual(self.checked_class.get_cache(), expected_cache)
        hashes_content = [hash(x) for x in content]
        db_found: list[Any] = []
        for item in content:
            assert isinstance(item.id_, type(self.default_ids[0]))
            for row in self.db_conn.row_where(self.checked_class.table_name,
                                              'id', item.id_):
                db_found += [self.checked_class.from_table_row(self.db_conn,
                                                               row)]
        hashes_db_found = [hash(x) for x in db_found]
        self.assertEqual(sorted(hashes_content), sorted(hashes_db_found))

    def check_saving_and_caching(self, **kwargs: Any) -> None:
        """Test instance.save in its core without relations."""
        obj = self.checked_class(**kwargs)  # pylint: disable=not-callable
        # check object init itself doesn't store anything yet
        self.check_storage([])
        # check saving sets core attributes properly
        obj.save(self.db_conn)
        for key, value in kwargs.items():
            self.assertEqual(getattr(obj, key), value)
        # check saving stored properly in cache and DB
        self.check_storage([obj])

    def check_saving_of_versioned(self, attr_name: str, type_: type) -> None:
        """Test owner's versioned attributes survive saving and by_id."""
        owner, vals = self._save_owner_with_versions(attr_name, type_)
        retrieved = owner.__class__.by_id(self.db_conn, owner.id_)
        attr = getattr(retrieved, attr_name)
        self.assertEqual(sorted(attr.history.values()), vals)

    def check_by_id(self) -> None:
        """Test .by_id(), including creation."""
        # check failure if not yet saved
        id1, id2 = self.default_ids[0], self.default_ids[1]
        obj = self.checked_class(id1)  # pylint: disable=not-callable
        with self.assertRaises(NotFoundException):
            self.checked_class.by_id(self.db_conn, id1)
        # check identity of saved and retrieved
        obj.save(self.db_conn)
        self.assertEqual(obj, self.checked_class.by_id(self.db_conn, id1))
        # check create=True acts like normal instantiation (sans saving)
        by_id_created = self.checked_class.by_id(self.db_conn, id2,
                                                 create=True)
        # pylint: disable=not-callable
        self.assertEqual(self.checked_class(id2), by_id_created)
        self.check_storage([obj])

    def test_from_table_row(self) -> None:
        """Test .from_table_row() properly reads in class directly from DB."""
        if not hasattr(self, 'checked_class'):
            return
        id_ = self.default_ids[0]
        obj = self.checked_class(id_, **self.default_init_kwargs)
        obj.save(self.db_conn)
        assert isinstance(obj.id_, type(self.default_ids[0]))
        for row in self.db_conn.row_where(self.checked_class.table_name,
                                          'id', obj.id_):
            # check .from_table_row reproduces state saved, no matter if obj
            # later changed (with caching even)
            hash_original = hash(obj)
            attr_name = self.checked_class.to_save[-1]
            setattr(obj, attr_name,
                    self._altered_value(getattr(obj, attr_name)))
            obj.cache()
            to_cmp = getattr(obj, attr_name)
            retrieved = self.checked_class.from_table_row(self.db_conn, row)
            self.assertNotEqual(to_cmp, getattr(retrieved, attr_name))
            self.assertEqual(hash_original, hash(retrieved))
            # check cache contains what .from_table_row just produced
            self.assertEqual({retrieved.id_: retrieved},
                             self.checked_class.get_cache())

    def check_versioned_from_table_row(self, attr_name: str,
                                       type_: type) -> None:
        """Test .from_table_row() reads versioned attributes from DB."""
        owner, vals = self._save_owner_with_versions(attr_name, type_)
        for row in self.db_conn.row_where(owner.table_name, 'id', owner.id_):
            retrieved = owner.__class__.from_table_row(self.db_conn, row)
            attr = getattr(retrieved, attr_name)
            self.assertEqual(sorted(attr.history.values()), vals)

    def test_all(self) -> None:
        """Test .all() and its relation to cache and savings."""
        if not hasattr(self, 'checked_class'):
            return
        id_1, id_2, id_3 = self.default_ids
        item1 = self.checked_class(id_1, **self.default_init_kwargs)
        item2 = self.checked_class(id_2, **self.default_init_kwargs)
        item3 = self.checked_class(id_3, **self.default_init_kwargs)
        # check .all() returns empty list on un-cached items
        self.assertEqual(self.checked_class.all(self.db_conn), [])
        # check that all() shows only cached/saved items
        item1.cache()
        item3.save(self.db_conn)
        self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
                         sorted([item1, item3]))
        item2.save(self.db_conn)
        self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
                         sorted([item1, item2, item3]))

    def test_singularity(self) -> None:
        """Test pointers made for single object keep pointing to it."""
        if not hasattr(self, 'checked_class'):
            return
        id1 = self.default_ids[0]
        obj = self.checked_class(id1, **self.default_init_kwargs)
        obj.save(self.db_conn)
        attr_name = self.checked_class.to_save[-1]
        new_attr = self._altered_value(getattr(obj, attr_name))
        setattr(obj, attr_name, new_attr)
        # obj was changed without re-saving, so by_id must hand back that
        # same (changed) object
        retrieved = self.checked_class.by_id(self.db_conn, id1)
        self.assertEqual(new_attr, getattr(retrieved, attr_name))

    def check_versioned_singularity(self) -> None:
        """Test singularity of VersionedAttributes on saving (with .title)."""
        obj = self.checked_class(None)  # pylint: disable=not-callable
        obj.save(self.db_conn)
        assert isinstance(obj.id_, int)
        obj.title.set('named')
        retrieved = self.checked_class.by_id(self.db_conn, obj.id_)
        self.assertEqual(obj.title.history, retrieved.title.history)

    def check_remove(self, *args: Any) -> None:
        """Test .remove() effects on DB and cache.

        Removing an unsaved object must raise HandledException; removing a
        saved one must leave both cache and DB empty.
        """
        id_ = self.default_ids[0]
        obj = self.checked_class(id_, *args)  # pylint: disable=not-callable
        with self.assertRaises(HandledException):
            obj.remove(self.db_conn)
        obj.save(self.db_conn)
        obj.remove(self.db_conn)
        self.check_storage([])
class TestCaseWithServer(TestCaseWithDB):
    """Module tests against our HTTP server/handler (and database)."""

    def setUp(self) -> None:
        # parent setUp creates self.db_file, which the server needs
        super().setUp()
        # port 0 asks for any free port (assuming TaskServer binds like the
        # socketserver servers), hence reading the real address back below
        self.httpd = TaskServer(self.db_file, ('localhost', 0), TaskHandler)
        self.server_thread = Thread(target=self.httpd.serve_forever,
                                    daemon=True)
        self.server_thread.start()
        self.conn = HTTPConnection(str(self.httpd.server_address[0]),
                                   self.httpd.server_address[1])
        self.httpd.set_json_mode()

    def tearDown(self) -> None:
        self.conn.close()
        self.httpd.shutdown()
        self.httpd.server_close()
        self.server_thread.join()
        # parent tearDown closes the DB connection and removes the DB file
        super().tearDown()

    def check_redirect(self, target: str) -> None:
        """Check that self.conn answers with a 302 redirect to target."""
        response = self.conn.getresponse()
        self.assertEqual(response.status, 302)
        self.assertEqual(response.getheader('Location'), target)

    def check_get(self, target: str, expected_code: int) -> None:
        """Check that a GET to target yields expected_code."""
        self.conn.request('GET', target)
        self.assertEqual(self.conn.getresponse().status, expected_code)

    def check_post(self, data: Mapping[str, object], target: str,
                   expected_code: int, redirect_location: str = '') -> None:
        """Check that POST of data to target yields expected_code.

        data is sent form-urlencoded, with sequence values expanded into
        repeated fields. If expected_code is 302, the redirect is checked
        to point to redirect_location, which defaults to target itself.
        """
        encoded_form_data = urlencode(data, doseq=True).encode('utf-8')
        headers = {'Content-Type': 'application/x-www-form-urlencoded',
                   'Content-Length': str(len(encoded_form_data))}
        self.conn.request('POST', target,
                          body=encoded_form_data, headers=headers)
        if expected_code == 302:
            # check_redirect consumes the response itself
            self.check_redirect(redirect_location or target)
        else:
            self.assertEqual(self.conn.getresponse().status, expected_code)

    def check_get_defaults(self, path: str) -> None:
        """GET standard ID variants of a model path, checking status codes.

        Expects 200 for a missing or empty ID and for ID 1, 400 for a
        non-numeric ID, and 500 for the illegal ID 0.
        """
        self.check_get(path, 200)
        self.check_get(f'{path}?id=', 200)
        self.check_get(f'{path}?id=foo', 400)
        self.check_get(f'{path}?id=0', 500)
        self.check_get(f'{path}?id=1', 200)

    def post_process(self, id_: int = 1,
                     form_data: dict[str, Any] | None = None
                     ) -> dict[str, Any]:
        """POST basic Process, expecting a redirect back to it.

        Args:
            id_: ID of the Process to POST to.
            form_data: fields to send; if None, a minimal default of
                title, description and effort is used.

        Returns:
            The form data actually posted.
        """
        if form_data is None:
            form_data = {'title': 'foo', 'description': 'foo', 'effort': 1.1}
        self.check_post(form_data, f'/process?id={id_}', 302,
                        f'/process?id={id_}')
        return form_data

    def check_json_get(self, path: str, expected: dict[str, object]) -> None:
        """Compare JSON on GET path with expected.

        To simplify comparison of VersionedAttribute histories, transforms
        timestamp keys of VersionedAttribute history keys into integers
        counting chronologically forward from 0.
        """

        def rewrite_history_keys_in(item: Any) -> Any:
            # Rewrites in place and returns item, so list elements survive.
            # NOTE(review): "chronological" relies on the server emitting
            # each history in timestamp order; keys are not sorted here.
            if isinstance(item, dict):
                if '_versioned' in item:
                    for k in item['_versioned']:
                        vals = item['_versioned'][k].values()
                        item['_versioned'][k] = dict(enumerate(vals))
                for k in list(item.keys()):
                    rewrite_history_keys_in(item[k])
            elif isinstance(item, list):
                item[:] = [rewrite_history_keys_in(i) for i in item]
            return item

        self.conn.request('GET', path)
        response = self.conn.getresponse()
        self.assertEqual(response.status, 200)
        retrieved = json_loads(response.read().decode())
        rewrite_history_keys_in(retrieved)
        self.assertEqual(expected, retrieved)