1 """Shared test utilities."""
2 from unittest import TestCase
3 from threading import Thread
4 from http.client import HTTPConnection
5 from urllib.parse import urlencode
6 from datetime import datetime
7 from os import remove as remove_file
8 from typing import Mapping, Any
9 from plomtask.db import DatabaseFile, DatabaseConnection
10 from plomtask.http import TaskHandler, TaskServer
11 from plomtask.processes import Process, ProcessStep
12 from plomtask.conditions import Condition
13 from plomtask.days import Day
14 from plomtask.todos import Todo
15 from plomtask.exceptions import NotFoundException, HandledException
18 class TestCaseSansDB(TestCase):
19 """Tests requiring no DB setup."""
21 do_id_test: bool = False
22 default_init_args: list[Any] = []
23 versioned_defaults_to_test: dict[str, str | float] = {}
25 def test_id_setting(self) -> None:
26 """Test .id_ being set and its legal range being enforced."""
27 if not self.do_id_test:
29 with self.assertRaises(HandledException):
30 self.checked_class(0, *self.default_init_args)
31 obj = self.checked_class(5, *self.default_init_args)
32 self.assertEqual(obj.id_, 5)
34 def test_versioned_defaults(self) -> None:
35 """Test defaults of VersionedAttributes."""
36 if len(self.versioned_defaults_to_test) == 0:
38 obj = self.checked_class(1, *self.default_init_args)
39 for k, v in self.versioned_defaults_to_test.items():
40 self.assertEqual(getattr(obj, k).newest, v)
class TestCaseWithDB(TestCase):
    """Module tests requiring DB setup.

    setUp creates a fresh database file and a connection to it for every
    test; tearDown deletes that file again.  Subclasses set checked_class
    to the class under test to enable test_saving_and_caching and the
    check_* helpers, which they call from their own test methods.
    """
    # Class under test; only present as an attribute if a subclass sets it.
    checked_class: Any
    # IDs used by the check_* helpers when instantiating checked_class.
    default_ids: tuple[int | str, int | str, int | str] = (1, 2, 3)
    # Keyword arguments passed to checked_class in test_saving_and_caching.
    default_init_kwargs: dict[str, Any] = {}
    # Attribute name -> value type (str or float) of versioned attributes.
    test_versioneds: dict[str, type] = {}

    def setUp(self) -> None:
        """Empty the model caches and create a fresh DB file and connection."""
        # Reset the cache of every model class this module imports, so no
        # objects leak from one test into the next.
        # NOTE(review): the Day/Process/Todo resets are restored from
        # context (imported but otherwise unused) — confirm.
        Condition.empty_cache()
        Day.empty_cache()
        Process.empty_cache()
        ProcessStep.empty_cache()
        Todo.empty_cache()
        # The timestamp in the name keeps DB files of different tests apart.
        timestamp = datetime.now().timestamp()
        self.db_file = DatabaseFile.create_at(f'test_db:{timestamp}')
        self.db_conn = DatabaseConnection(self.db_file)

    def tearDown(self) -> None:
        """Close the DB connection and delete the test's DB file."""
        # NOTE(review): assumes DatabaseConnection offers .close() — confirm.
        self.db_conn.close()
        remove_file(self.db_file.path)

    def test_saving_and_caching(self) -> None:
        """Test storage and initialization of instances and attributes."""
        if not hasattr(self, 'checked_class'):
            return
        self.check_saving_and_caching(id_=1, **self.default_init_kwargs)
        # An object created without ID is expected to get the next free one.
        obj = self.checked_class(None, **self.default_init_kwargs)
        obj.save(self.db_conn)
        self.assertEqual(obj.id_, 2)
        for k, v in self.test_versioneds.items():
            self.check_saving_of_versioned(k, v)

    def check_storage(self, content: list[Any]) -> None:
        """Test cache and DB equal content."""
        expected_cache = {item.id_: item for item in content}
        self.assertEqual(self.checked_class.get_cache(), expected_cache)
        db_found: list[Any] = []
        for item in content:
            assert isinstance(item.id_, type(self.default_ids[0]))
            for row in self.db_conn.row_where(self.checked_class.table_name,
                                              'id', item.id_):
                db_found += [self.checked_class.from_table_row(self.db_conn,
                                                               row)]
        self.assertEqual(sorted(content), sorted(db_found))

    def check_saving_and_caching(self, **kwargs: Any) -> None:
        """Test instance.save in its core without relations."""
        obj = self.checked_class(**kwargs)  # pylint: disable=not-callable
        # check object init itself doesn't store anything yet
        self.check_storage([])
        # check saving stores in cache and DB
        obj.save(self.db_conn)
        self.check_storage([obj])
        # check core attributes set properly (and not unset by saving)
        for key, value in kwargs.items():
            self.assertEqual(getattr(obj, key), value)

    def check_saving_of_versioned(self, attr_name: str, type_: type) -> None:
        """Test owner's versioned attributes."""
        owner = self.checked_class(None)
        vals: list[Any] = ['t1', 't2'] if type_ is str else [0.9, 1.1]
        attr = getattr(owner, attr_name)
        attr.set(vals[0])
        attr.set(vals[1])
        owner.save(self.db_conn)
        # NOTE(review): one statement appears to be missing at this point in
        # the listing (possibly evicting owner from the cache so that by_id
        # has to read from the DB) — restore from version control.
        retrieved = owner.__class__.by_id(self.db_conn, owner.id_)
        attr = getattr(retrieved, attr_name)
        self.assertEqual(sorted(attr.history.values()), vals)

    def check_by_id(self) -> None:
        """Test .by_id(), including creation."""
        # check failure if not yet saved
        id1, id2 = self.default_ids[0], self.default_ids[1]
        obj = self.checked_class(id1)  # pylint: disable=not-callable
        with self.assertRaises(NotFoundException):
            self.checked_class.by_id(self.db_conn, id1)
        # check identity of saved and retrieved
        obj.save(self.db_conn)
        self.assertEqual(obj, self.checked_class.by_id(self.db_conn, id1))
        # check create=True acts like normal instantiation (sans saving)
        by_id_created = self.checked_class.by_id(self.db_conn, id2,
                                                 create=True)
        # pylint: disable=not-callable
        self.assertEqual(self.checked_class(id2), by_id_created)
        self.check_storage([obj])

    def check_from_table_row(self, *args: Any) -> None:
        """Test .from_table_row() properly reads in class from DB"""
        id_ = self.default_ids[0]
        obj = self.checked_class(id_, *args)  # pylint: disable=not-callable
        obj.save(self.db_conn)
        assert isinstance(obj.id_, type(self.default_ids[0]))
        for row in self.db_conn.row_where(self.checked_class.table_name,
                                          'id', obj.id_):
            retrieved = self.checked_class.from_table_row(self.db_conn, row)
            self.assertEqual(obj, retrieved)
            self.assertEqual({obj.id_: obj}, self.checked_class.get_cache())

    def check_versioned_from_table_row(self, attr_name: str,
                                       type_: type) -> None:
        """Test .from_table_row() reads versioned attributes from DB."""
        owner = self.checked_class(None)
        vals: list[Any] = ['t1', 't2'] if type_ is str else [0.9, 1.1]
        attr = getattr(owner, attr_name)
        attr.set(vals[0])
        attr.set(vals[1])
        owner.save(self.db_conn)
        for row in self.db_conn.row_where(owner.table_name, 'id', owner.id_):
            retrieved = owner.__class__.from_table_row(self.db_conn, row)
            attr = getattr(retrieved, attr_name)
            self.assertEqual(sorted(attr.history.values()), vals)

    def check_all(self) -> tuple[Any, Any, Any]:
        """Test .all() lists exactly the saved items; return the three made."""
        # pylint: disable=not-callable
        item1 = self.checked_class(self.default_ids[0])
        item2 = self.checked_class(self.default_ids[1])
        item3 = self.checked_class(self.default_ids[2])
        # check pre-save .all() returns empty list
        self.assertEqual(self.checked_class.all(self.db_conn), [])
        # check that all() shows all saved, but no unsaved items
        item1.save(self.db_conn)
        item3.save(self.db_conn)
        self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
                         sorted([item1, item3]))
        item2.save(self.db_conn)
        self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
                         sorted([item1, item2, item3]))
        return item1, item2, item3

    def check_singularity(self, defaulting_field: str,
                          non_default_value: Any, *args: Any) -> None:
        """Test pointers made for single object keep pointing to it."""
        id1 = self.default_ids[0]
        obj = self.checked_class(id1, *args)  # pylint: disable=not-callable
        obj.save(self.db_conn)
        # Change the field only after saving: the object retrieved below
        # must still show the change made on obj.
        setattr(obj, defaulting_field, non_default_value)
        retrieved = self.checked_class.by_id(self.db_conn, id1)
        self.assertEqual(non_default_value,
                         getattr(retrieved, defaulting_field))

    def check_versioned_singularity(self) -> None:
        """Test singularity of VersionedAttributes on saving (with .title)."""
        obj = self.checked_class(None)  # pylint: disable=not-callable
        obj.save(self.db_conn)
        assert isinstance(obj.id_, int)
        # Set the title only after saving; the retrieved object's history
        # must nevertheless match.
        obj.title.set('named')
        retrieved = self.checked_class.by_id(self.db_conn, obj.id_)
        self.assertEqual(obj.title.history, retrieved.title.history)

    def check_remove(self, *args: Any) -> None:
        """Test .remove() effects on DB and cache."""
        id_ = self.default_ids[0]
        obj = self.checked_class(id_, *args)  # pylint: disable=not-callable
        # removing what was never saved must fail
        with self.assertRaises(HandledException):
            obj.remove(self.db_conn)
        obj.save(self.db_conn)
        obj.remove(self.db_conn)
        self.check_storage([])
