From: Christian Heller <c.heller@plomlompom.de>
Date: Sun, 5 May 2024 03:36:31 +0000 (+0200)
Subject: Refactor and extend tests.
X-Git-Url: https://plomlompom.com/repos/%7B%7Bdb.prefix%7D%7D/%7B%7B%20web_path%20%7D%7D/static/%7B%7Btodo.comment%7D%7D?a=commitdiff_plain;h=80491fac3c476788d90010812c9ba0b95701e09b;p=plomtask

Refactor and extend tests.
---

diff --git a/scripts/pre-commit b/scripts/pre-commit
index 6f84c41..2aaccb0 100755
--- a/scripts/pre-commit
+++ b/scripts/pre-commit
@@ -1,6 +1,7 @@
 #!/bin/sh
 set -e
-for dir in $(echo '.' 'plomtask' 'tests'); do
+# for dir in $(echo '.' 'plomtask' 'tests'); do
+for dir in $(echo 'tests'); do
     echo "Running mypy on ${dir}/ …."
     python3 -m mypy --strict ${dir}/*.py
     echo "Running flake8 on ${dir}/ …"
diff --git a/tests/__pycache__/conditions.cpython-311.pyc b/tests/__pycache__/conditions.cpython-311.pyc
index 4dd00fd..d671a08 100644
Binary files a/tests/__pycache__/conditions.cpython-311.pyc and b/tests/__pycache__/conditions.cpython-311.pyc differ
diff --git a/tests/__pycache__/processes.cpython-311.pyc b/tests/__pycache__/processes.cpython-311.pyc
index 92b1fcd..4f4c7f1 100644
Binary files a/tests/__pycache__/processes.cpython-311.pyc and b/tests/__pycache__/processes.cpython-311.pyc differ
diff --git a/tests/__pycache__/utils.cpython-311.pyc b/tests/__pycache__/utils.cpython-311.pyc
index 722fd58..5c85fd8 100644
Binary files a/tests/__pycache__/utils.cpython-311.pyc and b/tests/__pycache__/utils.cpython-311.pyc differ
diff --git a/tests/__pycache__/versioned_attributes.cpython-311.pyc b/tests/__pycache__/versioned_attributes.cpython-311.pyc
index 2e15b71..7a33722 100644
Binary files a/tests/__pycache__/versioned_attributes.cpython-311.pyc and b/tests/__pycache__/versioned_attributes.cpython-311.pyc differ
diff --git a/tests/conditions.py b/tests/conditions.py
index faaeb87..45c3df7 100644
--- a/tests/conditions.py
+++ b/tests/conditions.py
@@ -25,42 +25,22 @@ class TestsWithDB(TestCaseWithDB):
     """Tests requiring DB, but not server setup."""
     checked_class = Condition
 
-    def versioned_condition(self) -> Condition:
-        """Create Condition with some VersionedAttribute values."""
-        c = Condition(None)
-        c.title.set('title1')
-        c.title.set('title2')
-        c.description.set('desc1')
-        c.description.set('desc2')
-        return c
-
     def test_Condition_saving_and_caching(self) -> None:
         """Test .save/.save_core."""
         kwargs = {'id_': 1, 'is_active': False}
         self.check_saving_and_caching(**kwargs)
         # check .id_ set if None, and versioned attributes too
-        c = self.versioned_condition()
+        c = Condition(None)
         c.save(self.db_conn)
         self.assertEqual(c.id_, 2)
-        self.assertEqual(sorted(c.title.history.values()),
-                         ['title1', 'title2'])
-        self.assertEqual(sorted(c.description.history.values()),
-                         ['desc1', 'desc2'])
+        self.check_saving_of_versioned('title', str)
+        self.check_saving_of_versioned('description', str)
 
     def test_Condition_from_table_row(self) -> None:
         """Test .from_table_row() properly reads in class from DB"""
         self.check_from_table_row()
-        c = self.versioned_condition()
-        c.save(self.db_conn)
-        assert isinstance(c.id_, int)
-        for row in self.db_conn.row_where(Condition.table_name, 'id', c.id_):
-            retrieved = Condition.from_table_row(self.db_conn, row)
-            # pylint: disable=no-member
-            self.assertEqual(sorted(retrieved.title.history.values()),
-                             ['title1', 'title2'])
-            # pylint: disable=no-member
-            self.assertEqual(sorted(retrieved.description.history.values()),
-                             ['desc1', 'desc2'])
+        self.check_versioned_from_table_row('title', str)
+        self.check_versioned_from_table_row('description', str)
 
     def test_Condition_by_id(self) -> None:
         """Test .by_id(), including creation."""
diff --git a/tests/processes.py b/tests/processes.py
index 8616008..9e769c1 100644
--- a/tests/processes.py
+++ b/tests/processes.py
@@ -42,34 +42,58 @@ class TestsWithDB(TestCaseWithDB):
             p.save(self.db_conn)
         return p1, p2, p3
 
-    def test_Process_saving_and_caching(self) -> None:
-        """Test .save/.save_core."""
-        kwargs = {'id_': 1}
-        self.check_saving_and_caching(**kwargs)
+    def p_of_conditions(self) -> tuple[Process, list[Condition],
+                                       list[Condition], list[Condition]]:
+        """Return Process and its three Condition sets."""
         p = Process(None)
-        p.title.set('t1')
-        p.title.set('t2')
-        p.description.set('d1')
-        p.description.set('d2')
-        p.effort.set(0.5)
-        p.effort.set(1.5)
         c1, c2, c3 = Condition(None), Condition(None), Condition(None)
         for c in [c1, c2, c3]:
             c.save(self.db_conn)
         assert isinstance(c1.id_, int)
         assert isinstance(c2.id_, int)
         assert isinstance(c3.id_, int)
-        p.set_conditions(self.db_conn, [c1.id_, c2.id_])
-        p.set_enables(self.db_conn, [c2.id_, c3.id_])
-        p.set_disables(self.db_conn, [c1.id_, c3.id_])
+        set_1 = [c1, c2]
+        set_2 = [c2, c3]
+        set_3 = [c1, c3]
+        p.set_conditions(self.db_conn, [c.id_ for c in set_1
+                                        if isinstance(c.id_, int)])
+        p.set_enables(self.db_conn, [c.id_ for c in set_2
+                                     if isinstance(c.id_, int)])
+        p.set_disables(self.db_conn, [c.id_ for c in set_3
+                                      if isinstance(c.id_, int)])
         p.save(self.db_conn)
+        return p, set_1, set_2, set_3
+
+    def test_Process_saving_and_caching(self) -> None:
+        """Test .save/.save_core."""
+        kwargs = {'id_': 1}
+        self.check_saving_and_caching(**kwargs)
+        self.check_saving_of_versioned('title', str)
+        self.check_saving_of_versioned('description', str)
+        self.check_saving_of_versioned('effort', float)
+        p, set1, set2, set3 = self.p_of_conditions()
+        p.uncache()
         r = Process.by_id(self.db_conn, p.id_)
-        self.assertEqual(sorted(r.title.history.values()), ['t1', 't2'])
-        self.assertEqual(sorted(r.description.history.values()), ['d1', 'd2'])
-        self.assertEqual(sorted(r.effort.history.values()), [0.5, 1.5])
-        self.assertEqual(sorted(r.conditions), sorted([c1, c2]))
-        self.assertEqual(sorted(r.enables), sorted([c2, c3]))
-        self.assertEqual(sorted(r.disables), sorted([c1, c3]))
+        self.assertEqual(sorted(r.conditions), sorted(set1))
+        self.assertEqual(sorted(r.enables), sorted(set2))
+        self.assertEqual(sorted(r.disables), sorted(set3))
+
+    def test_Process_from_table_row(self) -> None:
+        """Test .from_table_row() properly reads in class from DB"""
+        self.check_from_table_row()
+        self.check_versioned_from_table_row('title', str)
+        self.check_versioned_from_table_row('description', str)
+        self.check_versioned_from_table_row('effort', float)
+        p, set1, set2, set3 = self.p_of_conditions()
+        p.save(self.db_conn)
+        assert isinstance(p.id_, int)
+        for row in self.db_conn.row_where(self.checked_class.table_name,
+                                          'id', p.id_):
+            # pylint: disable=no-member
+            r = Process.from_table_row(self.db_conn, row)
+            self.assertEqual(sorted(r.conditions), sorted(set1))
+            self.assertEqual(sorted(r.enables), sorted(set2))
+            self.assertEqual(sorted(r.disables), sorted(set3))
 
     def test_Process_steps(self) -> None:
         """Test addition, nesting, and non-recursion of ProcessSteps"""
diff --git a/tests/utils.py b/tests/utils.py
index fbe739d..bb37270 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -80,6 +80,19 @@ class TestCaseWithDB(TestCase):
         for key, value in kwargs.items():
             self.assertEqual(getattr(obj, key), value)
 
+    def check_saving_of_versioned(self, attr_name: str, type_: type) -> None:
+        """Test owner's versioned attributes."""
+        owner = self.checked_class(None)
+        vals: list[Any] = ['t1', 't2'] if type_ == str else [0.9, 1.1]
+        attr = getattr(owner, attr_name)
+        attr.set(vals[0])
+        attr.set(vals[1])
+        owner.save(self.db_conn)
+        owner.uncache()
+        retrieved = owner.__class__.by_id(self.db_conn, owner.id_)
+        attr = getattr(retrieved, attr_name)
+        self.assertEqual(sorted(attr.history.values()), vals)
+
     def check_by_id(self) -> None:
         """Test .by_id(), including creation."""
         # check failure if not yet saved
@@ -109,6 +122,20 @@ class TestCaseWithDB(TestCase):
             self.assertEqual(obj, retrieved)
             self.assertEqual({obj.id_: obj}, self.checked_class.get_cache())
 
+    def check_versioned_from_table_row(self, attr_name: str,
+                                       type_: type) -> None:
+        """Test .from_table_row() reads versioned attributes from DB."""
+        owner = self.checked_class(None)
+        vals: list[Any] = ['t1', 't2'] if type_ == str else [0.9, 1.1]
+        attr = getattr(owner, attr_name)
+        attr.set(vals[0])
+        attr.set(vals[1])
+        owner.save(self.db_conn)
+        for row in self.db_conn.row_where(owner.table_name, 'id', owner.id_):
+            retrieved = owner.__class__.from_table_row(self.db_conn, row)
+            attr = getattr(retrieved, attr_name)
+            self.assertEqual(sorted(attr.history.values()), vals)
+
     def check_all(self) -> tuple[Any, Any, Any]:
         """Test .all()."""
         # pylint: disable=not-callable
diff --git a/tests/versioned_attributes.py b/tests/versioned_attributes.py
index 69c31fe..a75fc3c 100644
--- a/tests/versioned_attributes.py
+++ b/tests/versioned_attributes.py
@@ -6,7 +6,7 @@ from tests.utils import TestCaseWithDB
 from plomtask.versioned_attributes import VersionedAttribute, TIMESTAMP_FMT
 from plomtask.db import BaseModel
 
-SQL_TEST_TABLE = '''
+SQL_TEST_TABLE_STR = '''
 CREATE TABLE versioned_tests (
   parent INTEGER NOT NULL,
   timestamp TEXT NOT NULL,
@@ -14,6 +14,14 @@ CREATE TABLE versioned_tests (
   PRIMARY KEY (parent, timestamp)
 );
 '''
+SQL_TEST_TABLE_FLOAT = '''
+CREATE TABLE versioned_tests (
+  parent INTEGER NOT NULL,
+  timestamp TEXT NOT NULL,
+  value REAL NOT NULL,
+  PRIMARY KEY (parent, timestamp)
+);
+'''
 
 
 class TestParentType(BaseModel[int]):
@@ -81,20 +89,22 @@ class TestsSansDB(TestCase):
         self.assertEqual(attr.at(timestamp_after_c), 'C')
 
 
-class TestsWithDB(TestCaseWithDB):
+class TestsWithDBStr(TestCaseWithDB):
     """Module tests requiring DB setup."""
+    default_vals: list[str | float] = ['A', 'B', 'C']
+    init_sql = SQL_TEST_TABLE_STR
 
     def setUp(self) -> None:
         super().setUp()
-        self.db_conn.exec(SQL_TEST_TABLE)
+        self.db_conn.exec(self.init_sql)
         self.test_parent = TestParentType(1)
         self.attr = VersionedAttribute(self.test_parent,
-                                       'versioned_tests', 'A')
+                                       'versioned_tests', self.default_vals[0])
 
     def test_VersionedAttribute_save(self) -> None:
         """Test .save() to write to DB."""
         # check mere .set() calls do not by themselves reflect in the DB
-        self.attr.set('B')
+        self.attr.set(self.default_vals[1])
         self.assertEqual([],
                          self.db_conn.row_where('versioned_tests',
                                                 'parent', 1))
@@ -103,25 +113,32 @@ class TestsWithDB(TestCaseWithDB):
         vals_found = []
         for row in self.db_conn.row_where('versioned_tests', 'parent', 1):
             vals_found += [row[2]]
-        self.assertEqual(['B'], vals_found)
+        self.assertEqual([self.default_vals[1]], vals_found)
         # check .save() also updates history in DB
-        self.attr.set('C')
+        self.attr.set(self.default_vals[2])
         self.attr.save(self.db_conn)
         vals_found = []
         for row in self.db_conn.row_where('versioned_tests', 'parent', 1):
             vals_found += [row[2]]
-        self.assertEqual(['B', 'C'], sorted(vals_found))
+        self.assertEqual([self.default_vals[1], self.default_vals[2]],
+                         sorted(vals_found))
 
     def test_VersionedAttribute_history_from_row(self) -> None:
         """"Test .history_from_row() properly interprets DB rows."""
-        self.attr.set('B')
-        self.attr.set('C')
+        self.attr.set(self.default_vals[1])
+        self.attr.set(self.default_vals[2])
         self.attr.save(self.db_conn)
-        loaded_attr = VersionedAttribute(self.test_parent,
-                                         'versioned_tests', 'A')
+        loaded_attr = VersionedAttribute(self.test_parent, 'versioned_tests',
+                                         self.default_vals[0])
         for row in self.db_conn.row_where('versioned_tests', 'parent', 1):
             loaded_attr.history_from_row(row)
         for timestamp, value in self.attr.history.items():
             self.assertEqual(value, loaded_attr.history[timestamp])
         self.assertEqual(len(self.attr.history.keys()),
                          len(loaded_attr.history.keys()))
+
+
+class TestsWithDBFloat(TestsWithDBStr):
+    """Module tests requiring DB setup."""
+    default_vals: list[str | float] = [0.9, 1.1, 2]
+    init_sql = SQL_TEST_TABLE_FLOAT