home · contact · privacy
a826c16a5adb1d24222ef1f4b42f5cefc711baf8
[plomtask] / tests / utils.py
1 """Shared test utilities."""
2 from unittest import TestCase
3 from threading import Thread
4 from http.client import HTTPConnection
5 from urllib.parse import urlencode
6 from datetime import datetime
7 from os import remove as remove_file
8 from typing import Mapping, Any
9 from plomtask.db import DatabaseFile, DatabaseConnection
10 from plomtask.http import TaskHandler, TaskServer
11 from plomtask.processes import Process, ProcessStep
12 from plomtask.conditions import Condition
13 from plomtask.days import Day
14 from plomtask.todos import Todo
15 from plomtask.exceptions import NotFoundException, HandledException
16
17
18 class TestCaseWithDB(TestCase):
19     """Module tests not requiring DB setup."""
20     checked_class: Any
21     default_ids: tuple[int | str, int | str, int | str]
22
23     def setUp(self) -> None:
24         Condition.empty_cache()
25         Day.empty_cache()
26         Process.empty_cache()
27         ProcessStep.empty_cache()
28         Todo.empty_cache()
29         timestamp = datetime.now().timestamp()
30         self.db_file = DatabaseFile(f'test_db:{timestamp}')
31         self.db_file.remake()
32         self.db_conn = DatabaseConnection(self.db_file)
33
34     def tearDown(self) -> None:
35         self.db_conn.close()
36         remove_file(self.db_file.path)
37
38     def check_storage(self, content: list[Any]) -> None:
39         """Test cache and DB equal content."""
40         expected_cache = {}
41         for item in content:
42             expected_cache[item.id_] = item
43         self.assertEqual(self.checked_class.get_cache(), expected_cache)
44         db_found: list[Any] = []
45         for item in content:
46             assert isinstance(item.id_, (str, int))
47             for row in self.db_conn.row_where(self.checked_class.table_name,
48                                               'id', item.id_):
49                 db_found += [self.checked_class.from_table_row(self.db_conn,
50                                                                row)]
51         self.assertEqual(sorted(content), sorted(db_found))
52
53     def check_saving_and_caching(self, **kwargs: Any) -> Any:
54         """Test instance.save in its core without relations."""
55         obj = self.checked_class(**kwargs)  # pylint: disable=not-callable
56         # check object init itself doesn't store anything yet
57         self.check_storage([])
58         # check saving stores in cache and DB
59         obj.save(self.db_conn)
60         self.check_storage([obj])
61         # check core attributes set properly (and not unset by saving)
62         for key, value in kwargs.items():
63             self.assertEqual(getattr(obj, key), value)
64
65     def check_by_id(self) -> None:
66         """Test .by_id(), including creation."""
67         # check failure if not yet saved
68         id1, id2 = self.default_ids[0], self.default_ids[1]
69         obj = self.checked_class(id1)  # pylint: disable=not-callable
70         with self.assertRaises(NotFoundException):
71             self.checked_class.by_id(self.db_conn, id1)
72         # check identity of saved and retrieved
73         obj.save(self.db_conn)
74         self.assertEqual(obj, self.checked_class.by_id(self.db_conn, id1))
75         # check create=True acts like normal instantiation (sans saving)
76         by_id_created = self.checked_class.by_id(self.db_conn, id2,
77                                                  create=True)
78         # pylint: disable=not-callable
79         self.assertEqual(self.checked_class(id2), by_id_created)
80         self.check_storage([obj])
81
82     def check_from_table_row(self, id_: int | str) -> None:
83         """Test .from_table_row() properly reads in class from DB"""
84         obj = self.checked_class(id_)  # pylint: disable=not-callable
85         obj.save(self.db_conn)
86         assert isinstance(obj.id_, (str, int))
87         for row in self.db_conn.row_where(self.checked_class.table_name,
88                                           'id', obj.id_):
89             retrieved = self.checked_class.from_table_row(self.db_conn, row)
90             self.assertEqual(obj, retrieved)
91             self.assertEqual({obj.id_: obj}, self.checked_class.get_cache())
92
93     def check_all(self) -> tuple[Any, Any, Any]:
94         """Test .all()."""
95         # pylint: disable=not-callable
96         item1 = self.checked_class(self.default_ids[0])
97         item2 = self.checked_class(self.default_ids[1])
98         item3 = self.checked_class(self.default_ids[2])
99         # check pre-save .all() returns empty list
100         self.assertEqual(self.checked_class.all(self.db_conn), [])
101         # check that all() shows all saved, but no unsaved items
102         item1.save(self.db_conn)
103         item3.save(self.db_conn)
104         self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
105                          sorted([item1, item3]))
106         item2.save(self.db_conn)
107         self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
108                          sorted([item1, item2, item3]))
109         return item1, item2, item3
110
111     def check_singularity(self, defaulting_field: str,
112                           non_default_value: Any) -> None:
113         """Test pointers made for single object keep pointing to it."""
114         id1 = self.default_ids[0]
115         obj = self.checked_class(id1)  # pylint: disable=not-callable
116         obj.save(self.db_conn)
117         setattr(obj, defaulting_field, non_default_value)
118         retrieved = self.checked_class.by_id(self.db_conn, id1)
119         self.assertEqual(non_default_value,
120                          getattr(retrieved, defaulting_field))
121
122     def check_remove(self) -> None:
123         """Test .remove() effects on DB and cache."""
124         id_ = self.default_ids[0]
125         obj = self.checked_class(id_)  # pylint: disable=not-callable
126         with self.assertRaises(HandledException):
127             obj.remove(self.db_conn)
128         obj.save(self.db_conn)
129         obj.remove(self.db_conn)
130         self.check_storage([])
131
132
class TestCaseWithServer(TestCaseWithDB):
    """Module tests against our HTTP server/handler (and database).

    Each test gets a TaskServer on the test DB, served from a background
    daemon thread, plus an HTTPConnection (self.conn) pointed at it.
    """

    def setUp(self) -> None:
        super().setUp()
        # Port 0: presumably lets the OS pick a free port (assuming
        # TaskServer binds like socketserver); the real address is read
        # back from server_address below.
        self.httpd = TaskServer(self.db_file, ('localhost', 0), TaskHandler)
        self.server_thread = Thread(target=self.httpd.serve_forever)
        self.server_thread.daemon = True
        self.server_thread.start()
        self.conn = HTTPConnection(str(self.httpd.server_address[0]),
                                   self.httpd.server_address[1])

    def tearDown(self) -> None:
        # Close the client side too, so its socket is not leaked past the
        # test.
        self.conn.close()
        self.httpd.shutdown()
        self.httpd.server_close()
        self.server_thread.join()
        super().tearDown()

    def check_redirect(self, target: str) -> None:
        """Check that self.conn answers with a 302 redirect to target."""
        response = self.conn.getresponse()
        self.assertEqual(response.status, 302)
        self.assertEqual(response.getheader('Location'), target)

    def check_get(self, target: str, expected_code: int) -> None:
        """Check that a GET to target yields expected_code."""
        self.conn.request('GET', target)
        self.assertEqual(self.conn.getresponse().status, expected_code)

    def check_post(self, data: Mapping[str, object], target: str,
                   expected_code: int, redirect_location: str = '') -> None:
        """Check that POST of data to target yields expected_code.

        data is sent form-urlencoded (sequence values as repeated keys).
        For expected_code 302, the Location header must also equal
        redirect_location, which defaults to target itself.
        """
        encoded_form_data = urlencode(data, doseq=True).encode('utf-8')
        headers = {'Content-Type': 'application/x-www-form-urlencoded',
                   'Content-Length': str(len(encoded_form_data))}
        self.conn.request('POST', target,
                          body=encoded_form_data, headers=headers)
        if 302 == expected_code:
            if redirect_location == '':
                redirect_location = target
            self.check_redirect(redirect_location)
        else:
            self.assertEqual(self.conn.getresponse().status, expected_code)

    def check_get_defaults(self, path: str) -> None:
        """Some standard model paths to test."""
        self.check_get(path, 200)
        self.check_get(f'{path}?id=', 200)
        self.check_get(f'{path}?id=foo', 400)
        # Built from path exactly like the other calls: a prepended '/'
        # would turn e.g. '/process' into '//process', which URL parsing
        # reads as a netloc, not a path, so a different page gets tested.
        self.check_get(f'{path}?id=0', 500)
        self.check_get(f'{path}?id=1', 200)

    def post_process(self, id_: int = 1,
                     form_data: dict[str, Any] | None = None
                     ) -> dict[str, Any]:
        """POST basic Process; return the form data that was posted.

        Any falsy form_data (None or empty) is replaced by a minimal valid
        form. The POST goes to '/process?id=' and must redirect (302) to
        f'/process?id={id_}'.
        """
        if not form_data:
            form_data = {'title': 'foo', 'description': 'foo', 'effort': 1.1}
        self.check_post(form_data, '/process?id=', 302, f'/process?id={id_}')
        return form_data