home · contact · privacy
ccb485ad104cc37ae040a4b144093e69c75b87c0
[plomtask] / tests / utils.py
1 """Shared test utilities."""
2 from unittest import TestCase
3 from threading import Thread
4 from http.client import HTTPConnection
5 from urllib.parse import urlencode
6 from datetime import datetime
7 from os import remove as remove_file
8 from typing import Mapping, Any
9 from plomtask.db import DatabaseFile, DatabaseConnection
10 from plomtask.http import TaskHandler, TaskServer
11 from plomtask.processes import Process, ProcessStep
12 from plomtask.conditions import Condition
13 from plomtask.days import Day
14 from plomtask.todos import Todo
15 from plomtask.exceptions import NotFoundException, HandledException
16
17
class TestCaseSansDB(TestCase):
    """Tests requiring no DB setup."""
    checked_class: Any

    def check_id_setting(self) -> None:
        """Test .id_ being set and its legal range being enforced.

        Constructing checked_class with ID 0 must raise HandledException;
        constructing it with ID 5 must succeed and expose 5 as .id_.
        """
        with self.assertRaises(HandledException):
            self.checked_class(0)
        valid_id = 5
        instance = self.checked_class(valid_id)
        self.assertEqual(instance.id_, valid_id)

    def check_versioned_defaults(self, attrs: dict[str, Any]) -> None:
        """Test defaults of VersionedAttributes.

        For an instance constructed with a None ID, every attribute named in
        attrs must report a .newest value equal to the mapped expected value.
        """
        fresh = self.checked_class(None)
        for attr_name, expected in attrs.items():
            versioned = getattr(fresh, attr_name)
            self.assertEqual(versioned.newest, expected)
