1 """Shared test utilities."""
from __future__ import annotations
from unittest import TestCase
from typing import Mapping, Any, Callable
from threading import Thread
from http.client import HTTPConnection
from json import loads as json_loads
from urllib.parse import urlencode
from uuid import uuid4
from os import remove as remove_file
from plomtask.db import DatabaseFile, DatabaseConnection
from plomtask.http import TaskHandler, TaskServer
from plomtask.processes import Process, ProcessStep
from plomtask.conditions import Condition
from plomtask.days import Day
from plomtask.todos import Todo
from plomtask.exceptions import NotFoundException, HandledException
class TestCaseSansDB(TestCase):
    """Tests requiring no DB setup.

    Subclasses point .checked_class at the class under test and opt into
    the generic checks below through the class attributes.
    """
    checked_class: Any
    # test_id_setting only does anything if a subclass sets this to True
    do_id_test: bool = False
    # positional arguments handed to checked_class after the ID
    default_init_args: list[Any] = []
    # attribute name -> value expected as that attribute's .newest
    versioned_defaults_to_test: dict[str, str | float] = {}

    def test_id_setting(self) -> None:
        """Test .id_ being set and its legal range being enforced."""
        if not self.do_id_test:
            return
        # an ID of 0 is expected to be rejected on instantiation
        with self.assertRaises(HandledException):
            self.checked_class(0, *self.default_init_args)
        obj = self.checked_class(5, *self.default_init_args)
        self.assertEqual(obj.id_, 5)

    def test_versioned_defaults(self) -> None:
        """Test defaults of VersionedAttributes."""
        if not self.versioned_defaults_to_test:
            return
        obj = self.checked_class(1, *self.default_init_args)
        for k, v in self.versioned_defaults_to_test.items():
            self.assertEqual(getattr(obj, k).newest, v)
