1 """Shared test utilities."""
from unittest import TestCase
from threading import Thread
from http.client import HTTPConnection
from json import loads as json_loads
from urllib.parse import urlencode
from uuid import uuid4
from os import remove as remove_file
from typing import Mapping, Any

from plomtask.db import DatabaseFile, DatabaseConnection
from plomtask.http import TaskHandler, TaskServer
from plomtask.processes import Process, ProcessStep
from plomtask.conditions import Condition
from plomtask.days import Day
from plomtask.todos import Todo
from plomtask.exceptions import NotFoundException, HandledException
19 class TestCaseSansDB(TestCase):
20 """Tests requiring no DB setup."""
22 do_id_test: bool = False
23 default_init_args: list[Any] = []
24 versioned_defaults_to_test: dict[str, str | float] = {}
26 def test_id_setting(self) -> None:
27 """Test .id_ being set and its legal range being enforced."""
28 if not self.do_id_test:
30 with self.assertRaises(HandledException):
31 self.checked_class(0, *self.default_init_args)
32 obj = self.checked_class(5, *self.default_init_args)
33 self.assertEqual(obj.id_, 5)
35 def test_versioned_defaults(self) -> None:
36 """Test defaults of VersionedAttributes."""
37 if len(self.versioned_defaults_to_test) == 0:
39 obj = self.checked_class(1, *self.default_init_args)
40 for k, v in self.versioned_defaults_to_test.items():
41 self.assertEqual(getattr(obj, k).newest, v)
class TestCaseWithDB(TestCase):
    """Module tests requiring DB setup.

    Every test runs against a fresh, uniquely named database file, with the
    model caches emptied beforehand. Subclasses set ``checked_class`` (the
    model class under test) to enable the generic ``test_*`` methods and may
    override the class attributes below.
    """

    # IDs for up to three test instances (str IDs are permitted, too).
    default_ids: tuple[int | str, int | str, int | str] = (1, 2, 3)
    # Keyword arguments needed to construct a checked_class instance.
    default_init_kwargs: dict[str, Any] = {}
    # Maps VersionedAttribute member names to their value type.
    test_versioneds: dict[str, type] = {}

    def setUp(self) -> None:
        """Empty model caches, create fresh DB file and connect to it."""
        Condition.empty_cache()
        Day.empty_cache()
        Process.empty_cache()
        ProcessStep.empty_cache()
        Todo.empty_cache()
        self.db_file = DatabaseFile.create_at(f'test_db:{uuid4()}')
        self.db_conn = DatabaseConnection(self.db_file)

    def tearDown(self) -> None:
        """Close DB connection and delete the DB file."""
        try:
            self.db_conn.close()
        finally:
            # remove the file even if closing failed, so no test DBs pile up
            remove_file(self.db_file.path)

    def test_saving_and_caching(self) -> None:
        """Test storage and initialization of instances and attributes."""
        if not hasattr(self, 'checked_class'):
            return
        self.check_saving_and_caching(id_=1, **self.default_init_kwargs)
        # an instance created with ID None gets the next free ID on saving
        obj = self.checked_class(None, **self.default_init_kwargs)
        obj.save(self.db_conn)
        self.assertEqual(obj.id_, 2)
        for attr_name, type_ in self.test_versioneds.items():
            self.check_saving_of_versioned(attr_name, type_)

    def check_storage(self, content: list[Any]) -> None:
        """Test cache and DB equal content.

        NOTE: the DB side is only queried for the IDs of items in content,
        so an empty content list only checks the cache for emptiness.
        """
        expected_cache = {item.id_: item for item in content}
        self.assertEqual(self.checked_class.get_cache(), expected_cache)
        hashes_content = [hash(x) for x in content]
        db_found: list[Any] = []
        for item in content:
            assert isinstance(item.id_, type(self.default_ids[0]))
            for row in self.db_conn.row_where(self.checked_class.table_name,
                                              'id', item.id_):
                db_found.append(
                    self.checked_class.from_table_row(self.db_conn, row))
        hashes_db_found = [hash(x) for x in db_found]
        self.assertEqual(sorted(hashes_content), sorted(hashes_db_found))

    def check_saving_and_caching(self, **kwargs: Any) -> None:
        """Test instance.save in its core without relations."""
        obj = self.checked_class(**kwargs)  # pylint: disable=not-callable
        # check object init itself doesn't store anything yet
        self.check_storage([])
        # check saving sets core attributes properly
        obj.save(self.db_conn)
        for key, value in kwargs.items():
            self.assertEqual(getattr(obj, key), value)
        # check saving stored properly in cache and DB
        self.check_storage([obj])

    def _save_owner_of_versioned(self, attr_name: str,
                                 type_: type) -> tuple[Any, list[Any]]:
        """Save new checked_class instance after two .set() on attr_name.

        Returns the saved owner and the (sorted) list of values set.
        """
        owner = self.checked_class(None)
        vals: list[Any] = ['t1', 't2'] if type_ is str else [0.9, 1.1]
        attr = getattr(owner, attr_name)
        for val in vals:
            attr.set(val)
        owner.save(self.db_conn)
        return owner, vals

    def check_saving_of_versioned(self, attr_name: str, type_: type) -> None:
        """Test owner's versioned attributes."""
        owner, vals = self._save_owner_of_versioned(attr_name, type_)
        retrieved = owner.__class__.by_id(self.db_conn, owner.id_)
        attr = getattr(retrieved, attr_name)
        self.assertEqual(sorted(attr.history.values()), vals)

    def check_by_id(self) -> None:
        """Test .by_id(), including creation."""
        # check failure if not yet saved
        id1, id2 = self.default_ids[0], self.default_ids[1]
        obj = self.checked_class(id1)  # pylint: disable=not-callable
        with self.assertRaises(NotFoundException):
            self.checked_class.by_id(self.db_conn, id1)
        # check identity of saved and retrieved
        obj.save(self.db_conn)
        self.assertEqual(obj, self.checked_class.by_id(self.db_conn, id1))
        # check create=True acts like normal instantiation (sans saving)
        by_id_created = self.checked_class.by_id(self.db_conn, id2,
                                                 create=True)
        # pylint: disable=not-callable
        self.assertEqual(self.checked_class(id2), by_id_created)
        self.check_storage([obj])

    def test_from_table_row(self) -> None:
        """Test .from_table_row() properly reads in class from DB."""
        if not hasattr(self, 'checked_class'):
            return
        id_ = self.default_ids[0]
        obj = self.checked_class(id_, **self.default_init_kwargs)
        obj.save(self.db_conn)
        assert isinstance(obj.id_, type(self.default_ids[0]))
        for row in self.db_conn.row_where(self.checked_class.table_name,
                                          'id', obj.id_):
            # check .from_table_row reproduces state saved, no matter if obj
            # later changed (with caching even)
            hash_original = hash(obj)
            attr_name = self.checked_class.to_save[-1]
            attr = getattr(obj, attr_name)
            # bool must be checked before (int, float): bool is a subclass of
            # int, so the numeric branch would otherwise swallow booleans
            if isinstance(attr, bool):
                setattr(obj, attr_name, not attr)
            elif isinstance(attr, (int, float)):
                setattr(obj, attr_name, attr + 1)
            elif isinstance(attr, str):
                setattr(obj, attr_name, attr + "_")
            to_cmp = getattr(obj, attr_name)
            retrieved = self.checked_class.from_table_row(self.db_conn, row)
            self.assertNotEqual(to_cmp, getattr(retrieved, attr_name))
            self.assertEqual(hash_original, hash(retrieved))
            # check cache contains what .from_table_row just produced
            self.assertEqual({retrieved.id_: retrieved},
                             self.checked_class.get_cache())

    def check_versioned_from_table_row(self, attr_name: str,
                                       type_: type) -> None:
        """Test .from_table_row() reads versioned attributes from DB."""
        owner, vals = self._save_owner_of_versioned(attr_name, type_)
        for row in self.db_conn.row_where(owner.table_name, 'id', owner.id_):
            retrieved = owner.__class__.from_table_row(self.db_conn, row)
            attr = getattr(retrieved, attr_name)
            self.assertEqual(sorted(attr.history.values()), vals)

    def check_all(self) -> tuple[Any, Any, Any]:
        """Test .all() lists exactly the saved instances.

        Returns the three instances created (all saved by the end).
        """
        # pylint: disable=not-callable
        item1 = self.checked_class(self.default_ids[0])
        item2 = self.checked_class(self.default_ids[1])
        item3 = self.checked_class(self.default_ids[2])
        # check pre-save .all() returns empty list
        self.assertEqual(self.checked_class.all(self.db_conn), [])
        # check that all() shows all saved, but no unsaved items
        item1.save(self.db_conn)
        item3.save(self.db_conn)
        self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
                         sorted([item1, item3]))
        item2.save(self.db_conn)
        self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
                         sorted([item1, item2, item3]))
        return item1, item2, item3

    def check_singularity(self, defaulting_field: str,
                          non_default_value: Any, *args: Any) -> None:
        """Test pointers made for single object keep pointing to it."""
        id1 = self.default_ids[0]
        obj = self.checked_class(id1, *args)  # pylint: disable=not-callable
        obj.save(self.db_conn)
        # mutate after saving: .by_id must hand out the very same object
        setattr(obj, defaulting_field, non_default_value)
        retrieved = self.checked_class.by_id(self.db_conn, id1)
        self.assertEqual(non_default_value,
                         getattr(retrieved, defaulting_field))

    def check_versioned_singularity(self) -> None:
        """Test singularity of VersionedAttributes on saving (with .title)."""
        obj = self.checked_class(None)  # pylint: disable=not-callable
        obj.save(self.db_conn)
        assert isinstance(obj.id_, int)
        obj.title.set('named')
        retrieved = self.checked_class.by_id(self.db_conn, obj.id_)
        self.assertEqual(obj.title.history, retrieved.title.history)

    def check_remove(self, *args: Any) -> None:
        """Test .remove() effects on DB and cache."""
        id_ = self.default_ids[0]
        obj = self.checked_class(id_, *args)  # pylint: disable=not-callable
        # removing a never-saved instance must fail
        with self.assertRaises(HandledException):
            obj.remove(self.db_conn)
        obj.save(self.db_conn)
        obj.remove(self.db_conn)
        self.check_storage([])
