1 """Shared test utilities."""
from http.client import HTTPConnection
from os import remove as remove_file
from threading import Thread
from typing import Mapping, Any
from unittest import TestCase
from urllib.parse import urlencode
from uuid import uuid4

from plomtask.conditions import Condition
from plomtask.days import Day
from plomtask.db import DatabaseFile, DatabaseConnection
from plomtask.exceptions import NotFoundException, HandledException
from plomtask.http import TaskHandler, TaskServer
from plomtask.processes import Process, ProcessStep
from plomtask.todos import Todo
18 class TestCaseSansDB(TestCase):
19 """Tests requiring no DB setup."""
21 do_id_test: bool = False
22 default_init_args: list[Any] = []
23 versioned_defaults_to_test: dict[str, str | float] = {}
25 def test_id_setting(self) -> None:
26 """Test .id_ being set and its legal range being enforced."""
27 if not self.do_id_test:
29 with self.assertRaises(HandledException):
30 self.checked_class(0, *self.default_init_args)
31 obj = self.checked_class(5, *self.default_init_args)
32 self.assertEqual(obj.id_, 5)
34 def test_versioned_defaults(self) -> None:
35 """Test defaults of VersionedAttributes."""
36 if len(self.versioned_defaults_to_test) == 0:
38 obj = self.checked_class(1, *self.default_init_args)
39 for k, v in self.versioned_defaults_to_test.items():
40 self.assertEqual(getattr(obj, k).newest, v)
class TestCaseWithDB(TestCase):
    """Module tests requiring DB setup.

    Every test gets emptied model caches plus a freshly created, uniquely
    named database file with an open connection to it. Subclasses set
    ``checked_class`` to run the generic check_* helpers against a model.
    """
    # IDs the generic checks hand to checked_class; override for models
    # whose IDs are not ints.
    default_ids: tuple[int | str, int | str, int | str] = (1, 2, 3)
    # Extra keyword arguments used when constructing checked_class.
    default_init_kwargs: dict[str, Any] = {}
    # Versioned attribute name -> value type; str gets string test values,
    # anything else gets float test values.
    test_versioneds: dict[str, type] = {}

    def setUp(self) -> None:
        Condition.empty_cache()
        Day.empty_cache()
        Process.empty_cache()
        ProcessStep.empty_cache()
        Todo.empty_cache()
        self.db_file = DatabaseFile.create_at(f'test_db:{uuid4()}')
        self.db_conn = DatabaseConnection(self.db_file)

    def tearDown(self) -> None:
        # Close the connection before deleting its file; the finally still
        # removes the file should closing fail.
        try:
            self.db_conn.close()
        finally:
            remove_file(self.db_file.path)

    def test_saving_and_caching(self) -> None:
        """Test storage and initialization of instances and attributes."""
        if not hasattr(self, 'checked_class'):
            return
        self.check_saving_and_caching(id_=1, **self.default_init_kwargs)
        # pylint: disable=not-callable
        obj = self.checked_class(None, **self.default_init_kwargs)
        obj.save(self.db_conn)
        # an object saved without ID gets the next one after the ID-1 object
        self.assertEqual(obj.id_, 2)
        for k, v in self.test_versioneds.items():
            self.check_saving_of_versioned(k, v)

    def check_storage(self, content: list[Any]) -> None:
        """Test cache and DB equal content.

        Checks the class cache equals exactly ``content`` (keyed by .id_),
        then that reading each item's DB row back yields equally-hashing
        objects. The cache is compared first, before the DB reads below.
        """
        expected_cache = {item.id_: item for item in content}
        self.assertEqual(self.checked_class.get_cache(), expected_cache)
        hashes_content = [hash(x) for x in content]
        db_found: list[Any] = []
        for item in content:
            assert isinstance(item.id_, type(self.default_ids[0]))
            for row in self.db_conn.row_where(self.checked_class.table_name,
                                              'id', item.id_):
                db_found += [self.checked_class.from_table_row(self.db_conn,
                                                               row)]
        hashes_db_found = [hash(x) for x in db_found]
        self.assertEqual(sorted(hashes_content), sorted(hashes_db_found))

    def check_saving_and_caching(self, **kwargs: Any) -> None:
        """Test instance.save in its core without relations."""
        obj = self.checked_class(**kwargs)  # pylint: disable=not-callable
        # check object init itself doesn't store anything yet
        self.check_storage([])
        # check saving sets core attributes properly
        obj.save(self.db_conn)
        for key, value in kwargs.items():
            self.assertEqual(getattr(obj, key), value)
        # check saving stored properly in cache and DB
        self.check_storage([obj])

    def _owner_with_versioned_history(self, attr_name: str, type_: type
                                      ) -> tuple[Any, list[Any]]:
        """Save a new owner whose attr_name was set to two values in turn.

        Returns the saved owner and the (ascending) list of values set.
        """
        owner = self.checked_class(None)  # pylint: disable=not-callable
        vals: list[Any] = ['t1', 't2'] if type_ is str else [0.9, 1.1]
        attr = getattr(owner, attr_name)
        for val in vals:
            attr.set(val)
        owner.save(self.db_conn)
        return owner, vals

    def check_saving_of_versioned(self, attr_name: str, type_: type) -> None:
        """Test owner's versioned attributes."""
        owner, vals = self._owner_with_versioned_history(attr_name, type_)
        retrieved = owner.__class__.by_id(self.db_conn, owner.id_)
        attr = getattr(retrieved, attr_name)
        self.assertEqual(sorted(attr.history.values()), vals)

    def check_by_id(self) -> None:
        """Test .by_id(), including creation."""
        # check failure if not yet saved
        id1, id2 = self.default_ids[0], self.default_ids[1]
        obj = self.checked_class(id1)  # pylint: disable=not-callable
        with self.assertRaises(NotFoundException):
            self.checked_class.by_id(self.db_conn, id1)
        # check identity of saved and retrieved
        obj.save(self.db_conn)
        self.assertEqual(obj, self.checked_class.by_id(self.db_conn, id1))
        # check create=True acts like normal instantiation (sans saving)
        by_id_created = self.checked_class.by_id(self.db_conn, id2,
                                                 create=True)
        # pylint: disable=not-callable
        self.assertEqual(self.checked_class(id2), by_id_created)
        self.check_storage([obj])

    def check_from_table_row(self, *args: Any) -> None:
        """Test .from_table_row() properly reads in class from DB"""
        id_ = self.default_ids[0]
        obj = self.checked_class(id_, *args)  # pylint: disable=not-callable
        obj.save(self.db_conn)
        assert isinstance(obj.id_, type(self.default_ids[0]))
        rows = list(self.db_conn.row_where(self.checked_class.table_name,
                                           'id', obj.id_))
        # without this, a missing row would silently skip all checks below
        self.assertTrue(rows, 'saved object not found in DB')
        for row in rows:
            hash_original = hash(obj)
            retrieved = self.checked_class.from_table_row(self.db_conn, row)
            self.assertEqual(hash_original, hash(retrieved))
            self.assertEqual({retrieved.id_: retrieved},
                             self.checked_class.get_cache())

    def check_versioned_from_table_row(self, attr_name: str,
                                       type_: type) -> None:
        """Test .from_table_row() reads versioned attributes from DB."""
        owner, vals = self._owner_with_versioned_history(attr_name, type_)
        rows = list(self.db_conn.row_where(owner.table_name, 'id', owner.id_))
        # without this, a missing row would silently skip all checks below
        self.assertTrue(rows, 'saved owner not found in DB')
        for row in rows:
            retrieved = owner.__class__.from_table_row(self.db_conn, row)
            attr = getattr(retrieved, attr_name)
            self.assertEqual(sorted(attr.history.values()), vals)

    def check_all(self) -> tuple[Any, Any, Any]:
        """Test .all() lists exactly the saved items; return all three."""
        # pylint: disable=not-callable
        item1 = self.checked_class(self.default_ids[0])
        item2 = self.checked_class(self.default_ids[1])
        item3 = self.checked_class(self.default_ids[2])
        # check pre-save .all() returns empty list
        self.assertEqual(self.checked_class.all(self.db_conn), [])
        # check that all() shows all saved, but no unsaved items
        item1.save(self.db_conn)
        item3.save(self.db_conn)
        self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
                         sorted([item1, item3]))
        item2.save(self.db_conn)
        self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
                         sorted([item1, item2, item3]))
        return item1, item2, item3

    def check_singularity(self, defaulting_field: str,
                          non_default_value: Any, *args: Any) -> None:
        """Test pointers made for single object keep pointing to it.

        Sets ``defaulting_field`` to ``non_default_value`` on a saved object
        and expects .by_id() to return an object showing that change; extra
        ``args`` are passed to checked_class after the ID.
        """
        id1 = self.default_ids[0]
        obj = self.checked_class(id1, *args)  # pylint: disable=not-callable
        obj.save(self.db_conn)
        setattr(obj, defaulting_field, non_default_value)
        retrieved = self.checked_class.by_id(self.db_conn, id1)
        self.assertEqual(non_default_value,
                         getattr(retrieved, defaulting_field))

    def check_versioned_singularity(self) -> None:
        """Test singularity of VersionedAttributes on saving (with .title)."""
        obj = self.checked_class(None)  # pylint: disable=not-callable
        obj.save(self.db_conn)
        assert isinstance(obj.id_, int)
        obj.title.set('named')
        retrieved = self.checked_class.by_id(self.db_conn, obj.id_)
        self.assertEqual(obj.title.history, retrieved.title.history)

    def check_remove(self, *args: Any) -> None:
        """Test .remove() effects on DB and cache.

        Removing an unsaved object must raise HandledException; removing a
        saved one must leave the cache empty.
        """
        id_ = self.default_ids[0]
        obj = self.checked_class(id_, *args)  # pylint: disable=not-callable
        with self.assertRaises(HandledException):
            obj.remove(self.db_conn)
        obj.save(self.db_conn)
        obj.remove(self.db_conn)
        self.check_storage([])
