home · contact · privacy
Add Todo.comment, and for that purpose basic SQL migration infrastructure.
[plomtask] / tests / utils.py
1 """Shared test utilities."""
2 from unittest import TestCase
3 from threading import Thread
4 from http.client import HTTPConnection
5 from urllib.parse import urlencode
6 from datetime import datetime
7 from os import remove as remove_file
8 from typing import Mapping, Any
9 from plomtask.db import DatabaseFile, DatabaseConnection
10 from plomtask.http import TaskHandler, TaskServer
11 from plomtask.processes import Process, ProcessStep
12 from plomtask.conditions import Condition
13 from plomtask.days import Day
14 from plomtask.todos import Todo
15 from plomtask.exceptions import NotFoundException, HandledException
16
17
class TestCaseSansDB(TestCase):
    """Base for model tests that need no database at all.

    Subclasses assign checked_class, the model class being exercised.
    """
    checked_class: Any

    def check_id_setting(self, *args: Any) -> None:
        """Check that id 0 is rejected and that a legal id lands in .id_.

        Any extra positional args are passed on to the checked_class
        constructor after the id.
        """
        illegal_id, legal_id = 0, 5
        with self.assertRaises(HandledException):
            self.checked_class(illegal_id, *args)
        instance = self.checked_class(legal_id, *args)
        self.assertEqual(instance.id_, legal_id)

    def check_versioned_defaults(self, attrs: dict[str, Any]) -> None:
        """Check .newest of named versioned attributes on a fresh instance.

        attrs maps attribute names to the .newest value each is expected to
        report right after constructing checked_class(None).
        """
        instance = self.checked_class(None)
        for attr_name in attrs:
            expected = attrs[attr_name]
            self.assertEqual(getattr(instance, attr_name).newest, expected)
