home · contact · privacy
6581c61a1546e952c08b2669d28efdc0e27a06ee
[plomtask] / tests / utils.py
1 """Shared test utilities."""
2 from unittest import TestCase
3 from threading import Thread
4 from http.client import HTTPConnection
5 from urllib.parse import urlencode
6 from uuid import uuid4
7 from os import remove as remove_file
8 from typing import Mapping, Any
9 from plomtask.db import DatabaseFile, DatabaseConnection
10 from plomtask.http import TaskHandler, TaskServer
11 from plomtask.processes import Process, ProcessStep
12 from plomtask.conditions import Condition
13 from plomtask.days import Day
14 from plomtask.todos import Todo
15 from plomtask.exceptions import NotFoundException, HandledException
16
17
18 class TestCaseSansDB(TestCase):
19     """Tests requiring no DB setup."""
20     checked_class: Any
21     do_id_test: bool = False
22     default_init_args: list[Any] = []
23     versioned_defaults_to_test: dict[str, str | float] = {}
24
25     def test_id_setting(self) -> None:
26         """Test .id_ being set and its legal range being enforced."""
27         if not self.do_id_test:
28             return
29         with self.assertRaises(HandledException):
30             self.checked_class(0, *self.default_init_args)
31         obj = self.checked_class(5, *self.default_init_args)
32         self.assertEqual(obj.id_, 5)
33
34     def test_versioned_defaults(self) -> None:
35         """Test defaults of VersionedAttributes."""
36         if len(self.versioned_defaults_to_test) == 0:
37             return
38         obj = self.checked_class(1, *self.default_init_args)
39         for k, v in self.versioned_defaults_to_test.items():
40             self.assertEqual(getattr(obj, k).newest, v)
41
42
class TestCaseWithDB(TestCase):
    """Module tests requiring DB setup.

    Every test starts with emptied model caches and a fresh database file,
    which tearDown closes and deletes again. Subclasses set
    ``checked_class`` and call the ``check_*`` helpers from their tests.
    """
    checked_class: Any
    # IDs used by the check_* helpers; their type must match checked_class's
    # .id_ type (check_storage / check_from_table_row assert this)
    default_ids: tuple[int | str, int | str, int | str] = (1, 2, 3)
    # keyword args passed to checked_class by test_saving_and_caching
    default_init_kwargs: dict[str, Any] = {}
    # versioned attribute name -> value type (str gets string test values,
    # anything else gets float test values)
    test_versioneds: dict[str, type] = {}

    def setUp(self) -> None:
        super().setUp()
        Condition.empty_cache()
        Day.empty_cache()
        Process.empty_cache()
        ProcessStep.empty_cache()
        Todo.empty_cache()
        self.db_file = DatabaseFile.create_at(f'test_db:{uuid4()}')
        self.db_conn = DatabaseConnection(self.db_file)

    def tearDown(self) -> None:
        # remove the DB file even if closing the connection fails
        try:
            self.db_conn.close()
        finally:
            remove_file(self.db_file.path)

    def test_saving_and_caching(self) -> None:
        """Test storage and initialization of instances and attributes.

        Returns early when checked_class is unset (i.e. on this base class).
        Otherwise checks saving with id 1, that saving an id-less instance
        assigns id 2, and the saving of every attribute in test_versioneds.
        """
        if not hasattr(self, 'checked_class'):
            return
        self.check_saving_and_caching(id_=1, **self.default_init_kwargs)
        obj = self.checked_class(None, **self.default_init_kwargs)
        obj.save(self.db_conn)
        self.assertEqual(obj.id_, 2)
        for k, v in self.test_versioneds.items():
            self.check_saving_of_versioned(k, v)

    def check_storage(self, content: list[Any]) -> None:
        """Test cache and DB equal content.

        Asserts checked_class's cache is exactly content keyed by .id_, and
        that the DB rows found for those IDs read back equal to content.
        """
        expected_cache = {item.id_: item for item in content}
        self.assertEqual(self.checked_class.get_cache(), expected_cache)
        db_found: list[Any] = []
        for item in content:
            assert isinstance(item.id_, type(self.default_ids[0]))
            for row in self.db_conn.row_where(self.checked_class.table_name,
                                              'id', item.id_):
                db_found.append(self.checked_class.from_table_row(
                    self.db_conn, row))
        self.assertEqual(sorted(content), sorted(db_found))

    def check_saving_and_caching(self, **kwargs: Any) -> None:
        """Test instance.save in its core without relations."""
        obj = self.checked_class(**kwargs)  # pylint: disable=not-callable
        # check object init itself doesn't store anything yet
        self.check_storage([])
        # check saving stores in cache and DB
        obj.save(self.db_conn)
        self.check_storage([obj])
        # check core attributes set properly (and not unset by saving)
        for key, value in kwargs.items():
            self.assertEqual(getattr(obj, key), value)

    def _save_versioned_owner(self, attr_name: str,
                              type_: type) -> tuple[Any, list[Any]]:
        """Save new owner with two values set on attr_name; return both."""
        owner = self.checked_class(None)  # pylint: disable=not-callable
        vals: list[Any] = ['t1', 't2'] if type_ is str else [0.9, 1.1]
        attr = getattr(owner, attr_name)
        attr.set(vals[0])
        attr.set(vals[1])
        owner.save(self.db_conn)
        return owner, vals

    def check_saving_of_versioned(self, attr_name: str, type_: type) -> None:
        """Test owner's versioned attributes survive save, uncache, by_id."""
        owner, vals = self._save_versioned_owner(attr_name, type_)
        owner.uncache()
        retrieved = owner.__class__.by_id(self.db_conn, owner.id_)
        attr = getattr(retrieved, attr_name)
        self.assertEqual(sorted(attr.history.values()), vals)

    def check_by_id(self) -> None:
        """Test .by_id(), including creation."""
        # check failure if not yet saved
        id1, id2 = self.default_ids[0], self.default_ids[1]
        obj = self.checked_class(id1)  # pylint: disable=not-callable
        with self.assertRaises(NotFoundException):
            self.checked_class.by_id(self.db_conn, id1)
        # check identity of saved and retrieved
        obj.save(self.db_conn)
        self.assertEqual(obj, self.checked_class.by_id(self.db_conn, id1))
        # check create=True acts like normal instantiation (sans saving)
        by_id_created = self.checked_class.by_id(self.db_conn, id2,
                                                 create=True)
        # pylint: disable=not-callable
        self.assertEqual(self.checked_class(id2), by_id_created)
        self.check_storage([obj])

    def check_from_table_row(self, *args: Any) -> None:
        """Test .from_table_row() properly reads in class from DB"""
        id_ = self.default_ids[0]
        obj = self.checked_class(id_, *args)  # pylint: disable=not-callable
        obj.save(self.db_conn)
        assert isinstance(obj.id_, type(self.default_ids[0]))
        rows = list(self.db_conn.row_where(self.checked_class.table_name,
                                           'id', obj.id_))
        # without this, a missing row would let the loop pass vacuously
        self.assertTrue(rows, f'no DB row found for id {obj.id_!r}')
        for row in rows:
            retrieved = self.checked_class.from_table_row(self.db_conn, row)
            self.assertEqual(obj, retrieved)
            self.assertEqual({obj.id_: obj}, self.checked_class.get_cache())

    def check_versioned_from_table_row(self, attr_name: str,
                                       type_: type) -> None:
        """Test .from_table_row() reads versioned attributes from DB."""
        owner, vals = self._save_versioned_owner(attr_name, type_)
        rows = list(self.db_conn.row_where(owner.table_name, 'id',
                                           owner.id_))
        # without this, a missing row would let the loop pass vacuously
        self.assertTrue(rows, f'no DB row found for id {owner.id_!r}')
        for row in rows:
            retrieved = owner.__class__.from_table_row(self.db_conn, row)
            attr = getattr(retrieved, attr_name)
            self.assertEqual(sorted(attr.history.values()), vals)

    def check_all(self) -> tuple[Any, Any, Any]:
        """Test .all(); return the three items created for it."""
        # pylint: disable=not-callable
        item1 = self.checked_class(self.default_ids[0])
        item2 = self.checked_class(self.default_ids[1])
        item3 = self.checked_class(self.default_ids[2])
        # check pre-save .all() returns empty list
        self.assertEqual(self.checked_class.all(self.db_conn), [])
        # check that all() shows all saved, but no unsaved items
        item1.save(self.db_conn)
        item3.save(self.db_conn)
        self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
                         sorted([item1, item3]))
        item2.save(self.db_conn)
        self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
                         sorted([item1, item2, item3]))
        return item1, item2, item3

    def check_singularity(self, defaulting_field: str,
                          non_default_value: Any, *args: Any) -> None:
        """Test pointers made for single object keep pointing to it.

        A field changed on a saved object (without re-saving) must show up
        on the object later returned by .by_id().
        """
        id1 = self.default_ids[0]
        obj = self.checked_class(id1, *args)  # pylint: disable=not-callable
        obj.save(self.db_conn)
        setattr(obj, defaulting_field, non_default_value)
        retrieved = self.checked_class.by_id(self.db_conn, id1)
        self.assertEqual(non_default_value,
                         getattr(retrieved, defaulting_field))

    def check_versioned_singularity(self) -> None:
        """Test singularity of VersionedAttributes on saving (with .title)."""
        obj = self.checked_class(None)  # pylint: disable=not-callable
        obj.save(self.db_conn)
        assert isinstance(obj.id_, int)
        obj.title.set('named')
        retrieved = self.checked_class.by_id(self.db_conn, obj.id_)
        self.assertEqual(obj.title.history, retrieved.title.history)

    def check_remove(self, *args: Any) -> None:
        """Test .remove() effects on DB and cache.

        Removing an unsaved object must raise HandledException; removing a
        saved one must leave neither cache nor DB holding it.
        """
        id_ = self.default_ids[0]
        obj = self.checked_class(id_, *args)  # pylint: disable=not-callable
        with self.assertRaises(HandledException):
            obj.remove(self.db_conn)
        obj.save(self.db_conn)
        obj.remove(self.db_conn)
        self.check_storage([])
