home · contact · privacy
Fix minor ProcessStep POST handling bugs.
[plomtask] / tests / utils.py
1 """Shared test utilities."""
2 from unittest import TestCase
3 from threading import Thread
4 from http.client import HTTPConnection
5 from urllib.parse import urlencode
6 from datetime import datetime
7 from os import remove as remove_file
8 from typing import Mapping, Any
9 from plomtask.db import DatabaseFile, DatabaseConnection
10 from plomtask.http import TaskHandler, TaskServer
11 from plomtask.processes import Process, ProcessStep
12 from plomtask.conditions import Condition
13 from plomtask.days import Day
14 from plomtask.todos import Todo
15 from plomtask.exceptions import NotFoundException, HandledException
16
17
18 class TestCaseSansDB(TestCase):
19     """Tests requiring no DB setup."""
20     checked_class: Any
21     do_id_test: bool = False
22     default_init_args: list[Any] = []
23     versioned_defaults_to_test: dict[str, str | float] = {}
24
25     def test_id_setting(self) -> None:
26         """Test .id_ being set and its legal range being enforced."""
27         if not self.do_id_test:
28             return
29         with self.assertRaises(HandledException):
30             self.checked_class(0, *self.default_init_args)
31         obj = self.checked_class(5, *self.default_init_args)
32         self.assertEqual(obj.id_, 5)
33
34     def test_versioned_defaults(self) -> None:
35         """Test defaults of VersionedAttributes."""
36         if len(self.versioned_defaults_to_test) == 0:
37             return
38         obj = self.checked_class(1, *self.default_init_args)
39         for k, v in self.versioned_defaults_to_test.items():
40             self.assertEqual(getattr(obj, k).newest, v)
41
42
class TestCaseWithDB(TestCase):
    """Module tests requiring DB setup.

    setUp empties the caches of the model classes and creates a fresh,
    timestamp-named database file plus a connection to it; tearDown closes
    the connection and deletes the file again. Subclasses set checked_class
    (and optionally the other class attributes) and call the check_*
    helpers from their own test methods.
    """
    # class under test; left unset here, which turns
    # test_saving_and_caching into a no-op on this base class
    checked_class: Any
    # three distinct .id_ values, of the id type checked_class uses
    default_ids: tuple[int | str, int | str, int | str] = (1, 2, 3)
    # keyword arguments besides id_ for constructing checked_class
    default_init_kwargs: dict[str, Any] = {}
    # versioned attribute names mapped to their value type (str selects
    # string test values, any other type selects float test values)
    test_versioneds: dict[str, type] = {}

    def setUp(self) -> None:
        """Empty model caches, create fresh DB file and connect to it."""
        Condition.empty_cache()
        Day.empty_cache()
        Process.empty_cache()
        ProcessStep.empty_cache()
        Todo.empty_cache()
        # timestamp in the name, presumably to keep test DB files distinct
        timestamp = datetime.now().timestamp()
        self.db_file = DatabaseFile.create_at(f'test_db:{timestamp}')
        self.db_conn = DatabaseConnection(self.db_file)

    def tearDown(self) -> None:
        """Close DB connection, then delete the test DB file."""
        self.db_conn.close()
        remove_file(self.db_file.path)

    def test_saving_and_caching(self) -> None:
        """Test storage and initialization of instances and attributes."""
        if not hasattr(self, 'checked_class'):
            return
        self.check_saving_and_caching(id_=1, **self.default_init_kwargs)
        # saving an instance created with id_ None is expected to give it
        # id_ 2, following the id_ 1 instance saved just above
        obj = self.checked_class(None, **self.default_init_kwargs)
        obj.save(self.db_conn)
        self.assertEqual(obj.id_, 2)
        for k, v in self.test_versioneds.items():
            self.check_saving_of_versioned(k, v)

    def check_storage(self, content: list[Any]) -> None:
        """Test cache and DB equal content."""
        expected_cache = {item.id_: item for item in content}
        self.assertEqual(self.checked_class.get_cache(), expected_cache)
        db_found: list[Any] = []
        for item in content:
            assert isinstance(item.id_, type(self.default_ids[0]))
            for row in self.db_conn.row_where(self.checked_class.table_name,
                                              'id', item.id_):
                db_found += [self.checked_class.from_table_row(self.db_conn,
                                                               row)]
        self.assertEqual(sorted(content), sorted(db_found))

    def check_saving_and_caching(self, **kwargs: Any) -> None:
        """Test instance.save in its core without relations."""
        obj = self.checked_class(**kwargs)  # pylint: disable=not-callable
        # check object init itself doesn't store anything yet
        self.check_storage([])
        # check saving stores in cache and DB
        obj.save(self.db_conn)
        self.check_storage([obj])
        # check core attributes set properly (and not unset by saving)
        for key, value in kwargs.items():
            self.assertEqual(getattr(obj, key), value)

    def _save_versioned_owner(self, attr_name: str,
                              type_: type) -> tuple[Any, list[Any]]:
        """Save new checked_class instance after setting attr_name twice.

        Returns the saved owner and the list of values set, in setting
        order (which is also their sorted order).
        """
        owner = self.checked_class(None)  # pylint: disable=not-callable
        vals: list[Any] = ['t1', 't2'] if type_ is str else [0.9, 1.1]
        attr = getattr(owner, attr_name)
        for val in vals:
            attr.set(val)
        owner.save(self.db_conn)
        return owner, vals

    def check_saving_of_versioned(self, attr_name: str, type_: type) -> None:
        """Test owner's versioned attributes."""
        owner, vals = self._save_versioned_owner(attr_name, type_)
        # drop from cache so by_id has to re-read the owner from the DB
        owner.uncache()
        retrieved = owner.__class__.by_id(self.db_conn, owner.id_)
        attr = getattr(retrieved, attr_name)
        self.assertEqual(sorted(attr.history.values()), vals)

    def check_by_id(self) -> None:
        """Test .by_id(), including creation."""
        # check failure if not yet saved
        id1, id2 = self.default_ids[0], self.default_ids[1]
        obj = self.checked_class(id1)  # pylint: disable=not-callable
        with self.assertRaises(NotFoundException):
            self.checked_class.by_id(self.db_conn, id1)
        # check identity of saved and retrieved
        obj.save(self.db_conn)
        self.assertEqual(obj, self.checked_class.by_id(self.db_conn, id1))
        # check create=True acts like normal instantiation (sans saving)
        by_id_created = self.checked_class.by_id(self.db_conn, id2,
                                                 create=True)
        # pylint: disable=not-callable
        self.assertEqual(self.checked_class(id2), by_id_created)
        self.check_storage([obj])

    def check_from_table_row(self, *args: Any) -> None:
        """Test .from_table_row() properly reads in class from DB"""
        id_ = self.default_ids[0]
        obj = self.checked_class(id_, *args)  # pylint: disable=not-callable
        obj.save(self.db_conn)
        assert isinstance(obj.id_, type(self.default_ids[0]))
        rows = list(self.db_conn.row_where(self.checked_class.table_name,
                                           'id', obj.id_))
        # without any row, the checks below would pass vacuously
        self.assertTrue(rows, 'saved object not found in DB')
        for row in rows:
            retrieved = self.checked_class.from_table_row(self.db_conn, row)
            self.assertEqual(obj, retrieved)
            self.assertEqual({obj.id_: obj}, self.checked_class.get_cache())

    def check_versioned_from_table_row(self, attr_name: str,
                                       type_: type) -> None:
        """Test .from_table_row() reads versioned attributes from DB."""
        owner, vals = self._save_versioned_owner(attr_name, type_)
        rows = list(self.db_conn.row_where(owner.table_name, 'id',
                                           owner.id_))
        # without any row, the checks below would pass vacuously
        self.assertTrue(rows, 'saved owner not found in DB')
        for row in rows:
            retrieved = owner.__class__.from_table_row(self.db_conn, row)
            attr = getattr(retrieved, attr_name)
            self.assertEqual(sorted(attr.history.values()), vals)

    def check_all(self) -> tuple[Any, Any, Any]:
        """Test .all()."""
        # pylint: disable=not-callable
        item1 = self.checked_class(self.default_ids[0])
        item2 = self.checked_class(self.default_ids[1])
        item3 = self.checked_class(self.default_ids[2])
        # check pre-save .all() returns empty list
        self.assertEqual(self.checked_class.all(self.db_conn), [])
        # check that all() shows all saved, but no unsaved items
        item1.save(self.db_conn)
        item3.save(self.db_conn)
        self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
                         sorted([item1, item3]))
        item2.save(self.db_conn)
        self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
                         sorted([item1, item2, item3]))
        return item1, item2, item3

    def check_singularity(self, defaulting_field: str,
                          non_default_value: Any, *args: Any) -> None:
        """Test pointers made for single object keep pointing to it."""
        id1 = self.default_ids[0]
        obj = self.checked_class(id1, *args)  # pylint: disable=not-callable
        obj.save(self.db_conn)
        # change after saving; by_id must return that same changed object
        setattr(obj, defaulting_field, non_default_value)
        retrieved = self.checked_class.by_id(self.db_conn, id1)
        self.assertEqual(non_default_value,
                         getattr(retrieved, defaulting_field))

    def check_versioned_singularity(self) -> None:
        """Test singularity of VersionedAttributes on saving (with .title)."""
        obj = self.checked_class(None)  # pylint: disable=not-callable
        obj.save(self.db_conn)
        assert isinstance(obj.id_, int)
        obj.title.set('named')
        retrieved = self.checked_class.by_id(self.db_conn, obj.id_)
        self.assertEqual(obj.title.history, retrieved.title.history)

    def check_remove(self, *args: Any) -> None:
        """Test .remove() effects on DB and cache."""
        id_ = self.default_ids[0]
        obj = self.checked_class(id_, *args)  # pylint: disable=not-callable
        # removing a not-yet-saved object is expected to fail
        with self.assertRaises(HandledException):
            obj.remove(self.db_conn)
        obj.save(self.db_conn)
        obj.remove(self.db_conn)
        self.check_storage([])
