home · contact · privacy
Slightly improve and re-organize Condition tests.
[plomtask] / tests / utils.py
index 0925b2d5b2adc0e415293526a4b01c04fc42b178..665436873c27af704a13827715d3c795e04e1fe1 100644 (file)
@@ -17,27 +17,36 @@ from plomtask.todos import Todo
 from plomtask.exceptions import NotFoundException, HandledException
 
 
+def _within_checked_class(f: Callable[..., None]) -> Callable[..., None]:
+    def wrapper(self: TestCase) -> None:
+        if hasattr(self, 'checked_class'):
+            f(self)
+    return wrapper
+
+
 class TestCaseSansDB(TestCase):
     """Tests requiring no DB setup."""
     checked_class: Any
-    do_id_test: bool = False
     default_init_args: list[Any] = []
     versioned_defaults_to_test: dict[str, str | float] = {}
+    legal_ids = [1, 5]
+    illegal_ids = [0]
 
-    def test_id_setting(self) -> None:
-        """Test .id_ being set and its legal range being enforced."""
-        if not self.do_id_test:
-            return
-        with self.assertRaises(HandledException):
-            self.checked_class(0, *self.default_init_args)
-        obj = self.checked_class(5, *self.default_init_args)
-        self.assertEqual(obj.id_, 5)
+    @_within_checked_class
+    def test_id_validation(self) -> None:
+        """Test .id_ validation/setting."""
+        for id_ in self.illegal_ids:
+            with self.assertRaises(HandledException):
+                self.checked_class(id_, *self.default_init_args)
+        for id_ in self.legal_ids:
+            obj = self.checked_class(id_, *self.default_init_args)
+            self.assertEqual(obj.id_, id_)
 
+    @_within_checked_class
     def test_versioned_defaults(self) -> None:
         """Test defaults of VersionedAttributes."""
-        if len(self.versioned_defaults_to_test) == 0:
-            return
-        obj = self.checked_class(1, *self.default_init_args)
+        id_ = self.legal_ids[0]
+        obj = self.checked_class(id_, *self.default_init_args)
         for k, v in self.versioned_defaults_to_test.items():
             self.assertEqual(getattr(obj, k).newest, v)
 
@@ -62,13 +71,6 @@ class TestCaseWithDB(TestCase):
         self.db_conn.close()
         remove_file(self.db_file.path)
 
-    @staticmethod
-    def _within_checked_class(f: Callable[..., None]) -> Callable[..., None]:
-        def wrapper(self: TestCaseWithDB) -> None:
-            if hasattr(self, 'checked_class'):
-                f(self)
-        return wrapper
-
     def _load_from_db(self, id_: int | str) -> list[object]:
         db_found: list[object] = []
         for row in self.db_conn.row_where(self.checked_class.table_name,
@@ -325,6 +327,71 @@ class TestCaseWithServer(TestCaseWithDB):
         self.server_thread.join()
         super().tearDown()
 
+    @staticmethod
+    def as_id_list(items: list[dict[str, object]]) -> list[int | str]:
+        """Return list of only 'id' fields of items."""
+        id_list = []
+        for item in items:
+            assert isinstance(item['id'], (int, str))
+            id_list += [item['id']]
+        return id_list
+
+    @staticmethod
+    def as_refs(items: list[dict[str, object]]
+                ) -> dict[str, dict[str, object]]:
+        """Return dictionary of items by their 'id' fields."""
+        refs = {}
+        for item in items:
+            refs[str(item['id'])] = item
+        return refs
+
+    @staticmethod
+    def cond_as_dict(id_: int = 1,
+                     is_active: bool = False,
+                     titles: None | list[str] = None,
+                     descriptions: None | list[str] = None
+                     ) -> dict[str, object]:
+        """Return JSON of Condition to expect."""
+        d = {'id': id_,
+             'is_active': is_active,
+             '_versioned': {
+                 'title': {},
+                 'description': {}}}
+        titles = titles if titles else []
+        descriptions = descriptions if descriptions else []
+        assert isinstance(d['_versioned'], dict)
+        for i, title in enumerate(titles):
+            d['_versioned']['title'][i] = title
+        for i, description in enumerate(descriptions):
+            d['_versioned']['description'][i] = description
+        return d
+
+    @staticmethod
+    def proc_as_dict(id_: int = 1,
+                     title: str = 'A',
+                     description: str = '',
+                     effort: float = 1.0,
+                     conditions: None | list[int] = None,
+                     disables: None | list[int] = None,
+                     blockers: None | list[int] = None,
+                     enables: None | list[int] = None
+                     ) -> dict[str, object]:
+        """Return JSON of Process to expect."""
+        # pylint: disable=too-many-arguments
+        d = {'id': id_,
+             'calendarize': False,
+             'suppressed_steps': [],
+             'explicit_steps': [],
+             '_versioned': {
+                 'title': {0: title},
+                 'description': {0: description},
+                 'effort': {0: effort}},
+             'conditions': conditions if conditions else [],
+             'disables': disables if disables else [],
+             'enables': enables if enables else [],
+             'blockers': blockers if blockers else []}
+        return d
+
     def check_redirect(self, target: str) -> None:
         """Check that self.conn answers with a 302 redirect to target."""
         response = self.conn.getresponse()