home · contact · privacy
Refactor tests.
[plomtask] / tests / utils.py
1 """Shared test utilities."""
2 from unittest import TestCase
3 from threading import Thread
4 from http.client import HTTPConnection
5 from urllib.parse import urlencode
6 from datetime import datetime
7 from os import remove as remove_file
8 from typing import Mapping, Any
9 from plomtask.db import DatabaseFile, DatabaseConnection
10 from plomtask.http import TaskHandler, TaskServer
11 from plomtask.processes import Process, ProcessStep
12 from plomtask.conditions import Condition
13 from plomtask.days import Day
14 from plomtask.todos import Todo
15 from plomtask.exceptions import NotFoundException, HandledException
16
17
18 class TestCaseSansDB(TestCase):
19     """Tests requiring no DB setup."""
20     checked_class: Any
21     do_id_test: bool = False
22     default_init_args: list[Any] = []
23     versioned_defaults_to_test: dict[str, str | float] = {}
24
25     def test_id_setting(self) -> None:
26         """Test .id_ being set and its legal range being enforced."""
27         if not self.do_id_test:
28             return
29         with self.assertRaises(HandledException):
30             self.checked_class(0, *self.default_init_args)
31         obj = self.checked_class(5, *self.default_init_args)
32         self.assertEqual(obj.id_, 5)
33
34     def test_versioned_defaults(self) -> None:
35         """Test defaults of VersionedAttributes."""
36         if len(self.versioned_defaults_to_test) == 0:
37             return
38         obj = self.checked_class(1, *self.default_init_args)
39         for k, v in self.versioned_defaults_to_test.items():
40             self.assertEqual(getattr(obj, k).newest, v)
41
42
43 class TestCaseWithDB(TestCase):
44     """Module tests not requiring DB setup."""
45     checked_class: Any
46     default_ids: tuple[int | str, int | str, int | str] = (1, 2, 3)
47
48     def setUp(self) -> None:
49         Condition.empty_cache()
50         Day.empty_cache()
51         Process.empty_cache()
52         ProcessStep.empty_cache()
53         Todo.empty_cache()
54         timestamp = datetime.now().timestamp()
55         self.db_file = DatabaseFile.create_at(f'test_db:{timestamp}')
56         self.db_conn = DatabaseConnection(self.db_file)
57
58     def tearDown(self) -> None:
59         self.db_conn.close()
60         remove_file(self.db_file.path)
61
62     def check_storage(self, content: list[Any]) -> None:
63         """Test cache and DB equal content."""
64         expected_cache = {}
65         for item in content:
66             expected_cache[item.id_] = item
67         self.assertEqual(self.checked_class.get_cache(), expected_cache)
68         db_found: list[Any] = []
69         for item in content:
70             assert isinstance(item.id_, type(self.default_ids[0]))
71             for row in self.db_conn.row_where(self.checked_class.table_name,
72                                               'id', item.id_):
73                 db_found += [self.checked_class.from_table_row(self.db_conn,
74                                                                row)]
75         self.assertEqual(sorted(content), sorted(db_found))
76
77     def check_saving_and_caching(self, **kwargs: Any) -> None:
78         """Test instance.save in its core without relations."""
79         obj = self.checked_class(**kwargs)  # pylint: disable=not-callable
80         # check object init itself doesn't store anything yet
81         self.check_storage([])
82         # check saving stores in cache and DB
83         obj.save(self.db_conn)
84         self.check_storage([obj])
85         # check core attributes set properly (and not unset by saving)
86         for key, value in kwargs.items():
87             self.assertEqual(getattr(obj, key), value)
88
89     def check_saving_of_versioned(self, attr_name: str, type_: type) -> None:
90         """Test owner's versioned attributes."""
91         owner = self.checked_class(None)
92         vals: list[Any] = ['t1', 't2'] if type_ == str else [0.9, 1.1]
93         attr = getattr(owner, attr_name)
94         attr.set(vals[0])
95         attr.set(vals[1])
96         owner.save(self.db_conn)
97         owner.uncache()
98         retrieved = owner.__class__.by_id(self.db_conn, owner.id_)
99         attr = getattr(retrieved, attr_name)
100         self.assertEqual(sorted(attr.history.values()), vals)
101
102     def check_by_id(self) -> None:
103         """Test .by_id(), including creation."""
104         # check failure if not yet saved
105         id1, id2 = self.default_ids[0], self.default_ids[1]
106         obj = self.checked_class(id1)  # pylint: disable=not-callable
107         with self.assertRaises(NotFoundException):
108             self.checked_class.by_id(self.db_conn, id1)
109         # check identity of saved and retrieved
110         obj.save(self.db_conn)
111         self.assertEqual(obj, self.checked_class.by_id(self.db_conn, id1))
112         # check create=True acts like normal instantiation (sans saving)
113         by_id_created = self.checked_class.by_id(self.db_conn, id2,
114                                                  create=True)
115         # pylint: disable=not-callable
116         self.assertEqual(self.checked_class(id2), by_id_created)
117         self.check_storage([obj])
118
119     def check_from_table_row(self, *args: Any) -> None:
120         """Test .from_table_row() properly reads in class from DB"""
121         id_ = self.default_ids[0]
122         obj = self.checked_class(id_, *args)  # pylint: disable=not-callable
123         obj.save(self.db_conn)
124         assert isinstance(obj.id_, type(self.default_ids[0]))
125         for row in self.db_conn.row_where(self.checked_class.table_name,
126                                           'id', obj.id_):
127             retrieved = self.checked_class.from_table_row(self.db_conn, row)
128             self.assertEqual(obj, retrieved)
129             self.assertEqual({obj.id_: obj}, self.checked_class.get_cache())
130
131     def check_versioned_from_table_row(self, attr_name: str,
132                                        type_: type) -> None:
133         """Test .from_table_row() reads versioned attributes from DB."""
134         owner = self.checked_class(None)
135         vals: list[Any] = ['t1', 't2'] if type_ == str else [0.9, 1.1]
136         attr = getattr(owner, attr_name)
137         attr.set(vals[0])
138         attr.set(vals[1])
139         owner.save(self.db_conn)
140         for row in self.db_conn.row_where(owner.table_name, 'id', owner.id_):
141             retrieved = owner.__class__.from_table_row(self.db_conn, row)
142             attr = getattr(retrieved, attr_name)
143             self.assertEqual(sorted(attr.history.values()), vals)
144
145     def check_all(self) -> tuple[Any, Any, Any]:
146         """Test .all()."""
147         # pylint: disable=not-callable
148         item1 = self.checked_class(self.default_ids[0])
149         item2 = self.checked_class(self.default_ids[1])
150         item3 = self.checked_class(self.default_ids[2])
151         # check pre-save .all() returns empty list
152         self.assertEqual(self.checked_class.all(self.db_conn), [])
153         # check that all() shows all saved, but no unsaved items
154         item1.save(self.db_conn)
155         item3.save(self.db_conn)
156         self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
157                          sorted([item1, item3]))
158         item2.save(self.db_conn)
159         self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
160                          sorted([item1, item2, item3]))
161         return item1, item2, item3
162
163     def check_singularity(self, defaulting_field: str,
164                           non_default_value: Any, *args: Any) -> None:
165         """Test pointers made for single object keep pointing to it."""
166         id1 = self.default_ids[0]
167         obj = self.checked_class(id1, *args)  # pylint: disable=not-callable
168         obj.save(self.db_conn)
169         setattr(obj, defaulting_field, non_default_value)
170         retrieved = self.checked_class.by_id(self.db_conn, id1)
171         self.assertEqual(non_default_value,
172                          getattr(retrieved, defaulting_field))
173
174     def check_versioned_singularity(self) -> None:
175         """Test singularity of VersionedAttributes on saving (with .title)."""
176         obj = self.checked_class(None)  # pylint: disable=not-callable
177         obj.save(self.db_conn)
178         assert isinstance(obj.id_, int)
179         obj.title.set('named')
180         retrieved = self.checked_class.by_id(self.db_conn, obj.id_)
181         self.assertEqual(obj.title.history, retrieved.title.history)
182
183     def check_remove(self, *args: Any) -> None:
184         """Test .remove() effects on DB and cache."""
185         id_ = self.default_ids[0]
186         obj = self.checked_class(id_, *args)  # pylint: disable=not-callable
187         with self.assertRaises(HandledException):
188             obj.remove(self.db_conn)
189         obj.save(self.db_conn)
190         obj.remove(self.db_conn)
191         self.check_storage([])
192
193
class TestCaseWithServer(TestCaseWithDB):
    """Module tests against our HTTP server/handler (and database).

    On top of the parent's setup, each test gets a TaskServer running in a
    background thread, and .conn as HTTP client connection towards it.
    """

    def setUp(self) -> None:
        """Start server on test DB file, prepare client connection to it."""
        super().setUp()
        # port 0: actual port is read back from .server_address below
        self.httpd = TaskServer(self.db_file, ('localhost', 0), TaskHandler)
        self.server_thread = Thread(target=self.httpd.serve_forever)
        self.server_thread.daemon = True
        self.server_thread.start()
        self.conn = HTTPConnection(str(self.httpd.server_address[0]),
                                   self.httpd.server_address[1])

    def tearDown(self) -> None:
        """Close client connection, stop server, then clean up DB."""
        # close client side first, so its socket is not leaked per test
        self.conn.close()
        self.httpd.shutdown()
        self.httpd.server_close()
        self.server_thread.join()
        super().tearDown()

    def check_redirect(self, target: str) -> None:
        """Check that self.conn answers with a 302 redirect to target.

        Expects the request to have been sent over self.conn already.
        """
        response = self.conn.getresponse()
        self.assertEqual(response.status, 302)
        self.assertEqual(response.getheader('Location'), target)

    def check_get(self, target: str, expected_code: int) -> None:
        """Check that a GET to target yields expected_code."""
        self.conn.request('GET', target)
        self.assertEqual(self.conn.getresponse().status, expected_code)

    def check_post(self, data: Mapping[str, object], target: str,
                   expected_code: int, redirect_location: str = '') -> None:
        """Check that POST of data to target yields expected_code.

        If expected_code is 302, also check the redirect's Location, which is
        expected to be redirect_location, or target if the former is empty.
        """
        # doseq=True: sequence values become one key=value pair per element
        encoded_form_data = urlencode(data, doseq=True).encode('utf-8')
        headers = {'Content-Type': 'application/x-www-form-urlencoded',
                   'Content-Length': str(len(encoded_form_data))}
        self.conn.request('POST', target,
                          body=encoded_form_data, headers=headers)
        if expected_code == 302:
            self.check_redirect(redirect_location or target)
        else:
            self.assertEqual(self.conn.getresponse().status, expected_code)

    def check_get_defaults(self, path: str) -> None:
        """Some standard model paths to test.

        path is used as is for all requests, only varying the ?id= query.
        """
        self.check_get(path, 200)
        self.check_get(f'{path}?id=', 200)
        self.check_get(f'{path}?id=foo', 400)
        self.check_get(f'{path}?id=0', 500)
        self.check_get(f'{path}?id=1', 200)

    def post_process(self, id_: int = 1,
                     form_data: dict[str, Any] | None = None
                     ) -> dict[str, Any]:
        """POST basic Process, return the form data used.

        Posts to /process?id= and expects a redirect to /process?id={id_};
        if form_data is empty or None, some default values are posted.
        """
        if not form_data:
            form_data = {'title': 'foo', 'description': 'foo', 'effort': 1.1}
        self.check_post(form_data, '/process?id=', 302, f'/process?id={id_}')
        return form_data