1 """Shared test utilities."""
2 from unittest import TestCase
3 from threading import Thread
4 from http.client import HTTPConnection
5 from urllib.parse import urlencode
6 from datetime import datetime
7 from os import remove as remove_file
8 from typing import Mapping, Any
9 from plomtask.db import DatabaseFile, DatabaseConnection
10 from plomtask.http import TaskHandler, TaskServer
11 from plomtask.processes import Process, ProcessStep
12 from plomtask.conditions import Condition
13 from plomtask.days import Day
14 from plomtask.todos import Todo
15 from plomtask.exceptions import NotFoundException, HandledException
18 class TestCaseSansDB(TestCase):
19 """Tests requiring no DB setup."""
def check_id_setting(self, *args: Any) -> None:
    """Test that .id_ is set on init and that an ID of 0 is refused.

    Any *args are handed to the checked class's constructor after the ID.
    """
    # constructing with ID 0 is expected to raise
    with self.assertRaises(HandledException):
        self.checked_class(0, *args)
    legal_id = 5
    instance = self.checked_class(legal_id, *args)
    self.assertEqual(instance.id_, legal_id)
def check_versioned_defaults(self, attrs: dict[str, Any]) -> None:
    """Test that a fresh instance carries the expected versioned defaults.

    attrs maps attribute names to the value expected as that attribute's
    .newest on an instance constructed with None as its only argument.
    """
    fresh = self.checked_class(None)
    for attr_name in attrs:
        versioned = getattr(fresh, attr_name)
        self.assertEqual(versioned.newest, attrs[attr_name])
36 class TestCaseWithDB(TestCase):
37 """Module tests requiring DB setup."""
39 default_ids: tuple[int | str, int | str, int | str] = (1, 2, 3)
def setUp(self) -> None:
    """Reset model caches and connect to a timestamp-named test DB file."""
    for cached_class in (Condition, ProcessStep):
        cached_class.empty_cache()
    # NOTE(review): Day, Process and Todo are imported by this module too,
    # but their caches are not reset here -- confirm that is intended.
    db_path = f'test_db:{datetime.now().timestamp()}'
    self.db_file = DatabaseFile.create_at(db_path)
    self.db_conn = DatabaseConnection(self.db_file)
def tearDown(self) -> None:
    """Delete the file found at self.db_file.path."""
    # NOTE(review): self.db_conn is not closed before the file is removed
    # -- confirm whether DatabaseConnection needs explicit closing.
    db_path = self.db_file.path
    remove_file(db_path)
def check_storage(self, content: list[Any]) -> None:
    """Test cache and DB equal content.

    content lists the instances of the checked class expected to be stored;
    the class cache must map exactly their IDs to them, and the rows found
    in the DB for those IDs must rebuild into an equal (sorted) list.
    """
    expected_cache = {item.id_: item for item in content}
    self.assertEqual(self.checked_class.get_cache(), expected_cache)
    db_found: list[Any] = []
    for item in content:
        # IDs must be of the same type as this test case's default IDs
        assert isinstance(item.id_, type(self.default_ids[0]))
        # NOTE(review): the lookup by 'id' column follows the row_where
        # usage elsewhere in this module -- confirm against the DB schema.
        for row in self.db_conn.row_where(self.checked_class.table_name,
                                          'id', item.id_):
            db_found += [self.checked_class.from_table_row(self.db_conn,
                                                           row)]
    self.assertEqual(sorted(content), sorted(db_found))
def check_saving_and_caching(self, **kwargs: Any) -> None:
    """Test .save() of a relation-less instance built from kwargs.

    Storage is checked (via self.check_storage) to be empty before saving
    and to hold exactly the instance afterwards; every kwarg must still be
    readable as an equally named attribute once saved.
    """
    instance = self.checked_class(**kwargs)  # pylint: disable=not-callable
    # mere construction must not store anything
    self.check_storage([])
    # saving must make the instance show up in storage
    instance.save(self.db_conn)
    self.check_storage([instance])
    # constructor arguments must survive as attributes, even after saving
    for attr_name in kwargs:
        self.assertEqual(getattr(instance, attr_name), kwargs[attr_name])
def check_saving_of_versioned(self, attr_name: str, type_: type) -> None:
    """Test owner's versioned attributes.

    attr_name names the versioned attribute on the checked class; type_
    selects the sample values written to it (two strings for str, two
    floats otherwise). After saving, the owner fetched via .by_id() must
    show exactly those values in the attribute's history.
    """
    owner = self.checked_class(None)
    vals: list[Any] = ['t1', 't2'] if type_ is str else [0.9, 1.1]
    attr = getattr(owner, attr_name)
    # write both sample values, else there is no history to compare against
    for val in vals:
        attr.set(val)
    owner.save(self.db_conn)
    # NOTE(review): .by_id() may answer from the class cache rather than
    # the DB -- confirm this really exercises DB storage.
    retrieved = owner.__class__.by_id(self.db_conn, owner.id_)
    attr = getattr(retrieved, attr_name)
    self.assertEqual(sorted(attr.history.values()), vals)
