1 """Shared test utilities."""
2 from unittest import TestCase
3 from threading import Thread
4 from http.client import HTTPConnection
5 from urllib.parse import urlencode
6 from datetime import datetime
7 from os import remove as remove_file
8 from typing import Mapping, Any
9 from plomtask.db import DatabaseFile, DatabaseConnection
10 from plomtask.http import TaskHandler, TaskServer
11 from plomtask.processes import Process, ProcessStep
12 from plomtask.conditions import Condition
13 from plomtask.days import Day
14 from plomtask.todos import Todo
15 from plomtask.exceptions import NotFoundException, HandledException
18 class TestCaseSansDB(TestCase):
19 """Tests requiring no DB setup."""
def check_id_setting(self, *args: Any) -> None:
    """Assert that an id_ of 0 is refused while a positive id_ is kept.

    Any extra *args are forwarded to self.checked_class after the id_.
    """
    illegal_id, legal_id = 0, 5
    # construction with id_ 0 is expected to raise HandledException
    with self.assertRaises(HandledException):
        self.checked_class(illegal_id, *args)
    instance = self.checked_class(legal_id, *args)
    self.assertEqual(instance.id_, legal_id)
def check_versioned_defaults(self, attrs: dict[str, Any]) -> None:
    """Assert the versioned attributes of a fresh instance hold defaults.

    attrs maps attribute names to the value expected as that attribute's
    .newest on an instance built as self.checked_class(None).
    """
    fresh = self.checked_class(None)
    for attr_name in attrs:
        versioned = getattr(fresh, attr_name)
        self.assertEqual(versioned.newest, attrs[attr_name])
