1 """Shared test utilities."""
2 from unittest import TestCase
3 from threading import Thread
4 from http.client import HTTPConnection
5 from urllib.parse import urlencode
6 from datetime import datetime
7 from os import remove as remove_file
8 from typing import Mapping, Any
9 from plomtask.db import DatabaseFile, DatabaseConnection
10 from plomtask.http import TaskHandler, TaskServer
11 from plomtask.processes import Process, ProcessStep
12 from plomtask.conditions import Condition
13 from plomtask.days import Day
14 from plomtask.todos import Todo
15 from plomtask.exceptions import NotFoundException, HandledException
18 class TestCaseSansDB(TestCase):
19 """Tests requiring no DB setup."""
21 do_id_test: bool = False
22 default_init_args: list[Any] = []
23 versioned_defaults_to_test: dict[str, str | float] = {}
25 def test_id_setting(self) -> None:
26 """Test .id_ being set and its legal range being enforced."""
27 if not self.do_id_test:
29 with self.assertRaises(HandledException):
30 self.checked_class(0, *self.default_init_args)
31 obj = self.checked_class(5, *self.default_init_args)
32 self.assertEqual(obj.id_, 5)
34 def test_versioned_defaults(self) -> None:
35 """Test defaults of VersionedAttributes."""
36 if len(self.versioned_defaults_to_test) == 0:
38 obj = self.checked_class(1, *self.default_init_args)
39 for k, v in self.versioned_defaults_to_test.items():
40 self.assertEqual(getattr(obj, k).newest, v)
43 class TestCaseWithDB(TestCase):
44 """Module tests not requiring DB setup."""
46 default_ids: tuple[int | str, int | str, int | str] = (1, 2, 3)
def setUp(self) -> None:
    """Reset model caches and create a fresh, timestamp-named DB file."""
    # Clear each model class's cache so objects cached by earlier tests
    # do not leak into this one.
    # NOTE(review): assumes Day, Process and Todo expose empty_cache() like
    # Condition and ProcessStep do — confirm.
    Condition.empty_cache()
    Day.empty_cache()
    Process.empty_cache()
    ProcessStep.empty_cache()
    Todo.empty_cache()
    # The timestamp suffix keeps DB file names distinct between tests.
    timestamp = datetime.now().timestamp()
    self.db_file = DatabaseFile.create_at(f'test_db:{timestamp}')
    self.db_conn = DatabaseConnection(self.db_file)
def tearDown(self) -> None:
    """Delete the DB file that setUp() created for this test."""
    # NOTE(review): self.db_conn is not closed before its file is removed
    # here — confirm whether DatabaseConnection needs an explicit close.
    db_path = self.db_file.path
    remove_file(db_path)
def check_storage(self, content: list[Any]) -> None:
    """Test cache and DB equal content.

    The checked class's cache must map exactly each item's .id_ to the item,
    and rebuilding each item from its DB row(s) must yield the same objects
    (compared after sorting). Only rows for the given items are queried, so
    extra rows in the table are not detected.
    """
    expected_cache = {item.id_: item for item in content}
    self.assertEqual(self.checked_class.get_cache(), expected_cache)
    db_found: list[Any] = []
    for item in content:
        # saved IDs must have the same type as default_ids (int vs str)
        assert isinstance(item.id_, type(self.default_ids[0]))
        for row in self.db_conn.row_where(self.checked_class.table_name,
                                          'id', item.id_):
            db_found.append(self.checked_class.from_table_row(self.db_conn,
                                                              row))
    self.assertEqual(sorted(content), sorted(db_found))
def check_saving_and_caching(self, **kwargs: Any) -> None:
    """Test instance.save in its core without relations.

    Builds an instance from kwargs and runs check_storage([]) before saving
    and check_storage([instance]) after. Then every kwarg must still be
    reflected on the instance.
    """
    instance = self.checked_class(**kwargs)  # pylint: disable=not-callable
    # Before saving, nothing may be cached.
    self.check_storage([])
    # Saving must make the instance show up in cache and DB.
    instance.save(self.db_conn)
    self.check_storage([instance])
    # Saving must not have clobbered the attributes set on construction.
    for attr_name in kwargs:
        self.assertEqual(getattr(instance, attr_name), kwargs[attr_name])
def check_saving_of_versioned(self, attr_name: str, type_: type) -> None:
    """Test owner's versioned attributes.

    Sets two values on the versioned attribute attr_name of a fresh owner,
    saves it, fetches it via .by_id() and checks both values are in the
    attribute's history. type_ picks the values: str gets ['t1', 't2'],
    anything else [0.9, 1.1].
    """
    owner = self.checked_class(None)
    vals: list[Any] = ['t1', 't2'] if type_ is str else [0.9, 1.1]
    attr = getattr(owner, attr_name)
    # Both values must be set, or the history check below cannot pass.
    attr.set(vals[0])
    attr.set(vals[1])
    owner.save(self.db_conn)
    # NOTE(review): if .by_id() serves the cached instance, this does not
    # exercise reading back from the DB — confirm intent.
    retrieved = owner.__class__.by_id(self.db_conn, owner.id_)
    attr = getattr(retrieved, attr_name)
    self.assertEqual(sorted(attr.history.values()), vals)
