1 """Shared test utilities."""
2 from unittest import TestCase
3 from threading import Thread
4 from http.client import HTTPConnection
5 from urllib.parse import urlencode
6 from datetime import datetime
7 from os import remove as remove_file
8 from typing import Mapping, Any
9 from plomtask.db import DatabaseFile, DatabaseConnection
10 from plomtask.http import TaskHandler, TaskServer
11 from plomtask.processes import Process, ProcessStep
12 from plomtask.conditions import Condition
13 from plomtask.days import Day
14 from plomtask.todos import Todo
15 from plomtask.exceptions import NotFoundException, HandledException
class TestCaseSansDB(TestCase):
    """Tests requiring no DB setup.

    Subclasses are expected to set ``checked_class`` to the model class
    whose behavior the ``check_*`` helpers below exercise.
    """
    checked_class: Any

    def check_id_setting(self) -> None:
        """Test .id_ being set and its legal range being enforced.

        Constructing with id 0 must raise HandledException; constructing
        with id 5 must keep 5 as ``.id_``.
        """
        # The assertRaises block needs a call that is expected to fail;
        # the legal construction below must stay outside of it.
        with self.assertRaises(HandledException):
            self.checked_class(0)
        obj = self.checked_class(5)
        self.assertEqual(obj.id_, 5)

    def check_versioned_defaults(self, attrs: dict[str, Any]) -> None:
        """Test defaults of VersionedAttributes.

        For every name -> value pair in ``attrs``, a fresh instance (id None)
        must report ``value`` as ``.newest`` of the attribute ``name``.
        """
        obj = self.checked_class(None)
        for k, v in attrs.items():
            self.assertEqual(getattr(obj, k).newest, v)