34
35
36 class TestCaseWithDB(TestCase):
37     """Module tests not requiring DB setup."""
38     checked_class: Any
39     default_ids: tuple[int | str, int | str, int | str] = (1, 2, 3)
40
41     def setUp(self) -> None:
42         Condition.empty_cache()
43         Day.empty_cache()
44         Process.empty_cache()
45         ProcessStep.empty_cache()
46         Todo.empty_cache()
47         timestamp = datetime.now().timestamp()
48         self.db_file = DatabaseFile(f'test_db:{timestamp}')
49         self.db_file.remake()
50         self.db_conn = DatabaseConnection(self.db_file)
51
52     def tearDown(self) -> None:
53         self.db_conn.close()
54         remove_file(self.db_file.path)
55
56     def check_storage(self, content: list[Any]) -> None:
57         """Test cache and DB equal content."""
58         expected_cache = {}
59         for item in content:
60             expected_cache[item.id_] = item
61         self.assertEqual(self.checked_class.get_cache(), expected_cache)
62         db_found: list[Any] = []
63         for item in content:
64             assert isinstance(item.id_, type(self.default_ids[0]))
65             for row in self.db_conn.row_where(self.checked_class.table_name,
66                                               'id', item.id_):
67                 db_found += [self.checked_class.from_table_row(self.db_conn,
68                                                                row)]
69         self.assertEqual(sorted(content), sorted(db_found))
70
71     def check_saving_and_caching(self, **kwargs: Any) -> Any:
72         """Test instance.save in its core without relations."""
73         obj = self.checked_class(**kwargs)  # pylint: disable=not-callable
74         # check object init itself doesn't store anything yet
75         self.check_storage([])
76         # check saving stores in cache and DB
77         obj.save(self.db_conn)
78         self.check_storage([obj])
79         # check core attributes set properly (and not unset by saving)
80         for key, value in kwargs.items():
81             self.assertEqual(getattr(obj, key), value)
82
83     def check_by_id(self) -> None:
84         """Test .by_id(), including creation."""
85         # check failure if not yet saved
86         id1, id2 = self.default_ids[0], self.default_ids[1]
87         obj = self.checked_class(id1)  # pylint: disable=not-callable
88         with self.assertRaises(NotFoundException):
89             self.checked_class.by_id(self.db_conn, id1)
90         # check identity of saved and retrieved
91         obj.save(self.db_conn)
92         self.assertEqual(obj, self.checked_class.by_id(self.db_conn, id1))
93         # check create=True acts like normal instantiation (sans saving)
94         by_id_created = self.checked_class.by_id(self.db_conn, id2,
95                                                  create=True)
96         # pylint: disable=not-callable
97         self.assertEqual(self.checked_class(id2), by_id_created)
98         self.check_storage([obj])
99
100     def check_from_table_row(self, *args: Any) -> None:
101         """Test .from_table_row() properly reads in class from DB"""
102         id_ = self.default_ids[0]
103         obj = self.checked_class(id_, *args)  # pylint: disable=not-callable
104         obj.save(self.db_conn)
105         assert isinstance(obj.id_, type(self.default_ids[0]))
106         for row in self.db_conn.row_where(self.checked_class.table_name,
107                                           'id', obj.id_):
108             retrieved = self.checked_class.from_table_row(self.db_conn, row)
109             self.assertEqual(obj, retrieved)
110             self.assertEqual({obj.id_: obj}, self.checked_class.get_cache())
111
112     def check_all(self) -> tuple[Any, Any, Any]:
113         """Test .all()."""
114         # pylint: disable=not-callable
115         item1 = self.checked_class(self.default_ids[0])
116         item2 = self.checked_class(self.default_ids[1])
117         item3 = self.checked_class(self.default_ids[2])
118         # check pre-save .all() returns empty list
119         self.assertEqual(self.checked_class.all(self.db_conn), [])
120         # check that all() shows all saved, but no unsaved items
121         item1.save(self.db_conn)
122         item3.save(self.db_conn)
123         self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
124                          sorted([item1, item3]))
125         item2.save(self.db_conn)
126         self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
127                          sorted([item1, item2, item3]))
128         return item1, item2, item3
129
130     def check_singularity(self, defaulting_field: str,
131                           non_default_value: Any, *args: Any) -> None:
132         """Test pointers made for single object keep pointing to it."""
133         id1 = self.default_ids[0]
134         obj = self.checked_class(id1, *args)  # pylint: disable=not-callable
135         obj.save(self.db_conn)
136         setattr(obj, defaulting_field, non_default_value)
137         retrieved = self.checked_class.by_id(self.db_conn, id1)
138         self.assertEqual(non_default_value,
139                          getattr(retrieved, defaulting_field))
140
141     def check_versioned_singularity(self) -> None:
142         """Test singularity of VersionedAttributes on saving (with .title)."""
143         obj = self.checked_class(None)  # pylint: disable=not-callable
144         obj.save(self.db_conn)
145         assert isinstance(obj.id_, int)
146         obj.title.set('named')
147         retrieved = self.checked_class.by_id(self.db_conn, obj.id_)
148         self.assertEqual(obj.title.history, retrieved.title.history)
149
150     def check_remove(self, *args: Any) -> None:
151         """Test .remove() effects on DB and cache."""
152         id_ = self.default_ids[0]
153         obj = self.checked_class(id_, *args)  # pylint: disable=not-callable
154         with self.assertRaises(HandledException):
155             obj.remove(self.db_conn)
156         obj.save(self.db_conn)
157         obj.remove(self.db_conn)
158         self.check_storage([])
159
160
class TestCaseWithServer(TestCaseWithDB):
    """Module tests against our HTTP server/handler (and database)."""

    def setUp(self) -> None:
        super().setUp()
        # Port 0: the OS picks a free port, read back via server_address.
        self.httpd = TaskServer(self.db_file, ('localhost', 0), TaskHandler)
        self.server_thread = Thread(target=self.httpd.serve_forever)
        self.server_thread.daemon = True
        self.server_thread.start()
        self.conn = HTTPConnection(str(self.httpd.server_address[0]),
                                   self.httpd.server_address[1])

    def tearDown(self) -> None:
        # Close the client connection too, so no socket outlives the test.
        self.conn.close()
        self.httpd.shutdown()
        self.httpd.server_close()
        self.server_thread.join()
        super().tearDown()

    def check_redirect(self, target: str) -> None:
        """Check that self.conn answers with a 302 redirect to target."""
        response = self.conn.getresponse()
        self.assertEqual(response.status, 302)
        self.assertEqual(response.getheader('Location'), target)

    def check_get(self, target: str, expected_code: int) -> None:
        """Check that a GET to target yields expected_code."""
        self.conn.request('GET', target)
        self.assertEqual(self.conn.getresponse().status, expected_code)

    def check_post(self, data: Mapping[str, object], target: str,
                   expected_code: int, redirect_location: str = '') -> None:
        """Check that POST of data to target yields expected_code.

        data is sent form-urlencoded (sequence values as repeated keys).
        For expected_code 302, the Location header must equal
        redirect_location, which defaults to target itself when empty.
        """
        encoded_form_data = urlencode(data, doseq=True).encode('utf-8')
        headers = {'Content-Type': 'application/x-www-form-urlencoded',
                   'Content-Length': str(len(encoded_form_data))}
        self.conn.request('POST', target,
                          body=encoded_form_data, headers=headers)
        if expected_code == 302:
            if not redirect_location:
                redirect_location = target
            self.check_redirect(redirect_location)
        else:
            self.assertEqual(self.conn.getresponse().status, expected_code)

    def check_get_defaults(self, path: str) -> None:
        """Some standard model paths to test.

        GETs path with no id, an empty id, a non-numeric id, id 0 and id 1,
        expecting 200, 200, 400, 500 and 200 respectively.
        """
        self.check_get(path, 200)
        self.check_get(f'{path}?id=', 200)
        self.check_get(f'{path}?id=foo', 400)
        # Built like its siblings: an extra leading '/' turned '/process'
        # into '//process', which URL parsing reads as a network location.
        self.check_get(f'{path}?id=0', 500)
        self.check_get(f'{path}?id=1', 200)

    def post_process(self, id_: int = 1,
                     form_data: dict[str, Any] | None = None
                     ) -> dict[str, Any]:
        """POST basic Process.

        Posts form_data (or a default title/description/effort form when
        form_data is empty or None) to '/process?id=', expects a 302
        redirect to '/process?id={id_}', and returns the posted form data.
        """
        if not form_data:
            form_data = {'title': 'foo', 'description': 'foo', 'effort': 1.1}
        self.check_post(form_data, '/process?id=', 302, f'/process?id={id_}')
        return form_data