1 """Shared test utilities."""
2 from __future__ import annotations
3 from unittest import TestCase
4 from typing import Mapping, Any, Callable
5 from threading import Thread
6 from http.client import HTTPConnection
7 from json import loads as json_loads
8 from urllib.parse import urlencode
10 from os import remove as remove_file
11 from plomtask.db import DatabaseFile, DatabaseConnection
12 from plomtask.http import TaskHandler, TaskServer
13 from plomtask.processes import Process, ProcessStep
14 from plomtask.conditions import Condition
15 from plomtask.days import Day
16 from plomtask.todos import Todo
17 from plomtask.exceptions import NotFoundException, HandledException
class TestCaseSansDB(TestCase):
    """Tests requiring no DB setup.

    Subclasses point checked_class at the model class under test and adapt
    the class attributes below; with the defaults, both tests do nothing.
    """
    # model class under test, to be provided by subclasses
    checked_class: Any
    # whether test_id_setting should do anything at all
    do_id_test: bool = False
    # positional constructor arguments following the id
    default_init_args: list[Any] = []
    # VersionedAttribute names mapped to the .newest value expected on a
    # freshly constructed instance
    versioned_defaults_to_test: dict[str, str | float] = {}

    def test_id_setting(self) -> None:
        """Test .id_ being set and its legal range being enforced."""
        if self.do_id_test:
            # an id of 0 must be rejected
            with self.assertRaises(HandledException):
                self.checked_class(0, *self.default_init_args)
            instance = self.checked_class(5, *self.default_init_args)
            self.assertEqual(instance.id_, 5)

    def test_versioned_defaults(self) -> None:
        """Test defaults of VersionedAttributes."""
        if not self.versioned_defaults_to_test:
            return
        instance = self.checked_class(1, *self.default_init_args)
        for attr_name, expected in self.versioned_defaults_to_test.items():
            self.assertEqual(getattr(instance, attr_name).newest, expected)