36 class TestCaseWithDB(TestCase):
37 """Module tests not requiring DB setup."""
39 default_ids: tuple[int | str, int | str, int | str] = (1, 2, 3)
# setUp: call empty_cache() on Condition and ProcessStep, then create a new
# DatabaseFile whose name embeds the current timestamp and open a
# DatabaseConnection on it (stored as self.db_file / self.db_conn).
# NOTE(review): this copy is missing lines - the embedded numbering skips
# 43-44, 46 and 49. Day, Process and Todo are imported by this module but are
# not reset here, and nothing sits between creating the DatabaseFile and
# connecting to it; confirm against the upstream file before relying on it.
41 def setUp(self) -> None:
42 Condition.empty_cache()
45 ProcessStep.empty_cache()
47 timestamp = datetime.now().timestamp()
# Presumably the timestamp keeps each test's DB file name distinct.
48 self.db_file = DatabaseFile(f'test_db:{timestamp}')
50 self.db_conn = DatabaseConnection(self.db_file)
# tearDown: delete the DB file whose path setUp stored in self.db_file.
# NOTE(review): the connection opened in setUp (self.db_conn) is not closed
# in this copy, and the embedded numbering skips line 53 - confirm whether
# the connection should be closed before the file is removed.
52 def tearDown(self) -> None:
54 remove_file(self.db_file.path)
# check_storage: compare checked_class.get_cache() against a dict mapping
# item.id_ -> item (presumably built from `content`), then collect objects
# rebuilt with from_table_row() from rows that row_where() returns for the
# class's table, and require sorted(content) == sorted(db_found).
# NOTE(review): lines 58-59, 63, 66 and 68 are missing from this copy, so
# `expected_cache` and `item` are never bound here and the row_where(...) and
# from_table_row(...) calls are unterminated - restore them from upstream.
56 def check_storage(self, content: list[Any]) -> None:
57 """Test cache and DB equal content."""
60 expected_cache[item.id_] = item
61 self.assertEqual(self.checked_class.get_cache(), expected_cache)
62 db_found: list[Any] = []
# Stored ids must have the same type as default_ids[0] (int vs str).
64 assert isinstance(item.id_, type(self.default_ids[0]))
65 for row in self.db_conn.row_where(self.checked_class.table_name,
67 db_found += [self.checked_class.from_table_row(self.db_conn,
69 self.assertEqual(sorted(content), sorted(db_found))
def check_saving_and_caching(self, **kwargs: Any) -> Any:
    """Test instance.save in its core without relations.

    Builds ``checked_class(**kwargs)`` and checks via check_storage that
    construction alone stores nothing and that ``.save()`` stores exactly
    this object. Afterwards each keyword argument must still equal the
    same-named attribute.

    Returns the saved object (the ``-> Any`` annotation promised a value),
    so callers can run further checks on it.
    """
    obj = self.checked_class(**kwargs)  # pylint: disable=not-callable
    # check object init itself doesn't store anything yet
    self.check_storage([])
    # check saving stores in cache and DB
    obj.save(self.db_conn)
    self.check_storage([obj])
    # check core attributes set properly (and not unset by saving)
    for key, value in kwargs.items():
        self.assertEqual(getattr(obj, key), value)
    return obj
def check_by_id(self) -> None:
    """Test .by_id(), including creation.

    - before saving, by_id(id1) must raise NotFoundException;
    - after saving, by_id(id1) must return an object equal to the saved one;
    - by_id(id2, create=True) must return an object equal to a plain
      checked_class(id2) without storing it, so only the saved object passes
      check_storage afterwards.
    """
    # check failure if not yet saved
    id1, id2 = self.default_ids[0], self.default_ids[1]
    obj = self.checked_class(id1)  # pylint: disable=not-callable
    with self.assertRaises(NotFoundException):
        self.checked_class.by_id(self.db_conn, id1)
    # check identity of saved and retrieved
    obj.save(self.db_conn)
    self.assertEqual(obj, self.checked_class.by_id(self.db_conn, id1))
    # check create=True acts like normal instantiation (sans saving)
    by_id_created = self.checked_class.by_id(self.db_conn, id2,
                                             create=True)
    # pylint: disable=not-callable
    self.assertEqual(self.checked_class(id2), by_id_created)
    self.check_storage([obj])
# check_from_table_row: save checked_class(default_ids[0], *args), then for
# each row that row_where() yields on the class's table, rebuild an object via
# from_table_row() and require it to equal the saved object, with the class
# cache holding only {obj.id_: obj}.
# NOTE(review): line 107 (the remaining row_where(...) arguments) is missing
# from this copy, so the call is unterminated - restore it from upstream.
100 def check_from_table_row(self, *args: Any) -> None:
101 """Test .from_table_row() properly reads in class from DB"""
102 id_ = self.default_ids[0]
103 obj = self.checked_class(id_, *args) # pylint: disable=not-callable
104 obj.save(self.db_conn)
105 assert isinstance(obj.id_, type(self.default_ids[0]))
106 for row in self.db_conn.row_where(self.checked_class.table_name,
108 retrieved = self.checked_class.from_table_row(self.db_conn, row)
109 self.assertEqual(obj, retrieved)
110 self.assertEqual({obj.id_: obj}, self.checked_class.get_cache())
def check_all(self) -> tuple[Any, Any, Any]:
    """Test .all() lists exactly the saved instances.

    Builds one instance from each of the first three default_ids. .all()
    must be empty before any save, then match each batch of saves.
    Returns the three instances, all saved by then.
    """
    # pylint: disable=not-callable
    ids = self.default_ids
    first, second, third = [self.checked_class(ids[pos]) for pos in range(3)]
    # nothing saved yet: .all() must come back as an empty list
    self.assertEqual(self.checked_class.all(self.db_conn), [])

    def listed() -> list[Any]:
        return sorted(self.checked_class.all(self.db_conn))

    # save the outer two only; the unsaved middle one must not be listed
    for item in (first, third):
        item.save(self.db_conn)
    self.assertEqual(listed(), sorted([first, third]))
    second.save(self.db_conn)
    self.assertEqual(listed(), sorted([first, second, third]))
    return first, second, third
def check_singularity(self, defaulting_field: str,
                      non_default_value: Any, *args: Any) -> None:
    """Test pointers made for single object keep pointing to it.

    After saving, a value set through the original reference must be what
    the object returned by by_id() for the same id reports.
    """
    obj_id = self.default_ids[0]
    # pylint: disable=not-callable
    original = self.checked_class(obj_id, *args)
    original.save(self.db_conn)
    setattr(original, defaulting_field, non_default_value)
    fetched = self.checked_class.by_id(self.db_conn, obj_id)
    seen_value = getattr(fetched, defaulting_field)
    self.assertEqual(non_default_value, seen_value)
def check_versioned_singularity(self) -> None:
    """Test singularity of VersionedAttributes on saving (with .title).

    A title set on the saved object after saving must appear in the
    .title.history of the object by_id() returns for the same id.
    """
    # pylint: disable=not-callable
    saved = self.checked_class(None)
    saved.save(self.db_conn)
    # saving an id-less object must have assigned it an int id
    assert isinstance(saved.id_, int)
    saved.title.set('named')
    fetched = self.checked_class.by_id(self.db_conn, saved.id_)
    self.assertEqual(saved.title.history, fetched.title.history)
def check_remove(self, *args: Any) -> None:
    """Test .remove() effects on DB and cache.

    Removing a never-saved object must raise HandledException; after saving
    and then removing it, check_storage([]) must pass.
    """
    # pylint: disable=not-callable
    target = self.checked_class(self.default_ids[0], *args)
    with self.assertRaises(HandledException):
        target.remove(self.db_conn)
    target.save(self.db_conn)
    target.remove(self.db_conn)
    self.check_storage([])
161 class TestCaseWithServer(TestCaseWithDB):
162 """Module tests against our HTTP server/handler (and database)."""
def setUp(self) -> None:
    """Create the DB fixtures, then start the HTTP server and a client.

    TaskServer is bound to ('localhost', 0) and run via serve_forever() in a
    daemon thread; self.conn is an HTTPConnection to the address the server
    reports in server_address.
    """
    # The parent setUp creates self.db_file, which TaskServer needs below;
    # without this call the attribute does not exist.
    super().setUp()
    self.httpd = TaskServer(self.db_file, ('localhost', 0), TaskHandler)
    self.server_thread = Thread(target=self.httpd.serve_forever)
    self.server_thread.daemon = True
    self.server_thread.start()
    # Port 0 was requested above, so read back the port actually bound.
    self.conn = HTTPConnection(str(self.httpd.server_address[0]),
                               self.httpd.server_address[1])
def tearDown(self) -> None:
    """Close the client, stop the server and its thread, then clean up DB."""
    # Release the client socket opened in setUp (no-op if never connected).
    self.conn.close()
    self.httpd.shutdown()
    self.httpd.server_close()
    self.server_thread.join()
    # The parent tearDown removes the DB file; skipping it leaks the file.
    super().tearDown()
def check_redirect(self, target: str) -> None:
    """Assert the pending response on self.conn is a 302 pointing at target.

    Consumes that response via self.conn.getresponse().
    """
    reply = self.conn.getresponse()
    self.assertEqual(reply.status, 302)
    location = reply.getheader('Location')
    self.assertEqual(location, target)
def check_get(self, target: str, expected_code: int) -> None:
    """Send GET target over self.conn and assert the reply status."""
    self.conn.request('GET', target)
    status = self.conn.getresponse().status
    self.assertEqual(status, expected_code)
def check_post(self, data: Mapping[str, object], target: str,
               expected_code: int, redirect_location: str = '') -> None:
    """Check that POST of data to target yields expected_code.

    ``data`` is form-urlencoded (with doseq=True, so sequence values become
    repeated keys). If expected_code is 302, the response must also redirect
    to ``redirect_location``, which defaults to ``target`` when left empty.
    """
    encoded_form_data = urlencode(data, doseq=True).encode('utf-8')
    headers = {'Content-Type': 'application/x-www-form-urlencoded',
               'Content-Length': str(len(encoded_form_data))}
    self.conn.request('POST', target,
                      body=encoded_form_data, headers=headers)
    if 302 == expected_code:
        if redirect_location == '':
            redirect_location = target
        # check_redirect consumes the response and asserts the 302 itself;
        # a second getresponse() would raise http.client.ResponseNotReady.
        self.check_redirect(redirect_location)
    else:
        self.assertEqual(self.conn.getresponse().status, expected_code)
def check_get_defaults(self, path: str) -> None:
    """Some standard model paths to test.

    GETs ``path`` plain and with ``?id=`` set to empty, ``foo``, ``0`` and
    ``1``, expecting 200, 200, 400, 500 and 200 respectively.
    """
    self.check_get(path, 200)
    self.check_get(f'{path}?id=', 200)
    self.check_get(f'{path}?id=foo', 400)
    # Same prefix as the other requests: an extra leading '/' would turn
    # '/process' into '//process', which URL parsing reads as a host name.
    self.check_get(f'{path}?id=0', 500)
    self.check_get(f'{path}?id=1', 200)
213 def post_process(self, id_: int = 1,
214 form_data: dict[str, Any] | None = None
216 """POST basic Process."""
218 form_data = {'title': 'foo', 'description': 'foo', 'effort': 1.1}
219 self.check_post(form_data, '/process?id=', 302, f'/process?id={id_}')