home · contact · privacy
fb7e22746a96482e04153adff90a4bfa086bd58c
[plomtask] / tests / utils.py
1 """Shared test utilities."""
2 from unittest import TestCase
3 from threading import Thread
4 from http.client import HTTPConnection
5 from urllib.parse import urlencode
6 from datetime import datetime
7 from os import remove as remove_file
8 from typing import Mapping, Any
9 from plomtask.db import DatabaseFile, DatabaseConnection
10 from plomtask.http import TaskHandler, TaskServer
11 from plomtask.processes import Process, ProcessStep
12 from plomtask.conditions import Condition
13 from plomtask.days import Day
14 from plomtask.todos import Todo
15 from plomtask.exceptions import NotFoundException, HandledException
16
17
class TestCaseSansDB(TestCase):
    """Tests requiring no DB setup.

    Subclasses set checked_class to the model class under test.
    """
    checked_class: Any

    def check_id_setting(self, *args: Any) -> None:
        """Test .id_ being set and its legal range being enforced.

        An ID of 0 must be rejected with HandledException, while a valid ID
        (5 here) must end up as the instance's .id_. Any extra args are
        passed to the constructor after the ID.
        """
        klass = self.checked_class
        with self.assertRaises(HandledException):
            klass(0, *args)
        self.assertEqual(klass(5, *args).id_, 5)

    def check_versioned_defaults(self, attrs: dict[str, Any]) -> None:
        """Test defaults of VersionedAttributes.

        Instantiates checked_class with an ID of None and asserts, for every
        attribute name in attrs, that the attribute's .newest equals the
        value mapped to it.
        """
        fresh = self.checked_class(None)
        for name, expected in attrs.items():
            versioned = getattr(fresh, name)
            self.assertEqual(versioned.newest, expected)