36 class TestCaseWithDB(TestCase):
37 """Module tests not requiring DB setup."""
39 default_ids: tuple[int | str, int | str, int | str] = (1, 2, 3)
41 def setUp(self) -> None:
42 Condition.empty_cache()
45 ProcessStep.empty_cache()
47 timestamp = datetime.now().timestamp()
48 self.db_file = DatabaseFile(f'test_db:{timestamp}')
50 self.db_conn = DatabaseConnection(self.db_file)
52 def tearDown(self) -> None:
54 remove_file(self.db_file.path)
56 def check_storage(self, content: list[Any]) -> None:
57 """Test cache and DB equal content."""
60 expected_cache[item.id_] = item
61 self.assertEqual(self.checked_class.get_cache(), expected_cache)
62 db_found: list[Any] = []
64 assert isinstance(item.id_, type(self.default_ids[0]))
65 for row in self.db_conn.row_where(self.checked_class.table_name,
67 db_found += [self.checked_class.from_table_row(self.db_conn,
69 self.assertEqual(sorted(content), sorted(db_found))
def check_saving_and_caching(self, **kwargs: Any) -> None:
    """Test instance.save in its core without relations.

    Builds an instance from kwargs and asserts it is absent from storage
    before saving and present afterwards. It also asserts that every
    keyword value is still set on the object after saving.
    """
    # pylint: disable=not-callable
    obj = self.checked_class(**kwargs)
    # merely instantiating must not store anything yet
    self.check_storage([])
    # saving must store the object in cache and DB
    obj.save(self.db_conn)
    self.check_storage([obj])
    # core attributes must be set, and not unset by saving
    for attr_name, expected in kwargs.items():
        self.assertEqual(getattr(obj, attr_name), expected)
83 def check_saving_of_versioned(self, attr_name: str, type_: type) -> None:
84 """Test owner's versioned attributes."""
85 owner = self.checked_class(None)
86 vals: list[Any] = ['t1', 't2'] if type_ == str else [0.9, 1.1]
87 attr = getattr(owner, attr_name)
90 owner.save(self.db_conn)
92 retrieved = owner.__class__.by_id(self.db_conn, owner.id_)
93 attr = getattr(retrieved, attr_name)
94 self.assertEqual(sorted(attr.history.values()), vals)
96 def check_by_id(self) -> None:
97 """Test .by_id(), including creation."""
98 # check failure if not yet saved
99 id1, id2 = self.default_ids[0], self.default_ids[1]
100 obj = self.checked_class(id1) # pylint: disable=not-callable
101 with self.assertRaises(NotFoundException):
102 self.checked_class.by_id(self.db_conn, id1)
103 # check identity of saved and retrieved
104 obj.save(self.db_conn)
105 self.assertEqual(obj, self.checked_class.by_id(self.db_conn, id1))
106 # check create=True acts like normal instantiation (sans saving)
107 by_id_created = self.checked_class.by_id(self.db_conn, id2,
109 # pylint: disable=not-callable
110 self.assertEqual(self.checked_class(id2), by_id_created)
111 self.check_storage([obj])
113 def check_from_table_row(self, *args: Any) -> None:
114 """Test .from_table_row() properly reads in class from DB"""
115 id_ = self.default_ids[0]
116 obj = self.checked_class(id_, *args) # pylint: disable=not-callable
117 obj.save(self.db_conn)
118 assert isinstance(obj.id_, type(self.default_ids[0]))
119 for row in self.db_conn.row_where(self.checked_class.table_name,
121 retrieved = self.checked_class.from_table_row(self.db_conn, row)
122 self.assertEqual(obj, retrieved)
123 self.assertEqual({obj.id_: obj}, self.checked_class.get_cache())
125 def check_versioned_from_table_row(self, attr_name: str,
126 type_: type) -> None:
127 """Test .from_table_row() reads versioned attributes from DB."""
128 owner = self.checked_class(None)
129 vals: list[Any] = ['t1', 't2'] if type_ == str else [0.9, 1.1]
130 attr = getattr(owner, attr_name)
133 owner.save(self.db_conn)
134 for row in self.db_conn.row_where(owner.table_name, 'id', owner.id_):
135 retrieved = owner.__class__.from_table_row(self.db_conn, row)
136 attr = getattr(retrieved, attr_name)
137 self.assertEqual(sorted(attr.history.values()), vals)
139 def check_all(self) -> tuple[Any, Any, Any]:
141 # pylint: disable=not-callable
142 item1 = self.checked_class(self.default_ids[0])
143 item2 = self.checked_class(self.default_ids[1])
144 item3 = self.checked_class(self.default_ids[2])
145 # check pre-save .all() returns empty list
146 self.assertEqual(self.checked_class.all(self.db_conn), [])
147 # check that all() shows all saved, but no unsaved items
148 item1.save(self.db_conn)
149 item3.save(self.db_conn)
150 self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
151 sorted([item1, item3]))
152 item2.save(self.db_conn)
153 self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
154 sorted([item1, item2, item3]))
155 return item1, item2, item3
def check_singularity(self, defaulting_field: str,
                      non_default_value: Any, *args: Any) -> None:
    """Test pointers made for single object keep pointing to it.

    Saves an instance under the first default ID. It then sets
    defaulting_field on that in-memory object without saving again, and
    asserts that .by_id() for the same ID returns an object showing the
    new value.
    """
    first_id = self.default_ids[0]
    # pylint: disable=not-callable
    saved = self.checked_class(first_id, *args)
    saved.save(self.db_conn)
    setattr(saved, defaulting_field, non_default_value)
    fetched = self.checked_class.by_id(self.db_conn, first_id)
    self.assertEqual(non_default_value, getattr(fetched, defaulting_field))
def check_versioned_singularity(self) -> None:
    """Test singularity of VersionedAttributes on saving (with .title).

    Saves an ID-less instance, then changes its .title without saving
    again. It asserts that the instance returned by .by_id() for the
    integer id_ held after saving has the same .title history.
    """
    original = self.checked_class(None)  # pylint: disable=not-callable
    original.save(self.db_conn)
    assert isinstance(original.id_, int)
    original.title.set('named')
    fetched = self.checked_class.by_id(self.db_conn, original.id_)
    self.assertEqual(original.title.history, fetched.title.history)
def check_remove(self, *args: Any) -> None:
    """Test .remove() effects on DB and cache.

    Removing a never-saved instance must raise HandledException. Removing
    it after saving must leave both cache and DB without any content.
    """
    # pylint: disable=not-callable
    target = self.checked_class(self.default_ids[0], *args)
    with self.assertRaises(HandledException):
        target.remove(self.db_conn)
    target.save(self.db_conn)
    target.remove(self.db_conn)
    self.check_storage([])
188 class TestCaseWithServer(TestCaseWithDB):
189 """Module tests against our HTTP server/handler (and database)."""
191 def setUp(self) -> None:
193 self.httpd = TaskServer(self.db_file, ('localhost', 0), TaskHandler)
194 self.server_thread = Thread(target=self.httpd.serve_forever)
195 self.server_thread.daemon = True
196 self.server_thread.start()
197 self.conn = HTTPConnection(str(self.httpd.server_address[0]),
198 self.httpd.server_address[1])
200 def tearDown(self) -> None:
201 self.httpd.shutdown()
202 self.httpd.server_close()
203 self.server_thread.join()
def check_redirect(self, target: str) -> None:
    """Check that self.conn answers with a 302 redirect to target.

    This reads the pending response from self.conn, so a request must
    already have been sent on it.
    """
    answer = self.conn.getresponse()
    self.assertEqual(answer.status, 302)
    location = answer.getheader('Location')
    self.assertEqual(location, target)
def check_get(self, target: str, expected_code: int) -> None:
    """Check that a GET to target yields expected_code.

    Sends the request over self.conn and compares only the response status.
    """
    self.conn.request('GET', target)
    response = self.conn.getresponse()
    self.assertEqual(response.status, expected_code)
217 def check_post(self, data: Mapping[str, object], target: str,
218 expected_code: int, redirect_location: str = '') -> None:
219 """Check that POST of data to target yields expected_code."""
220 encoded_form_data = urlencode(data, doseq=True).encode('utf-8')
221 headers = {'Content-Type': 'application/x-www-form-urlencoded',
222 'Content-Length': str(len(encoded_form_data))}
223 self.conn.request('POST', target,
224 body=encoded_form_data, headers=headers)
225 if 302 == expected_code:
226 if redirect_location == '':
227 redirect_location = target
228 self.check_redirect(redirect_location)
230 self.assertEqual(self.conn.getresponse().status, expected_code)
def check_get_defaults(self, path: str) -> None:
    """Run the standard GET checks against a model path.

    path is used verbatim as the request target, so it is expected to carry
    its leading slash (e.g. '/process'). The expected codes are:
    200 with no id or an empty id, 400 for a non-numeric id, 500 for id 0,
    and 200 for id 1.
    """
    # Every target is built as f'{path}...' so that all checks hit the same
    # route. The id=0 case used to be built as f'/{path}...', which turns
    # '/process' into '//process?id=0'. That target is only routed to path
    # if the server happens to normalize the leading '//'.
    self.check_get(path, 200)
    self.check_get(f'{path}?id=', 200)
    self.check_get(f'{path}?id=foo', 400)
    self.check_get(f'{path}?id=0', 500)
    self.check_get(f'{path}?id=1', 200)
238 self.check_get(f'{path}?id=1', 200)
240 def post_process(self, id_: int = 1,
241 form_data: dict[str, Any] | None = None
243 """POST basic Process."""
245 form_data = {'title': 'foo', 'description': 'foo', 'effort': 1.1}
246 self.check_post(form_data, '/process?id=', 302, f'/process?id={id_}')