204
205
class TestCaseWithServer(TestCaseWithDB):
    """Module tests against our HTTP server/handler (and database).

    setUp serves a TaskServer on localhost in a daemon thread and opens
    self.conn, an HTTPConnection to it; tearDown closes that connection,
    stops the server, joins its thread and then runs the DB teardown.
    """

    def setUp(self) -> None:
        super().setUp()
        # port 0 asks the OS for a free port; the actual address is read
        # back from server_address below
        self.httpd = TaskServer(self.db_file, ('localhost', 0), TaskHandler)
        self.server_thread = Thread(target=self.httpd.serve_forever)
        self.server_thread.daemon = True
        self.server_thread.start()
        self.conn = HTTPConnection(str(self.httpd.server_address[0]),
                                   self.httpd.server_address[1])

    def tearDown(self) -> None:
        # close the client socket too, and always run the DB teardown
        try:
            self.conn.close()
            self.httpd.shutdown()
            self.httpd.server_close()
            self.server_thread.join()
        finally:
            super().tearDown()

    def check_redirect(self, target: str) -> None:
        """Check that self.conn answers with a 302 redirect to target."""
        response = self.conn.getresponse()
        self.assertEqual(response.status, 302)
        self.assertEqual(response.getheader('Location'), target)

    def check_get(self, target: str, expected_code: int) -> None:
        """Check that a GET to target yields expected_code."""
        self.conn.request('GET', target)
        self.assertEqual(self.conn.getresponse().status, expected_code)

    def check_post(self, data: Mapping[str, object], target: str,
                   expected_code: int, redirect_location: str = '') -> None:
        """Check that POST of data to target yields expected_code.

        data is sent form-urlencoded (sequence values as repeated keys).
        For expected_code 302, the Location header is also checked against
        redirect_location, which defaults to target itself.
        """
        encoded_form_data = urlencode(data, doseq=True).encode('utf-8')
        headers = {'Content-Type': 'application/x-www-form-urlencoded',
                   'Content-Length': str(len(encoded_form_data))}
        self.conn.request('POST', target,
                          body=encoded_form_data, headers=headers)
        if 302 == expected_code:
            if redirect_location == '':
                redirect_location = target
            self.check_redirect(redirect_location)
        else:
            self.assertEqual(self.conn.getresponse().status, expected_code)

    def check_get_defaults(self, path: str) -> None:
        """Some standard model paths to test.

        path is used verbatim as request target (e.g. '/process'); the id
        query variants are appended to it: none or empty -> 200,
        non-numeric -> 400, 0 -> 500, 1 -> 200.
        """
        self.check_get(path, 200)
        self.check_get(f'{path}?id=', 200)
        self.check_get(f'{path}?id=foo', 400)
        self.check_get(f'{path}?id=0', 500)
        self.check_get(f'{path}?id=1', 200)

    def post_process(self, id_: int = 1,
                     form_data: dict[str, Any] | None = None
                     ) -> dict[str, Any]:
        """POST basic Process to /process?id=id_, expect redirect there.

        Without (or with empty) form_data, a default title/description/
        effort form is sent. Returns the form data that was posted.
        """
        if not form_data:
            form_data = {'title': 'foo', 'description': 'foo', 'effort': 1.1}
        self.check_post(form_data, f'/process?id={id_}', 302,
                        f'/process?id={id_}')
        return form_data