32         for k, v in attrs.items():
33             self.assertEqual(getattr(obj, k).newest, v)
34
35
36 class TestCaseWithDB(TestCase):
37     """Module tests not requiring DB setup."""
38     checked_class: Any
39     default_ids: tuple[int | str, int | str, int | str] = (1, 2, 3)
40
41     def setUp(self) -> None:
42         Condition.empty_cache()
43         Day.empty_cache()
44         Process.empty_cache()
45         ProcessStep.empty_cache()
46         Todo.empty_cache()
47         timestamp = datetime.now().timestamp()
48         self.db_file = DatabaseFile.create_at(f'test_db:{timestamp}')
49         self.db_conn = DatabaseConnection(self.db_file)
50
51     def tearDown(self) -> None:
52         self.db_conn.close()
53         remove_file(self.db_file.path)
54
55     def check_storage(self, content: list[Any]) -> None:
56         """Test cache and DB equal content."""
57         expected_cache = {}
58         for item in content:
59             expected_cache[item.id_] = item
60         self.assertEqual(self.checked_class.get_cache(), expected_cache)
61         db_found: list[Any] = []
62         for item in content:
63             assert isinstance(item.id_, type(self.default_ids[0]))
64             for row in self.db_conn.row_where(self.checked_class.table_name,
65                                               'id', item.id_):
66                 db_found += [self.checked_class.from_table_row(self.db_conn,
67                                                                row)]
68         self.assertEqual(sorted(content), sorted(db_found))
69
70     def check_saving_and_caching(self, **kwargs: Any) -> None:
71         """Test instance.save in its core without relations."""
72         obj = self.checked_class(**kwargs)  # pylint: disable=not-callable
73         # check object init itself doesn't store anything yet
74         self.check_storage([])
75         # check saving stores in cache and DB
76         obj.save(self.db_conn)
77         self.check_storage([obj])
78         # check core attributes set properly (and not unset by saving)
79         for key, value in kwargs.items():
80             self.assertEqual(getattr(obj, key), value)
81
82     def check_saving_of_versioned(self, attr_name: str, type_: type) -> None:
83         """Test owner's versioned attributes."""
84         owner = self.checked_class(None)
85         vals: list[Any] = ['t1', 't2'] if type_ == str else [0.9, 1.1]
86         attr = getattr(owner, attr_name)
87         attr.set(vals[0])
88         attr.set(vals[1])
89         owner.save(self.db_conn)
90         owner.uncache()
91         retrieved = owner.__class__.by_id(self.db_conn, owner.id_)
92         attr = getattr(retrieved, attr_name)
93         self.assertEqual(sorted(attr.history.values()), vals)
94
95     def check_by_id(self) -> None:
96         """Test .by_id(), including creation."""
97         # check failure if not yet saved
98         id1, id2 = self.default_ids[0], self.default_ids[1]
99         obj = self.checked_class(id1)  # pylint: disable=not-callable
100         with self.assertRaises(NotFoundException):
101             self.checked_class.by_id(self.db_conn, id1)
102         # check identity of saved and retrieved
103         obj.save(self.db_conn)
104         self.assertEqual(obj, self.checked_class.by_id(self.db_conn, id1))
105         # check create=True acts like normal instantiation (sans saving)
106         by_id_created = self.checked_class.by_id(self.db_conn, id2,
107                                                  create=True)
108         # pylint: disable=not-callable
109         self.assertEqual(self.checked_class(id2), by_id_created)
110         self.check_storage([obj])
111
112     def check_from_table_row(self, *args: Any) -> None:
113         """Test .from_table_row() properly reads in class from DB"""
114         id_ = self.default_ids[0]
115         obj = self.checked_class(id_, *args)  # pylint: disable=not-callable
116         obj.save(self.db_conn)
117         assert isinstance(obj.id_, type(self.default_ids[0]))
118         for row in self.db_conn.row_where(self.checked_class.table_name,
119                                           'id', obj.id_):
120             retrieved = self.checked_class.from_table_row(self.db_conn, row)
121             self.assertEqual(obj, retrieved)
122             self.assertEqual({obj.id_: obj}, self.checked_class.get_cache())
123
124     def check_versioned_from_table_row(self, attr_name: str,
125                                        type_: type) -> None:
126         """Test .from_table_row() reads versioned attributes from DB."""
127         owner = self.checked_class(None)
128         vals: list[Any] = ['t1', 't2'] if type_ == str else [0.9, 1.1]
129         attr = getattr(owner, attr_name)
130         attr.set(vals[0])
131         attr.set(vals[1])
132         owner.save(self.db_conn)
133         for row in self.db_conn.row_where(owner.table_name, 'id', owner.id_):
134             retrieved = owner.__class__.from_table_row(self.db_conn, row)
135             attr = getattr(retrieved, attr_name)
136             self.assertEqual(sorted(attr.history.values()), vals)
137
138     def check_all(self) -> tuple[Any, Any, Any]:
139         """Test .all()."""
140         # pylint: disable=not-callable
141         item1 = self.checked_class(self.default_ids[0])
142         item2 = self.checked_class(self.default_ids[1])
143         item3 = self.checked_class(self.default_ids[2])
144         # check pre-save .all() returns empty list
145         self.assertEqual(self.checked_class.all(self.db_conn), [])
146         # check that all() shows all saved, but no unsaved items
147         item1.save(self.db_conn)
148         item3.save(self.db_conn)
149         self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
150                          sorted([item1, item3]))
151         item2.save(self.db_conn)
152         self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
153                          sorted([item1, item2, item3]))
154         return item1, item2, item3
155
156     def check_singularity(self, defaulting_field: str,
157                           non_default_value: Any, *args: Any) -> None:
158         """Test pointers made for single object keep pointing to it."""
159         id1 = self.default_ids[0]
160         obj = self.checked_class(id1, *args)  # pylint: disable=not-callable
161         obj.save(self.db_conn)
162         setattr(obj, defaulting_field, non_default_value)
163         retrieved = self.checked_class.by_id(self.db_conn, id1)
164         self.assertEqual(non_default_value,
165                          getattr(retrieved, defaulting_field))
166
167     def check_versioned_singularity(self) -> None:
168         """Test singularity of VersionedAttributes on saving (with .title)."""
169         obj = self.checked_class(None)  # pylint: disable=not-callable
170         obj.save(self.db_conn)
171         assert isinstance(obj.id_, int)
172         obj.title.set('named')
173         retrieved = self.checked_class.by_id(self.db_conn, obj.id_)
174         self.assertEqual(obj.title.history, retrieved.title.history)
175
176     def check_remove(self, *args: Any) -> None:
177         """Test .remove() effects on DB and cache."""
178         id_ = self.default_ids[0]
179         obj = self.checked_class(id_, *args)  # pylint: disable=not-callable
180         with self.assertRaises(HandledException):
181             obj.remove(self.db_conn)
182         obj.save(self.db_conn)
183         obj.remove(self.db_conn)
184         self.check_storage([])
185
186
class TestCaseWithServer(TestCaseWithDB):
    """Module tests against our HTTP server/handler (and database).

    On top of the DB fixture, setUp serves a TaskServer on localhost from a
    daemon thread and opens self.conn to it; tearDown closes the client
    connection, stops the server and joins its thread before the DB
    teardown runs.
    """

    def setUp(self) -> None:
        super().setUp()
        # port 0: the OS picks a free port; the real one is read back below
        self.httpd = TaskServer(self.db_file, ('localhost', 0), TaskHandler)
        self.server_thread = Thread(target=self.httpd.serve_forever)
        self.server_thread.daemon = True
        self.server_thread.start()
        self.conn = HTTPConnection(str(self.httpd.server_address[0]),
                                   self.httpd.server_address[1])

    def tearDown(self) -> None:
        self.conn.close()
        self.httpd.shutdown()
        self.httpd.server_close()
        self.server_thread.join()
        super().tearDown()

    def check_redirect(self, target: str) -> None:
        """Check that self.conn answers with a 302 redirect to target."""
        response = self.conn.getresponse()
        self.assertEqual(response.status, 302)
        self.assertEqual(response.getheader('Location'), target)

    def check_get(self, target: str, expected_code: int) -> None:
        """Check that a GET to target yields expected_code."""
        self.conn.request('GET', target)
        self.assertEqual(self.conn.getresponse().status, expected_code)

    def check_post(self, data: Mapping[str, object], target: str,
                   expected_code: int, redirect_location: str = '') -> None:
        """Check that a form-encoded POST of data to target yields
        expected_code.

        For an expected 302, also check the Location header against
        redirect_location, which defaults to target itself.
        """
        # doseq=True encodes sequence values as repeated keys
        encoded_form_data = urlencode(data, doseq=True).encode('utf-8')
        headers = {'Content-Type': 'application/x-www-form-urlencoded',
                   'Content-Length': str(len(encoded_form_data))}
        self.conn.request('POST', target,
                          body=encoded_form_data, headers=headers)
        if 302 == expected_code:
            if redirect_location == '':
                redirect_location = target
            self.check_redirect(redirect_location)
        else:
            self.assertEqual(self.conn.getresponse().status, expected_code)

    def check_get_defaults(self, path: str) -> None:
        """Check standard GET response codes for the model page at path.

        Expects 200 without id, with an empty id and with id=1, 400 for the
        non-numeric id 'foo', and 500 for id 0.
        """
        self.check_get(path, 200)
        self.check_get(f'{path}?id=', 200)
        self.check_get(f'{path}?id=foo', 400)
        # same path form as the other checks: an extra '/' prefix would
        # request '//…' for paths that already start with '/'
        self.check_get(f'{path}?id=0', 500)
        self.check_get(f'{path}?id=1', 200)

    def post_process(self, id_: int = 1,
                     form_data: dict[str, Any] | None = None
                     ) -> dict[str, Any]:
        """POST a basic Process to '/process?id=', expecting a redirect to
        '/process?id={id_}'; return the form data posted.

        A missing or empty form_data is replaced by a default with title and
        description 'foo' and effort 1.1.
        """
        if not form_data:
            form_data = {'title': 'foo', 'description': 'foo', 'effort': 1.1}
        self.check_post(form_data, '/process?id=', 302, f'/process?id={id_}')
        return form_data
244             form_data = {'title': 'foo', 'description': 'foo', 'effort': 1.1}
245         self.check_post(form_data, '/process?id=', 302, f'/process?id={id_}')
246         return form_data