1 """Shared test utilities."""
from unittest import TestCase
from threading import Thread
from http.client import HTTPConnection
from urllib.parse import urlencode
from uuid import uuid4
from os import remove as remove_file
from typing import Mapping, Any
from plomtask.db import DatabaseFile, DatabaseConnection
from plomtask.http import TaskHandler, TaskServer
from plomtask.processes import Process, ProcessStep
from plomtask.conditions import Condition
from plomtask.days import Day
from plomtask.todos import Todo
from plomtask.exceptions import NotFoundException, HandledException
18 class TestCaseSansDB(TestCase):
19 """Tests requiring no DB setup."""
21 do_id_test: bool = False
22 default_init_args: list[Any] = []
23 versioned_defaults_to_test: dict[str, str | float] = {}
25 def test_id_setting(self) -> None:
26 """Test .id_ being set and its legal range being enforced."""
27 if not self.do_id_test:
29 with self.assertRaises(HandledException):
30 self.checked_class(0, *self.default_init_args)
31 obj = self.checked_class(5, *self.default_init_args)
32 self.assertEqual(obj.id_, 5)
34 def test_versioned_defaults(self) -> None:
35 """Test defaults of VersionedAttributes."""
36 if len(self.versioned_defaults_to_test) == 0:
38 obj = self.checked_class(1, *self.default_init_args)
39 for k, v in self.versioned_defaults_to_test.items():
40 self.assertEqual(getattr(obj, k).newest, v)
class TestCaseWithDB(TestCase):
    """Module tests requiring DB setup.

    Every test runs against a freshly created database file and connection.
    Subclasses set ``checked_class`` to the model class under test and call
    the ``check_*`` helpers from their own test methods.
    """
    checked_class: Any
    # IDs used by the check_* helpers wherever explicit IDs are needed.
    default_ids: tuple[int | str, int | str, int | str] = (1, 2, 3)
    # Keyword arguments (besides the ID) for instantiating checked_class.
    default_init_kwargs: dict[str, Any] = {}
    # Maps names of versioned attributes to test to their value type.
    test_versioneds: dict[str, type] = {}

    def setUp(self) -> None:
        """Empty the model caches, create a fresh DB file and connect."""
        # NOTE(review): the reviewed copy only showed the Condition and
        # ProcessStep resets; Day, Process and Todo were restored from the
        # file's imports -- confirm against version control.
        Condition.empty_cache()
        Day.empty_cache()
        Process.empty_cache()
        ProcessStep.empty_cache()
        Todo.empty_cache()
        # random suffix keeps DB files of separate tests from colliding
        self.db_file = DatabaseFile.create_at(f'test_db:{uuid4()}')
        self.db_conn = DatabaseConnection(self.db_file)

    def tearDown(self) -> None:
        """Close the DB connection and delete the test's DB file."""
        # NOTE(review): the statement before the file removal was missing in
        # the reviewed copy; presumably it closed the connection -- confirm
        # that DatabaseConnection offers .close().
        self.db_conn.close()
        remove_file(self.db_file.path)

    def test_saving_and_caching(self) -> None:
        """Test storage and initialization of instances and attributes."""
        if not hasattr(self, 'checked_class'):
            return
        self.check_saving_and_caching(id_=1, **self.default_init_kwargs)
        # an instance created without ID is expected to get ID 2 on saving
        obj = self.checked_class(None, **self.default_init_kwargs)
        obj.save(self.db_conn)
        self.assertEqual(obj.id_, 2)
        for k, v in self.test_versioneds.items():
            self.check_saving_of_versioned(k, v)

    def check_storage(self, content: list[Any]) -> None:
        """Test cache and DB equal content."""
        expected_cache = {item.id_: item for item in content}
        self.assertEqual(self.checked_class.get_cache(), expected_cache)
        db_found: list[Any] = []
        for item in content:
            assert isinstance(item.id_, type(self.default_ids[0]))
            for row in self.db_conn.row_where(self.checked_class.table_name,
                                              'id', item.id_):
                db_found.append(
                    self.checked_class.from_table_row(self.db_conn, row))
        self.assertEqual(sorted(content), sorted(db_found))

    def check_saving_and_caching(self, **kwargs: Any) -> None:
        """Test instance.save in its core without relations."""
        obj = self.checked_class(**kwargs)  # pylint: disable=not-callable
        # check object init itself doesn't store anything yet
        self.check_storage([])
        # check saving stores in cache and DB
        obj.save(self.db_conn)
        self.check_storage([obj])
        # check core attributes set properly (and not unset by saving)
        for key, value in kwargs.items():
            self.assertEqual(getattr(obj, key), value)

    def check_saving_of_versioned(self, attr_name: str, type_: type) -> None:
        """Test owner's versioned attributes."""
        owner = self.checked_class(None)
        vals: list[Any] = ['t1', 't2'] if type_ is str else [0.9, 1.1]
        attr = getattr(owner, attr_name)
        # NOTE(review): the two .set() calls were restored from lines missing
        # in the reviewed copy; the final assertion needs both values stored.
        attr.set(vals[0])
        attr.set(vals[1])
        owner.save(self.db_conn)
        # NOTE(review): one statement is missing here in the reviewed copy
        # (presumably removing owner from the cache so that by_id re-reads
        # from the DB) -- restore it from version control.
        retrieved = owner.__class__.by_id(self.db_conn, owner.id_)
        attr = getattr(retrieved, attr_name)
        self.assertEqual(sorted(attr.history.values()), vals)

    def check_by_id(self) -> None:
        """Test .by_id(), including creation."""
        # check failure if not yet saved
        id1, id2 = self.default_ids[0], self.default_ids[1]
        obj = self.checked_class(id1)  # pylint: disable=not-callable
        with self.assertRaises(NotFoundException):
            self.checked_class.by_id(self.db_conn, id1)
        # check identity of saved and retrieved
        obj.save(self.db_conn)
        self.assertEqual(obj, self.checked_class.by_id(self.db_conn, id1))
        # check create=True acts like normal instantiation (sans saving)
        by_id_created = self.checked_class.by_id(self.db_conn, id2,
                                                 create=True)
        # pylint: disable=not-callable
        self.assertEqual(self.checked_class(id2), by_id_created)
        self.check_storage([obj])

    def check_from_table_row(self, *args: Any) -> None:
        """Test .from_table_row() properly reads in class from DB"""
        id_ = self.default_ids[0]
        obj = self.checked_class(id_, *args)  # pylint: disable=not-callable
        obj.save(self.db_conn)
        assert isinstance(obj.id_, type(self.default_ids[0]))
        for row in self.db_conn.row_where(self.checked_class.table_name,
                                          'id', obj.id_):
            retrieved = self.checked_class.from_table_row(self.db_conn, row)
            self.assertEqual(obj, retrieved)
            self.assertEqual({obj.id_: obj}, self.checked_class.get_cache())

    def check_versioned_from_table_row(self, attr_name: str,
                                       type_: type) -> None:
        """Test .from_table_row() reads versioned attributes from DB."""
        owner = self.checked_class(None)
        vals: list[Any] = ['t1', 't2'] if type_ is str else [0.9, 1.1]
        attr = getattr(owner, attr_name)
        # NOTE(review): the two .set() calls were restored from lines missing
        # in the reviewed copy; the final assertion needs both values stored.
        attr.set(vals[0])
        attr.set(vals[1])
        owner.save(self.db_conn)
        for row in self.db_conn.row_where(owner.table_name, 'id', owner.id_):
            retrieved = owner.__class__.from_table_row(self.db_conn, row)
            attr = getattr(retrieved, attr_name)
            self.assertEqual(sorted(attr.history.values()), vals)

    def check_all(self) -> tuple[Any, Any, Any]:
        """Test .all(), returning the three saved items for further checks."""
        # pylint: disable=not-callable
        item1 = self.checked_class(self.default_ids[0])
        item2 = self.checked_class(self.default_ids[1])
        item3 = self.checked_class(self.default_ids[2])
        # check pre-save .all() returns empty list
        self.assertEqual(self.checked_class.all(self.db_conn), [])
        # check that all() shows all saved, but no unsaved items
        item1.save(self.db_conn)
        item3.save(self.db_conn)
        self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
                         sorted([item1, item3]))
        item2.save(self.db_conn)
        self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
                         sorted([item1, item2, item3]))
        return item1, item2, item3

    def check_singularity(self, defaulting_field: str,
                          non_default_value: Any, *args: Any) -> None:
        """Test pointers made for single object keep pointing to it."""
        id1 = self.default_ids[0]
        obj = self.checked_class(id1, *args)  # pylint: disable=not-callable
        obj.save(self.db_conn)
        # change made only after saving must still show on the retrieved one
        setattr(obj, defaulting_field, non_default_value)
        retrieved = self.checked_class.by_id(self.db_conn, id1)
        self.assertEqual(non_default_value,
                         getattr(retrieved, defaulting_field))

    def check_versioned_singularity(self) -> None:
        """Test singularity of VersionedAttributes on saving (with .title)."""
        obj = self.checked_class(None)  # pylint: disable=not-callable
        obj.save(self.db_conn)
        assert isinstance(obj.id_, int)
        # change made only after saving must still show on the retrieved one
        obj.title.set('named')
        retrieved = self.checked_class.by_id(self.db_conn, obj.id_)
        self.assertEqual(obj.title.history, retrieved.title.history)

    def check_remove(self, *args: Any) -> None:
        """Test .remove() effects on DB and cache."""
        id_ = self.default_ids[0]
        obj = self.checked_class(id_, *args)  # pylint: disable=not-callable
        # removing what was never saved is expected to fail
        with self.assertRaises(HandledException):
            obj.remove(self.db_conn)
        obj.save(self.db_conn)
        obj.remove(self.db_conn)
        self.check_storage([])