class TestCaseWithServer(TestCaseWithDB):
    """Module tests against our HTTP server/handler (and database)."""

    def setUp(self) -> None:
        """Start a server on the test DB in a thread, and connect to it."""
        # The parent setUp creates self.db_file, which the server needs.
        super().setUp()
        # Port 0 lets the OS pick a free port; the actual address is read
        # back from the server below.
        self.httpd = TaskServer(self.db_file, ('localhost', 0), TaskHandler)
        self.server_thread = Thread(target=self.httpd.serve_forever)
        self.server_thread.daemon = True
        self.server_thread.start()
        self.conn = HTTPConnection(str(self.httpd.server_address[0]),
                                   self.httpd.server_address[1])

    def tearDown(self) -> None:
        """Stop the server and its thread, then clean up the database."""
        # Close the client side first so its socket is not left open.
        self.conn.close()
        self.httpd.shutdown()
        self.httpd.server_close()
        self.server_thread.join()
        super().tearDown()

    def check_redirect(self, target: str) -> None:
        """Check that self.conn answers with a 302 redirect to target."""
        response = self.conn.getresponse()
        self.assertEqual(response.status, 302)
        self.assertEqual(response.getheader('Location'), target)

    def check_get(self, target: str, expected_code: int) -> None:
        """Check that a GET to target yields expected_code."""
        self.conn.request('GET', target)
        self.assertEqual(self.conn.getresponse().status, expected_code)

    def check_post(self, data: Mapping[str, object], target: str,
                   expected_code: int, redirect_location: str = '') -> None:
        """Check that POST of data to target yields expected_code.

        If expected_code is 302, the Location header is checked as well,
        against redirect_location, or against target if that is empty.
        """
        encoded_form_data = urlencode(data, doseq=True).encode('utf-8')
        headers = {'Content-Type': 'application/x-www-form-urlencoded',
                   'Content-Length': str(len(encoded_form_data))}
        self.conn.request('POST', target,
                          body=encoded_form_data, headers=headers)
        if 302 == expected_code:
            if redirect_location == '':
                redirect_location = target
            self.check_redirect(redirect_location)
            # check_redirect has consumed the response; do not read it again.
            return
        self.assertEqual(self.conn.getresponse().status, expected_code)

    def check_get_defaults(self, path: str) -> None:
        """Some standard model paths to test."""
        self.check_get(path, 200)
        self.check_get(f'{path}?id=', 200)
        self.check_get(f'{path}?id=foo', 400)
        # NOTE(review): unlike its siblings, this target carries an extra
        # leading slash — verify that is intended before normalizing it.
        self.check_get(f'/{path}?id=0', 500)
        self.check_get(f'{path}?id=1', 200)
259 def post_process(self, id_: int = 1,
260 form_data: dict[str, Any] | None = None
262 """POST basic Process."""
264 form_data = {'title': 'foo', 'description': 'foo', 'effort': 1.1}
265 self.check_post(form_data, f'/process?id={id_}', 302,
266 f'/process?id={id_}')