home · contact · privacy
Minor template improvements.
[plomtask] / tests / utils.py
index 63b07e93e6f14ca51426e2fc00e959fdbfca7bf1..6f44f611b487f139536822b6340fae5f5ddcf5d5 100644 (file)
@@ -5,17 +5,47 @@ from http.client import HTTPConnection
 from urllib.parse import urlencode
 from datetime import datetime
 from os import remove as remove_file
-from typing import Mapping
+from typing import Mapping, Any
 from plomtask.db import DatabaseFile, DatabaseConnection
 from plomtask.http import TaskHandler, TaskServer
 from plomtask.processes import Process, ProcessStep
 from plomtask.conditions import Condition
 from plomtask.days import Day
 from plomtask.todos import Todo
+from plomtask.exceptions import NotFoundException, HandledException
+
+
+class TestCaseSansDB(TestCase):
+    """Tests requiring no DB setup."""
+    checked_class: Any
+    do_id_test: bool = False
+    default_init_args: list[Any] = []
+    versioned_defaults_to_test: dict[str, str | float] = {}
+
+    def test_id_setting(self) -> None:
+        """Test .id_ being set and its legal range being enforced."""
+        if not self.do_id_test:
+            return
+        with self.assertRaises(HandledException):
+            self.checked_class(0, *self.default_init_args)
+        obj = self.checked_class(5, *self.default_init_args)
+        self.assertEqual(obj.id_, 5)
+
+    def test_versioned_defaults(self) -> None:
+        """Test defaults of VersionedAttributes."""
+        if len(self.versioned_defaults_to_test) == 0:
+            return
+        obj = self.checked_class(1, *self.default_init_args)
+        for k, v in self.versioned_defaults_to_test.items():
+            self.assertEqual(getattr(obj, k).newest, v)
 
 
 class TestCaseWithDB(TestCase):
     """Module tests not requiring DB setup."""
+    checked_class: Any
+    default_ids: tuple[int | str, int | str, int | str] = (1, 2, 3)
+    default_init_kwargs: dict[str, Any] = {}
+    test_versioneds: dict[str, type] = {}
 
     def setUp(self) -> None:
         Condition.empty_cache()
@@ -24,14 +54,155 @@ class TestCaseWithDB(TestCase):
         ProcessStep.empty_cache()
         Todo.empty_cache()
         timestamp = datetime.now().timestamp()
-        self.db_file = DatabaseFile(f'test_db:{timestamp}')
-        self.db_file.remake()
+        self.db_file = DatabaseFile.create_at(f'test_db:{timestamp}')
         self.db_conn = DatabaseConnection(self.db_file)
 
     def tearDown(self) -> None:
         self.db_conn.close()
         remove_file(self.db_file.path)
 
