1 """Shared test utilities."""
2 from __future__ import annotations
3 from unittest import TestCase
4 from typing import Mapping, Any, Callable
5 from threading import Thread
6 from http.client import HTTPConnection
7 from json import loads as json_loads
8 from urllib.parse import urlencode
10 from os import remove as remove_file
11 from plomtask.db import DatabaseFile, DatabaseConnection
12 from plomtask.http import TaskHandler, TaskServer
13 from plomtask.processes import Process, ProcessStep
14 from plomtask.conditions import Condition
15 from plomtask.days import Day
16 from plomtask.todos import Todo
17 from plomtask.exceptions import NotFoundException, HandledException
class TestCaseSansDB(TestCase):
    """Tests requiring no DB setup.

    Subclasses set .checked_class to the model class under test and opt
    into the individual checks through the class attributes below.
    """
    checked_class: Any
    # Whether test_id_setting runs for checked_class.
    do_id_test: bool = False
    # Positional arguments passed to checked_class after the ID.
    default_init_args: list[Any] = []
    # VersionedAttribute names mapped to the .newest value expected on a
    # fresh object; an empty mapping skips test_versioned_defaults.
    versioned_defaults_to_test: dict[str, str | float] = {}

    def test_id_setting(self) -> None:
        """Test .id_ being set and its legal range being enforced."""
        if self.do_id_test:
            extra_args = self.default_init_args
            # an ID of 0 must be rejected ...
            with self.assertRaises(HandledException):
                self.checked_class(0, *extra_args)
            # ... while a positive one is stored as given
            instance = self.checked_class(5, *extra_args)
            self.assertEqual(instance.id_, 5)

    def test_versioned_defaults(self) -> None:
        """Test defaults of VersionedAttributes."""
        expectations = self.versioned_defaults_to_test
        if not expectations:
            return
        instance = self.checked_class(1, *self.default_init_args)
        for attr_name, expected in expectations.items():
            self.assertEqual(getattr(instance, attr_name).newest, expected)
