From 6c73e0e0ba5a50fd34139821d80466c81251d52e Mon Sep 17 00:00:00 2001
From: Christian Heller <c.heller@plomlompom.de>
Date: Thu, 11 Jul 2024 07:34:11 +0200
Subject: [PATCH] Run VersionedAttributes tests over all models.

---
 tests/utils.py                | 227 +++++++++++++++++++++++++---------
 tests/versioned_attributes.py | 144 ---------------------
 2 files changed, 166 insertions(+), 205 deletions(-)
 delete mode 100644 tests/versioned_attributes.py

diff --git a/tests/utils.py b/tests/utils.py
index 4d81c91..52cc66e 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -4,6 +4,8 @@ from unittest import TestCase
 from typing import Mapping, Any, Callable
 from threading import Thread
 from http.client import HTTPConnection
+from datetime import datetime
+from time import sleep
 from json import loads as json_loads
 from urllib.parse import urlencode
 from uuid import uuid4
@@ -14,7 +16,7 @@ from plomtask.processes import Process, ProcessStep
 from plomtask.conditions import Condition
 from plomtask.days import Day
 from plomtask.todos import Todo
-from plomtask.versioned_attributes import VersionedAttribute
+from plomtask.versioned_attributes import VersionedAttribute, TIMESTAMP_FMT
 from plomtask.exceptions import NotFoundException, HandledException
 
 
@@ -25,6 +27,10 @@ def _within_checked_class(f: Callable[..., None]) -> Callable[..., None]:
     return wrapper
 
 
+vals_str: list[Any] = ['A', 'B']
+vals_float: list[Any] = [0.3, 1.1]
+
+
 class TestCaseSansDB(TestCase):
     """Tests requiring no DB setup."""
     checked_class: Any
@@ -33,6 +39,17 @@ class TestCaseSansDB(TestCase):
     legal_ids = [1, 5]
     illegal_ids = [0]
 
+    @staticmethod
+    def _for_versioned_attr(f: Callable[..., None]) -> Callable[..., None]:
+        def wrapper(self: TestCaseSansDB) -> None:
+            owner = self.checked_class(self.legal_ids[0],
+                                       *self.default_init_args)
+            for attr_name, default in self.versioned_defaults_to_test.items():
+                to_set = vals_str if isinstance(default, str) else vals_float
+                attr = getattr(owner, attr_name)
+                f(self, attr, default, to_set)
+        return wrapper
+
     @_within_checked_class
     def test_id_validation(self) -> None:
         """Test .id_ validation/setting."""
@@ -44,12 +61,79 @@ class TestCaseSansDB(TestCase):
             self.assertEqual(obj.id_, id_)
 
     @_within_checked_class
