1 """Shared test utilities."""
2 from unittest import TestCase
3 from threading import Thread
4 from http.client import HTTPConnection
5 from urllib.parse import urlencode
6 from datetime import datetime
7 from os import remove as remove_file
8 from typing import Mapping, Any
9 from plomtask.db import DatabaseFile, DatabaseConnection
10 from plomtask.http import TaskHandler, TaskServer
11 from plomtask.processes import Process, ProcessStep
12 from plomtask.conditions import Condition
13 from plomtask.days import Day
14 from plomtask.todos import Todo
15 from plomtask.exceptions import NotFoundException, HandledException
18 class TestCaseWithDB(TestCase):
19 """Module tests not requiring DB setup."""
21 default_ids: tuple[int | str, int | str, int | str]
def setUp(self) -> None:
    """Empty the model caches and connect to a timestamp-named test DB file.

    Sets self.db_file (a DatabaseFile named after the current timestamp) and
    self.db_conn (a DatabaseConnection on that file).
    """
    for cached_class in (Condition, ProcessStep):
        cached_class.empty_cache()
    self.db_file = DatabaseFile(f'test_db:{datetime.now().timestamp()}')
    self.db_conn = DatabaseConnection(self.db_file)
def tearDown(self) -> None:
    """Delete the file found at self.db_file.path."""
    db_path = self.db_file.path
    remove_file(db_path)