def check_by_id(self) -> None:
    """Test .by_id(), including creation.

    .by_id() must raise NotFoundException for a not-yet-saved ID and return
    an object equal to the saved one once saved. With create=True it must
    return an instance equal to a freshly constructed one without storing
    it.
    """
    # check failure if not yet saved
    id1, id2 = self.default_ids[0], self.default_ids[1]
    obj = self.checked_class(id1)  # pylint: disable=not-callable
    with self.assertRaises(NotFoundException):
        self.checked_class.by_id(self.db_conn, id1)
    # check saved object is retrievable (as an equal object)
    obj.save(self.db_conn)
    self.assertEqual(obj, self.checked_class.by_id(self.db_conn, id1))
    # check create=True acts like normal instantiation (sans saving)
    by_id_created = self.checked_class.by_id(self.db_conn, id2, create=True)
    # pylint: disable=not-callable
    self.assertEqual(self.checked_class(id2), by_id_created)
    # only the explicitly saved obj may have been stored
    self.check_storage([obj])
def check_from_table_row(self, *args: Any) -> None:
    """Test .from_table_row() properly reads in class from DB.

    Saves an instance, rebuilds it from its table row(s) and checks that the
    result equals the original. The cache must then hold only the original.
    """
    id_ = self.default_ids[0]
    obj = self.checked_class(id_, *args)  # pylint: disable=not-callable
    obj.save(self.db_conn)
    assert isinstance(obj.id_, type(self.default_ids[0]))
    rows = list(self.db_conn.row_where(self.checked_class.table_name,
                                       'id', obj.id_))
    # Without this, an empty result would make the loop pass vacuously.
    self.assertTrue(rows, 'saved object not found in DB')
    for row in rows:
        retrieved = self.checked_class.from_table_row(self.db_conn, row)
        self.assertEqual(obj, retrieved)
    self.assertEqual({obj.id_: obj}, self.checked_class.get_cache())
def check_versioned_from_table_row(self, attr_name: str,
                                   type_: type) -> None:
    """Test .from_table_row() reads versioned attributes from DB.

    Sets two values on the versioned attribute attr_name, saves the owner,
    rebuilds it from its table row(s) and checks both values are in the
    rebuilt attribute's history. type_ picks the values: str gets
    ['t1', 't2'], anything else [0.9, 1.1].
    """
    owner = self.checked_class(None)
    vals: list[Any] = ['t1', 't2'] if type_ is str else [0.9, 1.1]
    attr = getattr(owner, attr_name)
    # Both values must be set, or the history check below cannot pass.
    attr.set(vals[0])
    attr.set(vals[1])
    owner.save(self.db_conn)
    rows = list(self.db_conn.row_where(owner.table_name, 'id', owner.id_))
    # Without this, an empty result would make the loop pass vacuously.
    self.assertTrue(rows, 'saved owner not found in DB')
    for row in rows:
        retrieved = owner.__class__.from_table_row(self.db_conn, row)
        attr = getattr(retrieved, attr_name)
        self.assertEqual(sorted(attr.history.values()), vals)
def check_all(self) -> tuple[Any, Any, Any]:
    """Test .all() lists saved instances and no unsaved ones.

    Creates one instance for each of the first three default_ids. .all()
    must be empty before any save, list the first and third after saving
    those, and list all three once the second is saved too. Returns the
    three instances in default_ids order.
    """
    # pylint: disable=not-callable
    ids = self.default_ids
    first = self.checked_class(ids[0])
    second = self.checked_class(ids[1])
    third = self.checked_class(ids[2])
    # Nothing saved yet, so .all() must come back empty.
    self.assertEqual(self.checked_class.all(self.db_conn), [])
    # Save first and third only: .all() must list those but not second.
    for item in (first, third):
        item.save(self.db_conn)
    self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
                     sorted([first, third]))
    # Once second is saved as well, all three must be listed.
    second.save(self.db_conn)
    self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
                     sorted([first, second, third]))
    return first, second, third
def check_singularity(self, defaulting_field: str,
                      non_default_value: Any, *args: Any) -> None:
    """Test pointers made for single object keep pointing to it.

    After saving an instance, a change made on that in-memory instance must
    be visible on what .by_id() returns for the same ID.
    """
    obj_id = self.default_ids[0]
    saved = self.checked_class(obj_id, *args)  # pylint: disable=not-callable
    saved.save(self.db_conn)
    setattr(saved, defaulting_field, non_default_value)
    fetched = self.checked_class.by_id(self.db_conn, obj_id)
    fetched_value = getattr(fetched, defaulting_field)
    self.assertEqual(non_default_value, fetched_value)
