1 """Shared test utilities."""
from unittest import TestCase
from threading import Thread
from http.client import HTTPConnection
from urllib.parse import urlencode
from uuid import uuid4
from os import remove as remove_file
from typing import Mapping, Any

from plomtask.db import DatabaseFile, DatabaseConnection
from plomtask.http import TaskHandler, TaskServer
from plomtask.processes import Process, ProcessStep
from plomtask.conditions import Condition
from plomtask.days import Day
from plomtask.todos import Todo
from plomtask.exceptions import NotFoundException, HandledException
18 class TestCaseSansDB(TestCase):
19 """Tests requiring no DB setup."""
21 do_id_test: bool = False
22 default_init_args: list[Any] = []
23 versioned_defaults_to_test: dict[str, str | float] = {}
25 def test_id_setting(self) -> None:
26 """Test .id_ being set and its legal range being enforced."""
27 if not self.do_id_test:
29 with self.assertRaises(HandledException):
30 self.checked_class(0, *self.default_init_args)
31 obj = self.checked_class(5, *self.default_init_args)
32 self.assertEqual(obj.id_, 5)
34 def test_versioned_defaults(self) -> None:
35 """Test defaults of VersionedAttributes."""
36 if len(self.versioned_defaults_to_test) == 0:
38 obj = self.checked_class(1, *self.default_init_args)
39 for k, v in self.versioned_defaults_to_test.items():
40 self.assertEqual(getattr(obj, k).newest, v)
class TestCaseWithDB(TestCase):
    """Module tests requiring DB setup.

    Each test runs against a freshly created database file (removed again in
    tearDown), with the model caches emptied beforehand. Subclasses set
    checked_class to the model class under test and call the check_*
    helpers from their own test methods.
    """
    checked_class: Any
    # IDs the check_* helpers instantiate checked_class with; their type must
    # match that of the instances' .id_ (see the isinstance asserts below).
    default_ids: tuple[int | str, int | str, int | str] = (1, 2, 3)
    # Constructor keyword arguments required in addition to the ID.
    default_init_kwargs: dict[str, Any] = {}
    # Versioned attribute names mapped to their value type: str attributes
    # are exercised with string values, all others with floats.
    test_versioneds: dict[str, type] = {}

    def setUp(self) -> None:
        # The caches live on the model classes, so reset them per test.
        Condition.empty_cache()
        Day.empty_cache()
        Process.empty_cache()
        ProcessStep.empty_cache()
        Todo.empty_cache()
        # A uuid4-suffixed file name gives every test its own database.
        self.db_file = DatabaseFile.create_at(f'test_db:{uuid4()}')
        self.db_conn = DatabaseConnection(self.db_file)

    def tearDown(self) -> None:
        self.db_conn.close()
        remove_file(self.db_file.path)

    def test_saving_and_caching(self) -> None:
        """Test storage and initialization of instances and attributes."""
        # Nothing to check for the generic base class itself.
        if not hasattr(self, 'checked_class'):
            return
        self.check_saving_and_caching(id_=1, **self.default_init_kwargs)
        # Saving an instance created without ID must assign ID 2 after 1.
        obj = self.checked_class(None, **self.default_init_kwargs)
        obj.save(self.db_conn)
        self.assertEqual(obj.id_, 2)
        for attr_name, type_ in self.test_versioneds.items():
            self.check_saving_of_versioned(attr_name, type_)

    def check_storage(self, content: list[Any]) -> None:
        """Test cache and DB equal content.

        The cache must hold exactly the items of content (keyed by .id_), and
        the DB rows matching their IDs must rebuild objects of equal hashes.
        Only rows whose 'id' matches an item of content are read, so for an
        empty content list only the cache is actually inspected.
        """
        expected_cache = {item.id_: item for item in content}
        self.assertEqual(self.checked_class.get_cache(), expected_cache)
        hashes_content = [hash(x) for x in content]
        db_found: list[Any] = []
        for item in content:
            # Also narrows .id_ for the type checker.
            assert isinstance(item.id_, type(self.default_ids[0]))
            for row in self.db_conn.row_where(self.checked_class.table_name,
                                              'id', item.id_):
                db_found.append(
                    self.checked_class.from_table_row(self.db_conn, row))
        hashes_db_found = [hash(x) for x in db_found]
        self.assertEqual(sorted(hashes_content), sorted(hashes_db_found))

    def check_saving_and_caching(self, **kwargs: Any) -> None:
        """Test instance.save in its core without relations.

        kwargs go straight to the checked_class constructor and must reappear
        as same-named attributes after saving.
        """
        obj = self.checked_class(**kwargs)  # pylint: disable=not-callable
        # check object init itself doesn't store anything yet
        self.check_storage([])
        # check saving sets core attributes properly
        obj.save(self.db_conn)
        for key, value in kwargs.items():
            self.assertEqual(getattr(obj, key), value)
        # check saving stored properly in cache and DB
        self.check_storage([obj])

    def check_saving_of_versioned(self, attr_name: str, type_: type) -> None:
        """Test owner's versioned attributes.

        Sets two values on attribute attr_name of a new instance, saves it,
        and expects both in the history of the instance retrieved by ID.
        """
        # Pass default_init_kwargs as test_saving_and_caching does, so this
        # also works for classes whose constructor needs more than an ID.
        owner = self.checked_class(None, **self.default_init_kwargs)
        vals: list[Any] = ['t1', 't2'] if type_ is str else [0.9, 1.1]
        attr = getattr(owner, attr_name)
        for val in vals:
            attr.set(val)
        owner.save(self.db_conn)
        retrieved = owner.__class__.by_id(self.db_conn, owner.id_)
        attr = getattr(retrieved, attr_name)
        self.assertEqual(sorted(attr.history.values()), vals)

    def check_by_id(self) -> None:
        """Test .by_id(), including creation."""
        # check failure if not yet saved
        id1, id2 = self.default_ids[0], self.default_ids[1]
        obj = self.checked_class(id1)  # pylint: disable=not-callable
        with self.assertRaises(NotFoundException):
            self.checked_class.by_id(self.db_conn, id1)
        # check identity of saved and retrieved
        obj.save(self.db_conn)
        self.assertEqual(obj, self.checked_class.by_id(self.db_conn, id1))
        # check create=True acts like normal instantiation (sans saving)
        by_id_created = self.checked_class.by_id(self.db_conn, id2,
                                                 create=True)
        # pylint: disable=not-callable
        self.assertEqual(self.checked_class(id2), by_id_created)
        self.check_storage([obj])

    def check_from_table_row(self, *args: Any) -> None:
        """Test .from_table_row() properly reads in class from DB.

        args are passed to the constructor after the ID.
        """
        id_ = self.default_ids[0]
        obj = self.checked_class(id_, *args)  # pylint: disable=not-callable
        obj.save(self.db_conn)
        # Also narrows .id_ for the type checker.
        assert isinstance(obj.id_, type(self.default_ids[0]))
        for row in self.db_conn.row_where(self.checked_class.table_name,
                                          'id', obj.id_):
            hash_original = hash(obj)
            retrieved = self.checked_class.from_table_row(self.db_conn, row)
            self.assertEqual(hash_original, hash(retrieved))
            self.assertEqual({retrieved.id_: retrieved},
                             self.checked_class.get_cache())

    def check_versioned_from_table_row(self, attr_name: str,
                                       type_: type) -> None:
        """Test .from_table_row() reads versioned attributes from DB."""
        # Pass default_init_kwargs as test_saving_and_caching does, so this
        # also works for classes whose constructor needs more than an ID.
        owner = self.checked_class(None, **self.default_init_kwargs)
        vals: list[Any] = ['t1', 't2'] if type_ is str else [0.9, 1.1]
        attr = getattr(owner, attr_name)
        for val in vals:
            attr.set(val)
        owner.save(self.db_conn)
        for row in self.db_conn.row_where(owner.table_name, 'id', owner.id_):
            retrieved = owner.__class__.from_table_row(self.db_conn, row)
            attr = getattr(retrieved, attr_name)
            self.assertEqual(sorted(attr.history.values()), vals)

    def check_all(self) -> tuple[Any, Any, Any]:
        """Test .all() lists exactly the saved items; return all three items.

        Instances for all of default_ids are created up front but saved step
        by step, with .all() checked before and after each saving stage.
        """
        # pylint: disable=not-callable
        item1 = self.checked_class(self.default_ids[0])
        item2 = self.checked_class(self.default_ids[1])
        item3 = self.checked_class(self.default_ids[2])
        # check pre-save .all() returns empty list
        self.assertEqual(self.checked_class.all(self.db_conn), [])
        # check that all() shows all saved, but no unsaved items
        item1.save(self.db_conn)
        item3.save(self.db_conn)
        self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
                         sorted([item1, item3]))
        item2.save(self.db_conn)
        self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
                         sorted([item1, item2, item3]))
        return item1, item2, item3

    def check_singularity(self, defaulting_field: str,
                          non_default_value: Any, *args: Any) -> None:
        """Test pointers made for single object keep pointing to it.

        After saving, defaulting_field is changed on the original object only;
        the object retrieved by ID must show that change too. args are passed
        to the constructor after the ID.
        """
        id1 = self.default_ids[0]
        obj = self.checked_class(id1, *args)  # pylint: disable=not-callable
        obj.save(self.db_conn)
        setattr(obj, defaulting_field, non_default_value)
        retrieved = self.checked_class.by_id(self.db_conn, id1)
        self.assertEqual(non_default_value,
                         getattr(retrieved, defaulting_field))

    def check_versioned_singularity(self) -> None:
        """Test singularity of VersionedAttributes on saving (with .title)."""
        # Pass default_init_kwargs as test_saving_and_caching does, so this
        # also works for classes whose constructor needs more than an ID.
        # pylint: disable=not-callable
        obj = self.checked_class(None, **self.default_init_kwargs)
        obj.save(self.db_conn)
        assert isinstance(obj.id_, int)
        # Changed after saving, yet must show on the retrieved object.
        obj.title.set('named')
        retrieved = self.checked_class.by_id(self.db_conn, obj.id_)
        self.assertEqual(obj.title.history, retrieved.title.history)

    def check_remove(self, *args: Any) -> None:
        """Test .remove() effects on DB and cache.

        Removing a never-saved instance must raise HandledException; after
        saving and removing it, the cache must be empty again.
        """
        id_ = self.default_ids[0]
        obj = self.checked_class(id_, *args)  # pylint: disable=not-callable
        with self.assertRaises(HandledException):
            obj.remove(self.db_conn)
        obj.save(self.db_conn)
        obj.remove(self.db_conn)
        self.check_storage([])