def check_by_id(self) -> None:
    """Test .by_id(), including creation.

    Checks that an unsaved ID raises NotFoundException, that a saved object
    is returned equal to itself, and that create=True yields an object
    equal to a plainly constructed one without storing it.
    """
    # check failure if not yet saved
    id1, id2 = self.default_ids[0], self.default_ids[1]
    obj = self.checked_class(id1)  # pylint: disable=not-callable
    with self.assertRaises(NotFoundException):
        self.checked_class.by_id(self.db_conn, id1)
    # check identity of saved and retrieved
    obj.save(self.db_conn)
    self.assertEqual(obj, self.checked_class.by_id(self.db_conn, id1))
    # check create=True acts like normal instantiation (sans saving)
    by_id_created = self.checked_class.by_id(self.db_conn, id2,
                                             create=True)
    # pylint: disable=not-callable
    self.assertEqual(self.checked_class(id2), by_id_created)
    # only obj may be found in storage, not the create=True result
    self.check_storage([obj])
def check_from_table_row(self, *args: Any) -> None:
    """Test .from_table_row() properly reads in class from DB.

    Saves one instance (first default ID, plus *args), then rebuilds it
    from each DB row found for its ID and compares it to the original;
    the class cache must hold exactly that one instance.
    """
    id_ = self.default_ids[0]
    obj = self.checked_class(id_, *args)  # pylint: disable=not-callable
    obj.save(self.db_conn)
    assert isinstance(obj.id_, type(self.default_ids[0]))
    # NOTE(review): the lookup by 'id' column follows the row_where usage
    # elsewhere in this module -- confirm against the DB schema.
    for row in self.db_conn.row_where(self.checked_class.table_name,
                                      'id', obj.id_):
        retrieved = self.checked_class.from_table_row(self.db_conn, row)
        self.assertEqual(obj, retrieved)
        self.assertEqual({obj.id_: obj}, self.checked_class.get_cache())
def check_versioned_from_table_row(self, attr_name: str,
                                   type_: type) -> None:
    """Test .from_table_row() reads versioned attributes from DB.

    attr_name names the versioned attribute on the checked class; type_
    selects the sample values written to it (two strings for str, two
    floats otherwise). Each DB row found for the saved owner must rebuild
    into an object whose attribute history holds exactly those values.
    """
    owner = self.checked_class(None)
    vals: list[Any] = ['t1', 't2'] if type_ is str else [0.9, 1.1]
    attr = getattr(owner, attr_name)
    # write both sample values, else there is no history to compare against
    for val in vals:
        attr.set(val)
    owner.save(self.db_conn)
    for row in self.db_conn.row_where(owner.table_name, 'id', owner.id_):
        retrieved = owner.__class__.from_table_row(self.db_conn, row)
        attr = getattr(retrieved, attr_name)
        self.assertEqual(sorted(attr.history.values()), vals)
def check_all(self) -> tuple[Any, Any, Any]:
    """Test .all() lists exactly the saved instances.

    Returns the three instances built from self.default_ids, all of them
    saved by the time this returns.
    """
    # pylint: disable=not-callable
    first, second, third = (self.checked_class(self.default_ids[idx])
                            for idx in range(3))
    conn = self.db_conn
    # nothing saved yet, so .all() must come back empty
    self.assertEqual(self.checked_class.all(conn), [])
    # saved items must show up, the still unsaved one must not
    first.save(conn)
    third.save(conn)
    self.assertEqual(sorted(self.checked_class.all(conn)),
                     sorted([first, third]))
    # with the last one saved too, all three must be listed
    second.save(conn)
    self.assertEqual(sorted(self.checked_class.all(conn)),
                     sorted([first, second, third]))
    return first, second, third
def check_singularity(self, defaulting_field: str,
                      non_default_value: Any, *args: Any) -> None:
    """Test that .by_id() reflects changes made to an already saved object.

    defaulting_field is set to non_default_value on the object only after
    it was saved; the object fetched by the same ID must show that value.
    """
    first_id = self.default_ids[0]
    # pylint: disable=not-callable
    original = self.checked_class(first_id, *args)
    original.save(self.db_conn)
    # deliberately changed after saving, and not saved again
    setattr(original, defaulting_field, non_default_value)
    fetched = self.checked_class.by_id(self.db_conn, first_id)
    found_value = getattr(fetched, defaulting_field)
    self.assertEqual(non_default_value, found_value)
