return wrapper
@classmethod
- def _make_from_defaults(cls, id_: float | str | None) -> Any:
+ def _make_from_defaults(cls, id_: int | None) -> Any:
return cls.checked_class(id_, **cls.default_init_kwargs)
class TestCaseSansDB(TestCaseAugmented):
"""Tests requiring no DB setup."""
- legal_ids: list[str] | list[int] = [1, 5]
- illegal_ids: list[str] | list[int] = [0]
+ _legal_ids: list[int] = [1, 5]
+ _illegal_ids: list[int] = [0]
@TestCaseAugmented._run_if_sans_db
def test_id_validation(self) -> None:
"""Test .id_ validation/setting."""
- for id_ in self.illegal_ids:
+ for id_ in self._illegal_ids:
with self.assertRaises(HandledException):
self._make_from_defaults(id_)
- for id_ in self.legal_ids:
+ for id_ in self._legal_ids:
obj = self._make_from_defaults(id_)
self.assertEqual(obj.id_, id_)
class TestCaseWithDB(TestCaseAugmented):
"""Module tests not requiring DB setup."""
- default_ids: tuple[int, int, int] | tuple[str, str, str] = (1, 2, 3)
+ _default_ids: tuple[int, int, int] = (1, 2, 3)
def setUp(self) -> None:
Condition.empty_cache()
self.db_conn.close()
remove_file(self.db_file.path)
- def _load_from_db(self, id_: int | str) -> list[object]:
+ def _load_from_db(self, id_: int) -> list[object]:
db_found: list[object] = []
for row in self.db_conn.row_where(self.checked_class.table_name,
'id', id_):
@TestCaseAugmented._run_if_with_db_but_not_server
def test_saving_and_caching(self) -> None:
"""Test effects of .cache() and .save()."""
- id1 = self.default_ids[0]
+ id1 = self._default_ids[0]
# check failure to cache without ID (if None-ID input possible)
- if isinstance(id1, int):
- obj0 = self._make_from_defaults(None)
- with self.assertRaises(HandledException):
- obj0.cache()
+ obj0 = self._make_from_defaults(None)
+ with self.assertRaises(HandledException):
+ obj0.cache()
# check mere object init itself doesn't even store in cache
obj1 = self._make_from_defaults(id1)
self.assertEqual(self.checked_class.get_cache(), {})
self.assertEqual(self.checked_class.get_cache(), {id1: obj1})
found_in_db = self._load_from_db(id1)
self.assertEqual(found_in_db, [])
- # check .save() sets ID (for int IDs), updates cache, and fills DB
+ # check .save() sets ID, updates cache, and fills DB
# (expect ID to be set to id1, despite obj1 already having that as ID:
# it's generated by cursor.lastrowid on the DB table, and with obj1
# not written there, obj2 should get it first!)
- id_input = None if isinstance(id1, int) else id1
- obj2 = self._make_from_defaults(id_input)
+ obj2 = self._make_from_defaults(None)
obj2.save(self.db_conn)
self.assertEqual(self.checked_class.get_cache(), {id1: obj2})
# NB: we'll only compare hashes because obj2 itself disappears on
@TestCaseAugmented._run_if_with_db_but_not_server
def test_by_id(self) -> None:
"""Test .by_id()."""
- id1, id2, _ = self.default_ids
+ id1, id2, _ = self._default_ids
# check failure if not yet saved
obj1 = self._make_from_defaults(id1)
with self.assertRaises(NotFoundException):
with self.assertRaises(HandledException):
self.checked_class.by_id_or_create(self.db_conn, None)
return
- # check ID input of None creates, on saving, ID=1,2,… for int IDs
- if isinstance(self.default_ids[0], int):
- for n in range(2):
- item = self.checked_class.by_id_or_create(self.db_conn, None)
- self.assertEqual(item.id_, None)
- item.save(self.db_conn)
- self.assertEqual(item.id_, n+1)
+ # check ID input of None creates, on saving, ID=1,2,…
+ for n in range(2):
+ item = self.checked_class.by_id_or_create(self.db_conn, None)
+ self.assertEqual(item.id_, None)
+ item.save(self.db_conn)
+ self.assertEqual(item.id_, n+1)
# check .by_id_or_create acts like normal instantiation (sans saving)
- id_ = self.default_ids[2]
+ id_ = self._default_ids[2]
item = self.checked_class.by_id_or_create(self.db_conn, id_)
self.assertEqual(item.id_, id_)
with self.assertRaises(NotFoundException):
@TestCaseAugmented._run_if_with_db_but_not_server
def test_from_table_row(self) -> None:
"""Test .from_table_row() properly reads in class directly from DB."""
- id_ = self.default_ids[0]
- obj = self._make_from_defaults(id_)
+ obj = self._make_from_defaults(self._default_ids[0])
obj.save(self.db_conn)
for row in self.db_conn.row_where(self.checked_class.table_name,
'id', obj.id_):
@TestCaseAugmented._run_if_with_db_but_not_server
def test_all(self) -> None:
"""Test .all() and its relation to cache and savings."""
- id1, id2, id3 = self.default_ids
+ id1, id2, id3 = self._default_ids
item1 = self._make_from_defaults(id1)
item2 = self._make_from_defaults(id2)
item3 = self._make_from_defaults(id3)
@TestCaseAugmented._run_if_with_db_but_not_server
def test_singularity(self) -> None:
"""Test pointers made for single object keep pointing to it."""
- id1 = self.default_ids[0]
+ id1 = self._default_ids[0]
obj = self._make_from_defaults(id1)
obj.save(self.db_conn)
# change object, expect retrieved through .by_id to carry change
@TestCaseAugmented._run_if_with_db_but_not_server
def test_remove(self) -> None:
"""Test .remove() effects on DB and cache."""
- id_ = self.default_ids[0]
- obj = self._make_from_defaults(id_)
+ obj = self._make_from_defaults(self._default_ids[0])
# check removal only works after saving
with self.assertRaises(HandledException):
obj.remove(self.db_conn)
_fields: dict[str, Any]
_on_empty_make_temp: tuple[str, str]
- def __init__(self,
- todos: list[dict[str, Any]] | None = None,
- procs: list[dict[str, Any]] | None = None,
- procsteps: list[dict[str, Any]] | None = None,
- conds: list[dict[str, Any]] | None = None,
- days: list[dict[str, Any]] | None = None
- ) -> None:
- # pylint: disable=too-many-arguments
+ def __init__(self) -> None:
for name in ['_default_dict', '_fields', '_forced']:
if not hasattr(self, name):
setattr(self, name, {})
- self._lib = {}
- for title, items in [('Todo', todos),
- ('Process', procs),
- ('ProcessStep', procsteps),
- ('Condition', conds),
- ('Day', days)]:
- if items:
- self._lib[title] = self._as_refs(items)
+ self._lib: dict[str, dict[int, dict[str, Any]]] = {}
for k, v in self._default_dict.items():
if k not in self._fields:
self._fields[k] = v
d[k] = v
if make_temp:
json = json_dumps(d)
- id_ = id_ if id_ is not None else '?'
+ id_ = id_ if id_ is not None else -1
self.lib_del(category, id_)
d = json_loads(json)
return d
- def lib_get(self, category: str, id_: str | int) -> dict[str, Any]:
+ def lib_get(self, category: str, id_: int) -> dict[str, Any]:
"""From library, return item of category and id_, or empty dict."""
- str_id = str(id_)
- if category in self._lib and str_id in self._lib[category]:
- return self._lib[category][str_id]
+ if category in self._lib and id_ in self._lib[category]:
+ return self._lib[category][id_]
return {}
def lib_all(self, category: str) -> list[dict[str, Any]]:
"""Update library for category with items."""
if category not in self._lib:
self._lib[category] = {}
- for k, v in self._as_refs(items).items():
- self._lib[category][k] = v
+ for item in items:
+ id_ = item['id'] if item['id'] is not None else -1
+ assert isinstance(id_, int)
+ self._lib[category][id_] = item
- def lib_del(self, category: str, id_: str | int) -> None:
+ def lib_del(self, category: str, id_: int) -> None:
"""Remove category element of id_ from library."""
- del self._lib[category][str(id_)]
+ del self._lib[category][id_]
if 0 == len(self._lib[category]):
del self._lib[category]
"""Set ._forced field to ensure value in .as_dict."""
self._forced[field_name] = value
- def unforce(self, field_name: str) -> None:
- """Unset ._forced field."""
- del self._forced[field_name]
-
- @staticmethod
- def _as_refs(items: list[dict[str, object]]
- ) -> dict[str, dict[str, object]]:
- """Return dictionary of items by their 'id' fields."""
- refs = {}
- for item in items:
- id_ = str(item['id']) if item['id'] is not None else '?'
- refs[id_] = item
- return refs
-
@staticmethod
def as_ids(items: list[dict[str, Any]]) -> list[int] | list[str]:
"""Return list of only 'id' fields of items."""
self.assertEqual(response.status, 200)
retrieved = json_loads(response.read().decode())
rewrite_history_keys_in(retrieved)
- cmp = expected.as_dict
+ # to convert ._lib int keys to str
+ cmp = json_loads(json_dumps(expected.as_dict))
try:
self.assertEqual(cmp, retrieved)
except AssertionError as e: