home · contact · privacy
Minor tests refactoring.
[plomtask] / tests / utils.py
index 545a2ba2881372e2bcd479ecd4e8cd2d7da2ed8c..61dbb36b949ee3fdfd5a38dcada155a5ab18a924 100644 (file)
@@ -5,17 +5,20 @@ from http.client import HTTPConnection
 from urllib.parse import urlencode
 from datetime import datetime
 from os import remove as remove_file
-from typing import Mapping
+from typing import Mapping, Any
 from plomtask.db import DatabaseFile, DatabaseConnection
 from plomtask.http import TaskHandler, TaskServer
 from plomtask.processes import Process, ProcessStep
 from plomtask.conditions import Condition
 from plomtask.days import Day
 from plomtask.todos import Todo
+from plomtask.exceptions import NotFoundException, HandledException
 
 
 class TestCaseWithDB(TestCase):
     """Module tests not requiring DB setup."""
+    checked_class: Any
+    default_ids: tuple[int | str, int | str, int | str]
 
     def setUp(self) -> None:
         Condition.empty_cache()
@@ -32,6 +35,101 @@ class TestCaseWithDB(TestCase):
         self.db_conn.close()
         remove_file(self.db_file.path)
 
+    def check_storage(self, content: list[Any]) -> None:
+        """Test cache and DB equal content."""
+        expected_cache = {}
+        for item in content:
+            expected_cache[item.id_] = item
+        self.assertEqual(self.checked_class.get_cache(), expected_cache)
+        db_found: list[Any] = []
+        for item in content:
+            assert isinstance(item.id_, (str, int))
+            for row in self.db_conn.row_where(self.checked_class.table_name,
+                                              'id', item.id_):
+                db_found += [self.checked_class.from_table_row(self.db_conn,
+                                                               row)]
+        self.assertEqual(sorted(content), sorted(db_found))
+
+    def check_saving_and_caching(self, **kwargs: Any) -> Any:
+        """Test instance.save in its core without relations."""
+        obj = self.checked_class(**kwargs)  # pylint: disable=not-callable
+        # check object init itself doesn't store anything yet
+        self.check_storage([])
+        # check saving stores in cache and DB
+        obj.save(self.db_conn)
+        self.check_storage([obj])
+        # check core attributes set properly (and not unset by saving)
+        for key, value in kwargs.items():
+            self.assertEqual(getattr(obj, key), value)
+
+    def check_by_id(self) -> None:
+        """Test .by_id(), including creation."""
+        # check failure if not yet saved
+        id1, id2 = self.default_ids[0], self.default_ids[1]
+        obj = self.checked_class(id1)  # pylint: disable=not-callable
+        with self.assertRaises(NotFoundException):
+            self.checked_class.by_id(self.db_conn, id1)
+        # check identity of saved and retrieved
+        obj.save(self.db_conn)
+        self.assertEqual(obj, self.checked_class.by_id(self.db_conn, id1))
+        # check create=True acts like normal instantiation (sans saving)
+        by_id_created = self.checked_class.by_id(self.db_conn, id2,
+                                                 create=True)
+        # pylint: disable=not-callable
+        self.assertEqual(self.checked_class(id2), by_id_created)
+        self.check_storage([obj])
+
+    def check_from_table_row(self) -> None:
+        """Test .from_table_row() properly reads in class from DB"""
+        id_ = self.default_ids[0]
+        obj = self.checked_class(id_)  # pylint: disable=not-callable
+        obj.save(self.db_conn)
+        assert isinstance(obj.id_, (str, int))
+        for row in self.db_conn.row_where(self.checked_class.table_name,
+                                          'id', obj.id_):
+            retrieved = self.checked_class.from_table_row(self.db_conn, row)
+            self.assertEqual(obj, retrieved)
+            self.assertEqual({obj.id_: obj}, self.checked_class.get_cache())
+
+    def check_all(self) -> tuple[Any, Any, Any]:
+        """Test .all()."""
+        # pylint: disable=not-callable
+        item1 = self.checked_class(self.default_ids[0])
+        item2 = self.checked_class(self.default_ids[1])
+        item3 = self.checked_class(self.default_ids[2])
+        # check pre-save .all() returns empty list
+        self.assertEqual(self.checked_class.all(self.db_conn), [])
+        # check that all() shows all saved, but no unsaved items
+        item1.save(self.db_conn)
+        item3.save(self.db_conn)
+        self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
+                         sorted([item1, item3]))
+        item2.save(self.db_conn)
+        self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
+                         sorted([item1, item2, item3]))
+        return item1, item2, item3
+
+    def check_singularity(self, defaulting_field: str,
+                          non_default_value: Any) -> None:
+        """Test pointers made for single object keep pointing to it."""
+        id1 = self.default_ids[0]
+        obj = self.checked_class(id1)  # pylint: disable=not-callable
+        obj.save(self.db_conn)
+        setattr(obj, defaulting_field, non_default_value)
+        retrieved = self.checked_class.by_id(self.db_conn, id1)
+        self.assertEqual(non_default_value,
+                         getattr(retrieved, defaulting_field))
+
+    def check_remove(self) -> None:
+        """Test .remove() effects on DB and cache."""
+        id_ = self.default_ids[0]
+        obj = self.checked_class(id_)  # pylint: disable=not-callable
+        with self.assertRaises(HandledException):
+            obj.remove(self.db_conn)
+        obj.save(self.db_conn)
+        obj.remove(self.db_conn)
+        self.check_storage([])
+
 
 class TestCaseWithServer(TestCaseWithDB):
     """Module tests against our HTTP server/handler (and database)."""
@@ -84,3 +182,12 @@ class TestCaseWithServer(TestCaseWithDB):
         self.check_get(f'{path}?id=foo', 400)
         self.check_get(f'/{path}?id=0', 500)
         self.check_get(f'{path}?id=1', 200)
+
+    def post_process(self, id_: int = 1,
+                     form_data: dict[str, Any] | None = None
+                     ) -> dict[str, Any]:
+        """POST basic Process."""
+        if not form_data:
+            form_data = {'title': 'foo', 'description': 'foo', 'effort': 1.1}
+        self.check_post(form_data, '/process?id=', 302, f'/process?id={id_}')
+        return form_data