1 """Shared test utilities."""
2 from unittest import TestCase
3 from threading import Thread
4 from http.client import HTTPConnection
5 from urllib.parse import urlencode
6 from datetime import datetime
7 from os import remove as remove_file
8 from typing import Mapping, Any
9 from plomtask.db import DatabaseFile, DatabaseConnection
10 from plomtask.http import TaskHandler, TaskServer
11 from plomtask.processes import Process, ProcessStep
12 from plomtask.conditions import Condition
13 from plomtask.days import Day
14 from plomtask.todos import Todo
15 from plomtask.exceptions import NotFoundException, HandledException
18 class TestCaseWithDB(TestCase):
19 """Module tests requiring DB setup: emptied caches, per-test DB file."""
# IDs the check_* helpers of this class construct test objects with
# (indices 0-2 are read). No value is assigned in the visible lines;
# presumably subclasses set it — TODO confirm.
# NOTE(review): the helpers also read self.checked_class, which is not
# declared in the visible lines (original line 20 is missing here).
21 default_ids: tuple[int | str, int | str, int | str]
23 def setUp(self) -> None:
# Reset model-class caches before each test (presumed purpose of
# empty_cache — TODO confirm).
# NOTE(review): original lines 25, 26 and 28 are missing from this
# listing; presumably further empty_cache() calls for the other imported
# model classes (Day, Process, Todo) — TODO confirm.
24 Condition.empty_cache()
27 ProcessStep.empty_cache()
# Name the DB file after the current timestamp so each test presumably
# gets its own file.
# NOTE(review): the name contains ':', which Windows does not allow in
# file names if this string ends up as the on-disk path.
29 timestamp = datetime.now().timestamp()
30 self.db_file = DatabaseFile(f'test_db:{timestamp}')
# NOTE(review): original line 31 is missing from this listing — TODO
# confirm what it does with self.db_file before connecting.
32 self.db_conn = DatabaseConnection(self.db_file)
34 def tearDown(self) -> None:
# Delete the DB file created in setUp.
# NOTE(review): original line 35 is missing from this listing; presumably
# it closes self.db_conn before the file is removed — TODO confirm.
36 remove_file(self.db_file.path)
38 def check_storage(self, content: list[Any]) -> None:
39 """Test cache and DB equal content."""
# NOTE(review): original lines 40-41, 45, 48 and 50 are missing from this
# listing — presumably `expected_cache = {}`, the `for item in content:`
# loop headers, and the continuations of the wrapped row_where and
# from_table_row calls. TODO confirm against the full file.
# Part 1: the class cache must equal the items keyed by their id_.
42 expected_cache[item.id_] = item
43 self.assertEqual(self.checked_class.get_cache(), expected_cache)
# Part 2: rebuild objects from table rows via from_table_row and compare
# with content; both sides are sorted, so row order does not matter (the
# items must be orderable for sorted() to work).
44 db_found: list[Any] = []
# Presumably narrows id_ for the type checker; as a bare assert it is
# skipped under `python -O`.
46 assert isinstance(item.id_, (str, int))
47 for row in self.db_conn.row_where(self.checked_class.table_name,
49 db_found += [self.checked_class.from_table_row(self.db_conn,
51 self.assertEqual(sorted(content), sorted(db_found))
def check_saving_and_caching(self, **kwargs: Any) -> Any:
    """Test instance.save in its core without relations.

    Instantiates self.checked_class with kwargs and checks three things.
    Instantiation alone must store nothing. After .save(), check_storage
    must find the object. Every kwarg must be set as an attribute of the
    same name, and must still be set after saving.

    Returns:
        The saved object, so callers can run further checks on it.
    """
    obj = self.checked_class(**kwargs)  # pylint: disable=not-callable
    # check object init itself doesn't store anything yet
    self.check_storage([])
    # check saving stores in cache and DB
    obj.save(self.db_conn)
    self.check_storage([obj])
    # check core attributes set properly (and not unset by saving)
    for key, value in kwargs.items():
        self.assertEqual(getattr(obj, key), value)
    # the signature has always promised a value; hand back the object
    return obj
65 def check_by_id(self) -> None:
66 """Test .by_id(), including creation."""
67 # check failure if not yet saved
68 id1, id2 = self.default_ids[0], self.default_ids[1]
69 obj = self.checked_class(id1) # pylint: disable=not-callable
70 with self.assertRaises(NotFoundException):
71 self.checked_class.by_id(self.db_conn, id1)
72 # check identity of saved and retrieved
# NOTE(review): assertEqual below checks ==, not identity; assertIs would
# be the stricter check if the very same object is expected — TODO
# confirm intent.
73 obj.save(self.db_conn)
74 self.assertEqual(obj, self.checked_class.by_id(self.db_conn, id1))
75 # check create=True acts like normal instantiation (sans saving)
# NOTE(review): original line 77 is missing from this listing; going by
# the comment above, it presumably closes this call with `create=True)`.
76 by_id_created = self.checked_class.by_id(self.db_conn, id2,
78 # pylint: disable=not-callable
79 self.assertEqual(self.checked_class(id2), by_id_created)
# Only obj was saved, so check_storage must find nothing else stored.
80 self.check_storage([obj])
82 def check_from_table_row(self, id_: int | str) -> None:
83 """Test .from_table_row() properly reads in class from DB"""
84 obj = self.checked_class(id_) # pylint: disable=not-callable
85 obj.save(self.db_conn)
# Presumably narrows obj.id_ for the type checker (skipped under -O).
86 assert isinstance(obj.id_, (str, int))
# NOTE(review): original line 88 (the rest of the row_where arguments) is
# missing from this listing — TODO confirm the row filter.
# NOTE(review): indentation is lost here. If the assertions sit inside the
# loop and row_where yields no rows, nothing is checked and the test
# passes vacuously; consider asserting that a row was found.
87 for row in self.db_conn.row_where(self.checked_class.table_name,
89 retrieved = self.checked_class.from_table_row(self.db_conn, row)
90 self.assertEqual(obj, retrieved)
# The cache must hold exactly obj under its id_.
91 self.assertEqual({obj.id_: obj}, self.checked_class.get_cache())
93 def check_all(self) -> tuple[Any, Any, Any]:
# Test .all(): it must list exactly the instances saved so far. Returns
# the three (finally all saved) items so callers can check them further.
# NOTE(review): original line 94 is missing from this listing (presumably
# the docstring) — TODO confirm.
95 # pylint: disable=not-callable
96 item1 = self.checked_class(self.default_ids[0])
97 item2 = self.checked_class(self.default_ids[1])
98 item3 = self.checked_class(self.default_ids[2])
99 # check pre-save .all() returns empty list
100 self.assertEqual(self.checked_class.all(self.db_conn), [])
101 # check that all() shows all saved, but no unsaved items
102 item1.save(self.db_conn)
103 item3.save(self.db_conn)
# Both sides are sorted, so the ordering .all() returns is not tested
# (the items must be orderable for sorted() to work).
104 self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
105 sorted([item1, item3]))
106 item2.save(self.db_conn)
107 self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
108 sorted([item1, item2, item3]))
109 return item1, item2, item3
def check_singularity(self, defaulting_field: str,
                      non_default_value: Any) -> None:
    """Assert that fetching by ID hands back the very object saved.

    Saves a new checked_class instance and then changes defaulting_field
    on that in-memory instance without saving again. The object returned
    by .by_id() for the same ID must already show non_default_value.
    """
    model = self.checked_class
    target_id = self.default_ids[0]
    saved = model(target_id)  # pylint: disable=not-callable
    saved.save(self.db_conn)
    # Mutate only *after* saving: the change can show up in the fetched
    # object only if it is the same instance.
    setattr(saved, defaulting_field, non_default_value)
    fetched = model.by_id(self.db_conn, target_id)
    self.assertEqual(non_default_value,
                     getattr(fetched, defaulting_field))
def check_remove(self) -> None:
    """Check .remove() first on an unsaved, then on a saved instance.

    Removing an instance that was never saved must raise HandledException.
    Once the instance has been saved and removed, check_storage([]) must
    pass, i.e. nothing of it may remain stored.
    """
    conn = self.db_conn
    # pylint: disable=not-callable
    candidate = self.checked_class(self.default_ids[0])
    # nothing stored yet, so the removal has to be refused
    with self.assertRaises(HandledException):
        candidate.remove(conn)
    candidate.save(conn)
    candidate.remove(conn)
    # NOTE(review): if check_storage only queries the DB per item of its
    # argument, an empty list verifies the cache alone — TODO confirm.
    self.check_storage([])
133 class TestCaseWithServer(TestCaseWithDB):
134 """Module tests against our HTTP server/handler (and database)."""
136 def setUp(self) -> None:
# NOTE(review): original line 137 is missing from this listing. self.db_file
# below is only assigned in TestCaseWithDB.setUp, so it is presumably
# super().setUp() — TODO confirm.
# Port 0 asks the OS for a free port (assuming TaskServer binds like a
# socketserver server); the actual address is read back from
# server_address for the client connection.
138 self.httpd = TaskServer(self.db_file, ('localhost', 0), TaskHandler)
# Serve in a daemon thread so a still-running server cannot keep the
# interpreter alive.
139 self.server_thread = Thread(target=self.httpd.serve_forever)
140 self.server_thread.daemon = True
141 self.server_thread.start()
142 self.conn = HTTPConnection(str(self.httpd.server_address[0]),
143 self.httpd.server_address[1])
145 def tearDown(self) -> None:
# Stop the serve_forever() loop, close the server, then wait for the
# serving thread to finish.
146 self.httpd.shutdown()
147 self.httpd.server_close()
148 self.server_thread.join()
# NOTE(review): original lines 149-150 are missing from this listing. The
# visible lines neither close self.conn nor call super().tearDown(), which
# deletes the DB file — TODO confirm the missing lines do.
def check_redirect(self, target: str) -> None:
    """Assert that the pending response on self.conn redirects to target.

    Reads the next response from self.conn, so a request must already be
    in flight. That response must have status 302 and a Location header
    equal to target.
    """
    reply = self.conn.getresponse()
    self.assertEqual(reply.status, 302)
    location_header = reply.getheader('Location')
    self.assertEqual(location_header, target)
def check_get(self, target: str, expected_code: int) -> None:
    """Send a GET for target over self.conn and assert the status code."""
    connection = self.conn
    connection.request('GET', target)
    response = connection.getresponse()
    self.assertEqual(response.status, expected_code)
162 def check_post(self, data: Mapping[str, object], target: str,
163 expected_code: int, redirect_location: str = '') -> None:
164 """Check that POST of data to target yields expected_code."""
# Form-encode data; doseq=True turns sequence values into one key=value
# pair per element.
165 encoded_form_data = urlencode(data, doseq=True).encode('utf-8')
166 headers = {'Content-Type': 'application/x-www-form-urlencoded',
167 'Content-Length': str(len(encoded_form_data))}
168 self.conn.request('POST', target,
169 body=encoded_form_data, headers=headers)
# For an expected 302, verify the Location header; an empty
# redirect_location means "redirect back to target".
170 if 302 == expected_code:
171 if redirect_location == '':
172 redirect_location = target
173 self.check_redirect(redirect_location)
# NOTE(review): original line 174 is missing from this listing.
# check_redirect already calls self.conn.getresponse(), so line 174 is
# presumably `else:`. Without it, the line below would request the
# response a second time after a redirect and presumably fail — TODO
# confirm.
175 self.assertEqual(self.conn.getresponse().status, expected_code)
def check_get_defaults(self, path: str) -> None:
    """GET the standard ID query variants of a model path.

    Expects status 200 for the bare path, for an empty ID and for id=1.
    Expects 400 for a non-numeric ID and 500 for id=0.

    :param path: request path of a model page, e.g. '/process'.
    """
    # Every target is path plus a query suffix, built the same way for each
    # case (the id=0 case used to get an extra leading '/', i.e. '//...').
    for query, expected_code in (('', 200),
                                 ('?id=', 200),
                                 ('?id=foo', 400),
                                 ('?id=0', 500),
                                 ('?id=1', 200)):
        self.check_get(f'{path}{query}', expected_code)
185 def post_process(self, id_: int = 1,
186 form_data: dict[str, Any] | None = None
188 """POST basic Process."""
190 form_data = {'title': 'foo', 'description': 'foo', 'effort': 1.1}
191 self.check_post(form_data, '/process?id=', 302, f'/process?id={id_}')