1 """Shared test utilities."""
2 from unittest import TestCase
3 from threading import Thread
4 from http.client import HTTPConnection
5 from json import loads as json_loads
6 from urllib.parse import urlencode
8 from os import remove as remove_file
9 from typing import Mapping, Any
10 from plomtask.db import DatabaseFile, DatabaseConnection
11 from plomtask.http import TaskHandler, TaskServer
12 from plomtask.processes import Process, ProcessStep
13 from plomtask.conditions import Condition
14 from plomtask.days import Day
15 from plomtask.todos import Todo
16 from plomtask.exceptions import NotFoundException, HandledException
19 class TestCaseSansDB(TestCase):
20 """Tests requiring no DB setup."""
22 do_id_test: bool = False
23 default_init_args: list[Any] = []
24 versioned_defaults_to_test: dict[str, str | float] = {}
26 def test_id_setting(self) -> None:
27 """Test .id_ being set and its legal range being enforced."""
28 if not self.do_id_test:
30 with self.assertRaises(HandledException):
31 self.checked_class(0, *self.default_init_args)
32 obj = self.checked_class(5, *self.default_init_args)
33 self.assertEqual(obj.id_, 5)
35 def test_versioned_defaults(self) -> None:
36 """Test defaults of VersionedAttributes."""
37 if len(self.versioned_defaults_to_test) == 0:
39 obj = self.checked_class(1, *self.default_init_args)
40 for k, v in self.versioned_defaults_to_test.items():
41 self.assertEqual(getattr(obj, k).newest, v)
class TestCaseWithDB(TestCase):
    """Module tests requiring DB setup.

    Every test gets a fresh database file (created in setUp, removed in
    tearDown), with the model caches emptied beforehand. Subclasses set
    checked_class to the model class under test; without it, the generic
    test_* methods below return early.
    """
    # model class under test, to be set by subclasses
    checked_class: Any
    # IDs for the (up to three) instances the generic checks create
    default_ids: tuple[int | str, int | str, int | str] = (1, 2, 3)
    # keyword arguments passed to checked_class by the generic test_* methods
    default_init_kwargs: dict[str, Any] = {}
    # versioned attribute name -> value type; str selects string test values,
    # any other type selects float ones
    test_versioneds: dict[str, type] = {}

    def setUp(self) -> None:
        Condition.empty_cache()
        Day.empty_cache()
        Process.empty_cache()
        ProcessStep.empty_cache()
        Todo.empty_cache()
        self.db_file = DatabaseFile.create_at(f'test_db:{uuid4()}')
        self.db_conn = DatabaseConnection(self.db_file)

    def tearDown(self) -> None:
        self.db_conn.close()
        remove_file(self.db_file.path)

    def _altered_value(self, value: Any) -> Any:
        """Return a value of value's type that differs from value.

        bool must be checked before (int, float): bool subclasses int, so the
        numeric branch would otherwise catch it and turn e.g. True into 2.
        Fails the running test for any other type.
        """
        if isinstance(value, bool):
            return not value
        if isinstance(value, (int, float)):
            return value + 1
        if isinstance(value, str):
            return value + '_'
        self.fail(f'cannot alter value of type {type(value).__name__}')

    def _save_owner_with_history(self, attr_name: str,
                                 type_: type) -> tuple[Any, list[Any]]:
        """Save new checked_class instance after setting attr_name twice.

        Returns the saved owner and the (ascending) list of values set, so
        callers can compare them against a retrieved attribute's history.
        """
        owner = self.checked_class(None)  # pylint: disable=not-callable
        vals: list[Any] = ['t1', 't2'] if type_ is str else [0.9, 1.1]
        attr = getattr(owner, attr_name)
        for val in vals:
            attr.set(val)
        owner.save(self.db_conn)
        return owner, vals

    def test_saving_and_caching(self) -> None:
        """Test storage and initialization of instances and attributes."""
        if not hasattr(self, 'checked_class'):
            return
        self.check_saving_and_caching(id_=1, **self.default_init_kwargs)
        # saving with id_ None is expected to assign the next free ID
        obj = self.checked_class(None, **self.default_init_kwargs)
        obj.save(self.db_conn)
        self.assertEqual(obj.id_, 2)
        for attr_name, type_ in self.test_versioneds.items():
            self.check_saving_of_versioned(attr_name, type_)

    def check_storage(self, content: list[Any]) -> None:
        """Test cache and DB equal content.

        The class cache must map exactly content's IDs to content's items;
        the DB rows found for those IDs must rebuild objects hashing equal
        to content.
        """
        expected_cache = {}
        for item in content:
            expected_cache[item.id_] = item
        self.assertEqual(self.checked_class.get_cache(), expected_cache)
        hashes_content = [hash(x) for x in content]
        db_found: list[Any] = []
        for item in content:
            assert isinstance(item.id_, type(self.default_ids[0]))
            for row in self.db_conn.row_where(self.checked_class.table_name,
                                              'id', item.id_):
                db_found += [self.checked_class.from_table_row(self.db_conn,
                                                               row)]
        hashes_db_found = [hash(x) for x in db_found]
        self.assertEqual(sorted(hashes_content), sorted(hashes_db_found))

    def check_saving_and_caching(self, **kwargs: Any) -> None:
        """Test instance.save in its core without relations."""
        obj = self.checked_class(**kwargs)  # pylint: disable=not-callable
        # check object init itself doesn't store anything yet
        self.check_storage([])
        # check saving sets core attributes properly
        obj.save(self.db_conn)
        for key, value in kwargs.items():
            self.assertEqual(getattr(obj, key), value)
        # check saving stored properly in cache and DB
        self.check_storage([obj])

    def check_saving_of_versioned(self, attr_name: str, type_: type) -> None:
        """Test owner's versioned attributes survive saving and .by_id()."""
        owner, vals = self._save_owner_with_history(attr_name, type_)
        retrieved = owner.__class__.by_id(self.db_conn, owner.id_)
        attr = getattr(retrieved, attr_name)
        self.assertEqual(sorted(attr.history.values()), vals)

    def check_by_id(self) -> None:
        """Test .by_id(), including creation."""
        # check failure if not yet saved
        id1, id2 = self.default_ids[0], self.default_ids[1]
        obj = self.checked_class(id1)  # pylint: disable=not-callable
        with self.assertRaises(NotFoundException):
            self.checked_class.by_id(self.db_conn, id1)
        # check identity of saved and retrieved
        obj.save(self.db_conn)
        self.assertEqual(obj, self.checked_class.by_id(self.db_conn, id1))
        # check create=True acts like normal instantiation (sans saving)
        by_id_created = self.checked_class.by_id(self.db_conn, id2,
                                                 create=True)
        # pylint: disable=not-callable
        self.assertEqual(self.checked_class(id2), by_id_created)
        self.check_storage([obj])

    def test_from_table_row(self) -> None:
        """Test .from_table_row() properly reads in class from DB."""
        if not hasattr(self, 'checked_class'):
            return
        id_ = self.default_ids[0]
        obj = self.checked_class(id_, **self.default_init_kwargs)
        obj.save(self.db_conn)
        assert isinstance(obj.id_, type(self.default_ids[0]))
        rows = list(self.db_conn.row_where(self.checked_class.table_name,
                                           'id', obj.id_))
        # exactly one row is expected for a just-saved ID; asserting it keeps
        # the checks below from passing vacuously on an empty result
        self.assertEqual(len(rows), 1)
        row = rows[0]
        # check .from_table_row reproduces state saved, no matter if obj
        # later changed (with caching even)
        hash_original = hash(obj)
        attr_name = self.checked_class.to_save[-1]
        setattr(obj, attr_name,
                self._altered_value(getattr(obj, attr_name)))
        to_cmp = getattr(obj, attr_name)
        retrieved = self.checked_class.from_table_row(self.db_conn, row)
        self.assertNotEqual(to_cmp, getattr(retrieved, attr_name))
        self.assertEqual(hash_original, hash(retrieved))
        # check cache contains what .from_table_row just produced
        self.assertEqual({retrieved.id_: retrieved},
                         self.checked_class.get_cache())

    def check_versioned_from_table_row(self, attr_name: str,
                                       type_: type) -> None:
        """Test .from_table_row() reads versioned attributes from DB."""
        owner, vals = self._save_owner_with_history(attr_name, type_)
        rows = list(self.db_conn.row_where(owner.table_name, 'id', owner.id_))
        # guard against passing vacuously if no row was found
        self.assertEqual(len(rows), 1)
        retrieved = owner.__class__.from_table_row(self.db_conn, rows[0])
        attr = getattr(retrieved, attr_name)
        self.assertEqual(sorted(attr.history.values()), vals)

    def check_all(self) -> tuple[Any, Any, Any]:
        """Test .all() lists all saved, but no unsaved, instances.

        Returns the three instances created with default_ids (all saved by
        the end), for further checks by the caller.
        """
        # pylint: disable=not-callable
        item1 = self.checked_class(self.default_ids[0])
        item2 = self.checked_class(self.default_ids[1])
        item3 = self.checked_class(self.default_ids[2])
        # check pre-save .all() returns empty list
        self.assertEqual(self.checked_class.all(self.db_conn), [])
        # check that all() shows all saved, but no unsaved items
        item1.save(self.db_conn)
        item3.save(self.db_conn)
        self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
                         sorted([item1, item3]))
        item2.save(self.db_conn)
        self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
                         sorted([item1, item2, item3]))
        return item1, item2, item3

    def test_singularity(self) -> None:
        """Test pointers made for single object keep pointing to it."""
        if not hasattr(self, 'checked_class'):
            return
        id1 = self.default_ids[0]
        obj = self.checked_class(id1, **self.default_init_kwargs)
        obj.save(self.db_conn)
        attr_name = self.checked_class.to_save[-1]
        new_attr = self._altered_value(getattr(obj, attr_name))
        setattr(obj, attr_name, new_attr)
        # retrieval must hand back the very object just changed
        retrieved = self.checked_class.by_id(self.db_conn, id1)
        self.assertEqual(new_attr, getattr(retrieved, attr_name))

    def check_versioned_singularity(self) -> None:
        """Test singularity of VersionedAttributes on saving (with .title)."""
        obj = self.checked_class(None)  # pylint: disable=not-callable
        obj.save(self.db_conn)
        assert isinstance(obj.id_, int)
        obj.title.set('named')
        retrieved = self.checked_class.by_id(self.db_conn, obj.id_)
        self.assertEqual(obj.title.history, retrieved.title.history)

    def check_remove(self, *args: Any) -> None:
        """Test .remove() effects on DB and cache.

        Removing an unsaved instance must raise HandledException; removing
        a saved one must leave neither cache nor DB holding it.
        """
        id_ = self.default_ids[0]
        obj = self.checked_class(id_, *args)  # pylint: disable=not-callable
        with self.assertRaises(HandledException):
            obj.remove(self.db_conn)
        obj.save(self.db_conn)
        obj.remove(self.db_conn)
        self.check_storage([])