class TestCaseWithDB(TestCase):
    """Module tests requiring DB setup.

    Each test starts with emptied model caches and a freshly created
    database file, which tearDown deletes again. The generic test_* methods
    below only do something for subclasses that set checked_class.
    """
    # model class under test; deliberately only annotated, never assigned
    # here, so that hasattr(self, 'checked_class') is False on base classes
    checked_class: Any
    # ids to use for up to three test instances
    default_ids: tuple[int | str, int | str, int | str] = (1, 2, 3)
    # keyword arguments passed to checked_class alongside the id
    default_init_kwargs: dict[str, Any] = {}
    # VersionedAttribute names mapped to their value type (str, else numeric)
    test_versioneds: dict[str, type] = {}

    def setUp(self) -> None:
        """Empty all model caches, then open a fresh, uniquely named DB."""
        Condition.empty_cache()
        Day.empty_cache()
        Process.empty_cache()
        ProcessStep.empty_cache()
        Todo.empty_cache()
        self.db_file = DatabaseFile.create_at(f'test_db:{uuid4()}')
        self.db_conn = DatabaseConnection(self.db_file)

    def tearDown(self) -> None:
        """Close the DB connection and delete the DB file."""
        self.db_conn.close()
        remove_file(self.db_file.path)

    @staticmethod
    def _within_checked_class(f: Callable[..., None]) -> Callable[..., None]:
        """Wrap test method f to run only if self has a checked_class."""
        from functools import wraps  # pylint: disable=import-outside-toplevel

        # wraps keeps f's __name__ and __doc__, so unittest can still report
        # the test's docstring as its short description
        @wraps(f)
        def wrapper(self: TestCaseWithDB) -> None:
            if hasattr(self, 'checked_class'):
                f(self)
        return wrapper

    def _altered(self, value: Any) -> Any:
        """Return a value of value's type that differs from value.

        Fails the test for types other than bool, int, float and str. bool
        is checked first: as a subclass of int it would otherwise land in
        the numeric branch (turning True into 2 rather than into False).
        """
        if isinstance(value, bool):
            return not value
        if isinstance(value, (int, float)):
            return value + 1
        if isinstance(value, str):
            return value + '_'
        self.fail(f'cannot alter value of type {type(value).__name__}')

    def _save_versioned_owner(self, attr_name: str,
                              type_: type) -> tuple[Any, list[Any]]:
        """Save new checked_class instance after setting two attr values.

        Returns the saved instance and the (ascending) values set on its
        VersionedAttribute attr_name: strings if type_ is str, else floats.
        """
        owner = self.checked_class(None)
        vals: list[Any] = ['t1', 't2'] if type_ is str else [0.9, 1.1]
        attr = getattr(owner, attr_name)
        for val in vals:
            attr.set(val)
        owner.save(self.db_conn)
        return owner, vals

    @_within_checked_class
    def test_saving_and_caching(self) -> None:
        """Test storage and initialization of instances and attributes."""
        self.check_saving_and_caching(id_=1, **self.default_init_kwargs)
        # saving without explicit id assigns the next free one
        obj = self.checked_class(None, **self.default_init_kwargs)
        obj.save(self.db_conn)
        self.assertEqual(obj.id_, 2)
        # check VersionedAttribute histories survive saving and retrieval
        for attr_name, type_ in self.test_versioneds.items():
            owner, vals = self._save_versioned_owner(attr_name, type_)
            retrieved = owner.__class__.by_id(self.db_conn, owner.id_)
            attr = getattr(retrieved, attr_name)
            self.assertEqual(sorted(attr.history.values()), vals)

    def check_identity_with_cache_and_db(self, content: list[Any]) -> None:
        """Test both cache and DB equal content.

        The class cache must map exactly the ids of content to its items,
        and the objects built by .from_table_row from the DB rows of those
        ids must hash like content. As the DB side only looks up rows by the
        ids of content, an empty content does not prove the table empty.
        """
        # compare cache first: .from_table_row below caches what it reads
        expected_cache = {item.id_: item for item in content}
        self.assertEqual(self.checked_class.get_cache(), expected_cache)
        hashes_content = [hash(x) for x in content]
        db_found: list[Any] = []
        for item in content:
            assert isinstance(item.id_, type(self.default_ids[0]))
            for row in self.db_conn.row_where(self.checked_class.table_name,
                                              'id', item.id_):
                db_found += [self.checked_class.from_table_row(self.db_conn,
                                                               row)]
        hashes_db_found = [hash(x) for x in db_found]
        self.assertEqual(sorted(hashes_content), sorted(hashes_db_found))

    def check_saving_and_caching(self, **kwargs: Any) -> None:
        """Test instance.save in its core without relations."""
        obj = self.checked_class(**kwargs)  # pylint: disable=not-callable
        # check object init itself doesn't store anything yet
        self.check_identity_with_cache_and_db([])
        # check saving sets core attributes properly
        obj.save(self.db_conn)
        for key, value in kwargs.items():
            self.assertEqual(getattr(obj, key), value)
        # check saving stored properly in cache and DB
        self.check_identity_with_cache_and_db([obj])

    @_within_checked_class
    def test_by_id(self) -> None:
        """Test .by_id()."""
        id1, id2, _ = self.default_ids
        # check failure if not yet saved
        obj1 = self.checked_class(id1, **self.default_init_kwargs)
        with self.assertRaises(NotFoundException):
            self.checked_class.by_id(self.db_conn, id1)
        # check identity of cached and retrieved
        obj1.cache()
        self.assertEqual(obj1, self.checked_class.by_id(self.db_conn, id1))
        # check identity of saved and retrieved
        obj2 = self.checked_class(id2, **self.default_init_kwargs)
        obj2.save(self.db_conn)
        self.assertEqual(obj2, self.checked_class.by_id(self.db_conn, id2))
        # obj1.save(self.db_conn)
        # self.check_identity_with_cache_and_db([obj1, obj2])

    @_within_checked_class
    def test_by_id_or_create(self) -> None:
        """Test .by_id_or_create."""
        id_ = self.default_ids[0]
        # check .by_id_or_create fails for classes that cannot create by id
        if not self.checked_class.can_create_by_id:
            with self.assertRaises(HandledException):
                self.checked_class.by_id_or_create(self.db_conn, id_)
            return
        # check .by_id_or_create acts like normal instantiation (sans saving)
        by_id_created = self.checked_class.by_id_or_create(self.db_conn, id_)
        with self.assertRaises(NotFoundException):
            self.checked_class.by_id(self.db_conn, id_)
        self.assertEqual(self.checked_class(id_), by_id_created)

    @_within_checked_class
    def test_from_table_row(self) -> None:
        """Test .from_table_row() properly reads in class directly from DB."""
        id_ = self.default_ids[0]
        obj = self.checked_class(id_, **self.default_init_kwargs)
        obj.save(self.db_conn)
        assert isinstance(obj.id_, type(self.default_ids[0]))
        rows = list(self.db_conn.row_where(self.checked_class.table_name,
                                           'id', obj.id_))
        # without a row, the loop below would not run and the test would
        # pass without checking anything
        self.assertEqual(len(rows), 1)
        for row in rows:
            # check .from_table_row reproduces state saved, no matter if obj
            # later changed (with caching even)
            hash_original = hash(obj)
            attr_name = self.checked_class.to_save[-1]
            setattr(obj, attr_name, self._altered(getattr(obj, attr_name)))
            to_cmp = getattr(obj, attr_name)
            retrieved = self.checked_class.from_table_row(self.db_conn, row)
            self.assertNotEqual(to_cmp, getattr(retrieved, attr_name))
            self.assertEqual(hash_original, hash(retrieved))
            # check cache contains what .from_table_row just produced
            self.assertEqual({retrieved.id_: retrieved},
                             self.checked_class.get_cache())

    def check_versioned_from_table_row(self, attr_name: str,
                                       type_: type) -> None:
        """Test .from_table_row() reads versioned attributes from DB."""
        owner, vals = self._save_versioned_owner(attr_name, type_)
        for row in self.db_conn.row_where(owner.table_name, 'id', owner.id_):
            retrieved = owner.__class__.from_table_row(self.db_conn, row)
            attr = getattr(retrieved, attr_name)
            self.assertEqual(sorted(attr.history.values()), vals)

    @_within_checked_class
    def test_all(self) -> None:
        """Test .all() and its relation to cache and savings."""
        id_1, id_2, id_3 = self.default_ids
        item1 = self.checked_class(id_1, **self.default_init_kwargs)
        item2 = self.checked_class(id_2, **self.default_init_kwargs)
        item3 = self.checked_class(id_3, **self.default_init_kwargs)
        # check .all() returns empty list on un-cached items
        self.assertEqual(self.checked_class.all(self.db_conn), [])
        # check that all() shows only cached/saved items
        item1.cache()
        item3.save(self.db_conn)
        self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
                         sorted([item1, item3]))
        item2.save(self.db_conn)
        self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
                         sorted([item1, item2, item3]))

    @_within_checked_class
    def test_singularity(self) -> None:
        """Test pointers made for single object keep pointing to it."""
        id1 = self.default_ids[0]
        obj = self.checked_class(id1, **self.default_init_kwargs)
        obj.save(self.db_conn)
        attr_name = self.checked_class.to_save[-1]
        new_attr = self._altered(getattr(obj, attr_name))
        setattr(obj, attr_name, new_attr)
        retrieved = self.checked_class.by_id(self.db_conn, id1)
        self.assertEqual(new_attr, getattr(retrieved, attr_name))

    def check_versioned_singularity(self) -> None:
        """Test singularity of VersionedAttributes on saving (with .title)."""
        obj = self.checked_class(None)  # pylint: disable=not-callable
        obj.save(self.db_conn)
        assert isinstance(obj.id_, int)
        obj.title.set('named')
        retrieved = self.checked_class.by_id(self.db_conn, obj.id_)
        self.assertEqual(obj.title.history, retrieved.title.history)

    def check_remove(self, *args: Any) -> None:
        """Test .remove() effects on DB and cache.

        Removing a not-yet-saved object must raise HandledException;
        removing a saved one must clear it from both cache and DB.
        """
        id_ = self.default_ids[0]
        obj = self.checked_class(id_, *args)  # pylint: disable=not-callable
        with self.assertRaises(HandledException):
            obj.remove(self.db_conn)
        obj.save(self.db_conn)
        obj.remove(self.db_conn)
        self.check_identity_with_cache_and_db([])
        # with empty content, the call above cannot see leftover DB rows (it
        # only looks up rows by the ids of its content), so check them here
        self.assertEqual([], list(self.db_conn.row_where(
            self.checked_class.table_name, 'id', id_)))
