home · contact · privacy
More refactoring.
authorChristian Heller <c.heller@plomlompom.de>
Thu, 2 May 2024 04:34:49 +0000 (06:34 +0200)
committerChristian Heller <c.heller@plomlompom.de>
Thu, 2 May 2024 04:34:49 +0000 (06:34 +0200)
plomtask/conditions.py
plomtask/db.py
plomtask/processes.py
scripts/pre-commit
tests/conditions.py
tests/processes.py
tests/utils.py

index 629510af868ae401c609b2c9307fa21f9c732c23..cba606d1a71dd19a0e410677f0ee8fc528d5bdbf 100644 (file)
@@ -19,11 +19,6 @@ class Condition(BaseModel[int]):
         self.description = VersionedAttribute(self, 'condition_descriptions',
                                               '')
 
         self.description = VersionedAttribute(self, 'condition_descriptions',
                                               '')
 
-    def __lt__(self, other: Condition) -> bool:
-        assert isinstance(self.id_, int)
-        assert isinstance(other.id_, int)
-        return self.id_ < other.id_
-
     @classmethod
     def from_table_row(cls, db_conn: DatabaseConnection,
                        row: Row | list[Any]) -> Condition:
     @classmethod
     def from_table_row(cls, db_conn: DatabaseConnection,
                        row: Row | list[Any]) -> Condition:
index 0509492473c44730ea70fa3e793775cd892e2866..982ddfe3b96915d4ffc8bd8d4bce264671c09069 100644 (file)
@@ -130,6 +130,23 @@ class BaseModel(Generic[BaseModelId]):
             raise HandledException(msg)
         self.id_ = id_
 
             raise HandledException(msg)
         self.id_ = id_
 
+    def __eq__(self, other: object) -> bool:
+        if not isinstance(other, self.__class__):
+            return False
+        to_hash_me = tuple([self.id_] +
+                           [getattr(self, name) for name in self.to_save])
+        to_hash_other = tuple([other.id_] +
+                              [getattr(other, name) for name in other.to_save])
+        return hash(to_hash_me) == hash(to_hash_other)
+
+    def __lt__(self, other: Any) -> bool:
+        if not isinstance(other, self.__class__):
+            msg = 'cannot compare to object of different class'
+            raise HandledException(msg)
+        assert isinstance(self.id_, int)
+        assert isinstance(other.id_, int)
+        return self.id_ < other.id_
+
     @classmethod
     def get_cached(cls: type[BaseModelInstance],
                    id_: BaseModelId) -> BaseModelInstance | None:
     @classmethod
     def get_cached(cls: type[BaseModelInstance],
                    id_: BaseModelId) -> BaseModelInstance | None:
@@ -228,16 +245,6 @@ class BaseModel(Generic[BaseModelId]):
                 items[item.id_] = item
         return list(items.values())
 
                 items[item.id_] = item
         return list(items.values())
 
-    def __eq__(self, other: object) -> bool:
-        if not isinstance(other, self.__class__):
-            msg = 'cannot compare to object of different class'
-            raise HandledException(msg)
-        to_hash_me = tuple([self.id_] +
-                           [getattr(self, name) for name in self.to_save])
-        to_hash_other = tuple([other.id_] +
-                              [getattr(other, name) for name in other.to_save])
-        return hash(to_hash_me) == hash(to_hash_other)
-
     def save_core(self, db_conn: DatabaseConnection) -> None:
         """Write bare-bones self (sans connected items), ensuring self.id_.
 
     def save_core(self, db_conn: DatabaseConnection) -> None:
         """Write bare-bones self (sans connected items), ensuring self.id_.
 
index 21e2d8195edaf1a5d5fc1982b6aed26639e2da70..c4ccfa8fd3926491bf1a52fc3c2c36d6d271cd63 100644 (file)
@@ -171,7 +171,10 @@ class Process(BaseModel[int], ConditionsRelations):
             step.save(db_conn)
 
     def remove(self, db_conn: DatabaseConnection) -> None:
             step.save(db_conn)
 
     def remove(self, db_conn: DatabaseConnection) -> None:
