1 """Shared test utilities."""
from __future__ import annotations
from unittest import TestCase
from typing import Mapping, Any, Callable
from functools import wraps
from threading import Thread
from http.client import HTTPConnection
from json import loads as json_loads
from urllib.parse import urlencode
from uuid import uuid4
from os import remove as remove_file
from plomtask.db import DatabaseFile, DatabaseConnection
from plomtask.http import TaskHandler, TaskServer
from plomtask.processes import Process, ProcessStep
from plomtask.conditions import Condition
from plomtask.days import Day
from plomtask.todos import Todo
from plomtask.exceptions import NotFoundException, HandledException
class TestCaseSansDB(TestCase):
    """Tests requiring no DB setup.

    Subclasses point ``checked_class`` at the class under test and opt
    into the shared tests below through the other class attributes.
    """
    # class under test; expected to be provided by subclasses
    checked_class: Any
    # whether test_id_setting should actually run its checks
    do_id_test: bool = False
    # positional arguments passed to checked_class after the ID
    default_init_args: list[Any] = []
    # attribute name -> expected .newest value on a fresh instance
    versioned_defaults_to_test: dict[str, str | float] = {}

    def test_id_setting(self) -> None:
        """Test .id_ being set and its legal range being enforced."""
        if not self.do_id_test:
            return  # not requested by this (sub)class
        # an ID of 0 is expected to be rejected
        with self.assertRaises(HandledException):
            self.checked_class(0, *self.default_init_args)
        obj = self.checked_class(5, *self.default_init_args)
        self.assertEqual(obj.id_, 5)

    def test_versioned_defaults(self) -> None:
        """Test defaults of VersionedAttributes."""
        if not self.versioned_defaults_to_test:
            return  # nothing configured to check
        obj = self.checked_class(1, *self.default_init_args)
        for k, v in self.versioned_defaults_to_test.items():
            self.assertEqual(getattr(obj, k).newest, v)
45 class TestCaseWithDB(TestCase):
    """Module tests requiring DB setup."""
