X-Git-Url: https://plomlompom.com/repos/foo.html?a=blobdiff_plain;f=tests%2Fversioned_attributes.py;h=a75fc3cec338d5bed5bf5739101287cd8efa1851;hb=80491fac3c476788d90010812c9ba0b95701e09b;hp=69c31fef1ac83853092229192c765ed5a4122824;hpb=5e3c633f1994329297999899790e69d28516934b;p=plomtask
diff --git a/tests/versioned_attributes.py b/tests/versioned_attributes.py
index 69c31fe..a75fc3c 100644
--- a/tests/versioned_attributes.py
+++ b/tests/versioned_attributes.py
@@ -6,7 +6,7 @@ from tests.utils import TestCaseWithDB
from plomtask.versioned_attributes import VersionedAttribute, TIMESTAMP_FMT
from plomtask.db import BaseModel
-SQL_TEST_TABLE = '''
+SQL_TEST_TABLE_STR = '''
CREATE TABLE versioned_tests (
parent INTEGER NOT NULL,
timestamp TEXT NOT NULL,
@@ -14,6 +14,14 @@ CREATE TABLE versioned_tests (
PRIMARY KEY (parent, timestamp)
);
'''
+SQL_TEST_TABLE_FLOAT = '''
+CREATE TABLE versioned_tests (
+ parent INTEGER NOT NULL,
+ timestamp TEXT NOT NULL,
+ value REAL NOT NULL,
+ PRIMARY KEY (parent, timestamp)
+);
+'''
class TestParentType(BaseModel[int]):
@@ -81,20 +89,22 @@ class TestsSansDB(TestCase):
self.assertEqual(attr.at(timestamp_after_c), 'C')
-class TestsWithDB(TestCaseWithDB):
+class TestsWithDBStr(TestCaseWithDB):
"""Module tests requiring DB setup."""
+ default_vals: list[str | float] = ['A', 'B', 'C']
+ init_sql = SQL_TEST_TABLE_STR
def setUp(self) -> None:
super().setUp()
- self.db_conn.exec(SQL_TEST_TABLE)
+ self.db_conn.exec(self.init_sql)
self.test_parent = TestParentType(1)
self.attr = VersionedAttribute(self.test_parent,
- 'versioned_tests', 'A')
+ 'versioned_tests', self.default_vals[0])
def test_VersionedAttribute_save(self) -> None:
"""Test .save() to write to DB."""
# check mere .set() calls do not by themselves reflect in the DB
- self.attr.set('B')
+ self.attr.set(self.default_vals[1])
self.assertEqual([],
self.db_conn.row_where('versioned_tests',
'parent', 1))
@@ -103,25 +113,32 @@ class TestsWithDB(TestCaseWithDB):
vals_found = []
for row in self.db_conn.row_where('versioned_tests', 'parent', 1):
vals_found += [row[2]]
- self.assertEqual(['B'], vals_found)
+ self.assertEqual([self.default_vals[1]], vals_found)
# check .save() also updates history in DB
- self.attr.set('C')
+ self.attr.set(self.default_vals[2])
self.attr.save(self.db_conn)
vals_found = []
for row in self.db_conn.row_where('versioned_tests', 'parent', 1):
vals_found += [row[2]]
- self.assertEqual(['B', 'C'], sorted(vals_found))
+ self.assertEqual([self.default_vals[1], self.default_vals[2]],
+ sorted(vals_found))
def test_VersionedAttribute_history_from_row(self) -> None:
""""Test .history_from_row() properly interprets DB rows."""
- self.attr.set('B')
- self.attr.set('C')
+ self.attr.set(self.default_vals[1])
+ self.attr.set(self.default_vals[2])
self.attr.save(self.db_conn)
- loaded_attr = VersionedAttribute(self.test_parent,
- 'versioned_tests', 'A')
+ loaded_attr = VersionedAttribute(self.test_parent, 'versioned_tests',
+ self.default_vals[0])
for row in self.db_conn.row_where('versioned_tests', 'parent', 1):
loaded_attr.history_from_row(row)
for timestamp, value in self.attr.history.items():
self.assertEqual(value, loaded_attr.history[timestamp])
self.assertEqual(len(self.attr.history.keys()),
len(loaded_attr.history.keys()))
+
+
+class TestsWithDBFloat(TestsWithDBStr):
+ """Module tests requiring DB setup."""
+ default_vals: list[str | float] = [0.9, 1.1, 2]
+ init_sql = SQL_TEST_TABLE_FLOAT