class TestCaseWithServer(TestCaseWithDB):
    """Module tests against our HTTP server/handler (and database)."""

    def setUp(self) -> None:
        """Set up the DB, then serve it from a background thread.

        self.conn is an HTTP client connection to that server.
        """
        super().setUp()
        # bind to port 0 and read the actual address back via server_address
        self.httpd = TaskServer(self.db_file, ('localhost', 0), TaskHandler)
        self.server_thread = Thread(target=self.httpd.serve_forever)
        self.server_thread.daemon = True
        self.server_thread.start()
        self.conn = HTTPConnection(str(self.httpd.server_address[0]),
                                   self.httpd.server_address[1])
        self.httpd.set_json_mode()

    def tearDown(self) -> None:
        """Close client connection, stop server and thread, then drop DB."""
        self.conn.close()
        self.httpd.shutdown()
        self.httpd.server_close()
        self.server_thread.join()
        super().tearDown()

    def check_redirect(self, target: str) -> None:
        """Check that self.conn answers with a 302 redirect to target."""
        response = self.conn.getresponse()
        self.assertEqual(response.status, 302)
        self.assertEqual(response.getheader('Location'), target)

    def check_get(self, target: str, expected_code: int) -> None:
        """Check that a GET to target yields expected_code."""
        self.conn.request('GET', target)
        self.assertEqual(self.conn.getresponse().status, expected_code)

    def check_post(self, data: Mapping[str, object], target: str,
                   expected_code: int, redirect_location: str = '') -> None:
        """Check that POST of data to target yields expected_code.

        If expected_code is 302, also check the redirect's Location header
        against redirect_location, or against target if that is empty.
        """
        encoded_form_data = urlencode(data, doseq=True).encode('utf-8')
        headers = {'Content-Type': 'application/x-www-form-urlencoded',
                   'Content-Length': str(len(encoded_form_data))}
        self.conn.request('POST', target,
                          body=encoded_form_data, headers=headers)
        if 302 == expected_code:
            if redirect_location == '':
                redirect_location = target
            self.check_redirect(redirect_location)
        else:
            self.assertEqual(self.conn.getresponse().status, expected_code)

    def check_get_defaults(self, path: str) -> None:
        """GET standard id query variants of path, check the status codes.

        path is expected to be a full request path such as '/process'.
        """
        self.check_get(path, 200)
        self.check_get(f'{path}?id=', 200)
        self.check_get(f'{path}?id=foo', 400)
        # no extra '/' prefix here: '//…' would only reach path because
        # http.server happens to collapse leading double slashes
        self.check_get(f'{path}?id=0', 500)
        self.check_get(f'{path}?id=1', 200)

    def post_process(self, id_: int = 1,
                     form_data: dict[str, Any] | None = None
                     ) -> dict[str, Any]:
        """POST basic Process, return the form data posted.

        Falls back to a default title/description/effort if form_data is
        empty or None.
        """
        if not form_data:
            form_data = {'title': 'foo', 'description': 'foo', 'effort': 1.1}
        self.check_post(form_data, f'/process?id={id_}', 302,
                        f'/process?id={id_}')
        return form_data

    def check_json_get(self, path: str, expected: dict[str, object]) -> None:
        """Compare JSON on GET path with expected.

        To simplify comparison of VersionedAttribute histories, replaces the
        timestamp keys of each history found under a '_versioned' key with
        integers counting forward from 0 in the order the JSON lists them
        (presumably chronological – this relies on the server's ordering).
        """

        def rewrite_history_keys_in(item: Any) -> Any:
            if isinstance(item, dict):
                if '_versioned' in item:
                    versioneds = item['_versioned']
                    # reassigning existing keys does not resize the dict, so
                    # this is safe while iterating it
                    for k in versioneds:
                        versioneds[k] = dict(enumerate(versioneds[k].values()))
                for k in list(item.keys()):
                    rewrite_history_keys_in(item[k])
            elif isinstance(item, list):
                item[:] = [rewrite_history_keys_in(i) for i in item]
            # returned so the list branch above can rebuild lists in place
            return item

        self.conn.request('GET', path)
        response = self.conn.getresponse()
        self.assertEqual(response.status, 200)
        retrieved = json_loads(response.read().decode())
        rewrite_history_keys_in(retrieved)
        self.assertEqual(expected, retrieved)