-    def test_versioned_defaults(self) -> None:
-        """Test defaults of VersionedAttributes."""
-        id_ = self.legal_ids[0]
-        obj = self.checked_class(id_, *self.default_init_args)
-        for k, v in self.versioned_defaults_to_test.items():
-            self.assertEqual(getattr(obj, k).newest, v)
+    @_for_versioned_attr
+    def test_versioned_set(self, attr: VersionedAttribute,
+                           default: str | float, to_set: list[str | float]
+                           ) -> None:
+        """Test VersionedAttribute.set() behaves as expected."""
+        attr.set(default)
+        self.assertEqual(list(attr.history.values()), [default])
+        # check same value does not get set twice in a row,
+        # and that not even its timestamp get updated
+        timestamp = list(attr.history.keys())[0]
+        attr.set(default)
+        self.assertEqual(list(attr.history.values()), [default])
+        self.assertEqual(list(attr.history.keys())[0], timestamp)
+        # check that different value _will_ be set/added
+        attr.set(to_set[0])
+        timesorted_vals = [attr.history[t] for t in attr.history.keys()]
+        expected = [default, to_set[0]]
+        self.assertEqual(timesorted_vals, expected)
+        # check that a previously used value can be set if not most recent
+        attr.set(default)
+        timesorted_vals = [attr.history[t] for t in attr.history.keys()]
+        expected = [default, to_set[0], default]
+        self.assertEqual(timesorted_vals, expected)
+        # again check for same value not being set twice in a row, even for
+        # later items
+        attr.set(to_set[1])
+        timesorted_vals = [attr.history[t] for t in attr.history.keys()]
+        expected = [default, to_set[0], default, to_set[1]]
+        self.assertEqual(timesorted_vals, expected)
+        attr.set(to_set[1])
+        self.assertEqual(timesorted_vals, expected)
+
+    @_within_checked_class
+    @_for_versioned_attr
+    def test_versioned_newest(self, attr: VersionedAttribute,
+                              default: str | float, to_set: list[str | float]
+                              ) -> None:
+        """Test VersionedAttribute.newest."""
+        # check .newest on empty history returns .default
+        self.assertEqual(attr.newest, default)
+        # check newest element always returned
+        for v in [to_set[0], to_set[1]]:
+            attr.set(v)
+            self.assertEqual(attr.newest, v)
+        # check newest element returned even if also early value
+        attr.set(default)
+        self.assertEqual(attr.newest, default)
+
+    @_within_checked_class
+    @_for_versioned_attr
+    def test_versioned_at(self, attr: VersionedAttribute, default: str | float,
+                          to_set: list[str | float]) -> None:
+        """Test .at() returns values nearest to queried time, or default."""
+        # check .at() return default on empty history
+        timestamp_a = datetime.now().strftime(TIMESTAMP_FMT)
+        self.assertEqual(attr.at(timestamp_a), default)
+        # check value exactly at timestamp returned
+        attr.set(to_set[0])
+        timestamp_b = list(attr.history.keys())[0]
+        self.assertEqual(attr.at(timestamp_b), to_set[0])
+        # check earliest value returned if exists, rather than default
+        self.assertEqual(attr.at(timestamp_a), to_set[0])
+        # check reverts to previous value for timestamps not indexed
+        sleep(0.00001)
+        timestamp_between = datetime.now().strftime(TIMESTAMP_FMT)
+        sleep(0.00001)
+        attr.set(to_set[1])
+        timestamp_c = sorted(attr.history.keys())[-1]
+        self.assertEqual(attr.at(timestamp_c), to_set[1])
+        self.assertEqual(attr.at(timestamp_between), to_set[0])
+        sleep(0.00001)
+        timestamp_after_c = datetime.now().strftime(TIMESTAMP_FMT)
+        self.assertEqual(attr.at(timestamp_after_c), to_set[1])
 
 
 class TestCaseWithDB(TestCase):
@@ -72,6 +156,18 @@ class TestCaseWithDB(TestCase):
         self.db_conn.close()
         remove_file(self.db_file.path)
 
+    @staticmethod
+    def _for_versioned_attr(f: Callable[..., None]) -> Callable[..., None]:
+        def wrapper(self: TestCaseWithDB) -> None:
+            for attr_name, type_ in self.test_versioneds.items():
+                owner = self.checked_class(None, **self.default_init_kwargs)
+                to_set = vals_str if str == type_ else vals_float
+                attr = getattr(owner, attr_name)
+                attr.set(to_set[0])
+                attr.set(to_set[1])
+                f(self, owner, attr_name, attr, to_set)
+        return wrapper
+
     def _load_from_db(self, id_: int | str) -> list[object]:
         db_found: list[object] = []
         for row in self.db_conn.row_where(self.checked_class.table_name,
@@ -93,16 +189,6 @@ class TestCaseWithDB(TestCase):
         setattr(obj, attr_name, new_attr)
         return attr_name
 
-    def _versioned_attrs_and_owner(self, attr_name: str, type_: type
-                                   ) -> tuple[Any, list[str | float],
-                                              VersionedAttribute]:
-        owner = self.checked_class(None, **self.default_init_kwargs)
-        vals: list[str | float] = ['t1', 't2'] if type_ == str else [0.9, 1.1]
-        attr = getattr(owner, attr_name)
-        attr.set(vals[0])
-        attr.set(vals[1])
-        return owner, vals, attr
-
     def check_identity_with_cache_and_db(self, content: list[Any]) -> None:
         """Test both cache and DB equal content."""
         expected_cache = {}
@@ -118,39 +204,41 @@ class TestCaseWithDB(TestCase):
         self.assertEqual(sorted(hashes_content), sorted(hashes_db_found))
 
     @_within_checked_class
-    def test_saving_versioned_attributes(self) -> None:
+    @_for_versioned_attr
+    def test_saving_versioned_attributes(self, owner: Any, attr_name: str,
+                                         attr: VersionedAttribute,
+                                         vals: list[str | float]
+                                         ) -> None:
         """Test storage and initialization of versioned attributes."""
 
-        def retrieve_attr_vals(attr: VersionedAttribute, owner_id: int
-                               ) -> list[object]:
+        def retrieve_attr_vals(attr: VersionedAttribute) -> list[object]:
             attr_vals_saved: list[object] = []
             for row in self.db_conn.row_where(attr.table_name, 'parent',
-                                              owner_id):
+                                              owner.id_):
                 attr_vals_saved += [row[2]]
             return attr_vals_saved
 