class TestCaseWithServer(TestCaseWithDB):
    """Module tests against our HTTP server/handler (and database).

    setUp runs a TaskServer on an OS-chosen localhost port in a daemon
    thread, and opens self.conn as a client connection to it.
    """

    def setUp(self) -> None:
        super().setUp()
        self.httpd = TaskServer(self.db_file, ('localhost', 0), TaskHandler)
        self.server_thread = Thread(target=self.httpd.serve_forever)
        self.server_thread.daemon = True
        self.server_thread.start()
        self.conn = HTTPConnection(str(self.httpd.server_address[0]),
                                   self.httpd.server_address[1])
        # presumably makes the server answer in JSON, as check_json_get
        # expects; verify against TaskServer
        self.httpd.set_json_mode()

    def tearDown(self) -> None:
        # close the client side too, so its socket doesn't linger past test
        self.conn.close()
        self.httpd.shutdown()
        self.httpd.server_close()
        self.server_thread.join()
        super().tearDown()

    def check_redirect(self, target: str) -> None:
        """Check that self.conn answers with a 302 redirect to target."""
        response = self.conn.getresponse()
        self.assertEqual(response.status, 302)
        self.assertEqual(response.getheader('Location'), target)

    def check_get(self, target: str, expected_code: int) -> None:
        """Check that a GET to target yields expected_code."""
        self.conn.request('GET', target)
        self.assertEqual(self.conn.getresponse().status, expected_code)

    def check_post(self, data: Mapping[str, object], target: str,
                   expected_code: int, redirect_location: str = '') -> None:
        """Check that POST of data to target yields expected_code.

        data is sent URL-encoded as a form (sequence values become repeated
        keys). If expected_code is 302, additionally check the redirect
        goes to redirect_location, which defaults to target itself.
        """
        encoded_form_data = urlencode(data, doseq=True).encode('utf-8')
        headers = {'Content-Type': 'application/x-www-form-urlencoded',
                   'Content-Length': str(len(encoded_form_data))}
        self.conn.request('POST', target,
                          body=encoded_form_data, headers=headers)
        if expected_code == 302:
            self.check_redirect(redirect_location or target)
        else:
            self.assertEqual(self.conn.getresponse().status, expected_code)

    def check_get_defaults(self, path: str) -> None:
        """GET standard ID variations of a model path, check status codes.

        Every request is built as path + query, so path is expected to be
        the absolute request path (e.g. with a leading '/').
        """
        self.check_get(path, 200)
        self.check_get(f'{path}?id=', 200)
        self.check_get(f'{path}?id=foo', 400)
        # no extra '/' here: '//path' is not the same request path as 'path'
        self.check_get(f'{path}?id=0', 500)
        self.check_get(f'{path}?id=1', 200)

    def post_process(self, id_: int = 1,
                     form_data: dict[str, Any] | None = None
                     ) -> dict[str, Any]:
        """POST basic Process to /process?id=id_, expecting redirect there.

        Without (or with empty) form_data, posts a default title,
        description and effort. Returns the form data actually posted.
        """
        if not form_data:
            form_data = {'title': 'foo', 'description': 'foo', 'effort': 1.1}
        self.check_post(form_data, f'/process?id={id_}', 302,
                        f'/process?id={id_}')
        return form_data

    def check_json_get(self, path: str, expected: dict[str, object]) -> None:
        """Compare JSON on GET path with expected.

        To simplify comparison of VersionedAttribute histories, transforms
        timestamp keys of VersionedAttribute history keys into integers
        counting chronologically forward from 0.
        """
        def rewrite_history_keys_in(item: Any) -> Any:
            # recursively walks dicts and lists, rewriting in place; the
            # numbering follows the order in which the JSON lists the
            # history entries
            if isinstance(item, dict):
                if '_versioned' in item.keys():
                    for k in item['_versioned']:
                        vals = item['_versioned'][k].values()
                        item['_versioned'][k] = dict(enumerate(vals))
                for k in list(item.keys()):
                    rewrite_history_keys_in(item[k])
            elif isinstance(item, list):
                item[:] = [rewrite_history_keys_in(i) for i in item]
            return item

        self.conn.request('GET', path)
        response = self.conn.getresponse()
        self.assertEqual(response.status, 200)
        retrieved = json_loads(response.read().decode())
        rewrite_history_keys_in(retrieved)
        self.assertEqual(expected, retrieved)