205
206
class TestCaseWithServer(TestCaseWithDB):
    """Module tests against our HTTP server/handler (and database)."""

    def setUp(self) -> None:
        """Start a TaskServer in a daemon thread and a client to it.

        The server binds to port 0, i.e. a port the OS assigns; the actual
        address is read back from server_address for the client.
        """
        super().setUp()
        self.httpd = TaskServer(self.db_file, ('localhost', 0), TaskHandler)
        self.server_thread = Thread(target=self.httpd.serve_forever)
        self.server_thread.daemon = True
        self.server_thread.start()
        self.conn = HTTPConnection(str(self.httpd.server_address[0]),
                                   self.httpd.server_address[1])

    def tearDown(self) -> None:
        """Close client connection, stop server, then tear down the DB."""
        self.conn.close()
        self.httpd.shutdown()
        self.httpd.server_close()
        self.server_thread.join()
        super().tearDown()

    def _fetch_response(self) -> Any:
        """Fully read and close the pending response on self.conn.

        Reading the body instead of leaving it pending releases the
        response's socket file and keeps self.conn usable for further
        requests. Status and headers stay accessible after closing.
        """
        with self.conn.getresponse() as response:
            response.read()
        return response

    def check_redirect(self, target: str) -> None:
        """Check that self.conn answers with a 302 redirect to target."""
        response = self._fetch_response()
        self.assertEqual(response.status, 302)
        self.assertEqual(response.getheader('Location'), target)

    def check_get(self, target: str, expected_code: int) -> None:
        """Check that a GET to target yields expected_code."""
        self.conn.request('GET', target)
        self.assertEqual(self._fetch_response().status, expected_code)

    def check_post(self, data: Mapping[str, object], target: str,
                   expected_code: int, redirect_location: str = '') -> None:
        """Check that POST of data to target yields expected_code.

        For an expected 302, also check the Location header against
        redirect_location, which defaults to target itself.
        """
        encoded_form_data = urlencode(data, doseq=True).encode('utf-8')
        headers = {'Content-Type': 'application/x-www-form-urlencoded',
                   'Content-Length': str(len(encoded_form_data))}
        self.conn.request('POST', target,
                          body=encoded_form_data, headers=headers)
        if expected_code == 302:
            self.check_redirect(redirect_location or target)
        else:
            self.assertEqual(self._fetch_response().status, expected_code)

    def check_get_defaults(self, path: str) -> None:
        """Some standard model paths to test."""
        self.check_get(path, 200)
        self.check_get(f'{path}?id=', 200)
        self.check_get(f'{path}?id=foo', 400)
        # same path as the other checks (was f'/{path}', i.e. '//path')
        self.check_get(f'{path}?id=0', 500)
        self.check_get(f'{path}?id=1', 200)

    def post_process(self, id_: int = 1,
                     form_data: dict[str, Any] | None = None
                     ) -> dict[str, Any]:
        """POST basic Process, expecting redirect back to its own page.

        Any falsy form_data (None or empty) is replaced by a default form.
        Returns the form data actually posted.
        """
        if not form_data:
            form_data = {'title': 'foo', 'description': 'foo', 'effort': 1.1}
        self.check_post(form_data, f'/process?id={id_}', 302,
                        f'/process?id={id_}')
        return form_data