-        """Remove from DB, with dependencies."""
+        """Remove from DB, with dependencies.
+
+        Guard against removal of Processes in use.
+        """
         assert isinstance(self.id_, int)
         for _ in db_conn.row_where('process_steps', 'step_process', self.id_):
             raise HandledException('cannot remove Process in use')
         assert isinstance(self.id_, int)
         for _ in db_conn.row_where('process_steps', 'step_process', self.id_):
             raise HandledException('cannot remove Process in use')
index 6f84c41524e31d1aabd689c71dd37550e094ca0c..2aaccb027d613c3d960b634e1bf7747ff0627a2b 100755 (executable)
@@ -1,6 +1,7 @@
 #!/bin/sh
 set -e
 #!/bin/sh
 set -e
-for dir in $(echo '.' 'plomtask' 'tests'); do
+# for dir in $(echo '.' 'plomtask' 'tests'); do
+for dir in $(echo 'tests'); do
     echo "Running mypy on ${dir}/ …."
     python3 -m mypy --strict ${dir}/*.py
     echo "Running flake8 on ${dir}/ …"
     echo "Running mypy on ${dir}/ …."
     python3 -m mypy --strict ${dir}/*.py
     echo "Running flake8 on ${dir}/ …"
index dabcf06f767d02732b972d5b4bf40972f7d9fe66..9c95206aab63cbdd045aaedf1f415ba6779d8afa 100644 (file)
@@ -1,27 +1,29 @@
 """Test Conditions module."""
 """Test Conditions module."""
-from unittest import TestCase
-from tests.utils import TestCaseWithDB, TestCaseWithServer
+from tests.utils import TestCaseWithDB, TestCaseWithServer, TestCaseSansDB
 from plomtask.conditions import Condition
 from plomtask.processes import Process
 from plomtask.todos import Todo
 from plomtask.exceptions import HandledException
 
 
 from plomtask.conditions import Condition
 from plomtask.processes import Process
 from plomtask.todos import Todo
 from plomtask.exceptions import HandledException
 
 
-class TestsSansDB(TestCase):
+class TestsSansDB(TestCaseSansDB):
     """Tests requiring no DB setup."""
     """Tests requiring no DB setup."""
+    checked_class = Condition
 
     def test_Condition_id_setting(self) -> None:
         """Test .id_ being set and its legal range being enforced."""
 
     def test_Condition_id_setting(self) -> None:
         """Test .id_ being set and its legal range being enforced."""
-        with self.assertRaises(HandledException):
-            Condition(0)
-        condition = Condition(5)
-        self.assertEqual(condition.id_, 5)
+        self.check_id_setting()
+
+    def test_Condition_versioned_defaults(self) -> None:
+        """Test defaults of VersionedAttributes."""
+        self.check_versioned_defaults({
+            'title': 'UNNAMED',
+            'description': ''})
 
 
 class TestsWithDB(TestCaseWithDB):
     """Tests requiring DB, but not server setup."""
     checked_class = Condition
 
 
 class TestsWithDB(TestCaseWithDB):
     """Tests requiring DB, but not server setup."""
     checked_class = Condition
-    default_ids = (1, 2, 3)
 
     def versioned_condition(self) -> Condition:
         """Create Condition with some VersionedAttribute values."""
 
     def versioned_condition(self) -> Condition:
         """Create Condition with some VersionedAttribute values."""
@@ -72,6 +74,10 @@ class TestsWithDB(TestCaseWithDB):
         """Test pointers made for single object keep pointing to it."""
         self.check_singularity('is_active', True)
 
         """Test pointers made for single object keep pointing to it."""
         self.check_singularity('is_active', True)
 
+    def test_Condition_versioned_attributes_singularity(self) -> None:
+        """Test behavior of VersionedAttributes on saving (with .title)."""
+        self.check_versioned_singularity()
+
     def test_Condition_remove(self) -> None:
         """Test .remove() effects on DB and cache."""
         self.check_remove()
     def test_Condition_remove(self) -> None:
         """Test .remove() effects on DB and cache."""
         self.check_remove()
index ce7f8571cc7aa84aaa6a8a4855febfeb97d879b9..c3b1144de7563d956e5fd3a6d6474cb6180e726a 100644 (file)
@@ -1,49 +1,66 @@
 """Test Processes module."""
 """Test Processes module."""
-from unittest import TestCase
-from tests.utils import TestCaseWithDB, TestCaseWithServer
+from tests.utils import TestCaseWithDB, TestCaseWithServer, TestCaseSansDB
 from plomtask.processes import Process, ProcessStep, ProcessStepsNode
 from plomtask.conditions import Condition
 from plomtask.processes import Process, ProcessStep, ProcessStepsNode
 from plomtask.conditions import Condition
+from plomtask.exceptions import HandledException, NotFoundException
 from plomtask.todos import Todo
 from plomtask.todos import Todo
-from plomtask.exceptions import NotFoundException, HandledException
 
 
 
 
-class TestsSansDB(TestCase):
+class TestsSansDB(TestCaseSansDB):
     """Module tests not requiring DB setup."""
     """Module tests not requiring DB setup."""
+    checked_class = Process
 
 
-    def test_Process_versioned_defaults(self) -> None:
-        """Test defaults of Process' VersionedAttributes."""
-        self.assertEqual(Process(None).title.newest, 'UNNAMED')
-        self.assertEqual(Process(None).description.newest, '')
-        self.assertEqual(Process(None).effort.newest, 1.0)
+    def test_Process_id_setting(self) -> None:
+        """Test .id_ being set and its legal range being enforced."""
+        self.check_id_setting()
 
 
-    def test_Process_legal_ID(self) -> None:
-        """Test Process cannot be instantiated with id_=0."""
-        with self.assertRaises(HandledException):
-            Process(0)
+    def test_Process_versioned_defaults(self) -> None:
+        """Test defaults of VersionedAttributes."""
+        self.check_versioned_defaults({
+            'title': 'UNNAMED',
+            'description': '',
+            'effort': 1.0})
 
 
 class TestsWithDB(TestCaseWithDB):
     """Module tests requiring DB setup."""
 
 
 class TestsWithDB(TestCaseWithDB):
     """Module tests requiring DB setup."""
-
-    def setUp(self) -> None:
-        super().setUp()
-        self.proc1 = Process(None)
-        self.proc1.save(self.db_conn)
-        self.proc2 = Process(None)
-        self.proc2.save(self.db_conn)
-        self.proc3 = Process(None)
-        self.proc3.save(self.db_conn)
-
-    def test_Process_ids(self) -> None:
-        """Test Process.save() re Process.id_."""
-        self.assertEqual(self.proc1.id_,
-                         Process.by_id(self.db_conn, 1, create=False).id_)
-        self.assertEqual(self.proc2.id_,
-                         Process.by_id(self.db_conn, 2, create=False).id_)
-        proc5 = Process(5)
-        proc5.save(self.db_conn)
-        self.assertEqual(proc5.id_,
-                         Process.by_id(self.db_conn, 5, create=False).id_)
+    checked_class = Process
+
+    def three_processes(self) -> tuple[Process, Process, Process]:
+        """Return three saved processes."""
+        p1, p2, p3 = Process(None), Process(None), Process(None)
+        for p in [p1, p2, p3]:
+            p.save(self.db_conn)
+        return p1, p2, p3
+
+    def test_Process_saving_and_caching(self) -> None:
+        """Test .save/.save_core."""
+        kwargs = {'id_': 1}
+        self.check_saving_and_caching(**kwargs)
+        p = Process(None)
+        p.title.set('t1')
+        p.title.set('t2')
+        p.description.set('d1')
+        p.description.set('d2')
+        p.effort.set(0.5)
+        p.effort.set(1.5)
+        c1, c2, c3 = Condition(None), Condition(None), Condition(None)
+        for c in [c1, c2, c3]:
+            c.save(self.db_conn)
+        assert isinstance(c1.id_, int)
+        assert isinstance(c2.id_, int)
+        assert isinstance(c3.id_, int)
+        p.set_conditions(self.db_conn, [c1.id_, c2.id_])
+        p.set_enables(self.db_conn, [c2.id_, c3.id_])
+        p.set_disables(self.db_conn, [c1.id_, c3.id_])
+        p.save(self.db_conn)
+        r = Process.by_id(self.db_conn, p.id_)
+        self.assertEqual(sorted(r.title.history.values()), ['t1', 't2'])
+        self.assertEqual(sorted(r.description.history.values()), ['d1', 'd2'])
+        self.assertEqual(sorted(r.effort.history.values()), [0.5, 1.5])
+        self.assertEqual(sorted(r.conditions), sorted([c1, c2]))
+        self.assertEqual(sorted(r.enables), sorted([c2, c3]))
+        self.assertEqual(sorted(r.disables), sorted([c1, c3]))
 
     def test_Process_steps(self) -> None:
         """Test addition, nesting, and non-recursion of ProcessSteps"""
 
     def test_Process_steps(self) -> None:
         """Test addition, nesting, and non-recursion of ProcessSteps"""
@@ -54,133 +71,134 @@ class TestsWithDB(TestCaseWithDB):
             steps_proc += [step_tuple]
             proc.set_steps(self.db_conn, steps_proc)
             steps_proc[-1] = (expected_id, step_tuple[1], step_tuple[2])
             steps_proc += [step_tuple]
             proc.set_steps(self.db_conn, steps_proc)
             steps_proc[-1] = (expected_id, step_tuple[1], step_tuple[2])
-        assert isinstance(self.proc2.id_, int)
-        assert isinstance(self.proc3.id_, int)
-        steps_proc1: list[tuple[int | None, int, int | None]] = []
-        add_step(self.proc1, steps_proc1, (None, self.proc2.id_, None), 1)
-        p_1_dict: dict[int, ProcessStepsNode] = {}
-        p_1_dict[1] = ProcessStepsNode(self.proc2, None, True, {}, False)
-        self.assertEqual(self.proc1.get_steps(self.db_conn, None), p_1_dict)
-        add_step(self.proc1, steps_proc1, (None, self.proc3.id_, None), 2)
-        step_2 = self.proc1.explicit_steps[-1]
-        assert isinstance(step_2.id_, int)
-        p_1_dict[2] = ProcessStepsNode(self.proc3, None, True, {}, False)
-        self.assertEqual(self.proc1.get_steps(self.db_conn, None), p_1_dict)
-        steps_proc2: list[tuple[int | None, int, int | None]] = []
-        add_step(self.proc2, steps_proc2, (None, self.proc3.id_, None), 3)
-        p_1_dict[1].steps[3] = ProcessStepsNode(self.proc3, None,
-                                                False, {}, False)
-        self.assertEqual(self.proc1.get_steps(self.db_conn, None), p_1_dict)
-        add_step(self.proc1, steps_proc1, (None, self.proc2.id_, step_2.id_),
-                 4)
-        step_3 = ProcessStepsNode(self.proc3, None, False, {}, True)
-        p_1_dict[2].steps[4] = ProcessStepsNode(self.proc2, step_2.id_, True,
-                                                {3: step_3}, False)
-        self.assertEqual(self.proc1.get_steps(self.db_conn, None), p_1_dict)
-        add_step(self.proc1, steps_proc1, (None, self.proc3.id_, 999), 5)
-        p_1_dict[5] = ProcessStepsNode(self.proc3, None, True, {}, False)
-        self.assertEqual(self.proc1.get_steps(self.db_conn, None), p_1_dict)
-        add_step(self.proc1, steps_proc1, (None, self.proc3.id_, 3), 6)
-        p_1_dict[6] = ProcessStepsNode(self.proc3, None, True, {}, False)
-        self.assertEqual(self.proc1.get_steps(self.db_conn, None),
-                         p_1_dict)
-        self.assertEqual(self.proc1.used_as_step_by(self.db_conn),
-                         [])
-        self.assertEqual(self.proc2.used_as_step_by(self.db_conn),
-                         [self.proc1])
-        self.assertEqual(self.proc3.used_as_step_by(self.db_conn),
-                         [self.proc1, self.proc2])
+        p1, p2, p3 = self.three_processes()
+        assert isinstance(p1.id_, int)
+        assert isinstance(p2.id_, int)
+        assert isinstance(p3.id_, int)
+        steps_p1: list[tuple[int | None, int, int | None]] = []
+        add_step(p1, steps_p1, (None, p2.id_, None), 1)
+        p1_dict: dict[int, ProcessStepsNode] = {}
+        p1_dict[1] = ProcessStepsNode(p2, None, True, {}, False)
+        self.assertEqual(p1.get_steps(self.db_conn, None), p1_dict)
+        add_step(p1, steps_p1, (None, p3.id_, None), 2)
+        step_2 = p1.explicit_steps[-1]
+        p1_dict[2] = ProcessStepsNode(p3, None, True, {}, False)
+        self.assertEqual(p1.get_steps(self.db_conn, None), p1_dict)
+        steps_p2: list[tuple[int | None, int, int | None]] = []
+        add_step(p2, steps_p2, (None, p3.id_, None), 3)
+        p1_dict[1].steps[3] = ProcessStepsNode(p3, None, False, {}, False)
+        self.assertEqual(p1.get_steps(self.db_conn, None), p1_dict)
+        add_step(p1, steps_p1, (None, p2.id_, step_2.id_), 4)
+        step_3 = ProcessStepsNode(p3, None, False, {}, True)
+        p1_dict[2].steps[4] = ProcessStepsNode(p2, step_2.id_, True,
+                                               {3: step_3}, False)
+        self.assertEqual(p1.get_steps(self.db_conn, None), p1_dict)
+        add_step(p1, steps_p1, (None, p3.id_, 999), 5)
+        p1_dict[5] = ProcessStepsNode(p3, None, True, {}, False)
+        self.assertEqual(p1.get_steps(self.db_conn, None), p1_dict)
+        add_step(p1, steps_p1, (None, p3.id_, 3), 6)
+        p1_dict[6] = ProcessStepsNode(p3, None, True, {}, False)
+        self.assertEqual(p1.get_steps(self.db_conn, None), p1_dict)
+        self.assertEqual(p1.used_as_step_by(self.db_conn), [])
+        self.assertEqual(p2.used_as_step_by(self.db_conn), [p1])
+        self.assertEqual(p3.used_as_step_by(self.db_conn), [p1, p2])
 
     def test_Process_conditions(self) -> None:
         """Test setting Process.conditions/enables/disables."""
 
     def test_Process_conditions(self) -> None:
         """Test setting Process.conditions/enables/disables."""
+        p = Process(None)
+        p.save(self.db_conn)
         for target in ('conditions', 'enables', 'disables'):
         for target in ('conditions', 'enables', 'disables'):
-            c1 = Condition(None, False)
+            method = getattr(p, f'set_{target}')
+            c1, c2 = Condition(None), Condition(None)
             c1.save(self.db_conn)
             c1.save(self.db_conn)
-            assert isinstance(c1.id_, int)
-            c2 = Condition(None, False)
             c2.save(self.db_conn)
             c2.save(self.db_conn)
+            assert isinstance(c1.id_, int)
             assert isinstance(c2.id_, int)
             assert isinstance(c2.id_, int)
-            self.proc1.set_conditions(self.db_conn, [], target)
-            self.assertEqual(getattr(self.proc1, target), [])
-            self.proc1.set_conditions(self.db_conn, [c1.id_], target)
-            self.assertEqual(getattr(self.proc1, target), [c1])
-            self.proc1.set_conditions(self.db_conn, [c2.id_], target)
-            self.assertEqual(getattr(self.proc1, target), [c2])
-            self.proc1.set_conditions(self.db_conn, [c1.id_, c2.id_], target)
-            self.assertEqual(getattr(self.proc1, target), [c1, c2])
+            method(self.db_conn, [])
+            self.assertEqual(getattr(p, target), [])
+            method(self.db_conn, [c1.id_])
+            self.assertEqual(getattr(p, target), [c1])
+            method(self.db_conn, [c2.id_])
+            self.assertEqual(getattr(p, target), [c2])
+            method(self.db_conn, [c1.id_, c2.id_])
+            self.assertEqual(getattr(p, target), [c1, c2])
 
     def test_Process_by_id(self) -> None:
 
     def test_Process_by_id(self) -> None:
-        """Test Process.by_id()."""
-        with self.assertRaises(NotFoundException):
-            Process.by_id(self.db_conn, None, create=False)
-        with self.assertRaises(NotFoundException):
-            Process.by_id(self.db_conn, 0, create=False)
-        self.assertNotEqual(self.proc1.id_,
-                            Process.by_id(self.db_conn, None, create=True).id_)
-        self.assertEqual(Process(2).id_,
-                         Process.by_id(self.db_conn, 2, create=True).id_)
+        """Test .by_id(), including creation"""
+        self.check_by_id()
 
     def test_Process_all(self) -> None:
 
     def test_Process_all(self) -> None:
-        """Test Process.all()."""
-        self.assertEqual({self.proc1.id_, self.proc2.id_, self.proc3.id_},
-                         set(p.id_ for p in Process.all(self.db_conn)))
-
-    def test_ProcessStep_singularity(self) -> None:
-        """Test pointers made for single object keep pointing to it."""
-        assert isinstance(self.proc2.id_, int)
-        self.proc1.set_steps(self.db_conn, [(None, self.proc2.id_, None)])
-        step = self.proc1.explicit_steps[-1]
-        assert isinstance(step.id_, int)
-        step_retrieved = ProcessStep.by_id(self.db_conn, step.id_)
-        step.parent_step_id = 99
-        self.assertEqual(step.parent_step_id, step_retrieved.parent_step_id)
+        """Test .all()."""
+        self.check_all()
 
     def test_Process_singularity(self) -> None:
 
     def test_Process_singularity(self) -> None:
-        """Test pointers made for single object keep pointing to it, and
-        subsequent retrievals don't overload relations."""
-        assert isinstance(self.proc1.id_, int)
-        assert isinstance(self.proc2.id_, int)
-        c1 = Condition(None, False)
-        c1.save(self.db_conn)
-        assert isinstance(c1.id_, int)
-        self.proc1.set_conditions(self.db_conn, [c1.id_])
-        self.proc1.set_steps(self.db_conn, [(None, self.proc2.id_, None)])
-        self.proc1.save(self.db_conn)
-        p_retrieved = Process.by_id(self.db_conn, self.proc1.id_)
-        self.assertEqual(self.proc1.explicit_steps, p_retrieved.explicit_steps)
-        self.assertEqual(self.proc1.conditions, p_retrieved.conditions)
-        self.proc1.save(self.db_conn)
+        """Test pointers made for single object keep pointing to it."""
+        self.check_singularity('conditions', [Condition(None)])
 
     def test_Process_versioned_attributes_singularity(self) -> None:
         """Test behavior of VersionedAttributes on saving (with .title)."""
 
     def test_Process_versioned_attributes_singularity(self) -> None:
         """Test behavior of VersionedAttributes on saving (with .title)."""
-        assert isinstance(self.proc1.id_, int)
-        self.proc1.title.set('named')
-        p_loaded = Process.by_id(self.db_conn, self.proc1.id_)
-        self.assertEqual(self.proc1.title.history, p_loaded.title.history)
+        self.check_versioned_singularity()
 
     def test_Process_removal(self) -> None:
         """Test removal of Processes and ProcessSteps."""
 
     def test_Process_removal(self) -> None:
         """Test removal of Processes and ProcessSteps."""
-        assert isinstance(self.proc3.id_, int)
-        self.proc1.remove(self.db_conn)
-        self.assertEqual({self.proc2.id_, self.proc3.id_},
-                         set(p.id_ for p in Process.all(self.db_conn)))
-        self.proc2.set_steps(self.db_conn, [(None, self.proc3.id_, None)])
+        self.check_remove()
+        p1, p2, p3 = self.three_processes()
+        assert isinstance(p1.id_, int)
+        assert isinstance(p2.id_, int)
+        assert isinstance(p3.id_, int)
+        p2.set_steps(self.db_conn, [(None, p1.id_, None)])
         with self.assertRaises(HandledException):
         with self.assertRaises(HandledException):
-            self.proc3.remove(self.db_conn)
-        self.proc2.explicit_steps[0].remove(self.db_conn)
-        retrieved = Process.by_id(self.db_conn, self.proc2.id_)
-        self.assertEqual(retrieved.explicit_steps, [])
-        self.proc2.set_steps(self.db_conn, [(None, self.proc3.id_, None)])
-        step = retrieved.explicit_steps[0]
-        self.proc2.remove(self.db_conn)
+            p1.remove(self.db_conn)
+        step = p2.explicit_steps[0]
+        p2.set_steps(self.db_conn, [])
+        with self.assertRaises(NotFoundException):
+            ProcessStep.by_id(self.db_conn, step.id_)
+        p1.remove(self.db_conn)
+        p2.set_steps(self.db_conn, [(None, p3.id_, None)])
+        step = p2.explicit_steps[0]
+        p2.remove(self.db_conn)
         with self.assertRaises(NotFoundException):
             ProcessStep.by_id(self.db_conn, step.id_)
         with self.assertRaises(NotFoundException):
             ProcessStep.by_id(self.db_conn, step.id_)
-        todo = Todo(None, self.proc3, False, '2024-01-01')
+        todo = Todo(None, p3, False, '2024-01-01')
         todo.save(self.db_conn)
         with self.assertRaises(HandledException):
         todo.save(self.db_conn)
         with self.assertRaises(HandledException):
-            self.proc3.remove(self.db_conn)
+            p3.remove(self.db_conn)
         todo.remove(self.db_conn)
         todo.remove(self.db_conn)
-        self.proc3.remove(self.db_conn)
+        p3.remove(self.db_conn)
+
+
+class TestsWithDBForProcessStep(TestCaseWithDB):
+    """Module tests requiring DB setup."""
+    checked_class = ProcessStep
+
+    def test_ProcessStep_saving_and_caching(self) -> None:
+        """Test .save/.save_core."""
+        kwargs = {'id_': 1,
+                  'owner_id': 2,
+                  'step_process_id': 3,
+                  'parent_step_id': 4}
+        self.check_saving_and_caching(**kwargs)
+
+    def test_ProcessStep_from_table_row(self) -> None:
+        """Test .from_table_row() properly reads in class from DB"""
+        self.check_from_table_row(2, 3, None)
+
+    def test_ProcessStep_singularity(self) -> None:
+        """Test pointers made for single object keep pointing to it."""
+        self.check_singularity('parent_step_id', 1, 2, 3, None)
+
+    def test_ProcessStep_remove(self) -> None:
+        """Test .remove and unsetting of owner's .explicit_steps entry."""
+        p1 = Process(None)
+        p2 = Process(None)
+        p1.save(self.db_conn)
+        p2.save(self.db_conn)
+        assert isinstance(p2.id_, int)
+        p1.set_steps(self.db_conn, [(None, p2.id_, None)])
+        step = p1.explicit_steps[0]
+        step.remove(self.db_conn)
+        self.assertEqual(p1.explicit_steps, [])
+        self.check_storage([])
 
 
 class TestsWithServer(TestCaseWithServer):
 
 
 class TestsWithServer(TestCaseWithServer):
index 61dbb36b949ee3fdfd5a38dcada155a5ab18a924..ccb485ad104cc37ae040a4b144093e69c75b87c0 100644 (file)
@@ -15,10 +15,28 @@ from plomtask.todos import Todo
 from plomtask.exceptions import NotFoundException, HandledException
 
 
 from plomtask.exceptions import NotFoundException, HandledException
 
 
+class TestCaseSansDB(TestCase):
+    """Tests requiring no DB setup."""
+    checked_class: Any
+
+    def check_id_setting(self) -> None:
+        """Test .id_ being set and its legal range being enforced."""
+        with self.assertRaises(HandledException):
+            self.checked_class(0)
+        obj = self.checked_class(5)
+        self.assertEqual(obj.id_, 5)
+
+    def check_versioned_defaults(self, attrs: dict[str, Any]) -> None:
+        """Test defaults of VersionedAttributes."""
+        obj = self.checked_class(None)
+        for k, v in attrs.items():
+            self.assertEqual(getattr(obj, k).newest, v)
+
+
 class TestCaseWithDB(TestCase):
     """Module tests not requiring DB setup."""
     checked_class: Any
 class TestCaseWithDB(TestCase):
     """Module tests not requiring DB setup."""
     checked_class: Any
-    default_ids: tuple[int | str, int | str, int | str]
+    default_ids: tuple[int | str, int | str, int | str] = (1, 2, 3)
 
     def setUp(self) -> None:
         Condition.empty_cache()
 
     def setUp(self) -> None:
         Condition.empty_cache()
@@ -43,7 +61,7 @@ class TestCaseWithDB(TestCase):
         self.assertEqual(self.checked_class.get_cache(), expected_cache)
         db_found: list[Any] = []
         for item in content:
         self.assertEqual(self.checked_class.get_cache(), expected_cache)
         db_found: list[Any] = []
         for item in content:
-            assert isinstance(item.id_, (str, int))
+            assert isinstance(item.id_, type(self.default_ids[0]))
             for row in self.db_conn.row_where(self.checked_class.table_name,
                                               'id', item.id_):
                 db_found += [self.checked_class.from_table_row(self.db_conn,
             for row in self.db_conn.row_where(self.checked_class.table_name,
                                               'id', item.id_):
                 db_found += [self.checked_class.from_table_row(self.db_conn,
@@ -79,12 +97,12 @@ class TestCaseWithDB(TestCase):
         self.assertEqual(self.checked_class(id2), by_id_created)
         self.check_storage([obj])
 
         self.assertEqual(self.checked_class(id2), by_id_created)
         self.check_storage([obj])
 
-    def check_from_table_row(self) -> None:
+    def check_from_table_row(self, *args: Any) -> None:
         """Test .from_table_row() properly reads in class from DB"""
         id_ = self.default_ids[0]
         """Test .from_table_row() properly reads in class from DB"""
         id_ = self.default_ids[0]
-        obj = self.checked_class(id_)  # pylint: disable=not-callable
+        obj = self.checked_class(id_, *args)  # pylint: disable=not-callable
         obj.save(self.db_conn)
         obj.save(self.db_conn)
-        assert isinstance(obj.id_, (str, int))
+        assert isinstance(obj.id_, type(self.default_ids[0]))
         for row in self.db_conn.row_where(self.checked_class.table_name,
                                           'id', obj.id_):
             retrieved = self.checked_class.from_table_row(self.db_conn, row)
         for row in self.db_conn.row_where(self.checked_class.table_name,
                                           'id', obj.id_):
             retrieved = self.checked_class.from_table_row(self.db_conn, row)
@@ -110,20 +128,29 @@ class TestCaseWithDB(TestCase):
         return item1, item2, item3
 
     def check_singularity(self, defaulting_field: str,
         return item1, item2, item3
 
     def check_singularity(self, defaulting_field: str,
-                          non_default_value: Any) -> None:
+                          non_default_value: Any, *args: Any) -> None:
         """Test pointers made for single object keep pointing to it."""
         id1 = self.default_ids[0]
         """Test pointers made for single object keep pointing to it."""
         id1 = self.default_ids[0]
-        obj = self.checked_class(id1)  # pylint: disable=not-callable
+        obj = self.checked_class(id1, *args)  # pylint: disable=not-callable
         obj.save(self.db_conn)
         setattr(obj, defaulting_field, non_default_value)
         retrieved = self.checked_class.by_id(self.db_conn, id1)
         self.assertEqual(non_default_value,
                          getattr(retrieved, defaulting_field))
 
         obj.save(self.db_conn)
         setattr(obj, defaulting_field, non_default_value)
         retrieved = self.checked_class.by_id(self.db_conn, id1)
         self.assertEqual(non_default_value,
                          getattr(retrieved, defaulting_field))
 
-    def check_remove(self) -> None:
+    def check_versioned_singularity(self) -> None:
+        """Test singularity of VersionedAttributes on saving (with .title)."""
+        obj = self.checked_class(None)  # pylint: disable=not-callable
+        obj.save(self.db_conn)
+        assert isinstance(obj.id_, int)
+        obj.title.set('named')
+        retrieved = self.checked_class.by_id(self.db_conn, obj.id_)
+        self.assertEqual(obj.title.history, retrieved.title.history)
+
+    def check_remove(self, *args: Any) -> None:
         """Test .remove() effects on DB and cache."""
         id_ = self.default_ids[0]
         """Test .remove() effects on DB and cache."""
         id_ = self.default_ids[0]
-        obj = self.checked_class(id_)  # pylint: disable=not-callable
+        obj = self.checked_class(id_, *args)  # pylint: disable=not-callable
         with self.assertRaises(HandledException):
             obj.remove(self.db_conn)
         obj.save(self.db_conn)
         with self.assertRaises(HandledException):
             obj.remove(self.db_conn)
         obj.save(self.db_conn)