class TestCaseWithServer(TestCaseWithDB):
    """Module tests against our HTTP server/handler (and database)."""

    def setUp(self) -> None:
        """Start a TaskServer in a background thread and connect to it."""
        super().setUp()
        # port 0 lets the OS pick a free port, read back via server_address
        self.httpd = TaskServer(self.db_file, ('localhost', 0), TaskHandler)
        self.server_thread = Thread(target=self.httpd.serve_forever)
        self.server_thread.daemon = True
        self.server_thread.start()
        self.conn = HTTPConnection(str(self.httpd.server_address[0]),
                                   self.httpd.server_address[1])

    def tearDown(self) -> None:
        """Close the client connection, stop the server, clean up the DB."""
        # release the client socket, which otherwise stays open per test
        self.conn.close()
        self.httpd.shutdown()
        self.httpd.server_close()
        self.server_thread.join()
        super().tearDown()

    def check_redirect(self, target: str) -> None:
        """Check that self.conn answers with a 302 redirect to target."""
        response = self.conn.getresponse()
        self.assertEqual(response.status, 302)
        self.assertEqual(response.getheader('Location'), target)

    def check_get(self, target: str, expected_code: int) -> None:
        """Check that a GET to target yields expected_code."""
        self.conn.request('GET', target)
        self.assertEqual(self.conn.getresponse().status, expected_code)

    def check_post(self, data: Mapping[str, object], target: str,
                   expected_code: int, redirect_location: str = '') -> None:
        """Check that POST of data to target yields expected_code.

        If expected_code is 302, additionally check the Location header
        against redirect_location, which defaults to target when empty.
        """
        encoded_form_data = urlencode(data, doseq=True).encode('utf-8')
        headers = {'Content-Type': 'application/x-www-form-urlencoded',
                   'Content-Length': str(len(encoded_form_data))}
        self.conn.request('POST', target,
                          body=encoded_form_data, headers=headers)
        if expected_code == 302:
            self.check_redirect(redirect_location or target)
            return
        self.assertEqual(self.conn.getresponse().status, expected_code)

    def check_get_defaults(self, path: str) -> None:
        """Some standard model paths to test."""
        self.check_get(path, 200)
        self.check_get(f'{path}?id=', 200)
        self.check_get(f'{path}?id=foo', 400)
        # NOTE(review): unlike its siblings this target gets an extra leading
        # slash; kept as found -- confirm whether that is intended.
        self.check_get(f'/{path}?id=0', 500)
        self.check_get(f'{path}?id=1', 200)

    def post_process(self, id_: int = 1,
                     form_data: dict[str, Any] | None = None
                     ) -> dict[str, Any]:
        """POST basic Process.

        Posts form_data (or a minimal default if none is given) to the
        Process of id_, expects a redirect back to that Process, and returns
        the data posted so callers can compare against it.
        """
        if not form_data:
            form_data = {'title': 'foo', 'description': 'foo', 'effort': 1.1}
        self.check_post(form_data, f'/process?id={id_}', 302,
                        f'/process?id={id_}')
        return form_data