class TestCaseWithDB(TestCase):
    """Module tests requiring DB setup.

    Each test gets a fresh database file and connection. Subclasses that
    set .checked_class inherit the generic test_* methods below; without
    it those tests are silently skipped.
    """
    checked_class: Any
    # three distinct legal IDs for checked_class
    default_ids: tuple[int | str, int | str, int | str] = (1, 2, 3)
    # keyword arguments handed to checked_class after the ID
    default_init_kwargs: dict[str, Any] = {}
    # attribute name -> value type (str, otherwise treated as float)
    test_versioneds: dict[str, type] = {}

    def setUp(self) -> None:
        """Empty the model caches, create a fresh DB file and connect."""
        # NOTE(review): assumes Day, Process and Todo offer .empty_cache()
        # just like Condition and ProcessStep -- confirm
        for model in (Condition, Day, Process, ProcessStep, Todo):
            model.empty_cache()
        self.db_file = DatabaseFile.create_at(f'test_db:{uuid4()}')
        self.db_conn = DatabaseConnection(self.db_file)

    def tearDown(self) -> None:
        """Release the DB connection and delete the DB file."""
        # NOTE(review): assumes DatabaseConnection exposes .close() --
        # confirm; the connection should be released before file removal
        self.db_conn.close()
        remove_file(self.db_file.path)

    # used as a plain function at class-definition time, hence no self
    # pylint: disable=no-self-argument
    def _within_checked_class(f: Callable[..., None]) -> Callable[..., None]:
        """Decorate test f to run only if .checked_class is defined."""
        def wrapper(self: TestCaseWithDB) -> None:
            if hasattr(self, 'checked_class'):
                f(self)
        return wrapper

    @staticmethod
    def _altered(value: Any) -> Any:
        """Return a value of value's type that is unequal to value.

        Raises TypeError for anything but bool, int, float, or str.
        """
        # bool has to be tested first: it is a subclass of int, so the
        # numeric branch would otherwise swallow it and yield an int
        if isinstance(value, bool):
            return not value
        if isinstance(value, (int, float)):
            return value + 1
        if isinstance(value, str):
            return value + '_'
        raise TypeError(f'cannot alter value of type {type(value).__name__}')

    @_within_checked_class
    def test_saving_and_caching(self) -> None:
        """Test storage and initialization of instances and attributes."""
        self.check_saving_and_caching(id_=1, **self.default_init_kwargs)
        # an ID of None is expected to be replaced by the next free one
        obj = self.checked_class(None, **self.default_init_kwargs)
        obj.save(self.db_conn)
        self.assertEqual(obj.id_, 2)
        for attr_name, type_ in self.test_versioneds.items():
            owner = self.checked_class(None)
            vals: list[Any] = ['t1', 't2'] if type_ == str else [0.9, 1.1]
            attr = getattr(owner, attr_name)
            attr.set(vals[0])
            attr.set(vals[1])
            owner.save(self.db_conn)
            retrieved = owner.__class__.by_id(self.db_conn, owner.id_)
            attr = getattr(retrieved, attr_name)
            self.assertEqual(sorted(attr.history.values()), vals)

    def check_storage(self, content: list[Any]) -> None:
        """Test cache and DB equal content."""
        expected_cache = {item.id_: item for item in content}
        self.assertEqual(self.checked_class.get_cache(), expected_cache)
        hashes_content = [hash(x) for x in content]
        db_found: list[Any] = []
        for item in content:
            assert isinstance(item.id_, type(self.default_ids[0]))
            for row in self.db_conn.row_where(self.checked_class.table_name,
                                              'id', item.id_):
                db_found.append(
                    self.checked_class.from_table_row(self.db_conn, row))
        hashes_db_found = [hash(x) for x in db_found]
        self.assertEqual(sorted(hashes_content), sorted(hashes_db_found))

    def check_saving_and_caching(self, **kwargs: Any) -> None:
        """Test instance.save in its core without relations."""
        obj = self.checked_class(**kwargs)  # pylint: disable=not-callable
        # check object init itself doesn't store anything yet
        self.check_storage([])
        # check saving sets core attributes properly
        obj.save(self.db_conn)
        for key, value in kwargs.items():
            self.assertEqual(getattr(obj, key), value)
        # check saving stored properly in cache and DB
        self.check_storage([obj])

    def check_by_id(self) -> None:
        """Test .by_id(), including creation."""
        # check failure if not yet saved
        id1, id2 = self.default_ids[0], self.default_ids[1]
        obj = self.checked_class(id1)  # pylint: disable=not-callable
        with self.assertRaises(NotFoundException):
            self.checked_class.by_id(self.db_conn, id1)
        # check identity of saved and retrieved
        obj.save(self.db_conn)
        self.assertEqual(obj, self.checked_class.by_id(self.db_conn, id1))
        # check create=True acts like normal instantiation (sans saving)
        by_id_created = self.checked_class.by_id(self.db_conn, id2,
                                                 create=True)
        # pylint: disable=not-callable
        self.assertEqual(self.checked_class(id2), by_id_created)
        self.check_storage([obj])

    @_within_checked_class
    def test_from_table_row(self) -> None:
        """Test .from_table_row() properly reads in class directly from DB."""
        id_ = self.default_ids[0]
        obj = self.checked_class(id_, **self.default_init_kwargs)
        obj.save(self.db_conn)
        assert isinstance(obj.id_, type(self.default_ids[0]))
        for row in self.db_conn.row_where(self.checked_class.table_name,
                                          'id', obj.id_):
            # check .from_table_row reproduces state saved, no matter if obj
            # later changed (with caching even)
            hash_original = hash(obj)
            attr_name = self.checked_class.to_save[-1]
            setattr(obj, attr_name, self._altered(getattr(obj, attr_name)))
            # NOTE(review): assumes instances offer .cache() to enter the
            # class cache without writing to the DB -- confirm
            obj.cache()
            to_cmp = getattr(obj, attr_name)
            retrieved = self.checked_class.from_table_row(self.db_conn, row)
            self.assertNotEqual(to_cmp, getattr(retrieved, attr_name))
            self.assertEqual(hash_original, hash(retrieved))
            # check cache contains what .from_table_row just produced
            self.assertEqual({retrieved.id_: retrieved},
                             self.checked_class.get_cache())

    def check_versioned_from_table_row(self, attr_name: str,
                                       type_: type) -> None:
        """Test .from_table_row() reads versioned attributes from DB."""
        owner = self.checked_class(None)
        vals: list[Any] = ['t1', 't2'] if type_ == str else [0.9, 1.1]
        attr = getattr(owner, attr_name)
        attr.set(vals[0])
        attr.set(vals[1])
        owner.save(self.db_conn)
        for row in self.db_conn.row_where(owner.table_name, 'id', owner.id_):
            retrieved = owner.__class__.from_table_row(self.db_conn, row)
            attr = getattr(retrieved, attr_name)
            self.assertEqual(sorted(attr.history.values()), vals)

    @_within_checked_class
    def test_all(self) -> None:
        """Test .all() and its relation to cache and savings."""
        id_1, id_2, id_3 = self.default_ids
        item1 = self.checked_class(id_1, **self.default_init_kwargs)
        item2 = self.checked_class(id_2, **self.default_init_kwargs)
        item3 = self.checked_class(id_3, **self.default_init_kwargs)
        # check .all() returns empty list on un-cached items
        self.assertEqual(self.checked_class.all(self.db_conn), [])
        # check that all() shows only cached/saved items
        # NOTE(review): assumes instances offer .cache() to enter the
        # class cache without writing to the DB -- confirm
        item1.cache()
        item3.save(self.db_conn)
        self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
                         sorted([item1, item3]))
        item2.save(self.db_conn)
        self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
                         sorted([item1, item2, item3]))

    @_within_checked_class
    def test_singularity(self) -> None:
        """Test pointers made for single object keep pointing to it."""
        id1 = self.default_ids[0]
        obj = self.checked_class(id1, **self.default_init_kwargs)
        obj.save(self.db_conn)
        attr_name = self.checked_class.to_save[-1]
        new_attr = self._altered(getattr(obj, attr_name))
        # change is made on obj only, deliberately without saving again
        setattr(obj, attr_name, new_attr)
        retrieved = self.checked_class.by_id(self.db_conn, id1)
        self.assertEqual(new_attr, getattr(retrieved, attr_name))

    def check_versioned_singularity(self) -> None:
        """Test singularity of VersionedAttributes on saving (with .title)."""
        obj = self.checked_class(None)  # pylint: disable=not-callable
        obj.save(self.db_conn)
        assert isinstance(obj.id_, int)
        # set title only after saving, deliberately without saving again
        obj.title.set('named')
        retrieved = self.checked_class.by_id(self.db_conn, obj.id_)
        self.assertEqual(obj.title.history, retrieved.title.history)

    def check_remove(self, *args: Any) -> None:
        """Test .remove() effects on DB and cache."""
        id_ = self.default_ids[0]
        obj = self.checked_class(id_, *args)  # pylint: disable=not-callable
        # removing what was never saved is expected to fail
        with self.assertRaises(HandledException):
            obj.remove(self.db_conn)
        obj.save(self.db_conn)
        obj.remove(self.db_conn)
        self.check_storage([])
