home · contact · privacy
Refactor and extend tests.
[plomtask] / tests / versioned_attributes.py
index 69c31fef1ac83853092229192c765ed5a4122824..a75fc3cec338d5bed5bf5739101287cd8efa1851 100644 (file)
@@ -6,7 +6,7 @@ from tests.utils import TestCaseWithDB
 from plomtask.versioned_attributes import VersionedAttribute, TIMESTAMP_FMT
 from plomtask.db import BaseModel
 
-SQL_TEST_TABLE = '''
+SQL_TEST_TABLE_STR = '''
 CREATE TABLE versioned_tests (
   parent INTEGER NOT NULL,
   timestamp TEXT NOT NULL,
@@ -14,6 +14,14 @@ CREATE TABLE versioned_tests (
   PRIMARY KEY (parent, timestamp)
 );
 '''
+SQL_TEST_TABLE_FLOAT = '''
+CREATE TABLE versioned_tests (
+  parent INTEGER NOT NULL,
+  timestamp TEXT NOT NULL,
+  value REAL NOT NULL,
+  PRIMARY KEY (parent, timestamp)
+);
+'''
 
 
 class TestParentType(BaseModel[int]):
@@ -81,20 +89,22 @@ class TestsSansDB(TestCase):
         self.assertEqual(attr.at(timestamp_after_c), 'C')
 
 
-class TestsWithDB(TestCaseWithDB):
+class TestsWithDBStr(TestCaseWithDB):
     """Module tests requiring DB setup."""
+    default_vals: list[str | float] = ['A', 'B', 'C']
+    init_sql = SQL_TEST_TABLE_STR
 
     def setUp(self) -> None:
         super().setUp()
-        self.db_conn.exec(SQL_TEST_TABLE)
+        self.db_conn.exec(self.init_sql)
         self.test_parent = TestParentType(1)
         self.attr = VersionedAttribute(self.test_parent,
-                                       'versioned_tests', 'A')
+                                       'versioned_tests', self.default_vals[0])
 
     def test_VersionedAttribute_save(self) -> None:
         """Test .save() to write to DB."""
         # check mere .set() calls do not by themselves reflect in the DB
-        self.attr.set('B')
+        self.attr.set(self.default_vals[1])
         self.assertEqual([],
                          self.db_conn.row_where('versioned_tests',
                                                 'parent', 1))
@@ -103,25 +113,32 @@ class TestsWithDB(TestCaseWithDB):
         vals_found = []
         for row in self.db_conn.row_where('versioned_tests', 'parent', 1):
             vals_found += [row[2]]
-        self.assertEqual(['B'], vals_found)
+        self.assertEqual([self.default_vals[1]], vals_found)
         # check .save() also updates history in DB
-        self.attr.set('C')
+        self.attr.set(self.default_vals[2])
         self.attr.save(self.db_conn)
         vals_found = []
         for row in self.db_conn.row_where('versioned_tests', 'parent', 1):
             vals_found += [row[2]]
-        self.assertEqual(['B', 'C'], sorted(vals_found))
+        self.assertEqual([self.default_vals[1], self.default_vals[2]],
+                         sorted(vals_found))
 
     def test_VersionedAttribute_history_from_row(self) -> None:
         """"Test .history_from_row() properly interprets DB rows."""
-        self.attr.set('B')
-        self.attr.set('C')
+        self.attr.set(self.default_vals[1])
+        self.attr.set(self.default_vals[2])
         self.attr.save(self.db_conn)
-        loaded_attr = VersionedAttribute(self.test_parent,
-                                         'versioned_tests', 'A')
+        loaded_attr = VersionedAttribute(self.test_parent, 'versioned_tests',
+                                         self.default_vals[0])
         for row in self.db_conn.row_where('versioned_tests', 'parent', 1):
             loaded_attr.history_from_row(row)
         for timestamp, value in self.attr.history.items():
             self.assertEqual(value, loaded_attr.history[timestamp])
         self.assertEqual(len(self.attr.history.keys()),
                          len(loaded_attr.history.keys()))
+
+
+class TestsWithDBFloat(TestsWithDBStr):
+    """Module tests requiring DB setup."""
+    default_vals: list[str | float] = [0.9, 1.1, 2]
+    init_sql = SQL_TEST_TABLE_FLOAT