class TestCaseWithServer(TestCaseWithDB):
    """Module tests against our HTTP server/handler (and database).

    On top of the per-test database of TestCaseWithDB, each test gets a
    TaskServer on localhost, served from a daemon thread, plus an
    HTTPConnection (self.conn) to it.
    """

    def setUp(self) -> None:
        super().setUp()  # provides self.db_file for the server
        # Port 0: let the OS choose a free port (read back via
        # server_address below).
        self.httpd = TaskServer(self.db_file, ('localhost', 0), TaskHandler)
        self.server_thread = Thread(target=self.httpd.serve_forever)
        self.server_thread.daemon = True
        self.server_thread.start()
        self.conn = HTTPConnection(str(self.httpd.server_address[0]),
                                   self.httpd.server_address[1])
        self.httpd.set_json_mode()

    def tearDown(self) -> None:
        # Close the client side too, so no socket outlives the test.
        self.conn.close()
        self.httpd.shutdown()
        self.httpd.server_close()
        self.server_thread.join()
        super().tearDown()

    def check_redirect(self, target: str) -> None:
        """Check that self.conn answers with a 302 redirect to target."""
        response = self.conn.getresponse()
        self.assertEqual(response.status, 302)
        self.assertEqual(response.getheader('Location'), target)

    def check_get(self, target: str, expected_code: int) -> None:
        """Check that a GET to target yields expected_code."""
        self.conn.request('GET', target)
        self.assertEqual(self.conn.getresponse().status, expected_code)

    def check_post(self, data: Mapping[str, object], target: str,
                   expected_code: int, redirect_location: str = '') -> None:
        """Check that POST of data to target yields expected_code.

        data is sent URL-encoded with doseq=True, so sequence values become
        repeated keys. For an expected_code of 302, the Location header must
        equal redirect_location, which defaults to target itself.
        """
        encoded_form_data = urlencode(data, doseq=True).encode('utf-8')
        headers = {'Content-Type': 'application/x-www-form-urlencoded',
                   'Content-Length': str(len(encoded_form_data))}
        self.conn.request('POST', target,
                          body=encoded_form_data, headers=headers)
        if expected_code == 302:
            self.check_redirect(redirect_location or target)
        else:
            self.assertEqual(self.conn.getresponse().status, expected_code)

    def check_get_defaults(self, path: str) -> None:
        """Some standard model paths to test.

        GETs path bare and with an empty, a non-numeric, a zero and a 1 id=
        query, expecting status 200, 200, 400, 500 and 200 respectively.
        """
        self.check_get(path, 200)
        self.check_get(f'{path}?id=', 200)
        self.check_get(f'{path}?id=foo', 400)
        # Same URL shape as the sibling calls: no extra leading slash.
        self.check_get(f'{path}?id=0', 500)
        self.check_get(f'{path}?id=1', 200)

    def post_process(self, id_: int = 1,
                     form_data: dict[str, Any] | None = None
                     ) -> dict[str, Any]:
        """POST basic Process.

        Posts form_data (or, if it is None or empty, a minimal default form)
        to /process?id={id_}, expects a 302 redirect back to that same URL,
        and returns the form data actually posted.
        """
        if not form_data:
            form_data = {'title': 'foo', 'description': 'foo', 'effort': 1.1}
        self.check_post(form_data, f'/process?id={id_}', 302,
                        f'/process?id={id_}')
        return form_data
273 def blank_history_keys_in(d: dict[str, object]) -> None:
274 """Re-write "history" object keys to bracketed integer strings."""
275 def walk_tree(d: Any) -> Any:
276 if isinstance(d, dict):
277 if 'history' in d.keys():
278 vals = d['history'].values()
280 for i, val in enumerate(vals):
281 history[f'[{i}]'] = val
282 d['history'] = history
283 for k in list(d.keys()):
285 elif isinstance(d, list):
286 d[:] = [walk_tree(i) for i in d]