45 class TestCaseWithDB(TestCase):
46 """Module tests not requiring DB setup."""
48 default_ids: tuple[int | str, int | str, int | str] = (1, 2, 3)
49 default_init_kwargs: dict[str, Any] = {}
50 test_versioneds: dict[str, type] = {}
52 def setUp(self) -> None:
53 Condition.empty_cache()
56 ProcessStep.empty_cache()
58 self.db_file = DatabaseFile.create_at(f'test_db:{uuid4()}')
59 self.db_conn = DatabaseConnection(self.db_file)
61 def tearDown(self) -> None:
63 remove_file(self.db_file.path)
66 def _within_checked_class(f: Callable[..., None]) -> Callable[..., None]:
67 def wrapper(self: TestCaseWithDB) -> None:
68 if hasattr(self, 'checked_class'):
72 def _load_from_db(self, id_: int | str) -> list[object]:
73 db_found: list[object] = []
74 for row in self.db_conn.row_where(self.checked_class.table_name,
76 db_found += [self.checked_class.from_table_row(self.db_conn,
80 def _change_obj(self, obj: object) -> str:
81 attr_name: str = self.checked_class.to_save[-1]
82 attr = getattr(obj, attr_name)
83 new_attr: str | int | float | bool
84 if isinstance(attr, (int, float)):
86 elif isinstance(attr, str):
88 elif isinstance(attr, bool):
90 setattr(obj, attr_name, new_attr)
93 def check_identity_with_cache_and_db(self, content: list[Any]) -> None:
94 """Test both cache and DB equal content."""
97 expected_cache[item.id_] = item
98 self.assertEqual(self.checked_class.get_cache(), expected_cache)
99 hashes_content = [hash(x) for x in content]
100 db_found: list[Any] = []
102 assert isinstance(item.id_, type(self.default_ids[0]))
103 db_found += self._load_from_db(item.id_)
104 hashes_db_found = [hash(x) for x in db_found]
105 self.assertEqual(sorted(hashes_content), sorted(hashes_db_found))
107 @_within_checked_class
108 def test_saving_versioned(self) -> None:
109 """Test storage and initialization of versioned attributes."""
110 def retrieve_attr_vals() -> list[object]:
111 attr_vals_saved: list[object] = []
112 assert hasattr(retrieved, 'id_')
113 for row in self.db_conn.row_where(attr.table_name, 'parent',
115 attr_vals_saved += [row[2]]
116 return attr_vals_saved
117 for attr_name, type_ in self.test_versioneds.items():
118 # fail saving attributes on non-saved owner
119 owner = self.checked_class(None, **self.default_init_kwargs)
120 vals: list[Any] = ['t1', 't2'] if type_ == str else [0.9, 1.1]
121 attr = getattr(owner, attr_name)
124 with self.assertRaises(NotFoundException):
125 attr.save(self.db_conn)
126 owner.save(self.db_conn)
127 # check stored attribute is as expected
128 retrieved = self._load_from_db(owner.id_)[0]
129 attr = getattr(retrieved, attr_name)
130 self.assertEqual(sorted(attr.history.values()), vals)
131 # check owner.save() created entries in attr table
132 attr_vals_saved = retrieve_attr_vals()
133 self.assertEqual(vals, attr_vals_saved)
134 # check setting new val to attr inconsequential to DB without save
136 attr_vals_saved = retrieve_attr_vals()
137 self.assertEqual(vals, attr_vals_saved)
138 # check save finally adds new val
139 attr.save(self.db_conn)
140 attr_vals_saved = retrieve_attr_vals()
141 self.assertEqual(vals + [vals[0]], attr_vals_saved)
143 @_within_checked_class
144 def test_saving_and_caching(self) -> None:
145 """Test effects of .cache() and .save()."""
146 id1 = self.default_ids[0]
147 # check failure to cache without ID (if None-ID input possible)
148 if isinstance(id1, int):
149 obj0 = self.checked_class(None, **self.default_init_kwargs)
150 with self.assertRaises(HandledException):
152 # check mere object init itself doesn't even store in cache
153 obj1 = self.checked_class(id1, **self.default_init_kwargs)
154 self.assertEqual(self.checked_class.get_cache(), {})
155 # check .cache() fills cache, but not DB
157 self.assertEqual(self.checked_class.get_cache(), {id1: obj1})
158 db_found = self._load_from_db(id1)
159 self.assertEqual(db_found, [])
160 # check .save() sets ID (for int IDs), updates cache, and fills DB
161 # (expect ID to be set to id1, despite obj1 already having that as ID:
162 # it's generated by cursor.lastrowid on the DB table, and with obj1
163 # not written there, obj2 should get it first!)
164 id_input = None if isinstance(id1, int) else id1
165 obj2 = self.checked_class(id_input, **self.default_init_kwargs)
166 obj2.save(self.db_conn)
167 obj2_hash = hash(obj2)
168 self.assertEqual(self.checked_class.get_cache(), {id1: obj2})
169 db_found += self._load_from_db(id1)
170 self.assertEqual([hash(o) for o in db_found], [obj2_hash])
171 # check we cannot overwrite obj2 with obj1 despite its same ID,
172 # since it has disappeared now
173 with self.assertRaises(HandledException):
174 obj1.save(self.db_conn)
176 @_within_checked_class
177 def test_by_id(self) -> None:
179 id1, id2, _ = self.default_ids
180 # check failure if not yet saved
181 obj1 = self.checked_class(id1, **self.default_init_kwargs)
182 with self.assertRaises(NotFoundException):
183 self.checked_class.by_id(self.db_conn, id1)
184 # check identity of cached and retrieved
186 self.assertEqual(obj1, self.checked_class.by_id(self.db_conn, id1))
187 # check identity of saved and retrieved
188 obj2 = self.checked_class(id2, **self.default_init_kwargs)
189 obj2.save(self.db_conn)
190 self.assertEqual(obj2, self.checked_class.by_id(self.db_conn, id2))
192 @_within_checked_class
193 def test_by_id_or_create(self) -> None:
194 """Test .by_id_or_create."""
195 # check .by_id_or_create fails if wrong class
196 if not self.checked_class.can_create_by_id:
197 with self.assertRaises(HandledException):
198 self.checked_class.by_id_or_create(self.db_conn, None)
200 # check ID input of None creates, on saving, ID=1,2,… for int IDs
201 if isinstance(self.default_ids[0], int):
203 item = self.checked_class.by_id_or_create(self.db_conn, None)
204 self.assertEqual(item.id_, None)
205 item.save(self.db_conn)
206 self.assertEqual(item.id_, n+1)
207 # check .by_id_or_create acts like normal instantiation (sans saving)
208 id_ = self.default_ids[2]
209 item = self.checked_class.by_id_or_create(self.db_conn, id_)
210 self.assertEqual(item.id_, id_)
211 with self.assertRaises(NotFoundException):
212 self.checked_class.by_id(self.db_conn, item.id_)
213 self.assertEqual(self.checked_class(item.id_), item)
215 @_within_checked_class
216 def test_from_table_row(self) -> None:
217 """Test .from_table_row() properly reads in class directly from DB."""
218 id_ = self.default_ids[0]
219 obj = self.checked_class(id_, **self.default_init_kwargs)
220 obj.save(self.db_conn)
221 assert isinstance(obj.id_, type(id_))
222 for row in self.db_conn.row_where(self.checked_class.table_name,
224 # check .from_table_row reproduces state saved, no matter if obj
225 # later changed (with caching even)
226 hash_original = hash(obj)
227 attr_name = self._change_obj(obj)
229 to_cmp = getattr(obj, attr_name)
230 retrieved = self.checked_class.from_table_row(self.db_conn, row)
231 self.assertNotEqual(to_cmp, getattr(retrieved, attr_name))
232 self.assertEqual(hash_original, hash(retrieved))
233 # check cache contains what .from_table_row just produced
234 self.assertEqual({retrieved.id_: retrieved},
235 self.checked_class.get_cache())
236 # check .from_table_row also reads versioned attributes from DB
237 for attr_name, type_ in self.test_versioneds.items():
238 owner = self.checked_class(None)
239 vals: list[Any] = ['t1', 't2'] if type_ == str else [0.9, 1.1]
240 attr = getattr(owner, attr_name)
243 owner.save(self.db_conn)
244 for row in self.db_conn.row_where(owner.table_name, 'id',
246 retrieved = owner.__class__.from_table_row(self.db_conn, row)
247 attr = getattr(retrieved, attr_name)
248 self.assertEqual(sorted(attr.history.values()), vals)
250 @_within_checked_class
251 def test_all(self) -> None:
252 """Test .all() and its relation to cache and savings."""
253 id_1, id_2, id_3 = self.default_ids
254 item1 = self.checked_class(id_1, **self.default_init_kwargs)
255 item2 = self.checked_class(id_2, **self.default_init_kwargs)
256 item3 = self.checked_class(id_3, **self.default_init_kwargs)
257 # check .all() returns empty list on un-cached items
258 self.assertEqual(self.checked_class.all(self.db_conn), [])
259 # check that all() shows only cached/saved items
261 item3.save(self.db_conn)
262 self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
263 sorted([item1, item3]))
264 item2.save(self.db_conn)
265 self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
266 sorted([item1, item2, item3]))
268 @_within_checked_class
269 def test_singularity(self) -> None:
270 """Test pointers made for single object keep pointing to it."""
271 id1 = self.default_ids[0]
272 obj = self.checked_class(id1, **self.default_init_kwargs)
273 obj.save(self.db_conn)
274 # change object, expect retrieved through .by_id to carry change
275 attr_name = self._change_obj(obj)
276 new_attr = getattr(obj, attr_name)
277 retrieved = self.checked_class.by_id(self.db_conn, id1)
278 self.assertEqual(new_attr, getattr(retrieved, attr_name))
280 @_within_checked_class
281 def test_versioned_singularity_title(self) -> None:
282 """Test singularity of VersionedAttributes on saving (with .title)."""
283 if 'title' in self.test_versioneds:
284 obj = self.checked_class(None)
285 obj.save(self.db_conn)
286 assert isinstance(obj.id_, int)
287 # change obj, expect retrieved through .by_id to carry change
288 obj.title.set('named')
289 retrieved = self.checked_class.by_id(self.db_conn, obj.id_)
290 self.assertEqual(obj.title.history, retrieved.title.history)
292 @_within_checked_class
293 def test_remove(self) -> None:
294 """Test .remove() effects on DB and cache."""
295 id_ = self.default_ids[0]
296 obj = self.checked_class(id_, **self.default_init_kwargs)
297 # check removal only works after saving
298 with self.assertRaises(HandledException):
299 obj.remove(self.db_conn)
300 obj.save(self.db_conn)
301 obj.remove(self.db_conn)
302 # check access to obj fails after removal
303 with self.assertRaises(HandledException):
305 # check DB and cache now empty
306 self.check_identity_with_cache_and_db([])
309 class TestCaseWithServer(TestCaseWithDB):
310 """Module tests against our HTTP server/handler (and database)."""
312 def setUp(self) -> None:
314 self.httpd = TaskServer(self.db_file, ('localhost', 0), TaskHandler)
315 self.server_thread = Thread(target=self.httpd.serve_forever)
316 self.server_thread.daemon = True
317 self.server_thread.start()
318 self.conn = HTTPConnection(str(self.httpd.server_address[0]),
319 self.httpd.server_address[1])
320 self.httpd.set_json_mode()
322 def tearDown(self) -> None:
323 self.httpd.shutdown()
324 self.httpd.server_close()
325 self.server_thread.join()
328 def check_redirect(self, target: str) -> None:
329 """Check that self.conn answers with a 302 redirect to target."""
330 response = self.conn.getresponse()
331 self.assertEqual(response.status, 302)
332 self.assertEqual(response.getheader('Location'), target)
334 def check_get(self, target: str, expected_code: int) -> None:
335 """Check that a GET to target yields expected_code."""
336 self.conn.request('GET', target)
337 self.assertEqual(self.conn.getresponse().status, expected_code)
339 def check_post(self, data: Mapping[str, object], target: str,
340 expected_code: int, redirect_location: str = '') -> None:
341 """Check that POST of data to target yields expected_code."""
342 encoded_form_data = urlencode(data, doseq=True).encode('utf-8')
343 headers = {'Content-Type': 'application/x-www-form-urlencoded',
344 'Content-Length': str(len(encoded_form_data))}
345 self.conn.request('POST', target,
346 body=encoded_form_data, headers=headers)
347 if 302 == expected_code:
348 if redirect_location == '':
349 redirect_location = target
350 self.check_redirect(redirect_location)
352 self.assertEqual(self.conn.getresponse().status, expected_code)
354 def check_get_defaults(self, path: str) -> None:
355 """Some standard model paths to test."""
356 self.check_get(path, 200)
357 self.check_get(f'{path}?id=', 200)
358 self.check_get(f'{path}?id=foo', 400)
359 self.check_get(f'/{path}?id=0', 500)
360 self.check_get(f'{path}?id=1', 200)
362 def post_process(self, id_: int = 1,
363 form_data: dict[str, Any] | None = None
365 """POST basic Process."""
367 form_data = {'title': 'foo', 'description': 'foo', 'effort': 1.1}
368 self.check_post(form_data, f'/process?id={id_}', 302,
369 f'/process?id={id_}')
372 def check_json_get(self, path: str, expected: dict[str, object]) -> None:
373 """Compare JSON on GET path with expected.
375 To simplify comparison of VersionedAttribute histories, transforms
376 timestamp keys of VersionedAttribute history keys into integers
377 counting chronologically forward from 0.
379 def rewrite_history_keys_in(item: Any) -> Any:
380 if isinstance(item, dict):
381 if '_versioned' in item.keys():
382 for k in item['_versioned']:
383 vals = item['_versioned'][k].values()
385 for i, val in enumerate(vals):
387 item['_versioned'][k] = history
388 for k in list(item.keys()):
389 rewrite_history_keys_in(item[k])
390 elif isinstance(item, list):
391 item[:] = [rewrite_history_keys_in(i) for i in item]
393 self.conn.request('GET', path)
394 response = self.conn.getresponse()
395 self.assertEqual(response.status, 200)
396 retrieved = json_loads(response.read().decode())
397 rewrite_history_keys_in(retrieved)
398 self.assertEqual(expected, retrieved)