1 """Shared test utilities."""
2 from __future__ import annotations
3 from unittest import TestCase
4 from typing import Mapping, Any, Callable
5 from threading import Thread
6 from http.client import HTTPConnection
7 from json import loads as json_loads
8 from urllib.parse import urlencode
10 from os import remove as remove_file
11 from plomtask.db import DatabaseFile, DatabaseConnection
12 from plomtask.http import TaskHandler, TaskServer
13 from plomtask.processes import Process, ProcessStep
14 from plomtask.conditions import Condition
15 from plomtask.days import Day
16 from plomtask.todos import Todo
17 from plomtask.exceptions import NotFoundException, HandledException
20 class TestCaseSansDB(TestCase):
21 """Tests requiring no DB setup."""
# Per-subclass configuration. checked_class (the model class under test)
# is read via self below but is not assigned in this view.
# do_id_test: opt into test_id_setting.
# default_init_args: positional constructor args passed after the ID.
# versioned_defaults_to_test: attribute name -> expected .newest default.
23 do_id_test: bool = False
24 default_init_args: list[Any] = []
25 versioned_defaults_to_test: dict[str, str | float] = {}
27 def test_id_setting(self) -> None:
28 """Test .id_ being set and its legal range being enforced."""
# NOTE(review): the statement under this guard is missing from this view;
# presumably `return`, i.e. the test is a no-op unless do_id_test is set.
29 if not self.do_id_test:
# ID 0 must be rejected; a positive ID (5) must be kept as given.
31 with self.assertRaises(HandledException):
32 self.checked_class(0, *self.default_init_args)
33 obj = self.checked_class(5, *self.default_init_args)
34 self.assertEqual(obj.id_, 5)
36 def test_versioned_defaults(self) -> None:
37 """Test defaults of VersionedAttributes."""
# NOTE(review): the statement under this guard is missing from this view;
# presumably `return` when nothing is configured to test.
38 if len(self.versioned_defaults_to_test) == 0:
# A freshly constructed object must report each configured default as the
# .newest value of the named attribute.
40 obj = self.checked_class(1, *self.default_init_args)
41 for k, v in self.versioned_defaults_to_test.items():
42 self.assertEqual(getattr(obj, k).newest, v)
45 class TestCaseWithDB(TestCase):
46 """Module tests requiring DB setup."""
# Per-subclass configuration. checked_class (the model class under test)
# is read via self by this class's tests but is not assigned in this view.
# default_ids: three IDs the tests construct objects with; whether the
# first one is an int also decides which ID-generation checks apply.
# default_init_kwargs: keyword args the tests pass when constructing
# checked_class instances.
# test_versioneds: versioned attribute name -> value type; str attributes
# are exercised with string values, all others with floats.
48 default_ids: tuple[int | str, int | str, int | str] = (1, 2, 3)
49 default_init_kwargs: dict[str, Any] = {}
50 test_versioneds: dict[str, type] = {}
52 def setUp(self) -> None:
# Empty the model classes' caches so objects cached by one test are not
# visible in the next.
# NOTE(review): some lines of this sequence are missing from this view
# (presumably the empty_cache() calls for the other imported models).
53 Condition.empty_cache()
56 ProcessStep.empty_cache()
# Each test gets its own DB file, made unique by a uuid4 suffix.
# NOTE(review): the uuid4 import is not visible in this view -- confirm.
58 self.db_file = DatabaseFile.create_at(f'test_db:{uuid4()}')
59 self.db_conn = DatabaseConnection(self.db_file)
61 def tearDown(self) -> None:
# NOTE(review): one line of this method is missing from this view;
# presumably it closes self.db_conn before the file is deleted -- confirm.
# Delete the per-test DB file created in setUp.
63 remove_file(self.db_file.path)
66 def _within_checked_class(f: Callable[..., None]) -> Callable[..., None]:
# Decorator for this class's shared tests: the wrapper checks whether the
# running test case defines checked_class, presumably so the tests only
# run in subclasses that configure one.
# NOTE(review): the rest of this definition (the call to f and the return
# of wrapper), and any decorator line above it, are missing from this view.
67 def wrapper(self: TestCaseWithDB) -> None:
68 if hasattr(self, 'checked_class'):
72 def _load_from_db(self, id_: int | str) -> list[object]:
# Build checked_class objects via from_table_row from the rows of
# checked_class.table_name selected by row_where.
# NOTE(review): the row_where filter arguments (presumably a match on
# id_), the rest of the from_table_row call and the return statement are
# missing from this view; callers treat the result as the list of objects
# found, empty when there are none.
73 db_found: list[object] = []
74 for row in self.db_conn.row_where(self.checked_class.table_name,
76 db_found += [self.checked_class.from_table_row(self.db_conn,
80 def _change_obj(self, obj: object) -> str:
# Change obj's value for the last attribute named in checked_class.to_save
# and (per the signature and its callers) return that attribute's name.
# The return statement and the three branch bodies are missing from this
# view.
81 attr_name: str = self.checked_class.to_save[-1]
82 attr = getattr(obj, attr_name)
83 new_attr: str | int | float | bool
# NOTE(review): bool is a subclass of int, so a bool attribute already
# matches this first isinstance() check and the bool branch below can
# never run; the bool test should come before the int/float one.
84 if isinstance(attr, (int, float)):
86 elif isinstance(attr, str):
88 elif isinstance(attr, bool):
90 setattr(obj, attr_name, new_attr)
93 def check_identity_with_cache_and_db(self, content: list[Any]) -> None:
94 """Test both cache and DB equal content."""
# Cache side: checked_class's cache must equal content keyed by .id_.
# NOTE(review): the expected_cache initialisation and the loop header over
# content are missing from this view.
97 expected_cache[item.id_] = item
98 self.assertEqual(self.checked_class.get_cache(), expected_cache)
# DB side: load each content item by its ID and compare hashes; both lists
# are sorted, so the comparison ignores order.
99 hashes_content = [hash(x) for x in content]
100 db_found: list[Any] = []
# NOTE(review): the loop header over content is missing from this view.
102 assert isinstance(item.id_, type(self.default_ids[0]))
103 db_found += self._load_from_db(item.id_)
104 hashes_db_found = [hash(x) for x in db_found]
105 self.assertEqual(sorted(hashes_content), sorted(hashes_db_found))
107 @_within_checked_class
108 def test_saving_versioned(self) -> None:
109 """Test storage and initialization of versioned attributes."""
# Helper: collect column 2 of each row of the current attr's table whose
# 'parent' presumably matches retrieved.id_. attr and retrieved are read
# from the enclosing scope at call time, i.e. after the loop below has
# (re)bound them.
# NOTE(review): the filter value argument of row_where is missing from
# this view.
110 def retrieve_attr_vals() -> list[object]:
111 attr_vals_saved: list[object] = []
112 assert hasattr(retrieved, 'id_')
113 for row in self.db_conn.row_where(attr.table_name, 'parent',
115 attr_vals_saved += [row[2]]
116 return attr_vals_saved
117 for attr_name, type_ in self.test_versioneds.items():
118 # fail saving attributes on non-saved owner
119 owner = self.checked_class(None, **self.default_init_kwargs)
120 vals: list[Any] = ['t1', 't2'] if type_ == str else [0.9, 1.1]
121 attr = getattr(owner, attr_name)
# NOTE(review): two lines are missing here in this view; presumably they
# set vals on attr, which the history checks below expect.
124 with self.assertRaises(NotFoundException):
125 attr.save(self.db_conn)
126 owner.save(self.db_conn)
127 # check stored attribute is as expected
128 retrieved = self._load_from_db(owner.id_)[0]
129 attr = getattr(retrieved, attr_name)
130 self.assertEqual(sorted(attr.history.values()), vals)
131 # check owner.save() created entries in attr table
132 attr_vals_saved = retrieve_attr_vals()
133 self.assertEqual(vals, attr_vals_saved)
134 # check setting new val to attr inconsequential to DB without save
# NOTE(review): the line setting the new value is missing from this view;
# the final assertion implies it sets vals[0] again.
136 attr_vals_saved = retrieve_attr_vals()
137 self.assertEqual(vals, attr_vals_saved)
138 # check save finally adds new val
139 attr.save(self.db_conn)
140 attr_vals_saved = retrieve_attr_vals()
141 self.assertEqual(vals + [vals[0]], attr_vals_saved)
143 @_within_checked_class
144 def test_saving_and_caching(self) -> None:
145 """Test effects of .cache() and .save()."""
146 id1 = self.default_ids[0]
147 # check failure to cache without ID (if None-ID input possible)
148 if isinstance(id1, int):
149 obj0 = self.checked_class(None, **self.default_init_kwargs)
150 with self.assertRaises(HandledException):
# NOTE(review): the statement under this `with` is missing from this view
# (presumably obj0.cache()).
152 # check mere object init itself doesn't even store in cache
153 obj1 = self.checked_class(id1, **self.default_init_kwargs)
154 self.assertEqual(self.checked_class.get_cache(), {})
155 # check .cache() fills cache, but not DB
# NOTE(review): the call this comment refers to is missing from this view
# (presumably obj1.cache()).
157 self.assertEqual(self.checked_class.get_cache(), {id1: obj1})
158 db_found = self._load_from_db(id1)
159 self.assertEqual(db_found, [])
160 # check .save() sets ID (for int IDs), updates cache, and fills DB
161 # (expect ID to be set to id1, despite obj1 already having that as ID:
162 # it's generated by cursor.lastrowid on the DB table, and with obj1
163 # not written there, obj2 should get it first!)
# int IDs are left to save() to assign; any other ID type is passed in.
164 id_input = None if isinstance(id1, int) else id1
165 obj2 = self.checked_class(id_input, **self.default_init_kwargs)
166 obj2.save(self.db_conn)
167 obj2_hash = hash(obj2)
168 self.assertEqual(self.checked_class.get_cache(), {id1: obj2})
169 db_found += self._load_from_db(id1)
170 self.assertEqual([hash(o) for o in db_found], [obj2_hash])
171 # check we cannot overwrite obj2 with obj1 despite its same ID,
172 # since it has disappeared now
173 with self.assertRaises(HandledException):
174 obj1.save(self.db_conn)
176 @_within_checked_class
177 def test_by_id(self) -> None:
# Checks .by_id(): NotFoundException for an object neither cached nor
# saved, and an equal object returned once it is cached or saved.
# NOTE(review): this method's docstring line is missing from this view.
179 id1, id2, _ = self.default_ids
180 # check failure if not yet saved
181 obj1 = self.checked_class(id1, **self.default_init_kwargs)
182 with self.assertRaises(NotFoundException):
183 self.checked_class.by_id(self.db_conn, id1)
184 # check identity of cached and retrieved
# NOTE(review): the caching call is missing from this view (presumably
# obj1.cache()).
186 self.assertEqual(obj1, self.checked_class.by_id(self.db_conn, id1))
187 # check identity of saved and retrieved
188 obj2 = self.checked_class(id2, **self.default_init_kwargs)
189 obj2.save(self.db_conn)
190 self.assertEqual(obj2, self.checked_class.by_id(self.db_conn, id2))
192 @_within_checked_class
193 def test_by_id_or_create(self) -> None:
194 """Test .by_id_or_create."""
195 # check .by_id_or_create acts like normal instantiation (sans saving)
196 id_ = self.default_ids[0]
# Classes that disallow creation by ID must raise instead.
197 if not self.checked_class.can_create_by_id:
198 with self.assertRaises(HandledException):
199 self.checked_class.by_id_or_create(self.db_conn, id_)
200 # check .by_id_or_create fails if wrong class
# NOTE(review): a line is missing here in this view (presumably `else:`), as
# is the end of the call below. The comment above seems to describe the
# preceding branch; what follows checks that the created object is neither
# cached nor saved (by_id still raises) and equals a plain checked_class(id_).
202 by_id_created = self.checked_class.by_id_or_create(self.db_conn,
204 with self.assertRaises(NotFoundException):
205 self.checked_class.by_id(self.db_conn, id_)
206 self.assertEqual(self.checked_class(id_), by_id_created)
208 @_within_checked_class
209 def test_from_table_row(self) -> None:
210 """Test .from_table_row() properly reads in class directly from DB."""
211 id_ = self.default_ids[0]
212 obj = self.checked_class(id_, **self.default_init_kwargs)
213 obj.save(self.db_conn)
214 assert isinstance(obj.id_, type(id_))
# NOTE(review): the row_where filter arguments of the next call are missing
# from this view.
215 for row in self.db_conn.row_where(self.checked_class.table_name,
217 # check .from_table_row reproduces state saved, no matter if obj
218 # later changed (with caching even)
219 hash_original = hash(obj)
220 attr_name = self._change_obj(obj)
# NOTE(review): a line is missing here in this view (presumably re-caching
# obj, per the comment above).
222 to_cmp = getattr(obj, attr_name)
223 retrieved = self.checked_class.from_table_row(self.db_conn, row)
224 self.assertNotEqual(to_cmp, getattr(retrieved, attr_name))
225 self.assertEqual(hash_original, hash(retrieved))
226 # check cache contains what .from_table_row just produced
227 self.assertEqual({retrieved.id_: retrieved},
228 self.checked_class.get_cache())
229 # check .from_table_row also reads versioned attributes from DB
230 for attr_name, type_ in self.test_versioneds.items():
# NOTE(review): unlike the first part, this construction omits
# default_init_kwargs -- confirm classes with versioned attributes need none.
231 owner = self.checked_class(None)
232 vals: list[Any] = ['t1', 't2'] if type_ == str else [0.9, 1.1]
233 attr = getattr(owner, attr_name)
# NOTE(review): two lines are missing here in this view (presumably setting
# vals on attr, which the final history check expects).
236 owner.save(self.db_conn)
237 for row in self.db_conn.row_where(owner.table_name, 'id',
# (the filter value, presumably owner.id_, is missing from this view)
239 retrieved = owner.__class__.from_table_row(self.db_conn, row)
240 attr = getattr(retrieved, attr_name)
241 self.assertEqual(sorted(attr.history.values()), vals)
243 @_within_checked_class
244 def test_all(self) -> None:
245 """Test .all() and its relation to cache and savings."""
246 id_1, id_2, id_3 = self.default_ids
247 item1 = self.checked_class(id_1, **self.default_init_kwargs)
248 item2 = self.checked_class(id_2, **self.default_init_kwargs)
249 item3 = self.checked_class(id_3, **self.default_init_kwargs)
250 # check .all() returns empty list on un-cached items
251 self.assertEqual(self.checked_class.all(self.db_conn), [])
252 # check that all() shows only cached/saved items
# NOTE(review): a line is missing here in this view (presumably caching
# item1, which the next assertion expects in .all() without it being saved).
254 item3.save(self.db_conn)
255 self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
256 sorted([item1, item3]))
257 item2.save(self.db_conn)
258 self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
259 sorted([item1, item2, item3]))
@_within_checked_class
def test_singularity(self) -> None:
    """Check that changes to a saved object show up via .by_id().

    After saving, the object is mutated in memory; a subsequent .by_id()
    lookup for the same ID must report the mutated value.
    """
    first_id = self.default_ids[0]
    original = self.checked_class(first_id, **self.default_init_kwargs)
    original.save(self.db_conn)
    # mutate after saving; the .by_id result must reflect the mutation
    changed_name = self._change_obj(original)
    changed_value = getattr(original, changed_name)
    fetched = self.checked_class.by_id(self.db_conn, first_id)
    self.assertEqual(changed_value, getattr(fetched, changed_name))
@_within_checked_class
def test_versioned_singularity_title(self) -> None:
    """Check .title changes on a saved object reach its .by_id() result.

    Only relevant for checked classes listing 'title' in test_versioneds;
    all others return immediately.
    """
    if 'title' not in self.test_versioneds:
        return
    saved = self.checked_class(None)
    saved.save(self.db_conn)
    assert isinstance(saved.id_, int)
    # set a new title after saving; the fetched object must share its history
    saved.title.set('named')
    fetched = self.checked_class.by_id(self.db_conn, saved.id_)
    self.assertEqual(saved.title.history, fetched.title.history)
285 @_within_checked_class
286 def test_remove(self) -> None:
287 """Test .remove() effects on DB and cache."""
288 id_ = self.default_ids[0]
289 obj = self.checked_class(id_, **self.default_init_kwargs)
290 # check removal only works after saving
291 with self.assertRaises(HandledException):
292 obj.remove(self.db_conn)
293 obj.save(self.db_conn)
294 obj.remove(self.db_conn)
295 # check access to obj fails after removal
296 with self.assertRaises(HandledException):
# NOTE(review): the statement under this `with` is missing from this view
# (presumably a .by_id() lookup of id_).
298 # check DB and cache now empty
299 self.check_identity_with_cache_and_db([])
302 class TestCaseWithServer(TestCaseWithDB):
303 """Module tests against our HTTP server/handler (and database)."""
# Each test additionally runs a TaskServer in a background thread, with
# self.conn as an HTTP client connection to it; the check_* helpers issue
# requests over that connection and assert on the responses.
305 def setUp(self) -> None:
# NOTE(review): one line is missing here in this view; self.db_file is
# created by TestCaseWithDB.setUp, so presumably this is super().setUp().
# Port 0 lets the OS pick a free port; the actual address is read back from
# server_address when connecting below.
307 self.httpd = TaskServer(self.db_file, ('localhost', 0), TaskHandler)
308 self.server_thread = Thread(target=self.httpd.serve_forever)
# Daemon thread: it cannot keep the interpreter alive if shutdown fails.
309 self.server_thread.daemon = True
310 self.server_thread.start()
311 self.conn = HTTPConnection(str(self.httpd.server_address[0]),
312 self.httpd.server_address[1])
# NOTE(review): presumably switches the server to JSON responses, which
# check_json_get parses -- confirm against TaskServer.
313 self.httpd.set_json_mode()
315 def tearDown(self) -> None:
# Stop serve_forever(), release the listening socket, then wait for the
# server thread to finish.
316 self.httpd.shutdown()
317 self.httpd.server_close()
318 self.server_thread.join()
# NOTE(review): nothing visible here chains to TestCaseWithDB.tearDown
# (which deletes the DB file); a following line may be missing from this
# view -- confirm super().tearDown() is called.
def check_redirect(self, target: str) -> None:
    """Assert that the next response on self.conn is a 302 to target.

    Reads the pending response from the connection, so a request must
    already have been sent.
    """
    answer = self.conn.getresponse()
    # status first: a non-redirect makes the Location check meaningless
    self.assertEqual(answer.status, 302)
    location = answer.getheader('Location')
    self.assertEqual(location, target)
def check_get(self, target: str, expected_code: int) -> None:
    """GET target over self.conn and assert the status is expected_code."""
    self.conn.request('GET', target)
    reply = self.conn.getresponse()
    self.assertEqual(reply.status, expected_code)
332 def check_post(self, data: Mapping[str, object], target: str,
333 expected_code: int, redirect_location: str = '') -> None:
334 """Check that POST of data to target yields expected_code."""
# Form-encode data; doseq=True turns sequence values into repeated keys.
335 encoded_form_data = urlencode(data, doseq=True).encode('utf-8')
336 headers = {'Content-Type': 'application/x-www-form-urlencoded',
337 'Content-Length': str(len(encoded_form_data))}
338 self.conn.request('POST', target,
339 body=encoded_form_data, headers=headers)
# For a 302, an empty redirect_location means "expect a redirect back to
# target"; check_redirect then consumes the response.
340 if 302 == expected_code:
341 if redirect_location == '':
342 redirect_location = target
343 self.check_redirect(redirect_location)
# NOTE(review): the line before this is missing from this view; presumably
# `else:` -- otherwise a 302 response would be read a second time here.
345 self.assertEqual(self.conn.getresponse().status, expected_code)
def check_get_defaults(self, path: str) -> None:
    """GET the standard query variants of a model path and check codes.

    Expected status per request: the bare path and an empty id give 200,
    a non-numeric id gives 400, the illegal id 0 gives 500, and id 1
    gives 200.
    """
    self.check_get(path, 200)
    self.check_get(f'{path}?id=', 200)
    self.check_get(f'{path}?id=foo', 400)
    # Build this target like all the others. A stray extra '/' prefix used
    # to turn e.g. '/process' into '//process?id=0', which URL parsers such
    # as urllib.parse.urlparse read as a netloc rather than a path, so the
    # request likely never reached the page it meant to test.
    self.check_get(f'{path}?id=0', 500)
    self.check_get(f'{path}?id=1', 200)
# NOTE(review): several lines of this method are missing from this view:
# the line closing the signature, the guard before the default form_data
# assignment (presumably checking whether form_data was given), and the
# statement following the check_post call.
355 def post_process(self, id_: int = 1,
356 form_data: dict[str, Any] | None = None
358 """POST basic Process."""
# Default form: a minimal Process with title, description and effort.
360 form_data = {'title': 'foo', 'description': 'foo', 'effort': 1.1}
# A successful POST must redirect back to the same process page.
361 self.check_post(form_data, f'/process?id={id_}', 302,
362 f'/process?id={id_}')
# GET path, require status 200, parse the body as JSON, renumber history
# keys (see docstring), then compare the result with expected.
# NOTE(review): lines of this method are missing from this view: the
# docstring's closing quotes and, in rewrite_history_keys_in, the history
# dict set-up, the assignment inside the enumerate loop, and the final
# return (the list branch relies on the helper returning its argument).
# The renumbering follows the order of each history's values as parsed by
# json_loads, which keeps the order keys appear in the response -- so it
# assumes the server emits histories chronologically; confirm.
365 def check_json_get(self, path: str, expected: dict[str, object]) -> None:
366 """Compare JSON on GET path with expected.
368 To simplify comparison of VersionedAttribute histories, transforms
369 their timestamp keys into integers
370 counting chronologically forward from 0.
372 def rewrite_history_keys_in(item: Any) -> Any:
373 if isinstance(item, dict):
374 if '_versioned' in item.keys():
375 for k in item['_versioned']:
376 vals = item['_versioned'][k].values()
378 for i, val in enumerate(vals):
380 item['_versioned'][k] = history
381 for k in list(item.keys()):
382 rewrite_history_keys_in(item[k])
383 elif isinstance(item, list):
384 item[:] = [rewrite_history_keys_in(i) for i in item]
386 self.conn.request('GET', path)
387 response = self.conn.getresponse()
388 self.assertEqual(response.status, 200)
389 retrieved = json_loads(response.read().decode())
390 rewrite_history_keys_in(retrieved)
391 self.assertEqual(expected, retrieved)