X-Git-Url: https://plomlompom.com/repos/foo.html?a=blobdiff_plain;f=tests%2Fversioned_attributes.py;h=a75fc3cec338d5bed5bf5739101287cd8efa1851;hb=80491fac3c476788d90010812c9ba0b95701e09b;hp=69c31fef1ac83853092229192c765ed5a4122824;hpb=5e3c633f1994329297999899790e69d28516934b;p=plomtask diff --git a/tests/versioned_attributes.py b/tests/versioned_attributes.py index 69c31fe..a75fc3c 100644 --- a/tests/versioned_attributes.py +++ b/tests/versioned_attributes.py @@ -6,7 +6,7 @@ from tests.utils import TestCaseWithDB from plomtask.versioned_attributes import VersionedAttribute, TIMESTAMP_FMT from plomtask.db import BaseModel -SQL_TEST_TABLE = ''' +SQL_TEST_TABLE_STR = ''' CREATE TABLE versioned_tests ( parent INTEGER NOT NULL, timestamp TEXT NOT NULL, @@ -14,6 +14,14 @@ CREATE TABLE versioned_tests ( PRIMARY KEY (parent, timestamp) ); ''' +SQL_TEST_TABLE_FLOAT = ''' +CREATE TABLE versioned_tests ( + parent INTEGER NOT NULL, + timestamp TEXT NOT NULL, + value REAL NOT NULL, + PRIMARY KEY (parent, timestamp) +); +''' class TestParentType(BaseModel[int]): @@ -81,20 +89,22 @@ class TestsSansDB(TestCase): self.assertEqual(attr.at(timestamp_after_c), 'C') -class TestsWithDB(TestCaseWithDB): +class TestsWithDBStr(TestCaseWithDB): """Module tests requiring DB setup.""" + default_vals: list[str | float] = ['A', 'B', 'C'] + init_sql = SQL_TEST_TABLE_STR def setUp(self) -> None: super().setUp() - self.db_conn.exec(SQL_TEST_TABLE) + self.db_conn.exec(self.init_sql) self.test_parent = TestParentType(1) self.attr = VersionedAttribute(self.test_parent, - 'versioned_tests', 'A') + 'versioned_tests', self.default_vals[0]) def test_VersionedAttribute_save(self) -> None: """Test .save() to write to DB.""" # check mere .set() calls do not by themselves reflect in the DB - self.attr.set('B') + self.attr.set(self.default_vals[1]) self.assertEqual([], self.db_conn.row_where('versioned_tests', 'parent', 1)) @@ -103,25 +113,32 @@ class TestsWithDB(TestCaseWithDB): vals_found = [] for row in self.db_conn.row_where('versioned_tests', 'parent', 1): vals_found += [row[2]] - self.assertEqual(['B'], vals_found) + self.assertEqual([self.default_vals[1]], vals_found) # check .save() also updates history in DB - self.attr.set('C') + self.attr.set(self.default_vals[2]) self.attr.save(self.db_conn) vals_found = [] for row in self.db_conn.row_where('versioned_tests', 'parent', 1): vals_found += [row[2]] - self.assertEqual(['B', 'C'], sorted(vals_found)) + self.assertEqual([self.default_vals[1], self.default_vals[2]], + sorted(vals_found)) def test_VersionedAttribute_history_from_row(self) -> None: """"Test .history_from_row() properly interprets DB rows.""" - self.attr.set('B') - self.attr.set('C') + self.attr.set(self.default_vals[1]) + self.attr.set(self.default_vals[2]) self.attr.save(self.db_conn) - loaded_attr = VersionedAttribute(self.test_parent, - 'versioned_tests', 'A') + loaded_attr = VersionedAttribute(self.test_parent, 'versioned_tests', + self.default_vals[0]) for row in self.db_conn.row_where('versioned_tests', 'parent', 1): loaded_attr.history_from_row(row) for timestamp, value in self.attr.history.items(): self.assertEqual(value, loaded_attr.history[timestamp]) self.assertEqual(len(self.attr.history.keys()), len(loaded_attr.history.keys())) + + +class TestsWithDBFloat(TestsWithDBStr): + """Module tests requiring DB setup.""" + default_vals: list[str | float] = [0.9, 1.1, 2] + init_sql = SQL_TEST_TABLE_FLOAT