class TestCaseWithServer(TestCaseWithDB):
    """Module tests against our HTTP server/handler (and database)."""

    def setUp(self) -> None:
        """Start a TaskServer on the test DB in a background thread."""
        super().setUp()
        # port 0: let the OS pick a free port, read back via server_address
        self.httpd = TaskServer(self.db_file, ('localhost', 0), TaskHandler)
        self.server_thread = Thread(target=self.httpd.serve_forever)
        self.server_thread.daemon = True
        self.server_thread.start()
        self.conn = HTTPConnection(str(self.httpd.server_address[0]),
                                   self.httpd.server_address[1])
        self.httpd.set_json_mode()

    def tearDown(self) -> None:
        """Stop the server and its thread, then clean up the DB."""
        self.httpd.shutdown()
        self.httpd.server_close()
        self.server_thread.join()
        super().tearDown()

    def check_redirect(self, target: str) -> None:
        """Check that self.conn answers with a 302 redirect to target."""
        response = self.conn.getresponse()
        self.assertEqual(response.status, 302)
        self.assertEqual(response.getheader('Location'), target)

    def check_get(self, target: str, expected_code: int) -> None:
        """Check that a GET to target yields expected_code."""
        self.conn.request('GET', target)
        self.assertEqual(self.conn.getresponse().status, expected_code)

    def check_post(self, data: Mapping[str, object], target: str,
                   expected_code: int, redirect_location: str = '') -> None:
        """Check that POST of data to target yields expected_code.

        If expected_code is 302, also check the redirect's Location, which
        is expected to be redirect_location or, if that is empty, target.
        """
        encoded_form_data = urlencode(data, doseq=True).encode('utf-8')
        headers = {'Content-Type': 'application/x-www-form-urlencoded',
                   'Content-Length': str(len(encoded_form_data))}
        self.conn.request('POST', target,
                          body=encoded_form_data, headers=headers)
        if expected_code == 302:
            if redirect_location == '':
                redirect_location = target
            self.check_redirect(redirect_location)
        else:
            self.assertEqual(self.conn.getresponse().status, expected_code)

    def check_get_defaults(self, path: str) -> None:
        """Some standard model paths to test."""
        self.check_get(path, 200)
        self.check_get(f'{path}?id=', 200)
        self.check_get(f'{path}?id=foo', 400)
        # NOTE(review): only this target gets an extra leading slash --
        # confirm whether that is intended before relying on the 500
        self.check_get(f'/{path}?id=0', 500)
        self.check_get(f'{path}?id=1', 200)

    def post_process(self, id_: int = 1,
                     form_data: dict[str, Any] | None = None
                     ) -> dict[str, Any]:
        """POST basic Process.

        Posts form_data (or a minimal default if none is given) to the
        process of ID id_, expects a redirect back to it, and returns the
        form data actually sent.
        """
        if form_data is None:
            form_data = {'title': 'foo', 'description': 'foo', 'effort': 1.1}
        self.check_post(form_data, f'/process?id={id_}', 302,
                        f'/process?id={id_}')
        return form_data

    def check_json_get(self, path: str, expected: dict[str, object]) -> None:
        """Compare JSON on GET path with expected.

        To simplify comparison of VersionedAttribute histories, transforms
        timestamp keys of VersionedAttribute history keys into integers
        counting chronologically forward from 0.
        """
        def rewrite_history_keys_in(item: Any) -> Any:
            if isinstance(item, dict):
                if '_versioned' in item.keys():
                    for k in item['_versioned']:
                        vals = item['_versioned'][k].values()
                        # relies on the values arriving in history order
                        item['_versioned'][k] = dict(enumerate(vals))
                for k in list(item.keys()):
                    rewrite_history_keys_in(item[k])
            elif isinstance(item, list):
                item[:] = [rewrite_history_keys_in(i) for i in item]
            # returned so the list branch above can rebuild from results
            return item

        self.conn.request('GET', path)
        response = self.conn.getresponse()
        self.assertEqual(response.status, 200)
        retrieved = json_loads(response.read().decode())
        rewrite_history_keys_in(retrieved)
        self.assertEqual(expected, retrieved)