def check_versioned_singularity(self) -> None:
    """Test singularity of VersionedAttributes on saving (with .title).

    The title is set only after saving; the object fetched by the same ID
    must nevertheless show the same title history.
    """
    saved = self.checked_class(None)  # pylint: disable=not-callable
    saved.save(self.db_conn)
    assert isinstance(saved.id_, int)
    # deliberately changed after saving, and not saved again
    saved.title.set('named')
    fetched = self.checked_class.by_id(self.db_conn, saved.id_)
    self.assertEqual(saved.title.history, fetched.title.history)
def check_remove(self, *args: Any) -> None:
    """Test .remove() effects on DB and cache.

    Any *args are handed to the checked class's constructor after the ID.
    """
    first_id = self.default_ids[0]
    # pylint: disable=not-callable
    target = self.checked_class(first_id, *args)
    # removing what was never saved is expected to raise
    with self.assertRaises(HandledException):
        target.remove(self.db_conn)
    # after saving and removing, storage must be empty again
    target.save(self.db_conn)
    target.remove(self.db_conn)
    self.check_storage([])
187 class TestCaseWithServer(TestCaseWithDB):
188 """Module tests against our HTTP server/handler (and database)."""
def setUp(self) -> None:
    """Start a TaskServer in a background thread and connect to it.

    The server is bound to ('localhost', 0) -- presumably letting the OS
    pick a free port -- and self.conn is pointed at whatever address the
    server reports.
    """
    # the parent's setUp creates self.db_file, which the server is built on
    super().setUp()
    self.httpd = TaskServer(self.db_file, ('localhost', 0), TaskHandler)
    self.server_thread = Thread(target=self.httpd.serve_forever)
    # daemon thread, so a hanging server cannot block interpreter exit
    self.server_thread.daemon = True
    self.server_thread.start()
    self.conn = HTTPConnection(str(self.httpd.server_address[0]),
                               self.httpd.server_address[1])
def tearDown(self) -> None:
    """Close the client connection, stop the server, clean up the DB."""
    # release the client socket before the server goes away
    self.conn.close()
    self.httpd.shutdown()
    self.httpd.server_close()
    self.server_thread.join()
    # let the parent remove the DB file it created in its setUp
    super().tearDown()
def check_redirect(self, target: str) -> None:
    """Check that self.conn answers with a 302 redirect to target.

    Reads the pending response from self.conn, so a request must have been
    sent on it beforehand.
    """
    answer = self.conn.getresponse()
    self.assertEqual(answer.status, 302)
    location = answer.getheader('Location')
    self.assertEqual(location, target)
def check_get(self, target: str, expected_code: int) -> None:
    """Check that a GET to target yields expected_code."""
    self.conn.request('GET', target)
    answer = self.conn.getresponse()
    self.assertEqual(answer.status, expected_code)
def check_post(self, data: Mapping[str, object], target: str,
               expected_code: int, redirect_location: str = '') -> None:
    """Check that POST of data to target yields expected_code.

    data is sent form-urlencoded (sequence values expand into repeated
    keys). If expected_code is 302, the Location header is checked too,
    against redirect_location or, if that is empty, against target.
    """
    encoded_form_data = urlencode(data, doseq=True).encode('utf-8')
    headers = {'Content-Type': 'application/x-www-form-urlencoded',
               'Content-Length': str(len(encoded_form_data))}
    self.conn.request('POST', target,
                      body=encoded_form_data, headers=headers)
    if expected_code == 302:
        # check_redirect reads the response itself, so it must not be
        # fetched a second time below
        self.check_redirect(redirect_location or target)
        return
    self.assertEqual(self.conn.getresponse().status, expected_code)
def check_get_defaults(self, path: str) -> None:
    """Some standard model paths to test.

    GETs path bare and with various id= query values, each time checking
    for the status code listed alongside.
    """
    # NOTE(review): only the id=0 target gets an extra leading slash --
    # confirm that is intended.
    expectations = (
        (path, 200),
        (f'{path}?id=', 200),
        (f'{path}?id=foo', 400),
        (f'/{path}?id=0', 500),
        (f'{path}?id=1', 200),
    )
    for target, expected_code in expectations:
        self.check_get(target, expected_code)
239 def post_process(self, id_: int = 1,
240 form_data: dict[str, Any] | None = None
242 """POST basic Process."""
244 form_data = {'title': 'foo', 'description': 'foo', 'effort': 1.1}
245 self.check_post(form_data, '/process?id=', 302, f'/process?id={id_}')