34
35
36 class TestCaseWithDB(TestCase):
37     """Module tests not requiring DB setup."""
38     checked_class: Any
39     default_ids: tuple[int | str, int | str, int | str] = (1, 2, 3)
40
41     def setUp(self) -> None:
42         Condition.empty_cache()
43         Day.empty_cache()
44         Process.empty_cache()
45         ProcessStep.empty_cache()
46         Todo.empty_cache()
47         timestamp = datetime.now().timestamp()
48         self.db_file = DatabaseFile(f'test_db:{timestamp}')
49         self.db_file.remake()
50         self.db_conn = DatabaseConnection(self.db_file)
51
52     def tearDown(self) -> None:
53         self.db_conn.close()
54         remove_file(self.db_file.path)
55
56     def check_storage(self, content: list[Any]) -> None:
57         """Test cache and DB equal content."""
58         expected_cache = {}
59         for item in content:
60             expected_cache[item.id_] = item
61         self.assertEqual(self.checked_class.get_cache(), expected_cache)
62         db_found: list[Any] = []
63         for item in content:
64             assert isinstance(item.id_, type(self.default_ids[0]))
65             for row in self.db_conn.row_where(self.checked_class.table_name,
66                                               'id', item.id_):
67                 db_found += [self.checked_class.from_table_row(self.db_conn,
68                                                                row)]
69         self.assertEqual(sorted(content), sorted(db_found))
70
71     def check_saving_and_caching(self, **kwargs: Any) -> None:
72         """Test instance.save in its core without relations."""
73         obj = self.checked_class(**kwargs)  # pylint: disable=not-callable
74         # check object init itself doesn't store anything yet
75         self.check_storage([])
76         # check saving stores in cache and DB
77         obj.save(self.db_conn)
78         self.check_storage([obj])
79         # check core attributes set properly (and not unset by saving)
80         for key, value in kwargs.items():
81             self.assertEqual(getattr(obj, key), value)
82
83     def check_saving_of_versioned(self, attr_name: str, type_: type) -> None:
84         """Test owner's versioned attributes."""
85         owner = self.checked_class(None)
86         vals: list[Any] = ['t1', 't2'] if type_ == str else [0.9, 1.1]
87         attr = getattr(owner, attr_name)
88         attr.set(vals[0])
89         attr.set(vals[1])
90         owner.save(self.db_conn)
91         owner.uncache()
92         retrieved = owner.__class__.by_id(self.db_conn, owner.id_)
93         attr = getattr(retrieved, attr_name)
94         self.assertEqual(sorted(attr.history.values()), vals)
95
96     def check_by_id(self) -> None:
97         """Test .by_id(), including creation."""
98         # check failure if not yet saved
99         id1, id2 = self.default_ids[0], self.default_ids[1]
100         obj = self.checked_class(id1)  # pylint: disable=not-callable
101         with self.assertRaises(NotFoundException):
102             self.checked_class.by_id(self.db_conn, id1)
103         # check identity of saved and retrieved
104         obj.save(self.db_conn)
105         self.assertEqual(obj, self.checked_class.by_id(self.db_conn, id1))
106         # check create=True acts like normal instantiation (sans saving)
107         by_id_created = self.checked_class.by_id(self.db_conn, id2,
108                                                  create=True)
109         # pylint: disable=not-callable
110         self.assertEqual(self.checked_class(id2), by_id_created)
111         self.check_storage([obj])
112
113     def check_from_table_row(self, *args: Any) -> None:
114         """Test .from_table_row() properly reads in class from DB"""
115         id_ = self.default_ids[0]
116         obj = self.checked_class(id_, *args)  # pylint: disable=not-callable
117         obj.save(self.db_conn)
118         assert isinstance(obj.id_, type(self.default_ids[0]))
119         for row in self.db_conn.row_where(self.checked_class.table_name,
120                                           'id', obj.id_):
121             retrieved = self.checked_class.from_table_row(self.db_conn, row)
122             self.assertEqual(obj, retrieved)
123             self.assertEqual({obj.id_: obj}, self.checked_class.get_cache())
124
125     def check_versioned_from_table_row(self, attr_name: str,
126                                        type_: type) -> None:
127         """Test .from_table_row() reads versioned attributes from DB."""
128         owner = self.checked_class(None)
129         vals: list[Any] = ['t1', 't2'] if type_ == str else [0.9, 1.1]
130         attr = getattr(owner, attr_name)
131         attr.set(vals[0])
132         attr.set(vals[1])
133         owner.save(self.db_conn)
134         for row in self.db_conn.row_where(owner.table_name, 'id', owner.id_):
135             retrieved = owner.__class__.from_table_row(self.db_conn, row)
136             attr = getattr(retrieved, attr_name)
137             self.assertEqual(sorted(attr.history.values()), vals)
138
139     def check_all(self) -> tuple[Any, Any, Any]:
140         """Test .all()."""
141         # pylint: disable=not-callable
142         item1 = self.checked_class(self.default_ids[0])
143         item2 = self.checked_class(self.default_ids[1])
144         item3 = self.checked_class(self.default_ids[2])
145         # check pre-save .all() returns empty list
146         self.assertEqual(self.checked_class.all(self.db_conn), [])
147         # check that all() shows all saved, but no unsaved items
148         item1.save(self.db_conn)
149         item3.save(self.db_conn)
150         self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
151                          sorted([item1, item3]))
152         item2.save(self.db_conn)
153         self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
154                          sorted([item1, item2, item3]))
155         return item1, item2, item3
156
157     def check_singularity(self, defaulting_field: str,
158                           non_default_value: Any, *args: Any) -> None:
159         """Test pointers made for single object keep pointing to it."""
160         id1 = self.default_ids[0]
161         obj = self.checked_class(id1, *args)  # pylint: disable=not-callable
162         obj.save(self.db_conn)
163         setattr(obj, defaulting_field, non_default_value)
164         retrieved = self.checked_class.by_id(self.db_conn, id1)
165         self.assertEqual(non_default_value,
166                          getattr(retrieved, defaulting_field))
167
168     def check_versioned_singularity(self) -> None:
169         """Test singularity of VersionedAttributes on saving (with .title)."""
170         obj = self.checked_class(None)  # pylint: disable=not-callable
171         obj.save(self.db_conn)
172         assert isinstance(obj.id_, int)
173         obj.title.set('named')
174         retrieved = self.checked_class.by_id(self.db_conn, obj.id_)
175         self.assertEqual(obj.title.history, retrieved.title.history)
176
177     def check_remove(self, *args: Any) -> None:
178         """Test .remove() effects on DB and cache."""
179         id_ = self.default_ids[0]
180         obj = self.checked_class(id_, *args)  # pylint: disable=not-callable
181         with self.assertRaises(HandledException):
182             obj.remove(self.db_conn)
183         obj.save(self.db_conn)
184         obj.remove(self.db_conn)
185         self.check_storage([])
186
187
class TestCaseWithServer(TestCaseWithDB):
    """Module tests against our HTTP server/handler (and database).

    On top of TestCaseWithDB's per-test database, each test gets a TaskServer
    bound to localhost on an OS-chosen port (port 0), served from a daemon
    thread, plus an HTTPConnection to it as self.conn.
    """

    def setUp(self) -> None:
        super().setUp()
        self.httpd = TaskServer(self.db_file, ('localhost', 0), TaskHandler)
        self.server_thread = Thread(target=self.httpd.serve_forever)
        # Daemon thread, so a stuck server cannot keep the process alive.
        self.server_thread.daemon = True
        self.server_thread.start()
        self.conn = HTTPConnection(str(self.httpd.server_address[0]),
                                   self.httpd.server_address[1])

    def tearDown(self) -> None:
        # Close the client side first so no socket outlives the test.
        self.conn.close()
        self.httpd.shutdown()
        self.httpd.server_close()
        self.server_thread.join()
        super().tearDown()

    def _drained_response(self) -> Any:
        """Fetch the pending response on self.conn and read its body fully.

        Draining the body lets http.client finish with the response (and
        release its socket when the server closes the connection) right away
        instead of whenever the response object is garbage-collected.
        """
        response = self.conn.getresponse()
        response.read()
        return response

    def check_redirect(self, target: str) -> None:
        """Check that self.conn answers with a 302 redirect to target."""
        response = self._drained_response()
        self.assertEqual(response.status, 302)
        self.assertEqual(response.getheader('Location'), target)

    def check_get(self, target: str, expected_code: int) -> None:
        """Check that a GET to target yields expected_code."""
        self.conn.request('GET', target)
        self.assertEqual(self._drained_response().status, expected_code)

    def check_post(self, data: Mapping[str, object], target: str,
                   expected_code: int, redirect_location: str = '') -> None:
        """Check that POST of data to target yields expected_code.

        data is sent URL-form-encoded (sequence values become repeated keys).
        If expected_code is 302, the Location header must equal
        redirect_location, or target itself if redirect_location is empty.
        """
        encoded_form_data = urlencode(data, doseq=True).encode('utf-8')
        headers = {'Content-Type': 'application/x-www-form-urlencoded',
                   'Content-Length': str(len(encoded_form_data))}
        self.conn.request('POST', target,
                          body=encoded_form_data, headers=headers)
        if expected_code == 302:
            self.check_redirect(redirect_location or target)
        else:
            self.assertEqual(self._drained_response().status, expected_code)

    def check_get_defaults(self, path: str) -> None:
        """Check standard GET responses for a model page at path.

        path is presumably an absolute request path such as '/process'.
        Expects 200 without an id, with an empty id, and with id=1; 400 for a
        malformed id; 500 for id=0.
        """
        self.check_get(path, 200)
        self.check_get(f'{path}?id=', 200)
        self.check_get(f'{path}?id=foo', 400)
        self.check_get(f'{path}?id=0', 500)
        self.check_get(f'{path}?id=1', 200)

    def post_process(self, id_: int = 1,
                     form_data: dict[str, Any] | None = None
                     ) -> dict[str, Any]:
        """POST a basic Process, expecting it to be created with ID id_.

        Posts form_data (or, if None, a default title/description/effort
        set) to '/process?id=' and expects a 302 redirect to
        '/process?id={id_}'. Returns the form data that was posted.
        """
        if form_data is None:
            form_data = {'title': 'foo', 'description': 'foo', 'effort': 1.1}
        self.check_post(form_data, '/process?id=', 302, f'/process?id={id_}')
        return form_data