From 6c73e0e0ba5a50fd34139821d80466c81251d52e Mon Sep 17 00:00:00 2001 From: Christian Heller Date: Thu, 11 Jul 2024 07:34:11 +0200 Subject: [PATCH] Run VersionedAttributes tests over all models. --- tests/utils.py | 227 +++++++++++++++++++++++++--------- tests/versioned_attributes.py | 144 --------------------- 2 files changed, 166 insertions(+), 205 deletions(-) delete mode 100644 tests/versioned_attributes.py diff --git a/tests/utils.py b/tests/utils.py index 4d81c91..52cc66e 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -4,6 +4,8 @@ from unittest import TestCase from typing import Mapping, Any, Callable from threading import Thread from http.client import HTTPConnection +from datetime import datetime +from time import sleep from json import loads as json_loads from urllib.parse import urlencode from uuid import uuid4 @@ -14,7 +16,7 @@ from plomtask.processes import Process, ProcessStep from plomtask.conditions import Condition from plomtask.days import Day from plomtask.todos import Todo -from plomtask.versioned_attributes import VersionedAttribute +from plomtask.versioned_attributes import VersionedAttribute, TIMESTAMP_FMT from plomtask.exceptions import NotFoundException, HandledException @@ -25,6 +27,10 @@ def _within_checked_class(f: Callable[..., None]) -> Callable[..., None]: return wrapper +vals_str: list[Any] = ['A', 'B'] +vals_float: list[Any] = [0.3, 1.1] + + class TestCaseSansDB(TestCase): """Tests requiring no DB setup.""" checked_class: Any @@ -33,6 +39,17 @@ class TestCaseSansDB(TestCase): legal_ids = [1, 5] illegal_ids = [0] + @staticmethod + def _for_versioned_attr(f: Callable[..., None]) -> Callable[..., None]: + def wrapper(self: TestCaseSansDB) -> None: + owner = self.checked_class(self.legal_ids[0], + *self.default_init_args) + for attr_name, default in self.versioned_defaults_to_test.items(): + to_set = vals_str if isinstance(default, str) else vals_float + attr = getattr(owner, attr_name) + f(self, attr, default, to_set) + return wrapper + @_within_checked_class def test_id_validation(self) -> None: """Test .id_ validation/setting.""" @@ -44,12 +61,79 @@ class TestCaseSansDB(TestCase): self.assertEqual(obj.id_, id_) @_within_checked_class - def test_versioned_defaults(self) -> None: - """Test defaults of VersionedAttributes.""" - id_ = self.legal_ids[0] - obj = self.checked_class(id_, *self.default_init_args) - for k, v in self.versioned_defaults_to_test.items(): - self.assertEqual(getattr(obj, k).newest, v) + @_for_versioned_attr + def test_versioned_set(self, attr: VersionedAttribute, + default: str | float, to_set: list[str | float] + ) -> None: + """Test VersionedAttribute.set() behaves as expected.""" + attr.set(default) + self.assertEqual(list(attr.history.values()), [default]) + # check same value does not get set twice in a row, + # and that not even its timestamp get updated + timestamp = list(attr.history.keys())[0] + attr.set(default) + self.assertEqual(list(attr.history.values()), [default]) + self.assertEqual(list(attr.history.keys())[0], timestamp) + # check that different value _will_ be set/added + attr.set(to_set[0]) + timesorted_vals = [attr.history[t] for t in attr.history.keys()] + expected = [default, to_set[0]] + self.assertEqual(timesorted_vals, expected) + # check that a previously used value can be set if not most recent + attr.set(default) + timesorted_vals = [attr.history[t] for t in attr.history.keys()] + expected = [default, to_set[0], default] + self.assertEqual(timesorted_vals, expected) + # again check for same value not being set twice in a row, even for + # later items + attr.set(to_set[1]) + timesorted_vals = [attr.history[t] for t in attr.history.keys()] + expected = [default, to_set[0], default, to_set[1]] + self.assertEqual(timesorted_vals, expected) + attr.set(to_set[1]) + self.assertEqual(timesorted_vals, expected) + + @_within_checked_class + @_for_versioned_attr + def test_versioned_newest(self, attr: VersionedAttribute, + default: str | float, to_set: list[str | float] + ) -> None: + """Test VersionedAttribute.newest.""" + # check .newest on empty history returns .default + self.assertEqual(attr.newest, default) + # check newest element always returned + for v in [to_set[0], to_set[1]]: + attr.set(v) + self.assertEqual(attr.newest, v) + # check newest element returned even if also early value + attr.set(default) + self.assertEqual(attr.newest, default) + + @_within_checked_class + @_for_versioned_attr + def test_versioned_at(self, attr: VersionedAttribute, default: str | float, + to_set: list[str | float]) -> None: + """Test .at() returns values nearest to queried time, or default.""" + # check .at() return default on empty history + timestamp_a = datetime.now().strftime(TIMESTAMP_FMT) + self.assertEqual(attr.at(timestamp_a), default) + # check value exactly at timestamp returned + attr.set(to_set[0]) + timestamp_b = list(attr.history.keys())[0] + self.assertEqual(attr.at(timestamp_b), to_set[0]) + # check earliest value returned if exists, rather than default + self.assertEqual(attr.at(timestamp_a), to_set[0]) + # check reverts to previous value for timestamps not indexed + sleep(0.00001) + timestamp_between = datetime.now().strftime(TIMESTAMP_FMT) + sleep(0.00001) + attr.set(to_set[1]) + timestamp_c = sorted(attr.history.keys())[-1] + self.assertEqual(attr.at(timestamp_c), to_set[1]) + self.assertEqual(attr.at(timestamp_between), to_set[0]) + sleep(0.00001) + timestamp_after_c = datetime.now().strftime(TIMESTAMP_FMT) + self.assertEqual(attr.at(timestamp_after_c), to_set[1]) class TestCaseWithDB(TestCase): @@ -72,6 +156,18 @@ class TestCaseWithDB(TestCase): self.db_conn.close() remove_file(self.db_file.path) + @staticmethod + def _for_versioned_attr(f: Callable[..., None]) -> Callable[..., None]: + def wrapper(self: TestCaseWithDB) -> None: + for attr_name, type_ in self.test_versioneds.items(): + owner = self.checked_class(None, **self.default_init_kwargs) + to_set = vals_str if str == type_ else vals_float + attr = getattr(owner, attr_name) + attr.set(to_set[0]) + attr.set(to_set[1]) + f(self, owner, attr_name, attr, to_set) + return wrapper + def _load_from_db(self, id_: int | str) -> list[object]: db_found: list[object] = [] for row in self.db_conn.row_where(self.checked_class.table_name, @@ -93,16 +189,6 @@ class TestCaseWithDB(TestCase): setattr(obj, attr_name, new_attr) return attr_name - def _versioned_attrs_and_owner(self, attr_name: str, type_: type - ) -> tuple[Any, list[str | float], - VersionedAttribute]: - owner = self.checked_class(None, **self.default_init_kwargs) - vals: list[str | float] = ['t1', 't2'] if type_ == str else [0.9, 1.1] - attr = getattr(owner, attr_name) - attr.set(vals[0]) - attr.set(vals[1]) - return owner, vals, attr - def check_identity_with_cache_and_db(self, content: list[Any]) -> None: """Test both cache and DB equal content.""" expected_cache = {} @@ -118,39 +204,41 @@ class TestCaseWithDB(TestCase): self.assertEqual(sorted(hashes_content), sorted(hashes_db_found)) @_within_checked_class - def test_saving_versioned_attributes(self) -> None: + @_for_versioned_attr + def test_saving_versioned_attributes(self, owner: Any, attr_name: str, + attr: VersionedAttribute, + vals: list[str | float] + ) -> None: """Test storage and initialization of versioned attributes.""" - def retrieve_attr_vals(attr: VersionedAttribute, owner_id: int - ) -> list[object]: + def retrieve_attr_vals(attr: VersionedAttribute) -> list[object]: attr_vals_saved: list[object] = [] for row in self.db_conn.row_where(attr.table_name, 'parent', - owner_id): + owner.id_): attr_vals_saved += [row[2]] return attr_vals_saved - owner_id = 1 - for name, type_ in self.test_versioneds.items(): - # fail saving attributes on non-saved owner - owner, vals, attr = self._versioned_attrs_and_owner(name, type_) - with self.assertRaises(NotFoundException): - attr.save(self.db_conn) - # check owner.save() created entries as expected in attr table - owner.save(self.db_conn) - attr_vals_saved = retrieve_attr_vals(attr, owner_id) - self.assertEqual(vals, attr_vals_saved) - # check changing attr val without save affects owner in memory … - attr.set(vals[0]) - cmp_attr = getattr(owner, name) - self.assertEqual(cmp_attr.history, attr.history) - # … but does not yet affect DB - attr_vals_saved = retrieve_attr_vals(attr, owner_id) - self.assertEqual(vals, attr_vals_saved) - # check individual attr.save also stores new val to DB + # check that without attr.save() no rows in DB + rows = self.db_conn.row_where(attr.table_name, 'parent', owner.id_) + self.assertEqual([], rows) + # fail saving attributes on non-saved owner + with self.assertRaises(NotFoundException): attr.save(self.db_conn) - attr_vals_saved = retrieve_attr_vals(attr, owner_id) - self.assertEqual(vals + [vals[0]], attr_vals_saved) - owner_id += 1 + # check owner.save() created entries as expected in attr table + owner.save(self.db_conn) + attr_vals_saved = retrieve_attr_vals(attr) + self.assertEqual(vals, attr_vals_saved) + # check changing attr val without save affects owner in memory … + attr.set(vals[0]) + cmp_attr = getattr(owner, attr_name) + self.assertEqual(cmp_attr.history, attr.history) + # … but does not yet affect DB + attr_vals_saved = retrieve_attr_vals(attr) + self.assertEqual(vals, attr_vals_saved) + # check individual attr.save also stores new val to DB + attr.save(self.db_conn) + attr_vals_saved = retrieve_attr_vals(attr) + self.assertEqual(vals + [vals[0]], attr_vals_saved) @_within_checked_class def test_saving_and_caching(self) -> None: @@ -249,16 +337,31 @@ class TestCaseWithDB(TestCase): # check cache contains what .from_table_row just produced self.assertEqual({retrieved.id_: retrieved}, self.checked_class.get_cache()) - # check .from_table_row also reads versioned attributes from DB - for name, type_ in self.test_versioneds.items(): - owner, vals, attr = self._versioned_attrs_and_owner(name, type_) - owner.save(self.db_conn) - attr.set(vals[0]) - for row in self.db_conn.row_where(owner.table_name, 'id', + + @_within_checked_class + @_for_versioned_attr + def test_versioned_history_from_row(self, owner: Any, _: str, + attr: VersionedAttribute, + vals: list[str | float] + ) -> None: + """"Test VersionedAttribute.history_from_row() knows its DB rows.""" + vals += [vals[1] * 2] + attr.set(vals[2]) + attr.set(vals[1]) + attr.set(vals[2]) + owner.save(self.db_conn) + # make empty VersionedAttribute, fill from rows, compare to owner's + for row in self.db_conn.row_where(owner.table_name, 'id', + owner.id_): + loaded_attr = VersionedAttribute(owner, attr.table_name, + attr.default) + for row in self.db_conn.row_where(attr.table_name, 'parent', owner.id_): - retrieved = owner.__class__.from_table_row(self.db_conn, row) - attr = getattr(retrieved, name) - self.assertEqual(sorted(attr.history.values()), vals) + loaded_attr.history_from_row(row) + self.assertEqual(len(attr.history.keys()), + len(loaded_attr.history.keys())) + for timestamp, value in attr.history.items(): + self.assertEqual(value, loaded_attr.history[timestamp]) @_within_checked_class def test_all(self) -> None: @@ -291,16 +394,18 @@ class TestCaseWithDB(TestCase): self.assertEqual(new_attr, getattr(retrieved, attr_name)) @_within_checked_class - def test_versioned_singularity_title(self) -> None: - """Test singularity of VersionedAttributes on saving (with .title).""" - if 'title' in self.test_versioneds: - obj = self.checked_class(None) - obj.save(self.db_conn) - assert isinstance(obj.id_, int) - # change obj, expect retrieved through .by_id to carry change - obj.title.set('named') - retrieved = self.checked_class.by_id(self.db_conn, obj.id_) - self.assertEqual(obj.title.history, retrieved.title.history) + @_for_versioned_attr + def test_versioned_singularity(self, owner: Any, attr_name: str, + attr: VersionedAttribute, + vals: list[str | float] + ) -> None: + """Test singularity of VersionedAttributes on saving.""" + owner.save(self.db_conn) + # change obj, expect retrieved through .by_id to carry change + attr.set(vals[0]) + retrieved = self.checked_class.by_id(self.db_conn, owner.id_) + attr_retrieved = getattr(retrieved, attr_name) + self.assertEqual(attr.history, attr_retrieved.history) @_within_checked_class def test_remove(self) -> None: diff --git a/tests/versioned_attributes.py b/tests/versioned_attributes.py deleted file mode 100644 index a75fc3c..0000000 --- a/tests/versioned_attributes.py +++ /dev/null @@ -1,144 +0,0 @@ -""""Test Versioned Attributes in the abstract.""" -from unittest import TestCase -from time import sleep -from datetime import datetime -from tests.utils import TestCaseWithDB -from plomtask.versioned_attributes import VersionedAttribute, TIMESTAMP_FMT -from plomtask.db import BaseModel - -SQL_TEST_TABLE_STR = ''' -CREATE TABLE versioned_tests ( - parent INTEGER NOT NULL, - timestamp TEXT NOT NULL, - value TEXT NOT NULL, - PRIMARY KEY (parent, timestamp) -); -''' -SQL_TEST_TABLE_FLOAT = ''' -CREATE TABLE versioned_tests ( - parent INTEGER NOT NULL, - timestamp TEXT NOT NULL, - value REAL NOT NULL, - PRIMARY KEY (parent, timestamp) -); -''' - - -class TestParentType(BaseModel[int]): - """Dummy abstracting whatever may use VersionedAttributes.""" - - -class TestsSansDB(TestCase): - """Tests not requiring DB setup.""" - - def test_VersionedAttribute_set(self) -> None: - """Test .set() behaves as expected.""" - # check value gets set even if already is the default - attr = VersionedAttribute(None, '', 'A') - attr.set('A') - self.assertEqual(list(attr.history.values()), ['A']) - # check same value does not get set twice in a row, - # and that not even its timestamp get updated - timestamp = list(attr.history.keys())[0] - attr.set('A') - self.assertEqual(list(attr.history.values()), ['A']) - self.assertEqual(list(attr.history.keys())[0], timestamp) - # check that different value _will_ be set/added - attr.set('B') - self.assertEqual(sorted(attr.history.values()), ['A', 'B']) - # check that a previously used value can be set if not most recent - attr.set('A') - self.assertEqual(sorted(attr.history.values()), ['A', 'A', 'B']) - # again check for same value not being set twice in a row, even for - # later items - attr.set('D') - self.assertEqual(sorted(attr.history.values()), ['A', 'A', 'B', 'D']) - attr.set('D') - self.assertEqual(sorted(attr.history.values()), ['A', 'A', 'B', 'D']) - - def test_VersionedAttribute_newest(self) -> None: - """Test .newest returns newest element, or default on empty.""" - attr = VersionedAttribute(None, '', 'A') - self.assertEqual(attr.newest, 'A') - attr.set('B') - self.assertEqual(attr.newest, 'B') - attr.set('C') - - def test_VersionedAttribute_at(self) -> None: - """Test .at() returns values nearest to queried time, or default.""" - # check .at() return default on empty history - attr = VersionedAttribute(None, '', 'A') - timestamp_a = datetime.now().strftime(TIMESTAMP_FMT) - self.assertEqual(attr.at(timestamp_a), 'A') - # check value exactly at timestamp returned - attr.set('B') - timestamp_b = list(attr.history.keys())[0] - self.assertEqual(attr.at(timestamp_b), 'B') - # check earliest value returned if exists, rather than default - self.assertEqual(attr.at(timestamp_a), 'B') - # check reverts to previous value for timestamps not indexed - sleep(0.00001) - timestamp_between = datetime.now().strftime(TIMESTAMP_FMT) - sleep(0.00001) - attr.set('C') - timestamp_c = sorted(attr.history.keys())[-1] - self.assertEqual(attr.at(timestamp_c), 'C') - self.assertEqual(attr.at(timestamp_between), 'B') - sleep(0.00001) - timestamp_after_c = datetime.now().strftime(TIMESTAMP_FMT) - self.assertEqual(attr.at(timestamp_after_c), 'C') - - -class TestsWithDBStr(TestCaseWithDB): - """Module tests requiring DB setup.""" - default_vals: list[str | float] = ['A', 'B', 'C'] - init_sql = SQL_TEST_TABLE_STR - - def setUp(self) -> None: - super().setUp() - self.db_conn.exec(self.init_sql) - self.test_parent = TestParentType(1) - self.attr = VersionedAttribute(self.test_parent, - 'versioned_tests', self.default_vals[0]) - - def test_VersionedAttribute_save(self) -> None: - """Test .save() to write to DB.""" - # check mere .set() calls do not by themselves reflect in the DB - self.attr.set(self.default_vals[1]) - self.assertEqual([], - self.db_conn.row_where('versioned_tests', - 'parent', 1)) - # check .save() makes history appear in DB - self.attr.save(self.db_conn) - vals_found = [] - for row in self.db_conn.row_where('versioned_tests', 'parent', 1): - vals_found += [row[2]] - self.assertEqual([self.default_vals[1]], vals_found) - # check .save() also updates history in DB - self.attr.set(self.default_vals[2]) - self.attr.save(self.db_conn) - vals_found = [] - for row in self.db_conn.row_where('versioned_tests', 'parent', 1): - vals_found += [row[2]] - self.assertEqual([self.default_vals[1], self.default_vals[2]], - sorted(vals_found)) - - def test_VersionedAttribute_history_from_row(self) -> None: - """"Test .history_from_row() properly interprets DB rows.""" - self.attr.set(self.default_vals[1]) - self.attr.set(self.default_vals[2]) - self.attr.save(self.db_conn) - loaded_attr = VersionedAttribute(self.test_parent, 'versioned_tests', - self.default_vals[0]) - for row in self.db_conn.row_where('versioned_tests', 'parent', 1): - loaded_attr.history_from_row(row) - for timestamp, value in self.attr.history.items(): - self.assertEqual(value, loaded_attr.history[timestamp]) - self.assertEqual(len(self.attr.history.keys()), - len(loaded_attr.history.keys())) - - -class TestsWithDBFloat(TestsWithDBStr): - """Module tests requiring DB setup.""" - default_vals: list[str | float] = [0.9, 1.1, 2] - init_sql = SQL_TEST_TABLE_FLOAT -- 2.30.2