def check_versioned_singularity(self) -> None:
    """Test singularity of VersionedAttributes on saving (with .title).

    A .title change made on the saved in-memory instance must show up in the
    .title history of the instance .by_id() returns.
    """
    saved = self.checked_class(None)  # pylint: disable=not-callable
    saved.save(self.db_conn)
    # Saving is expected to have assigned an integer ID.
    assert isinstance(saved.id_, int)
    saved.title.set('named')
    fetched = self.checked_class.by_id(self.db_conn, saved.id_)
    self.assertEqual(saved.title.history, fetched.title.history)
def check_remove(self, *args: Any) -> None:
    """Test .remove() effects on DB and cache.

    Removing a never-saved instance must raise HandledException. After
    saving and then removing it, check_storage([]) must pass.
    """
    # NOTE(review): check_storage([]) compares the cache against {} but
    # queries no DB rows when given no items — confirm DB removal is
    # verified elsewhere.
    target = self.checked_class(  # pylint: disable=not-callable
        self.default_ids[0], *args)
    with self.assertRaises(HandledException):
        target.remove(self.db_conn)
    target.save(self.db_conn)
    target.remove(self.db_conn)
    self.check_storage([])
194 class TestCaseWithServer(TestCaseWithDB):
195 """Module tests against our HTTP server/handler (and database)."""
def setUp(self) -> None:
    """Set up the test DB, then serve it over HTTP in a daemon thread.

    Leaves the server in self.httpd, its thread in self.server_thread and an
    HTTP client connection to it in self.conn.
    """
    # The parent setUp creates self.db_file, which the server needs.
    super().setUp()
    # Port 0 asks the OS for a free port (assuming TaskServer binds like
    # socketserver.TCPServer); the real address is read back below.
    self.httpd = TaskServer(self.db_file, ('localhost', 0), TaskHandler)
    self.server_thread = Thread(target=self.httpd.serve_forever)
    self.server_thread.daemon = True
    self.server_thread.start()
    self.conn = HTTPConnection(str(self.httpd.server_address[0]),
                               self.httpd.server_address[1])
def tearDown(self) -> None:
    """Close the client, stop the server, then tear down the test DB."""
    # Release the client socket instead of leaving it to garbage collection.
    self.conn.close()
    self.httpd.shutdown()
    self.httpd.server_close()
    self.server_thread.join()
    # The parent tearDown removes the DB file created in its setUp.
    super().tearDown()
def check_redirect(self, target: str) -> None:
    """Check that self.conn answers with a 302 redirect to target."""
    response = self.conn.getresponse()
    # http.client requires the whole response to be read before the
    # connection can carry another request, so drain the body.
    response.read()
    self.assertEqual(response.status, 302)
    self.assertEqual(response.getheader('Location'), target)
def check_get(self, target: str, expected_code: int) -> None:
    """Check that a GET to target yields expected_code."""
    self.conn.request('GET', target)
    response = self.conn.getresponse()
    # http.client requires the whole response to be read before the
    # connection can carry another request, so drain the body.
    response.read()
    self.assertEqual(response.status, expected_code)
def check_post(self, data: Mapping[str, object], target: str,
               expected_code: int, redirect_location: str = '') -> None:
    """Check that POST of data to target yields expected_code.

    data is sent form-urlencoded (with doseq=True, sequence values expand
    to repeated keys). If expected_code is 302, the response must instead
    redirect to redirect_location, which defaults to target itself.
    """
    encoded_form_data = urlencode(data, doseq=True).encode('utf-8')
    headers = {'Content-Type': 'application/x-www-form-urlencoded',
               'Content-Length': str(len(encoded_form_data))}
    self.conn.request('POST', target,
                      body=encoded_form_data, headers=headers)
    if expected_code == 302:
        # check_redirect consumes the response, so do not read it again.
        self.check_redirect(redirect_location or target)
    else:
        response = self.conn.getresponse()
        # Drain the body so the connection can carry another request.
        response.read()
        self.assertEqual(response.status, expected_code)
def check_get_defaults(self, path: str) -> None:
    """GET standard ?id= variants of the model page at path.

    Expected codes: bare path, empty id and id=1 give 200; a non-numeric id
    gives 400; id=0 gives 500.
    """
    self.check_get(path, 200)
    self.check_get(f'{path}?id=', 200)
    self.check_get(f'{path}?id=foo', 400)
    # Same form as the other requests; the former extra leading '/'
    # produced '//…' targets.
    self.check_get(f'{path}?id=0', 500)
    self.check_get(f'{path}?id=1', 200)
246 def post_process(self, id_: int = 1,
247 form_data: dict[str, Any] | None = None
249 """POST basic Process."""
251 form_data = {'title': 'foo', 'description': 'foo', 'effort': 1.1}
252 self.check_post(form_data, '/process?id=', 302, f'/process?id={id_}')