class TestCaseWithServer(TestCaseWithDB):
    """Module tests against our HTTP server/handler (and database)."""

    def setUp(self) -> None:
        """Set up DB, then serve it on a free localhost port in a thread."""
        super().setUp()
        # port 0 lets the OS pick a free port; read back via server_address
        self.httpd = TaskServer(self.db_file, ('localhost', 0), TaskHandler)
        self.server_thread = Thread(target=self.httpd.serve_forever)
        self.server_thread.daemon = True
        self.server_thread.start()
        self.conn = HTTPConnection(str(self.httpd.server_address[0]),
                                   self.httpd.server_address[1])
        self.httpd.set_json_mode()

    def tearDown(self) -> None:
        """Close client, stop and close server, then tear down the DB."""
        self.conn.close()
        self.httpd.shutdown()
        self.httpd.server_close()
        self.server_thread.join()
        super().tearDown()

    def check_redirect(self, target: str) -> None:
        """Check that self.conn answers with a 302 redirect to target."""
        response = self.conn.getresponse()
        self.assertEqual(response.status, 302)
        self.assertEqual(response.getheader('Location'), target)

    def check_get(self, target: str, expected_code: int) -> None:
        """Check that a GET to target yields expected_code."""
        self.conn.request('GET', target)
        self.assertEqual(self.conn.getresponse().status, expected_code)

    def check_post(self, data: Mapping[str, object], target: str,
                   expected_code: int, redirect_location: str = '') -> None:
        """Check that POST of data to target yields expected_code.

        For an expected 302, also check the redirect's Location, which
        defaults to target itself if redirect_location is left empty.
        """
        encoded_form_data = urlencode(data, doseq=True).encode('utf-8')
        headers = {'Content-Type': 'application/x-www-form-urlencoded',
                   'Content-Length': str(len(encoded_form_data))}
        self.conn.request('POST', target,
                          body=encoded_form_data, headers=headers)
        if expected_code == 302:
            self.check_redirect(redirect_location or target)
        else:
            self.assertEqual(self.conn.getresponse().status, expected_code)

    def check_get_defaults(self, path: str) -> None:
        """Some standard model paths to test."""
        self.check_get(path, 200)
        self.check_get(f'{path}?id=', 200)
        self.check_get(f'{path}?id=foo', 400)
        # no extra leading slash here: '//x' would parse as netloc 'x'
        self.check_get(f'{path}?id=0', 500)
        self.check_get(f'{path}?id=1', 200)

    def post_process(self, id_: int = 1,
                     form_data: dict[str, Any] | None = None
                     ) -> dict[str, Any]:
        """POST basic Process, return the form data used."""
        if form_data is None:
            form_data = {'title': 'foo', 'description': 'foo', 'effort': 1.1}
        self.check_post(form_data, f'/process?id={id_}', 302,
                        f'/process?id={id_}')
        return form_data

    def check_json_get(self, path: str, expected: dict[str, object]) -> None:
        """Compare JSON on GET path with expected.

        To simplify comparison of VersionedAttribute histories, transforms
        timestamp keys of VersionedAttribute history keys into integers
        counting forward from 0, in the order the server emitted them (which
        is relied upon to be chronological).
        """
        def rewrite_history_keys_in(item: Any) -> Any:
            if isinstance(item, dict):
                if '_versioned' in item:
                    versioned = item['_versioned']
                    for attr_name in versioned:
                        versioned[attr_name] = dict(
                            enumerate(versioned[attr_name].values()))
                for key in list(item.keys()):
                    rewrite_history_keys_in(item[key])
            elif isinstance(item, list):
                item[:] = [rewrite_history_keys_in(i) for i in item]
            return item

        self.conn.request('GET', path)
        response = self.conn.getresponse()
        self.assertEqual(response.status, 200)
        retrieved = json_loads(response.read().decode())
        rewrite_history_keys_in(retrieved)
        self.assertEqual(expected, retrieved)