class BaseModel(Generic[BaseModelId]):
"""Template for most of the models we use/derive from the DB."""
table_name = ''
- to_save: list[str] = []
- to_save_versioned: list[str] = []
+ to_save_simples: list[str] = []
to_save_relations: list[tuple[str, str, str, int]] = []
+ versioned_defaults: dict[str, str | float] = {}
add_to_dict: list[str] = []
id_: None | BaseModelId
cache_: dict[BaseModelId, Self]
self.id_ = id_
def __hash__(self) -> int:
- hashable = [self.id_] + [getattr(self, name) for name in self.to_save]
+ hashable = [self.id_] + [getattr(self, name)
+ for name in self.to_save_simples]
for definition in self.to_save_relations:
attr = getattr(self, definition[2])
hashable += [tuple(rel.id_ for rel in attr)]
- for name in self.to_save_versioned:
+ for name in self.to_save_versioned():
hashable += [hash(getattr(self, name))]
return hash(tuple(hashable))
assert isinstance(other.id_, int)
return self.id_ < other.id_
+ @classmethod
+ def to_save_versioned(cls) -> list[str]:
+ """Return keys of cls.versioned_defaults assuming we wanna save 'em."""
+ return list(cls.versioned_defaults.keys())
+
@property
def as_dict(self) -> dict[str, object]:
"""Return self as (json.dumps-compatible) dict."""
library: dict[str, dict[str | int, object]] = {}
d: dict[str, object] = {'id': self.id_, '_library': library}
- for to_save in self.to_save:
+ for to_save in self.to_save_simples:
attr = getattr(self, to_save)
if hasattr(attr, 'as_dict_into_reference'):
d[to_save] = attr.as_dict_into_reference(library)
else:
d[to_save] = attr
- if len(self.to_save_versioned) > 0:
+ if len(self.to_save_versioned()) > 0:
d['_versioned'] = {}
- for k in self.to_save_versioned:
+ for k in self.to_save_versioned():
attr = getattr(self, k)
assert isinstance(d['_versioned'], dict)
d['_versioned'][k] = attr.history
"""Make from DB row (sans relations), update DB cache with it."""
obj = cls(*row)
assert obj.id_ is not None
- for attr_name in cls.to_save_versioned:
+ for attr_name in cls.to_save_versioned():
attr = getattr(obj, attr_name)
table_name = attr.table_name
for row_ in db_conn.row_where(table_name, 'parent', obj.id_):
"""Write self to DB and cache and ensure .id_.
Write both to DB, and to cache. To DB, write .id_ and attributes
- listed in cls.to_save[_versioned|_relations].
+ listed in cls.to_save_[simples|versioned|_relations].
Ensure self.id_ by setting it to what the DB command returns as the
last saved row's ID (cursor.lastrowid), EXCEPT if self.id_ already
only the case with the Day class, where it's to be a date string.
"""
values = tuple([self.id_] + [getattr(self, key)
- for key in self.to_save])
+ for key in self.to_save_simples])
table_name = self.table_name
cursor = db_conn.exec_on_vals(f'REPLACE INTO {table_name} VALUES',
values)
if not isinstance(self.id_, str):
self.id_ = cursor.lastrowid # type: ignore[assignment]
self.cache()
- for attr_name in self.to_save_versioned:
+ for attr_name in self.to_save_versioned():
getattr(self, attr_name).save(db_conn)
for table, column, attr_name, key_index in self.to_save_relations:
assert isinstance(self.id_, (int, str))
"""Remove from DB and cache, including dependencies."""
if self.id_ is None or self._get_cached(self.id_) is None:
raise HandledException('cannot remove unsaved item')
- for attr_name in self.to_save_versioned:
+ for attr_name in self.to_save_versioned():
getattr(self, attr_name).remove(db_conn)
for table, column, attr_name, _ in self.to_save_relations:
db_conn.delete_where(table, column, self.id_)
from plomtask.exceptions import NotFoundException, HandledException
-def _within_checked_class(f: Callable[..., None]) -> Callable[..., None]:
- def wrapper(self: TestCase) -> None:
- if hasattr(self, 'checked_class'):
- f(self)
- return wrapper
+VERSIONED_VALS: dict[str,
+ list[str] | list[float]] = {'str': ['A', 'B'],
+ 'float': [0.3, 1.1]}
-vals_str: list[Any] = ['A', 'B']
-vals_float: list[Any] = [0.3, 1.1]
-
-
-class TestCaseSansDB(TestCase):
- """Tests requiring no DB setup."""
+class TestCaseAugmented(TestCase):
+ """Tester core providing helpful basic internal decorators and methods."""
checked_class: Any
- default_init_args: list[Any] = []
- versioned_defaults_to_test: dict[str, str | float] = {}
- legal_ids = [1, 5]
- illegal_ids = [0]
+ default_init_kwargs: dict[str, Any] = {}
@staticmethod
- def _for_versioned_attr(f: Callable[..., None]) -> Callable[..., None]:
- def wrapper(self: TestCaseSansDB) -> None:
- owner = self.checked_class(self.legal_ids[0],
- *self.default_init_args)
- for attr_name, default in self.versioned_defaults_to_test.items():
- to_set = vals_str if isinstance(default, str) else vals_float
+ def _within_checked_class(f: Callable[..., None]) -> Callable[..., None]:
+ def wrapper(self: TestCase) -> None:
+ if hasattr(self, 'checked_class'):
+ f(self)
+ return wrapper
+
+ @classmethod
+ def _on_versioned_attributes(cls,
+ f: Callable[..., None]
+ ) -> Callable[..., None]:
+ @cls._within_checked_class
+ def wrapper(self: TestCase) -> None:
+ assert isinstance(self, TestCaseAugmented)
+ for attr_name in self.checked_class.to_save_versioned():
+ default = self.checked_class.versioned_defaults[attr_name]
+ owner = self.checked_class(None, **self.default_init_kwargs)
attr = getattr(owner, attr_name)
- f(self, attr, default, to_set)
+ to_set = VERSIONED_VALS[attr.value_type_name]
+ f(self, owner, attr_name, attr, default, to_set)
return wrapper
- @_within_checked_class
+ @classmethod
+ def _make_from_defaults(cls, id_: float | str | None) -> Any:
+ return cls.checked_class(id_, **cls.default_init_kwargs)
+
+
+class TestCaseSansDB(TestCaseAugmented):
+ """Tests requiring no DB setup."""
+ legal_ids = [1, 5]
+ illegal_ids = [0]
+
+ @TestCaseAugmented._within_checked_class
def test_id_validation(self) -> None:
"""Test .id_ validation/setting."""
for id_ in self.illegal_ids:
with self.assertRaises(HandledException):
- self.checked_class(id_, *self.default_init_args)
+ self._make_from_defaults(id_)
for id_ in self.legal_ids:
- obj = self.checked_class(id_, *self.default_init_args)
+ obj = self._make_from_defaults(id_)
self.assertEqual(obj.id_, id_)
- @_within_checked_class
- @_for_versioned_attr
- def test_versioned_set(self, attr: VersionedAttribute,
- default: str | float, to_set: list[str | float]
+ @TestCaseAugmented._on_versioned_attributes
+ def test_versioned_set(self,
+ _: Any,
+ __: str,
+ attr: VersionedAttribute,
+ default: str | float,
+ to_set: list[str | float]
) -> None:
"""Test VersionedAttribute.set() behaves as expected."""
attr.set(default)
self.assertEqual(list(attr.history.keys())[0], timestamp)
# check that different value _will_ be set/added
attr.set(to_set[0])
- timesorted_vals = [attr.history[t] for t in attr.history.keys()]
+ timesorted_vals = [attr.history[t] for
+ t in sorted(attr.history.keys())]
expected = [default, to_set[0]]
self.assertEqual(timesorted_vals, expected)
# check that a previously used value can be set if not most recent
attr.set(default)
- timesorted_vals = [attr.history[t] for t in attr.history.keys()]
+ timesorted_vals = [attr.history[t] for
+ t in sorted(attr.history.keys())]
expected = [default, to_set[0], default]
self.assertEqual(timesorted_vals, expected)
# again check for same value not being set twice in a row, even for
# later items
attr.set(to_set[1])
- timesorted_vals = [attr.history[t] for t in attr.history.keys()]
+ timesorted_vals = [attr.history[t] for
+ t in sorted(attr.history.keys())]
expected = [default, to_set[0], default, to_set[1]]
self.assertEqual(timesorted_vals, expected)
attr.set(to_set[1])
self.assertEqual(timesorted_vals, expected)
- @_within_checked_class
- @_for_versioned_attr
- def test_versioned_newest(self, attr: VersionedAttribute,
- default: str | float, to_set: list[str | float]
+ @TestCaseAugmented._on_versioned_attributes
+ def test_versioned_newest(self,
+ _: Any,
+ __: str,
+ attr: VersionedAttribute,
+ default: str | float,
+ to_set: list[str | float]
) -> None:
"""Test VersionedAttribute.newest."""
# check .newest on empty history returns .default
attr.set(default)
self.assertEqual(attr.newest, default)
- @_within_checked_class
- @_for_versioned_attr
- def test_versioned_at(self, attr: VersionedAttribute, default: str | float,
- to_set: list[str | float]) -> None:
+ @TestCaseAugmented._on_versioned_attributes
+ def test_versioned_at(self,
+ _: Any,
+ __: str,
+ attr: VersionedAttribute,
+ default: str | float,
+ to_set: list[str | float]
+ ) -> None:
"""Test .at() returns values nearest to queried time, or default."""
# check .at() return default on empty history
timestamp_a = datetime.now().strftime(TIMESTAMP_FMT)
self.assertEqual(attr.at(timestamp_after_c), to_set[1])
-class TestCaseWithDB(TestCase):
+class TestCaseWithDB(TestCaseAugmented):
"""Module tests not requiring DB setup."""
- checked_class: Any
default_ids: tuple[int | str, int | str, int | str] = (1, 2, 3)
- default_init_kwargs: dict[str, Any] = {}
- test_versioneds: dict[str, type] = {}
def setUp(self) -> None:
Condition.empty_cache()
self.db_conn.close()
remove_file(self.db_file.path)
- @staticmethod
- def _for_versioned_attr(f: Callable[..., None]) -> Callable[..., None]:
- def wrapper(self: TestCaseWithDB) -> None:
- for attr_name, type_ in self.test_versioneds.items():
- owner = self.checked_class(None, **self.default_init_kwargs)
- to_set = vals_str if str == type_ else vals_float
- attr = getattr(owner, attr_name)
- attr.set(to_set[0])
- attr.set(to_set[1])
- f(self, owner, attr_name, attr, to_set)
- return wrapper
-
def _load_from_db(self, id_: int | str) -> list[object]:
db_found: list[object] = []
for row in self.db_conn.row_where(self.checked_class.table_name,
return db_found
def _change_obj(self, obj: object) -> str:
- attr_name: str = self.checked_class.to_save[-1]
+ attr_name: str = self.checked_class.to_save_simples[-1]
attr = getattr(obj, attr_name)
new_attr: str | int | float | bool
if isinstance(attr, (int, float)):
hashes_db_found = [hash(x) for x in db_found]
self.assertEqual(sorted(hashes_content), sorted(hashes_db_found))
- @_within_checked_class
- @_for_versioned_attr
- def test_saving_versioned_attributes(self, owner: Any, attr_name: str,
+ @TestCaseAugmented._on_versioned_attributes
+ def test_saving_versioned_attributes(self,
+ owner: Any,
+ attr_name: str,
attr: VersionedAttribute,
- vals: list[str | float]
+ _: str | float,
+ to_set: list[str | float]
) -> None:
"""Test storage and initialization of versioned attributes."""
attr_vals_saved += [row[2]]
return attr_vals_saved
+ attr.set(to_set[0])
# check that without attr.save() no rows in DB
rows = self.db_conn.row_where(attr.table_name, 'parent', owner.id_)
self.assertEqual([], rows)
# check owner.save() created entries as expected in attr table
owner.save(self.db_conn)
attr_vals_saved = retrieve_attr_vals(attr)
- self.assertEqual(vals, attr_vals_saved)
+ self.assertEqual([to_set[0]], attr_vals_saved)
# check changing attr val without save affects owner in memory …
- attr.set(vals[0])
+ attr.set(to_set[1])
cmp_attr = getattr(owner, attr_name)
+ self.assertEqual(to_set, list(cmp_attr.history.values()))
self.assertEqual(cmp_attr.history, attr.history)
# … but does not yet affect DB
attr_vals_saved = retrieve_attr_vals(attr)
- self.assertEqual(vals, attr_vals_saved)
+ self.assertEqual([to_set[0]], attr_vals_saved)
# check individual attr.save also stores new val to DB
attr.save(self.db_conn)
attr_vals_saved = retrieve_attr_vals(attr)
- self.assertEqual(vals + [vals[0]], attr_vals_saved)
+ self.assertEqual(to_set, attr_vals_saved)
- @_within_checked_class
+ @TestCaseAugmented._within_checked_class
def test_saving_and_caching(self) -> None:
"""Test effects of .cache() and .save()."""
id1 = self.default_ids[0]
# check failure to cache without ID (if None-ID input possible)
if isinstance(id1, int):
- obj0 = self.checked_class(None, **self.default_init_kwargs)
+ obj0 = self._make_from_defaults(None)
with self.assertRaises(HandledException):
obj0.cache()
# check mere object init itself doesn't even store in cache
- obj1 = self.checked_class(id1, **self.default_init_kwargs)
+ obj1 = self._make_from_defaults(id1)
self.assertEqual(self.checked_class.get_cache(), {})
# check .cache() fills cache, but not DB
obj1.cache()
# it's generated by cursor.lastrowid on the DB table, and with obj1
# not written there, obj2 should get it first!)
id_input = None if isinstance(id1, int) else id1
- obj2 = self.checked_class(id_input, **self.default_init_kwargs)
+ obj2 = self._make_from_defaults(id_input)
obj2.save(self.db_conn)
self.assertEqual(self.checked_class.get_cache(), {id1: obj2})
# NB: we'll only compare hashes because obj2 itself disappears on
with self.assertRaises(HandledException):
obj1.save(self.db_conn)
- @_within_checked_class
+ @TestCaseAugmented._within_checked_class
def test_by_id(self) -> None:
"""Test .by_id()."""
id1, id2, _ = self.default_ids
# check failure if not yet saved
- obj1 = self.checked_class(id1, **self.default_init_kwargs)
+ obj1 = self._make_from_defaults(id1)
with self.assertRaises(NotFoundException):
self.checked_class.by_id(self.db_conn, id1)
# check identity of cached and retrieved
obj1.cache()
self.assertEqual(obj1, self.checked_class.by_id(self.db_conn, id1))
# check identity of saved and retrieved
- obj2 = self.checked_class(id2, **self.default_init_kwargs)
+ obj2 = self._make_from_defaults(id2)
obj2.save(self.db_conn)
self.assertEqual(obj2, self.checked_class.by_id(self.db_conn, id2))
- @_within_checked_class
+ @TestCaseAugmented._within_checked_class
def test_by_id_or_create(self) -> None:
"""Test .by_id_or_create."""
# check .by_id_or_create fails if wrong class
self.checked_class.by_id(self.db_conn, item.id_)
self.assertEqual(self.checked_class(item.id_), item)
- @_within_checked_class
+ @TestCaseAugmented._within_checked_class
def test_from_table_row(self) -> None:
"""Test .from_table_row() properly reads in class directly from DB."""
id_ = self.default_ids[0]
- obj = self.checked_class(id_, **self.default_init_kwargs)
+ obj = self._make_from_defaults(id_)
obj.save(self.db_conn)
assert isinstance(obj.id_, type(id_))
for row in self.db_conn.row_where(self.checked_class.table_name,
self.assertEqual({retrieved.id_: retrieved},
self.checked_class.get_cache())
- @_within_checked_class
- @_for_versioned_attr
- def test_versioned_history_from_row(self, owner: Any, _: str,
+ @TestCaseAugmented._on_versioned_attributes
+ def test_versioned_history_from_row(self,
+ owner: Any,
+ _: str,
attr: VersionedAttribute,
- vals: list[str | float]
+ default: str | float,
+ to_set: list[str | float]
) -> None:
""""Test VersionedAttribute.history_from_row() knows its DB rows."""
- vals += [vals[1] * 2]
- attr.set(vals[2])
- attr.set(vals[1])
- attr.set(vals[2])
+ attr.set(to_set[0])
+ attr.set(to_set[1])
owner.save(self.db_conn)
# make empty VersionedAttribute, fill from rows, compare to owner's
- for row in self.db_conn.row_where(owner.table_name, 'id',
- owner.id_):
- loaded_attr = VersionedAttribute(owner, attr.table_name,
- attr.default)
+ for row in self.db_conn.row_where(owner.table_name, 'id', owner.id_):
+ loaded_attr = VersionedAttribute(owner, attr.table_name, default)
for row in self.db_conn.row_where(attr.table_name, 'parent',
owner.id_):
loaded_attr.history_from_row(row)
for timestamp, value in attr.history.items():
self.assertEqual(value, loaded_attr.history[timestamp])
- @_within_checked_class
+ @TestCaseAugmented._within_checked_class
def test_all(self) -> None:
"""Test .all() and its relation to cache and savings."""
- id_1, id_2, id_3 = self.default_ids
- item1 = self.checked_class(id_1, **self.default_init_kwargs)
- item2 = self.checked_class(id_2, **self.default_init_kwargs)
- item3 = self.checked_class(id_3, **self.default_init_kwargs)
+ id1, id2, id3 = self.default_ids
+ item1 = self._make_from_defaults(id1)
+ item2 = self._make_from_defaults(id2)
+ item3 = self._make_from_defaults(id3)
# check .all() returns empty list on un-cached items
self.assertEqual(self.checked_class.all(self.db_conn), [])
# check that all() shows only cached/saved items
self.assertEqual(sorted(self.checked_class.all(self.db_conn)),
sorted([item1, item2, item3]))
- @_within_checked_class
+ @TestCaseAugmented._within_checked_class
def test_singularity(self) -> None:
"""Test pointers made for single object keep pointing to it."""
id1 = self.default_ids[0]
- obj = self.checked_class(id1, **self.default_init_kwargs)
+ obj = self._make_from_defaults(id1)
obj.save(self.db_conn)
# change object, expect retrieved through .by_id to carry change
attr_name = self._change_obj(obj)
retrieved = self.checked_class.by_id(self.db_conn, id1)
self.assertEqual(new_attr, getattr(retrieved, attr_name))
- @_within_checked_class
- @_for_versioned_attr
- def test_versioned_singularity(self, owner: Any, attr_name: str,
+ @TestCaseAugmented._on_versioned_attributes
+ def test_versioned_singularity(self,
+ owner: Any,
+ attr_name: str,
attr: VersionedAttribute,
- vals: list[str | float]
+ _: str | float,
+ to_set: list[str | float]
) -> None:
"""Test singularity of VersionedAttributes on saving."""
owner.save(self.db_conn)
# change obj, expect retrieved through .by_id to carry change
- attr.set(vals[0])
+ attr.set(to_set[0])
retrieved = self.checked_class.by_id(self.db_conn, owner.id_)
attr_retrieved = getattr(retrieved, attr_name)
self.assertEqual(attr.history, attr_retrieved.history)
- @_within_checked_class
+ @TestCaseAugmented._within_checked_class
def test_remove(self) -> None:
"""Test .remove() effects on DB and cache."""
id_ = self.default_ids[0]
- obj = self.checked_class(id_, **self.default_init_kwargs)
+ obj = self._make_from_defaults(id_)
# check removal only works after saving
with self.assertRaises(HandledException):
obj.remove(self.db_conn)