home · contact · privacy
Enable server to alternatively output response ctx as JSON, for debugging and testing...
[plomtask] / tests / utils.py
index 2a919a31dbb057f4db558861d2fafe8277db5521..15a53ae0ddc0b78835b5baacea15f43d3a81cba0 100644 (file)
@@ -3,19 +3,49 @@ from unittest import TestCase
 from threading import Thread
 from http.client import HTTPConnection
 from urllib.parse import urlencode
-from datetime import datetime
+from uuid import uuid4
 from os import remove as remove_file
-from typing import Mapping
+from typing import Mapping, Any
 from plomtask.db import DatabaseFile, DatabaseConnection
 from plomtask.http import TaskHandler, TaskServer
 from plomtask.processes import Process, ProcessStep
 from plomtask.conditions import Condition
 from plomtask.days import Day
 from plomtask.todos import Todo
+from plomtask.exceptions import NotFoundException, HandledException
+
+
+class TestCaseSansDB(TestCase):
+    """Tests requiring no DB setup."""
+    checked_class: Any
+    do_id_test: bool = False
+    default_init_args: list[Any] = []
+    versioned_defaults_to_test: dict[str, str | float] = {}
+
+    def test_id_setting(self) -> None:
+        """Test .id_ being set and its legal range being enforced."""
+        if not self.do_id_test:
+            return
+        with self.assertRaises(HandledException):
+            self.checked_class(0, *self.default_init_args)
+        obj = self.checked_class(5, *self.default_init_args)
+        self.assertEqual(obj.id_, 5)
+
+    def test_versioned_defaults(self) -> None:
+        """Test defaults of VersionedAttributes."""
+        if len(self.versioned_defaults_to_test) == 0:
+            return
+        obj = self.checked_class(1, *self.default_init_args)
+        for k, v in self.versioned_defaults_to_test.items():
+            self.assertEqual(getattr(obj, k).newest, v)
 
 
 class TestCaseWithDB(TestCase):
     """Module tests not requiring DB setup."""
+    checked_class: Any
+    default_ids: tuple[int | str, int | str, int | str] = (1, 2, 3)
+    default_init_kwargs: dict[str, Any] = {}
+    test_versioneds: dict[str, type] = {}
 
     def setUp(self) -> None:
         Condition.empty_cache()
@@ -23,15 +53,158 @@ class TestCaseWithDB(TestCase):
         Process.empty_cache()
         ProcessStep.empty_cache()
         Todo.empty_cache()
-        timestamp = datetime.now().timestamp()
-        self.db_file = DatabaseFile(f'test_db:{timestamp}')
-        self.db_file.remake()
+        self.db_file = DatabaseFile.create_at(f'test_db:{uuid4()}')
         self.db_conn = DatabaseConnection(self.db_file)
 
     def tearDown(self) -> None:
         self.db_conn.close()
         remove_file(self.db_file.path)
 
+    def test_saving_and_caching(self) -> None:
+        """Test storage and initialization of instances and attributes."""
+        if not hasattr(self, 'checked_class'):
+            return
+        self.check_saving_and_caching(id_=1, **self.default_init_kwargs)
+        obj = self.checked_class(None, **self.default_init_kwargs)
+        obj.save(self.db_conn)
+        self.assertEqual(obj.id_, 2)
+        for k, v in self.test_versioneds.items():
+            self.check_saving_of_versioned(k, v)
+
+    def check_storage(self, content: list[Any]) -> None:
+        """Test cache and DB equal content."""
+        expected_cache = {}
+        for item in content:
+            expected_cache[item.id_] = item
+        self.assertEqual(self.checked_class.get_cache(), expected_cache)
+        hashes_content = [hash(x) for x in content]
+        db_found: list[Any] = []
+        for item in content:
+            assert isinstance(item.id_, type(self.default_ids[0]))
+            for row in self.db_conn.row_where(self.checked_class.table_name,
+                                              'id', item.id_):
+                db_found += [self.checked_class.from_table_row(self.db_conn,
+                                                               row)]
+        hashes_db_found = [hash(x) for x in db_found]
+        self.assertEqual(sorted(hashes_content), sorted(hashes_db_found))
+
+    def check_saving_and_caching(self, **kwargs: Any) -> None:
+        """Test instance.save in its core without relations."""
+        obj = self.checked_class(**kwargs)  # pylint: disable=not-callable
+        # check object init itself doesn't store anything yet
+        self.check_storage([])
+        # check saving sets core attributes properly
+        obj.save(self.db_conn)
+        for key, value in kwargs.items():
+            self.assertEqual(getattr(obj, key), value)
+        # check saving stored properly in cache and DB
+        self.check_storage([obj])
+
+    def check_saving_of_versioned(self, attr_name: str, type_: type) -> None:
+        """Test owner's versioned attributes."""
+        owner = self.checked_class(None)
+        vals: list[Any] = ['t1', 't2'] if type_ == str else [0.9, 1.1]
+        attr = getattr(owner, attr_name)
+        attr.set(vals[0])
+        attr.set(vals[1])
+        owner.save(self.db_conn)
+        retrieved = owner.__class__.by_id(self.db_conn, owner.id_)
+        attr = getattr(retrieved, attr_name)
+        self.assertEqual(sorted(attr.history.values()), vals)
+
+    def check_by_id(self) -> None:
+        """Test .by_id(), including creation."""
+        # check failure if not yet saved
+        id1, id2 = self.default_ids[0], self.default_ids[1]
+        obj = self.checked_class(id1)  # pylint: disable=not-callable
+        with self.assertRaises(NotFoundException):
+            self.checked_class.by_id(self.db_conn, id1)
+        # check identity of saved and retrieved
+        obj.save(self.db_conn)
+        self.assertEqual(obj, self.checked_class.by_id(self.db_conn, id1))
+        # check create=True acts like normal instantiation (sans saving)
+        by_id_created = self.checked_class.by_id(self.db_conn, id2,
+                                                 create=True)
+        # pylint: disable=not-callable
+        self.assertEqual(self.checked_class(id2), by_id_created)
+        self.check_storage([obj])
+
+    def check_from_table_row(self, *args: Any) -> None:
+        """Test .from_table_row() properly reads in class from DB"""
+        id_ = self.default_ids[0]
+        obj = self.checked_class(id_, *args)  # pylint: disable=not-callable
+        obj.save(self.db_conn)
+        assert isinstance(obj.id_, type(self.default_ids[0]))
+        for row in self.db_conn.row_where(self.checked_class.table_name,
+                                          'id', obj.id_):
+            hash_original = hash(obj)
+            retrieved = self.checked_class.from_table_row(self.db_conn, row)
+            self.assertEqual(hash_original, hash(retrieved))
+            self.assertEqual({retrieved.id_: retrieved},
+                             self.checked_class.get_cache())
+
+    def check_versioned_from_table_row(self, attr_name: str,
+                                       type_: type) -> None:
+        """Test .from_table_row() reads versioned attributes from DB."""
+        owner = self.checked_class(None)
+        vals: list[Any] = ['t1', 't2'] if type_ == str else [0.9, 1.1]
+        attr = getattr(owner, attr_name)
+        attr.set(vals[0])
+        attr.set(vals[1])
+        owner.save(self.db_conn)
+        for row in self.db_conn.row_where(owner.table_name, 'id', owner.id_):
+            retrieved = owner.__class__.from_table_row(self.db_conn, row)
+            attr = getattr(retrieved, attr_name)
+            self.assertEqual(sorted(attr.history.values()), vals)
+
+    def check_all(self) -> tuple[Any, Any, Any]:
+        """Test .all()."""
+        # pylint: disable=not-callable
+        item1 = self.checked_class(self.default_ids[0])
+        item2 = self.checked_class(self.default_ids[1])
+        item3 = self.checked_class(self.default_ids[2])
+        # check pre-save .all() returns empty list
+        self.assertEqual(self.checked_class.all(self.db_conn), [])
+        # check that all() shows all saved, but no unsaved items
+        item1.save(self.db_conn)
+        item3.save(self.db_conn)
+        self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
+                         sorted([item1, item3]))
+        item2.save(self.db_conn)
+        self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
+                         sorted([item1, item2, item3]))
+        return item1, item2, item3
+
+    def check_singularity(self, defaulting_field: str,
+                          non_default_value: Any, *args: Any) -> None:
+        """Test pointers made for single object keep pointing to it."""
+        id1 = self.default_ids[0]
+        obj = self.checked_class(id1, *args)  # pylint: disable=not-callable
+        obj.save(self.db_conn)
+        setattr(obj, defaulting_field, non_default_value)
+        retrieved = self.checked_class.by_id(self.db_conn, id1)
+        self.assertEqual(non_default_value,
+                         getattr(retrieved, defaulting_field))
+
+    def check_versioned_singularity(self) -> None:
+        """Test singularity of VersionedAttributes on saving (with .title)."""
+        obj = self.checked_class(None)  # pylint: disable=not-callable
+        obj.save(self.db_conn)
+        assert isinstance(obj.id_, int)
+        obj.title.set('named')
+        retrieved = self.checked_class.by_id(self.db_conn, obj.id_)
+        self.assertEqual(obj.title.history, retrieved.title.history)
+
+    def check_remove(self, *args: Any) -> None:
+        """Test .remove() effects on DB and cache."""
+        id_ = self.default_ids[0]
+        obj = self.checked_class(id_, *args)  # pylint: disable=not-callable
+        with self.assertRaises(HandledException):
+            obj.remove(self.db_conn)
+        obj.save(self.db_conn)
+        obj.remove(self.db_conn)
+        self.check_storage([])
+
 
 class TestCaseWithServer(TestCaseWithDB):
     """Module tests against our HTTP server/handler (and database)."""
@@ -44,6 +217,7 @@ class TestCaseWithServer(TestCaseWithDB):
         self.server_thread.start()
         self.conn = HTTPConnection(str(self.httpd.server_address[0]),
                                    self.httpd.server_address[1])
+        self.httpd.set_json_mode()
 
     def tearDown(self) -> None:
         self.httpd.shutdown()
@@ -70,9 +244,27 @@ class TestCaseWithServer(TestCaseWithDB):
                    'Content-Length': str(len(encoded_form_data))}
         self.conn.request('POST', target,
                           body=encoded_form_data, headers=headers)
-        if redirect_location == '':
-            redirect_location = target
         if 302 == expected_code:
+            if redirect_location == '':
+                redirect_location = target
             self.check_redirect(redirect_location)
         else:
             self.assertEqual(self.conn.getresponse().status, expected_code)
+
+    def check_get_defaults(self, path: str) -> None:
+        """Some standard model paths to test."""
+        self.check_get(path, 200)
+        self.check_get(f'{path}?id=', 200)
+        self.check_get(f'{path}?id=foo', 400)
+        self.check_get(f'/{path}?id=0', 500)
+        self.check_get(f'{path}?id=1', 200)
+
+    def post_process(self, id_: int = 1,
+                     form_data: dict[str, Any] | None = None
+                     ) -> dict[str, Any]:
+        """POST basic Process."""
+        if not form_data:
+            form_data = {'title': 'foo', 'description': 'foo', 'effort': 1.1}
+        self.check_post(form_data, f'/process?id={id_}', 302,
+                        f'/process?id={id_}')
+        return form_data