209 class TestCaseWithServer(TestCaseWithDB):
210 """Module tests against our HTTP server/handler (and database)."""
212 def setUp(self) -> None:
214 self.httpd = TaskServer(self.db_file, ('localhost', 0), TaskHandler)
215 self.server_thread = Thread(target=self.httpd.serve_forever)
216 self.server_thread.daemon = True
217 self.server_thread.start()
218 self.conn = HTTPConnection(str(self.httpd.server_address[0]),
219 self.httpd.server_address[1])
221 def tearDown(self) -> None:
222 self.httpd.shutdown()
223 self.httpd.server_close()
224 self.server_thread.join()
227 def check_redirect(self, target: str) -> None:
228 """Check that self.conn answers with a 302 redirect to target."""
229 response = self.conn.getresponse()
230 self.assertEqual(response.status, 302)
231 self.assertEqual(response.getheader('Location'), target)
233 def check_get(self, target: str, expected_code: int) -> None:
234 """Check that a GET to target yields expected_code."""
235 self.conn.request('GET', target)
236 self.assertEqual(self.conn.getresponse().status, expected_code)
238 def check_post(self, data: Mapping[str, object], target: str,
239 expected_code: int, redirect_location: str = '') -> None:
240 """Check that POST of data to target yields expected_code."""
241 encoded_form_data = urlencode(data, doseq=True).encode('utf-8')
242 headers = {'Content-Type': 'application/x-www-form-urlencoded',
243 'Content-Length': str(len(encoded_form_data))}
244 self.conn.request('POST', target,
245 body=encoded_form_data, headers=headers)
246 if 302 == expected_code:
247 if redirect_location == '':
248 redirect_location = target
249 self.check_redirect(redirect_location)
251 self.assertEqual(self.conn.getresponse().status, expected_code)
253 def check_get_defaults(self, path: str) -> None:
254 """Some standard model paths to test."""
255 self.check_get(path, 200)
256 self.check_get(f'{path}?id=', 200)
257 self.check_get(f'{path}?id=foo', 400)
258 self.check_get(f'/{path}?id=0', 500)
259 self.check_get(f'{path}?id=1', 200)
261 def post_process(self, id_: int = 1,
262 form_data: dict[str, Any] | None = None
264 """POST basic Process."""
266 form_data = {'title': 'foo', 'description': 'foo', 'effort': 1.1}
267 self.check_post(form_data, f'/process?id={id_}', 302,
268 f'/process?id={id_}')