48 default_ids: tuple[int | str, int | str, int | str] = (1, 2, 3)
49 default_init_kwargs: dict[str, Any] = {}
50 test_versioneds: dict[str, type] = {}
52 def setUp(self) -> None:
53 Condition.empty_cache()
56 ProcessStep.empty_cache()
58 self.db_file = DatabaseFile.create_at(f'test_db:{uuid4()}')
59 self.db_conn = DatabaseConnection(self.db_file)
61 def tearDown(self) -> None:
63 remove_file(self.db_file.path)
66 def _within_checked_class(f: Callable[..., None]) -> Callable[..., None]:
67 def wrapper(self: TestCaseWithDB) -> None:
68 if hasattr(self, 'checked_class'):
72 def _load_from_db(self, id_: int | str) -> list[object]:
73 db_found: list[object] = []
74 for row in self.db_conn.row_where(self.checked_class.table_name,
76 db_found += [self.checked_class.from_table_row(self.db_conn,
80 @_within_checked_class
81 def test_saving_versioned(self) -> None:
82 """Test storage and initialization of versioned attributes."""
83 def retrieve_attr_vals() -> list[object]:
84 attr_vals_saved: list[object] = []
85 assert hasattr(retrieved, 'id_')
86 for row in self.db_conn.row_where(attr.table_name, 'parent',
88 attr_vals_saved += [row[2]]
89 return attr_vals_saved
90 for attr_name, type_ in self.test_versioneds.items():
91 # fail saving attributes on non-saved owner
92 owner = self.checked_class(None, **self.default_init_kwargs)
93 vals: list[Any] = ['t1', 't2'] if type_ == str else [0.9, 1.1]
94 attr = getattr(owner, attr_name)
97 with self.assertRaises(NotFoundException):
98 attr.save(self.db_conn)
99 owner.save(self.db_conn)
100 # check stored attribute is as expected
101 retrieved = self._load_from_db(owner.id_)[0]
102 attr = getattr(retrieved, attr_name)
103 self.assertEqual(sorted(attr.history.values()), vals)
104 # check owner.save() created entries in attr table
105 attr_vals_saved = retrieve_attr_vals()
106 self.assertEqual(vals, attr_vals_saved)
107 # check setting new val to attr inconsequential to DB without save
109 attr_vals_saved = retrieve_attr_vals()
110 self.assertEqual(vals, attr_vals_saved)
111 # check save finally adds new val
112 attr.save(self.db_conn)
113 attr_vals_saved = retrieve_attr_vals()
114 self.assertEqual(vals + [vals[0]], attr_vals_saved)
116 def check_identity_with_cache_and_db(self, content: list[Any]) -> None:
117 """Test both cache and DB equal content."""
120 expected_cache[item.id_] = item
121 self.assertEqual(self.checked_class.get_cache(), expected_cache)
122 hashes_content = [hash(x) for x in content]
123 db_found: list[Any] = []
125 assert isinstance(item.id_, type(self.default_ids[0]))
126 db_found += self._load_from_db(item.id_)
127 hashes_db_found = [hash(x) for x in db_found]
128 self.assertEqual(sorted(hashes_content), sorted(hashes_db_found))
130 @_within_checked_class
131 def test_saving_and_caching(self) -> None:
132 """Test effects of .cache() and .save()."""
133 id1 = self.default_ids[0]
134 # check failure to cache without ID (if None-ID input possible)
135 if isinstance(id1, int):
136 obj0 = self.checked_class(None, **self.default_init_kwargs)
137 with self.assertRaises(HandledException):
139 # check mere object init itself doesn't even store in cache
140 obj1 = self.checked_class(id1, **self.default_init_kwargs)
141 self.assertEqual(self.checked_class.get_cache(), {})
142 # check .cache() fills cache, but not DB
144 self.assertEqual(self.checked_class.get_cache(), {id1: obj1})
145 db_found = self._load_from_db(id1)
146 self.assertEqual(db_found, [])
147 # check .save() sets ID (for int IDs), updates cache, and fills DB
148 # (expect ID to be set to id1, despite obj1 already having that as ID:
149 # it's generated by cursor.lastrowid on the DB table, and with obj1
150 # not written there, obj2 should get it first!)
151 id_input = None if isinstance(id1, int) else id1
152 obj2 = self.checked_class(id_input, **self.default_init_kwargs)
153 obj2.save(self.db_conn)
154 obj2_hash = hash(obj2)
155 self.assertEqual(self.checked_class.get_cache(), {id1: obj2})
156 db_found += self._load_from_db(id1)
157 self.assertEqual([hash(o) for o in db_found], [obj2_hash])
158 # check we cannot overwrite obj2 with obj1 despite its same ID,
159 # since it has disappeared now
160 with self.assertRaises(HandledException):
161 obj1.save(self.db_conn)
163 @_within_checked_class
164 def test_by_id(self) -> None:
166 id1, id2, _ = self.default_ids
167 # check failure if not yet saved
168 obj1 = self.checked_class(id1, **self.default_init_kwargs)
169 with self.assertRaises(NotFoundException):
170 self.checked_class.by_id(self.db_conn, id1)
171 # check identity of cached and retrieved
173 self.assertEqual(obj1, self.checked_class.by_id(self.db_conn, id1))
174 # check identity of saved and retrieved
175 obj2 = self.checked_class(id2, **self.default_init_kwargs)
176 obj2.save(self.db_conn)
177 self.assertEqual(obj2, self.checked_class.by_id(self.db_conn, id2))
179 @_within_checked_class
180 def test_by_id_or_create(self) -> None:
181 """Test .by_id_or_create."""
182 # check .by_id_or_create acts like normal instantiation (sans saving)
183 id_ = self.default_ids[0]
184 if not self.checked_class.can_create_by_id:
185 with self.assertRaises(HandledException):
186 self.checked_class.by_id_or_create(self.db_conn, id_)
187 # check .by_id_or_create fails if wrong class
189 by_id_created = self.checked_class.by_id_or_create(self.db_conn,
191 with self.assertRaises(NotFoundException):
192 self.checked_class.by_id(self.db_conn, id_)
193 self.assertEqual(self.checked_class(id_), by_id_created)
195 @_within_checked_class
196 def test_from_table_row(self) -> None:
197 """Test .from_table_row() properly reads in class directly from DB."""
198 id_ = self.default_ids[0]
199 obj = self.checked_class(id_, **self.default_init_kwargs)
200 obj.save(self.db_conn)
201 assert isinstance(obj.id_, type(self.default_ids[0]))
202 for row in self.db_conn.row_where(self.checked_class.table_name,
204 # check .from_table_row reproduces state saved, no matter if obj
205 # later changed (with caching even)
206 hash_original = hash(obj)
207 attr_name = self.checked_class.to_save[-1]
208 attr = getattr(obj, attr_name)
209 if isinstance(attr, (int, float)):
210 setattr(obj, attr_name, attr + 1)
211 elif isinstance(attr, str):
212 setattr(obj, attr_name, attr + "_")
213 elif isinstance(attr, bool):
214 setattr(obj, attr_name, not attr)
216 to_cmp = getattr(obj, attr_name)
217 retrieved = self.checked_class.from_table_row(self.db_conn, row)
218 self.assertNotEqual(to_cmp, getattr(retrieved, attr_name))
219 self.assertEqual(hash_original, hash(retrieved))
220 # check cache contains what .from_table_row just produced
221 self.assertEqual({retrieved.id_: retrieved},
222 self.checked_class.get_cache())
224 def check_versioned_from_table_row(self, attr_name: str,
225 type_: type) -> None:
226 """Test .from_table_row() reads versioned attributes from DB."""
227 owner = self.checked_class(None)
228 vals: list[Any] = ['t1', 't2'] if type_ == str else [0.9, 1.1]
229 attr = getattr(owner, attr_name)
232 owner.save(self.db_conn)
233 for row in self.db_conn.row_where(owner.table_name, 'id', owner.id_):
234 retrieved = owner.__class__.from_table_row(self.db_conn, row)
235 attr = getattr(retrieved, attr_name)
236 self.assertEqual(sorted(attr.history.values()), vals)
238 @_within_checked_class
239 def test_all(self) -> None:
240 """Test .all() and its relation to cache and savings."""
241 id_1, id_2, id_3 = self.default_ids
242 item1 = self.checked_class(id_1, **self.default_init_kwargs)
243 item2 = self.checked_class(id_2, **self.default_init_kwargs)
244 item3 = self.checked_class(id_3, **self.default_init_kwargs)
245 # check .all() returns empty list on un-cached items
246 self.assertEqual(self.checked_class.all(self.db_conn), [])
247 # check that all() shows only cached/saved items
249 item3.save(self.db_conn)
250 self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
251 sorted([item1, item3]))
252 item2.save(self.db_conn)
253 self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
254 sorted([item1, item2, item3]))
256 @_within_checked_class
257 def test_singularity(self) -> None:
258 """Test pointers made for single object keep pointing to it."""
259 id1 = self.default_ids[0]
260 obj = self.checked_class(id1, **self.default_init_kwargs)
261 obj.save(self.db_conn)
262 attr_name = self.checked_class.to_save[-1]
263 attr = getattr(obj, attr_name)
264 new_attr: str | int | float | bool
265 if isinstance(attr, (int, float)):
267 elif isinstance(attr, str):
268 new_attr = attr + '_'
269 elif isinstance(attr, bool):
271 setattr(obj, attr_name, new_attr)
272 retrieved = self.checked_class.by_id(self.db_conn, id1)
273 self.assertEqual(new_attr, getattr(retrieved, attr_name))
275 def check_versioned_singularity(self) -> None:
276 """Test singularity of VersionedAttributes on saving (with .title)."""
277 obj = self.checked_class(None) # pylint: disable=not-callable
278 obj.save(self.db_conn)
279 assert isinstance(obj.id_, int)
280 obj.title.set('named')
281 retrieved = self.checked_class.by_id(self.db_conn, obj.id_)
282 self.assertEqual(obj.title.history, retrieved.title.history)
284 def check_remove(self, *args: Any) -> None:
285 """Test .remove() effects on DB and cache."""
286 id_ = self.default_ids[0]
287 obj = self.checked_class(id_, *args) # pylint: disable=not-callable
288 with self.assertRaises(HandledException):
289 obj.remove(self.db_conn)
290 obj.save(self.db_conn)
291 obj.remove(self.db_conn)
292 self.check_identity_with_cache_and_db([])
295 class TestCaseWithServer(TestCaseWithDB):
296 """Module tests against our HTTP server/handler (and database)."""
298 def setUp(self) -> None:
300 self.httpd = TaskServer(self.db_file, ('localhost', 0), TaskHandler)
301 self.server_thread = Thread(target=self.httpd.serve_forever)
302 self.server_thread.daemon = True
303 self.server_thread.start()
304 self.conn = HTTPConnection(str(self.httpd.server_address[0]),
305 self.httpd.server_address[1])
306 self.httpd.set_json_mode()
308 def tearDown(self) -> None:
309 self.httpd.shutdown()
310 self.httpd.server_close()
311 self.server_thread.join()
314 def check_redirect(self, target: str) -> None:
315 """Check that self.conn answers with a 302 redirect to target."""
316 response = self.conn.getresponse()
317 self.assertEqual(response.status, 302)
318 self.assertEqual(response.getheader('Location'), target)
320 def check_get(self, target: str, expected_code: int) -> None:
321 """Check that a GET to target yields expected_code."""
322 self.conn.request('GET', target)
323 self.assertEqual(self.conn.getresponse().status, expected_code)
325 def check_post(self, data: Mapping[str, object], target: str,
326 expected_code: int, redirect_location: str = '') -> None:
327 """Check that POST of data to target yields expected_code."""
328 encoded_form_data = urlencode(data, doseq=True).encode('utf-8')
329 headers = {'Content-Type': 'application/x-www-form-urlencoded',
330 'Content-Length': str(len(encoded_form_data))}
331 self.conn.request('POST', target,
332 body=encoded_form_data, headers=headers)
333 if 302 == expected_code:
334 if redirect_location == '':
335 redirect_location = target
336 self.check_redirect(redirect_location)
338 self.assertEqual(self.conn.getresponse().status, expected_code)
340 def check_get_defaults(self, path: str) -> None:
341 """Some standard model paths to test."""
342 self.check_get(path, 200)
343 self.check_get(f'{path}?id=', 200)
344 self.check_get(f'{path}?id=foo', 400)
345 self.check_get(f'/{path}?id=0', 500)
346 self.check_get(f'{path}?id=1', 200)
348 def post_process(self, id_: int = 1,
349 form_data: dict[str, Any] | None = None
351 """POST basic Process."""
353 form_data = {'title': 'foo', 'description': 'foo', 'effort': 1.1}
354 self.check_post(form_data, f'/process?id={id_}', 302,
355 f'/process?id={id_}')
358 def check_json_get(self, path: str, expected: dict[str, object]) -> None:
359 """Compare JSON on GET path with expected.
361 To simplify comparison of VersionedAttribute histories, transforms
362 timestamp keys of VersionedAttribute history keys into integers
363 counting chronologically forward from 0.
365 def rewrite_history_keys_in(item: Any) -> Any:
366 if isinstance(item, dict):
367 if '_versioned' in item.keys():
368 for k in item['_versioned']:
369 vals = item['_versioned'][k].values()
371 for i, val in enumerate(vals):
373 item['_versioned'][k] = history
374 for k in list(item.keys()):
375 rewrite_history_keys_in(item[k])
376 elif isinstance(item, list):
377 item[:] = [rewrite_history_keys_in(i) for i in item]
379 self.conn.request('GET', path)
380 response = self.conn.getresponse()
381 self.assertEqual(response.status, 200)
382 retrieved = json_loads(response.read().decode())
383 rewrite_history_keys_in(retrieved)
384 self.assertEqual(expected, retrieved)