-        owner_id = 1
-        for name, type_ in self.test_versioneds.items():
-            # fail saving attributes on non-saved owner
-            owner, vals, attr = self._versioned_attrs_and_owner(name, type_)
-            with self.assertRaises(NotFoundException):
-                attr.save(self.db_conn)
-            # check owner.save() created entries as expected in attr table
-            owner.save(self.db_conn)
-            attr_vals_saved = retrieve_attr_vals(attr, owner_id)
-            self.assertEqual(vals, attr_vals_saved)
-            # check changing attr val without save affects owner in memory …
-            attr.set(vals[0])
-            cmp_attr = getattr(owner, name)
-            self.assertEqual(cmp_attr.history, attr.history)
-            # … but does not yet affect DB
-            attr_vals_saved = retrieve_attr_vals(attr, owner_id)
-            self.assertEqual(vals, attr_vals_saved)
-            # check individual attr.save also stores new val to DB
+        # check that without attr.save() no rows in DB
+        rows = self.db_conn.row_where(attr.table_name, 'parent', owner.id_)
+        self.assertEqual([], rows)
+        # fail saving attributes on non-saved owner
+        with self.assertRaises(NotFoundException):
             attr.save(self.db_conn)
-            attr_vals_saved = retrieve_attr_vals(attr, owner_id)
-            self.assertEqual(vals + [vals[0]], attr_vals_saved)
-            owner_id += 1
+        # check owner.save() created entries as expected in attr table
+        owner.save(self.db_conn)
+        attr_vals_saved = retrieve_attr_vals(attr)
+        self.assertEqual(vals, attr_vals_saved)
+        # check changing attr val without save affects owner in memory …
+        attr.set(vals[0])
+        cmp_attr = getattr(owner, attr_name)
+        self.assertEqual(cmp_attr.history, attr.history)
+        # … but does not yet affect DB
+        attr_vals_saved = retrieve_attr_vals(attr)
+        self.assertEqual(vals, attr_vals_saved)
+        # check individual attr.save also stores new val to DB
+        attr.save(self.db_conn)
+        attr_vals_saved = retrieve_attr_vals(attr)
+        self.assertEqual(vals + [vals[0]], attr_vals_saved)
 
     @_within_checked_class
     def test_saving_and_caching(self) -> None:
@@ -249,16 +337,31 @@ class TestCaseWithDB(TestCase):
             # check cache contains what .from_table_row just produced
             self.assertEqual({retrieved.id_: retrieved},
                              self.checked_class.get_cache())