38 def check_storage(self, content: list[Any]) -> None:
39 """Test cache and DB equal content."""
42 expected_cache[item.id_] = item
43 self.assertEqual(self.checked_class.get_cache(), expected_cache)
44 db_found: list[Any] = []
46 assert isinstance(item.id_, (str, int))
47 for row in self.db_conn.row_where(self.checked_class.table_name,
49 db_found += [self.checked_class.from_table_row(self.db_conn,
51 self.assertEqual(sorted(content), sorted(db_found))
def check_saving_and_caching(self, **kwargs: Any) -> Any:
    """Test that .save() of a fresh instance fills cache and DB.

    Builds self.checked_class from kwargs, verifies storage is empty before
    saving and holds exactly the instance afterwards, then verifies every
    kwarg can be read back as an equally named attribute of the instance.
    """
    # pylint: disable=not-callable
    instance = self.checked_class(**kwargs)
    # merely constructing the instance must not store anything
    self.check_storage([])
    # saving must put the instance into both cache and DB
    instance.save(self.db_conn)
    self.check_storage([instance])
    # constructor arguments must still be set after saving
    for attr_name in kwargs:
        self.assertEqual(getattr(instance, attr_name), kwargs[attr_name])
def check_by_id(self) -> None:
    """Test .by_id(), including creation.

    Verifies that .by_id() raises NotFoundException for an ID not yet saved,
    returns an object equal to the saved one afterwards, and with create=True
    returns an object equal to a plain instantiation without storing it.
    """
    id1, id2 = self.default_ids[0], self.default_ids[1]
    obj = self.checked_class(id1)  # pylint: disable=not-callable
    # check failure if not yet saved
    with self.assertRaises(NotFoundException):
        self.checked_class.by_id(self.db_conn, id1)
    # check identity of saved and retrieved
    obj.save(self.db_conn)
    self.assertEqual(obj, self.checked_class.by_id(self.db_conn, id1))
    # check create=True acts like normal instantiation (sans saving)
    by_id_created = self.checked_class.by_id(self.db_conn, id2,
                                             create=True)
    # pylint: disable=not-callable
    self.assertEqual(self.checked_class(id2), by_id_created)
    # only the explicitly saved object may be in storage
    self.check_storage([obj])
82 def check_from_table_row(self) -> None:
83 """Test .from_table_row() properly reads in class from DB"""
84 id_ = self.default_ids[0]
85 obj = self.checked_class(id_) # pylint: disable=not-callable
86 obj.save(self.db_conn)
87 assert isinstance(obj.id_, (str, int))
88 for row in self.db_conn.row_where(self.checked_class.table_name,
90 retrieved = self.checked_class.from_table_row(self.db_conn, row)
91 self.assertEqual(obj, retrieved)
92 self.assertEqual({obj.id_: obj}, self.checked_class.get_cache())
def check_all(self) -> tuple[Any, Any, Any]:
    """Test .all() lists exactly the saved instances.

    Returns the three instances built from self.default_ids, all of them
    saved by the time this method returns.
    """
    def stored_sorted() -> list[Any]:
        # what .all() reports right now, sorted for order-independent compare
        return sorted(self.checked_class.all(self.db_conn))

    ids = self.default_ids
    # pylint: disable=not-callable
    first = self.checked_class(ids[0])
    second = self.checked_class(ids[1])
    third = self.checked_class(ids[2])
    # nothing saved yet, so .all() must be an empty list
    self.assertEqual(self.checked_class.all(self.db_conn), [])
    # saved items must show up, the still unsaved one must not
    first.save(self.db_conn)
    third.save(self.db_conn)
    self.assertEqual(stored_sorted(), sorted([first, third]))
    second.save(self.db_conn)
    self.assertEqual(stored_sorted(), sorted([first, second, third]))
    return first, second, third
def check_singularity(self, defaulting_field: str,
                      non_default_value: Any) -> None:
    """Test that a change to a saved object shows on a later .by_id() result.

    Sets defaulting_field to non_default_value on the saved object, without
    saving again, and expects the object retrieved by the same ID to carry
    that value.
    """
    obj_id = self.default_ids[0]
    saved = self.checked_class(obj_id)  # pylint: disable=not-callable
    saved.save(self.db_conn)
    setattr(saved, defaulting_field, non_default_value)
    fetched = self.checked_class.by_id(self.db_conn, obj_id)
    found_value = getattr(fetched, defaulting_field)
    self.assertEqual(non_default_value, found_value)
def check_remove(self) -> None:
    """Test .remove() raises if unsaved, and else empties cache and DB."""
    # pylint: disable=not-callable
    target = self.checked_class(self.default_ids[0])
    # removing what was never saved must raise
    with self.assertRaises(HandledException):
        target.remove(self.db_conn)
    # removing after saving must leave nothing in storage
    target.save(self.db_conn)
    target.remove(self.db_conn)
    self.check_storage([])
134 class TestCaseWithServer(TestCaseWithDB):
135 """Module tests against our HTTP server/handler (and database)."""
def setUp(self) -> None:
    """Start a TaskServer in a background thread and connect to it.

    Runs the parent setUp first: the server is built on self.db_file, which
    only the parent setUp assigns. Sets self.httpd, self.server_thread and
    self.conn.
    """
    super().setUp()
    # port 0: presumably lets the OS pick a free port (standard socketserver
    # behaviour — TODO confirm for TaskServer); the address actually bound is
    # read back from server_address below
    self.httpd = TaskServer(self.db_file, ('localhost', 0), TaskHandler)
    self.server_thread = Thread(target=self.httpd.serve_forever)
    # daemon thread, so a hanging server cannot keep the test process alive
    self.server_thread.daemon = True
    self.server_thread.start()
    self.conn = HTTPConnection(str(self.httpd.server_address[0]),
                               self.httpd.server_address[1])
def tearDown(self) -> None:
    """Stop the server, close the client connection, run parent tearDown."""
    self.httpd.shutdown()
    self.httpd.server_close()
    self.server_thread.join()
    # release the client socket too instead of leaving it to the garbage
    # collector; HTTPConnection.close() is safe on an unconnected instance
    self.conn.close()
    # the parent tearDown is what removes the test's DB file
    super().tearDown()
def check_redirect(self, target: str) -> None:
    """Check that the pending response on self.conn is a 302 to target.

    Expects a request to have been sent on self.conn already; reads its
    response and compares status and Location header.
    """
    resp = self.conn.getresponse()
    status, location = resp.status, resp.getheader('Location')
    self.assertEqual(status, 302)
    self.assertEqual(location, target)
def check_get(self, target: str, expected_code: int) -> None:
    """Send a GET for target and expect expected_code as response status."""
    self.conn.request('GET', target)
    got_code = self.conn.getresponse().status
    self.assertEqual(got_code, expected_code)
def check_post(self, data: Mapping[str, object], target: str,
               expected_code: int, redirect_location: str = '') -> None:
    """Check that POST of data to target yields expected_code.

    data is sent form-urlencoded; with doseq=True, sequence values become
    repeated keys. If expected_code is 302, the Location header is checked
    too: against redirect_location, or against target if redirect_location
    is empty.
    """
    encoded_form_data = urlencode(data, doseq=True).encode('utf-8')
    headers = {'Content-Type': 'application/x-www-form-urlencoded',
               'Content-Length': str(len(encoded_form_data))}
    self.conn.request('POST', target,
                      body=encoded_form_data, headers=headers)
    if expected_code == 302:
        # check_redirect reads the response and checks the status itself, so
        # the connection must not be asked for a second response afterwards
        self.check_redirect(redirect_location or target)
        return
    self.assertEqual(self.conn.getresponse().status, expected_code)
def check_get_defaults(self, path: str) -> None:
    """GET path bare and with several id= values, checking status codes."""
    expectations = (
        (path, 200),
        (f'{path}?id=', 200),
        (f'{path}?id=foo', 400),
        # NOTE(review): only this target gets an extra leading '/' in front
        # of path — confirm that is intended
        (f'/{path}?id=0', 500),
        (f'{path}?id=1', 200),
    )
    for target, expected_code in expectations:
        self.check_get(target, expected_code)
186 def post_process(self, id_: int = 1,
187 form_data: dict[str, Any] | None = None
189 """POST basic Process."""
191 form_data = {'title': 'foo', 'description': 'foo', 'effort': 1.1}
192 self.check_post(form_data, '/process?id=', 302, f'/process?id={id_}')