+    def test_saving_and_caching(self) -> None:
+        """Test storage and initialization of instances and attributes."""
+        if not hasattr(self, 'checked_class'):
+            return
+        self.check_saving_and_caching(id_=1, **self.default_init_kwargs)
+        obj = self.checked_class(None, **self.default_init_kwargs)
+        obj.save(self.db_conn)
+        self.assertEqual(obj.id_, 2)
+        for k, v in self.test_versioneds.items():
+            self.check_saving_of_versioned(k, v)
+
+    def check_storage(self, content: list[Any]) -> None:
+        """Test cache and DB equal content."""
+        expected_cache = {}
+        for item in content:
+            expected_cache[item.id_] = item
+        self.assertEqual(self.checked_class.get_cache(), expected_cache)
+        db_found: list[Any] = []
+        for item in content:
+            assert isinstance(item.id_, type(self.default_ids[0]))
+            for row in self.db_conn.row_where(self.checked_class.table_name,
+                                              'id', item.id_):
+                db_found += [self.checked_class.from_table_row(self.db_conn,
+                                                               row)]
+        self.assertEqual(sorted(content), sorted(db_found))
+
+    def check_saving_and_caching(self, **kwargs: Any) -> None:
+        """Test instance.save in its core without relations."""
+        obj = self.checked_class(**kwargs)  # pylint: disable=not-callable
+        # check object init itself doesn't store anything yet
+        self.check_storage([])
+        # check saving stores in cache and DB
+        obj.save(self.db_conn)
+        self.check_storage([obj])
+        # check core attributes set properly (and not unset by saving)
+        for key, value in kwargs.items():
+            self.assertEqual(getattr(obj, key), value)
+
+    def check_saving_of_versioned(self, attr_name: str, type_: type) -> None:
+        """Test owner's versioned attributes."""
+        owner = self.checked_class(None)
+        vals: list[Any] = ['t1', 't2'] if type_ == str else [0.9, 1.1]
+        attr = getattr(owner, attr_name)
+        attr.set(vals[0])
+        attr.set(vals[1])
+        owner.save(self.db_conn)
+        owner.uncache()
+        retrieved = owner.__class__.by_id(self.db_conn, owner.id_)
+        attr = getattr(retrieved, attr_name)
+        self.assertEqual(sorted(attr.history.values()), vals)
+
+    def check_by_id(self) -> None:
+        """Test .by_id(), including creation."""
+        # check failure if not yet saved
+        id1, id2 = self.default_ids[0], self.default_ids[1]
+        obj = self.checked_class(id1)  # pylint: disable=not-callable
+        with self.assertRaises(NotFoundException):
+            self.checked_class.by_id(self.db_conn, id1)
+        # check identity of saved and retrieved
+        obj.save(self.db_conn)
+        self.assertEqual(obj, self.checked_class.by_id(self.db_conn, id1))
+        # check create=True acts like normal instantiation (sans saving)
+        by_id_created = self.checked_class.by_id(self.db_conn, id2,
+                                                 create=True)
+        # pylint: disable=not-callable
+        self.assertEqual(self.checked_class(id2), by_id_created)
+        self.check_storage([obj])
+
+    def check_from_table_row(self, *args: Any) -> None:
+        """Test .from_table_row() properly reads in class from DB"""
+        id_ = self.default_ids[0]
+        obj = self.checked_class(id_, *args)  # pylint: disable=not-callable
+        obj.save(self.db_conn)
+        assert isinstance(obj.id_, type(self.default_ids[0]))
+        for row in self.db_conn.row_where(self.checked_class.table_name,
+                                          'id', obj.id_):
+            retrieved = self.checked_class.from_table_row(self.db_conn, row)
+            self.assertEqual(obj, retrieved)
+            self.assertEqual({obj.id_: obj}, self.checked_class.get_cache())
+
+    def check_versioned_from_table_row(self, attr_name: str,
+                                       type_: type) -> None:
+        """Test .from_table_row() reads versioned attributes from DB."""
+        owner = self.checked_class(None)
+        vals: list[Any] = ['t1', 't2'] if type_ == str else [0.9, 1.1]
+        attr = getattr(owner, attr_name)
+        attr.set(vals[0])
+        attr.set(vals[1])
+        owner.save(self.db_conn)
+        for row in self.db_conn.row_where(owner.table_name, 'id', owner.id_):
+            retrieved = owner.__class__.from_table_row(self.db_conn, row)
+            attr = getattr(retrieved, attr_name)
+            self.assertEqual(sorted(attr.history.values()), vals)
+
+    def check_all(self) -> tuple[Any, Any, Any]:
+        """Test .all()."""
+        # pylint: disable=not-callable
+        item1 = self.checked_class(self.default_ids[0])
+        item2 = self.checked_class(self.default_ids[1])
+        item3 = self.checked_class(self.default_ids[2])
+        # check pre-save .all() returns empty list
+        self.assertEqual(self.checked_class.all(self.db_conn), [])
+        # check that all() shows all saved, but no unsaved items
+        item1.save(self.db_conn)
+        item3.save(self.db_conn)
+        self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
+                         sorted([item1, item3]))
+        item2.save(self.db_conn)
+        self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
+                         sorted([item1, item2, item3]))
+        return item1, item2, item3
+
+    def check_singularity(self, defaulting_field: str,
+                          non_default_value: Any, *args: Any) -> None:
+        """Test pointers made for single object keep pointing to it."""
+        id1 = self.default_ids[0]
+        obj = self.checked_class(id1, *args)  # pylint: disable=not-callable
+        obj.save(self.db_conn)
+        setattr(obj, defaulting_field, non_default_value)
+        retrieved = self.checked_class.by_id(self.db_conn, id1)
+        self.assertEqual(non_default_value,
+                         getattr(retrieved, defaulting_field))
+
+    def check_versioned_singularity(self) -> None:
+        """Test singularity of VersionedAttributes on saving (with .title)."""
+        obj = self.checked_class(None)  # pylint: disable=not-callable
+        obj.save(self.db_conn)
+        assert isinstance(obj.id_, int)
+        obj.title.set('named')
+        retrieved = self.checked_class.by_id(self.db_conn, obj.id_)
+        self.assertEqual(obj.title.history, retrieved.title.history)
+
+    def check_remove(self, *args: Any) -> None:
+        """Test .remove() effects on DB and cache."""
+        id_ = self.default_ids[0]
+        obj = self.checked_class(id_, *args)  # pylint: disable=not-callable
+        with self.assertRaises(HandledException):
+            obj.remove(self.db_conn)
+        obj.save(self.db_conn)
+        obj.remove(self.db_conn)
+        self.check_storage([])
+
 
 class TestCaseWithServer(TestCaseWithDB):
     """Module tests against our HTTP server/handler (and database)."""
@@ -63,7 +234,7 @@ class TestCaseWithServer(TestCaseWithDB):
         self.assertEqual(self.conn.getresponse().status, expected_code)
 
     def check_post(self, data: Mapping[str, object], target: str,
-                   expected_code: int, redirect_location: str = '/') -> None:
+                   expected_code: int, redirect_location: str = '') -> None:
         """Check that POST of data to target yields expected_code."""
         encoded_form_data = urlencode(data, doseq=True).encode('utf-8')
         headers = {'Content-Type': 'application/x-www-form-urlencoded',
@@ -71,6 +242,26 @@ class TestCaseWithServer(TestCaseWithDB):
         self.conn.request('POST', target,
                           body=encoded_form_data, headers=headers)
         if 302 == expected_code:
+            if redirect_location == '':
+                redirect_location = target
             self.check_redirect(redirect_location)
         else:
             self.assertEqual(self.conn.getresponse().status, expected_code)
+
+    def check_get_defaults(self, path: str) -> None:
+        """Some standard model paths to test."""
+        self.check_get(path, 200)
+        self.check_get(f'{path}?id=', 200)
+        self.check_get(f'{path}?id=foo', 400)
+        self.check_get(f'/{path}?id=0', 500)
+        self.check_get(f'{path}?id=1', 200)
+
+    def post_process(self, id_: int = 1,
+                     form_data: dict[str, Any] | None = None
+                     ) -> dict[str, Any]:
+        """POST basic Process."""
+        if not form_data:
+            form_data = {'title': 'foo', 'description': 'foo', 'effort': 1.1}
+        self.check_post(form_data, f'/process?id={id_}', 302,
+                        f'/process?id={id_}')
+        return form_data