-        # check .from_table_row also reads versioned attributes from DB
-        for name, type_ in self.test_versioneds.items():
-            owner, vals, attr = self._versioned_attrs_and_owner(name, type_)
-            owner.save(self.db_conn)
-            attr.set(vals[0])
-            for row in self.db_conn.row_where(owner.table_name, 'id',
+
+    @_within_checked_class
+    @_for_versioned_attr
+    def test_versioned_history_from_row(self, owner: Any, _: str,
+                                        attr: VersionedAttribute,
+                                        vals: list[str | float]
+                                        ) -> None:
+        """"Test VersionedAttribute.history_from_row() knows its DB rows."""
+        vals += [vals[1] * 2]
+        attr.set(vals[2])
+        attr.set(vals[1])
+        attr.set(vals[2])
+        owner.save(self.db_conn)
+        # make empty VersionedAttribute, fill from rows, compare to owner's
+        for row in self.db_conn.row_where(owner.table_name, 'id',
+                                          owner.id_):
+            loaded_attr = VersionedAttribute(owner, attr.table_name,
+                                             attr.default)
+            for row in self.db_conn.row_where(attr.table_name, 'parent',
                                               owner.id_):
-                retrieved = owner.__class__.from_table_row(self.db_conn, row)
-                attr = getattr(retrieved, name)
-                self.assertEqual(sorted(attr.history.values()), vals)
+                loaded_attr.history_from_row(row)
+            self.assertEqual(len(attr.history.keys()),
+                             len(loaded_attr.history.keys()))
+            for timestamp, value in attr.history.items():
+                self.assertEqual(value, loaded_attr.history[timestamp])
 
     @_within_checked_class
     def test_all(self) -> None:
@@ -291,16 +394,18 @@ class TestCaseWithDB(TestCase):
         self.assertEqual(new_attr, getattr(retrieved, attr_name))
 
     @_within_checked_class
-    def test_versioned_singularity_title(self) -> None:
-        """Test singularity of VersionedAttributes on saving (with .title)."""
-        if 'title' in self.test_versioneds:
-            obj = self.checked_class(None)
-            obj.save(self.db_conn)
-            assert isinstance(obj.id_, int)
-            # change obj, expect retrieved through .by_id to carry change
-            obj.title.set('named')
-            retrieved = self.checked_class.by_id(self.db_conn, obj.id_)
-            self.assertEqual(obj.title.history, retrieved.title.history)
+    @_for_versioned_attr
+    def test_versioned_singularity(self, owner: Any, attr_name: str,
+                                   attr: VersionedAttribute,
+                                   vals: list[str | float]
+                                   ) -> None:
+        """Test singularity of VersionedAttributes on saving."""
+        owner.save(self.db_conn)
+        # change obj, expect retrieved through .by_id to carry change
+        attr.set(vals[0])
+        retrieved = self.checked_class.by_id(self.db_conn, owner.id_)
+        attr_retrieved = getattr(retrieved, attr_name)
+        self.assertEqual(attr.history, attr_retrieved.history)
 
     @_within_checked_class
     def test_remove(self) -> None:
diff --git a/tests/versioned_attributes.py b/tests/versioned_attributes.py
deleted file mode 100644
index a75fc3c..0000000
--- a/tests/versioned_attributes.py
+++ /dev/null
@@ -1,144 +0,0 @@
-""""Test Versioned Attributes in the abstract."""
-from unittest import TestCase
-from time import sleep
-from datetime import datetime
-from tests.utils import TestCaseWithDB
-from plomtask.versioned_attributes import VersionedAttribute, TIMESTAMP_FMT
-from plomtask.db import BaseModel
-
-SQL_TEST_TABLE_STR = '''
-CREATE TABLE versioned_tests (
-  parent INTEGER NOT NULL,
-  timestamp TEXT NOT NULL,
-  value TEXT NOT NULL,
-  PRIMARY KEY (parent, timestamp)
-);
-'''
-SQL_TEST_TABLE_FLOAT = '''
-CREATE TABLE versioned_tests (
-  parent INTEGER NOT NULL,
-  timestamp TEXT NOT NULL,
-  value REAL NOT NULL,
-  PRIMARY KEY (parent, timestamp)
-);
-'''
-
-
-class TestParentType(BaseModel[int]):
-    """Dummy abstracting whatever may use VersionedAttributes."""
-
-
-class TestsSansDB(TestCase):
-    """Tests not requiring DB setup."""
-
-    def test_VersionedAttribute_set(self) -> None:
-        """Test .set() behaves as expected."""
-        # check value gets set even if already is the default
-        attr = VersionedAttribute(None, '', 'A')
-        attr.set('A')
-        self.assertEqual(list(attr.history.values()), ['A'])
-        # check same value does not get set twice in a row,
-        # and that not even its timestamp get updated
-        timestamp = list(attr.history.keys())[0]
-        attr.set('A')
-        self.assertEqual(list(attr.history.values()), ['A'])
-        self.assertEqual(list(attr.history.keys())[0], timestamp)
-        # check that different value _will_ be set/added
-        attr.set('B')
-        self.assertEqual(sorted(attr.history.values()), ['A', 'B'])
-        # check that a previously used value can be set if not most recent
-        attr.set('A')
-        self.assertEqual(sorted(attr.history.values()), ['A', 'A', 'B'])
-        # again check for same value not being set twice in a row, even for
-        # later items
-        attr.set('D')
-        self.assertEqual(sorted(attr.history.values()), ['A', 'A', 'B', 'D'])
-        attr.set('D')
-        self.assertEqual(sorted(attr.history.values()), ['A', 'A', 'B', 'D'])
-
-    def test_VersionedAttribute_newest(self) -> None:
-        """Test .newest returns newest element, or default on empty."""
-        attr = VersionedAttribute(None, '', 'A')
-        self.assertEqual(attr.newest, 'A')
-        attr.set('B')
-        self.assertEqual(attr.newest, 'B')
-        attr.set('C')
-
-    def test_VersionedAttribute_at(self) -> None:
-        """Test .at() returns values nearest to queried time, or default."""
-        # check .at() return default on empty history
-        attr = VersionedAttribute(None, '', 'A')
-        timestamp_a = datetime.now().strftime(TIMESTAMP_FMT)
-        self.assertEqual(attr.at(timestamp_a), 'A')
-        # check value exactly at timestamp returned
-        attr.set('B')
-        timestamp_b = list(attr.history.keys())[0]
-        self.assertEqual(attr.at(timestamp_b), 'B')
-        # check earliest value returned if exists, rather than default
-        self.assertEqual(attr.at(timestamp_a), 'B')
-        # check reverts to previous value for timestamps not indexed
-        sleep(0.00001)
-        timestamp_between = datetime.now().strftime(TIMESTAMP_FMT)
-        sleep(0.00001)
-        attr.set('C')
-        timestamp_c = sorted(attr.history.keys())[-1]
-        self.assertEqual(attr.at(timestamp_c), 'C')
-        self.assertEqual(attr.at(timestamp_between), 'B')
-        sleep(0.00001)
-        timestamp_after_c = datetime.now().strftime(TIMESTAMP_FMT)
-        self.assertEqual(attr.at(timestamp_after_c), 'C')
-
-
-class TestsWithDBStr(TestCaseWithDB):
-    """Module tests requiring DB setup."""
-    default_vals: list[str | float] = ['A', 'B', 'C']
-    init_sql = SQL_TEST_TABLE_STR
-
-    def setUp(self) -> None:
-        super().setUp()
-        self.db_conn.exec(self.init_sql)
-        self.test_parent = TestParentType(1)
-        self.attr = VersionedAttribute(self.test_parent,
-                                       'versioned_tests', self.default_vals[0])
-
-    def test_VersionedAttribute_save(self) -> None:
-        """Test .save() to write to DB."""
-        # check mere .set() calls do not by themselves reflect in the DB
-        self.attr.set(self.default_vals[1])
-        self.assertEqual([],
-                         self.db_conn.row_where('versioned_tests',
-                                                'parent', 1))
-        # check .save() makes history appear in DB
-        self.attr.save(self.db_conn)
-        vals_found = []
-        for row in self.db_conn.row_where('versioned_tests', 'parent', 1):
-            vals_found += [row[2]]
-        self.assertEqual([self.default_vals[1]], vals_found)
-        # check .save() also updates history in DB
-        self.attr.set(self.default_vals[2])
-        self.attr.save(self.db_conn)
-        vals_found = []
-        for row in self.db_conn.row_where('versioned_tests', 'parent', 1):
-            vals_found += [row[2]]
-        self.assertEqual([self.default_vals[1], self.default_vals[2]],
-                         sorted(vals_found))
-
-    def test_VersionedAttribute_history_from_row(self) -> None:
-        """"Test .history_from_row() properly interprets DB rows."""
-        self.attr.set(self.default_vals[1])
-        self.attr.set(self.default_vals[2])
-        self.attr.save(self.db_conn)
-        loaded_attr = VersionedAttribute(self.test_parent, 'versioned_tests',
-                                         self.default_vals[0])
-        for row in self.db_conn.row_where('versioned_tests', 'parent', 1):
-            loaded_attr.history_from_row(row)
-        for timestamp, value in self.attr.history.items():
-            self.assertEqual(value, loaded_attr.history[timestamp])
-        self.assertEqual(len(self.attr.history.keys()),
-                         len(loaded_attr.history.keys()))
-
-
-class TestsWithDBFloat(TestsWithDBStr):
-    """Module tests requiring DB setup."""
-    default_vals: list[str | float] = [0.9, 1.1, 2]
-    init_sql = SQL_TEST_TABLE_FLOAT
-- 
2.30.2