From: Christian Heller Date: Mon, 17 Jun 2024 23:54:46 +0000 (+0200) Subject: Refactor BaseModel.from_table_row testing. X-Git-Url: https://plomlompom.com/repos/%7B%7B%20web_path%20%7D%7D/decks/%7B%7B%20deck_id%20%7D%7D/template?a=commitdiff_plain;h=e3bfd84f9061d5f03ec5f5764f75e4137505ea45;p=plomtask Refactor BaseModel.from_table_row testing. --- diff --git a/plomtask/db.py b/plomtask/db.py index 385e798..853b4c6 100644 --- a/plomtask/db.py +++ b/plomtask/db.py @@ -344,7 +344,7 @@ class BaseModel(Generic[BaseModelId]): return obj return None - def _cache(self) -> None: + def cache(self) -> None: """Update object in class's cache. Also calls ._disappear if cache holds older reference to object of same @@ -383,7 +383,7 @@ class BaseModel(Generic[BaseModelId]): table_name = attr.table_name for row_ in db_conn.row_where(table_name, 'parent', obj.id_): attr.history_from_row(row_) - obj._cache() + obj.cache() return obj @classmethod @@ -497,7 +497,7 @@ class BaseModel(Generic[BaseModelId]): values) if not isinstance(self.id_, str): self.id_ = cursor.lastrowid # type: ignore[assignment] - self._cache() + self.cache() for attr_name in self.to_save_versioned: getattr(self, attr_name).save(db_conn) for table, column, attr_name, key_index in self.to_save_relations: diff --git a/tests/conditions.py b/tests/conditions.py index 5270812..afb1841 100644 --- a/tests/conditions.py +++ b/tests/conditions.py @@ -19,9 +19,9 @@ class TestsWithDB(TestCaseWithDB): default_init_kwargs = {'is_active': False} test_versioneds = {'title': str, 'description': str} - def test_Condition_from_table_row(self) -> None: + def test_from_table_row(self) -> None: """Test .from_table_row() properly reads in class from DB""" - self.check_from_table_row() + super().test_from_table_row() self.check_versioned_from_table_row('title', str) self.check_versioned_from_table_row('description', str) diff --git a/tests/days.py b/tests/days.py index 901667f..e4c9de5 100644 --- a/tests/days.py +++ b/tests/days.py @@ -53,10 +53,6 @@ class TestsWithDB(TestCaseWithDB): kwargs = {'date': self.default_ids[0], 'comment': 'foo'} self.check_saving_and_caching(**kwargs) - def test_Day_from_table_row(self) -> None: - """Test .from_table_row() properly reads in class from DB""" - self.check_from_table_row() - def test_Day_by_id(self) -> None: """Test .by_id().""" self.check_by_id() diff --git a/tests/processes.py b/tests/processes.py index 4d2252c..d54fe84 100644 --- a/tests/processes.py +++ b/tests/processes.py @@ -63,9 +63,9 @@ class TestsWithDB(TestCaseWithDB): self.assertEqual(sorted(r.enables), sorted(set2)) self.assertEqual(sorted(r.disables), sorted(set3)) - def test_Process_from_table_row(self) -> None: + def test_from_table_row(self) -> None: """Test .from_table_row() properly reads in class from DB""" - self.check_from_table_row() + super().test_from_table_row() self.check_versioned_from_table_row('title', str) self.check_versioned_from_table_row('description', str) self.check_versioned_from_table_row('effort', float) diff --git a/tests/utils.py b/tests/utils.py index d6c5b20..f76fe33 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -130,17 +130,33 @@ class TestCaseWithDB(TestCase): self.assertEqual(self.checked_class(id2), by_id_created) self.check_storage([obj]) - def check_from_table_row(self, *args: Any) -> None: - """Test .from_table_row() properly reads in class from DB""" + def test_from_table_row(self) -> None: + """Test .from_table_row() properly reads in class from DB.""" + if not hasattr(self, 'checked_class'): + return id_ = self.default_ids[0] - obj = self.checked_class(id_, *args) # pylint: disable=not-callable + obj = self.checked_class(id_, **self.default_init_kwargs) obj.save(self.db_conn) assert isinstance(obj.id_, type(self.default_ids[0])) for row in self.db_conn.row_where(self.checked_class.table_name, 'id', obj.id_): + # check .from_table_row reproduces state saved, no matter if obj + # later changed (with caching even) hash_original = hash(obj) + attr_name = self.checked_class.to_save[-1] + attr = getattr(obj, attr_name) + if isinstance(attr, (int, float)): + setattr(obj, attr_name, attr + 1) + elif isinstance(attr, str): + setattr(obj, attr_name, attr + "_") + elif isinstance(attr, bool): + setattr(obj, attr_name, not attr) + obj.cache() + to_cmp = getattr(obj, attr_name) retrieved = self.checked_class.from_table_row(self.db_conn, row) + self.assertNotEqual(to_cmp, getattr(retrieved, attr_name)) self.assertEqual(hash_original, hash(retrieved)) + # check cache contains what .from_table_row just produced self.assertEqual({retrieved.id_: retrieved}, self.checked_class.get_cache())