from unittest import TestCase
from threading import Thread
from http.client import HTTPConnection
+from json import loads as json_loads
from urllib.parse import urlencode
from uuid import uuid4
from os import remove as remove_file
self.assertEqual(self.checked_class(id2), by_id_created)
self.check_storage([obj])
- def check_from_table_row(self, *args: Any) -> None:
- """Test .from_table_row() properly reads in class from DB"""
+ def test_from_table_row(self) -> None:
+ """Test .from_table_row() properly reads in class from DB."""
+ if not hasattr(self, 'checked_class'):
+ return
id_ = self.default_ids[0]
- obj = self.checked_class(id_, *args) # pylint: disable=not-callable
+ obj = self.checked_class(id_, **self.default_init_kwargs)
obj.save(self.db_conn)
assert isinstance(obj.id_, type(self.default_ids[0]))
for row in self.db_conn.row_where(self.checked_class.table_name,
'id', obj.id_):
+ # check .from_table_row reproduces state saved, no matter if obj
+ # later changed (with caching even)
hash_original = hash(obj)
+ attr_name = self.checked_class.to_save[-1]
+ attr = getattr(obj, attr_name)
+ if isinstance(attr, (int, float)):
+ setattr(obj, attr_name, attr + 1)
+ elif isinstance(attr, str):
+ setattr(obj, attr_name, attr + "_")
+ elif isinstance(attr, bool):
+ setattr(obj, attr_name, not attr)
+ obj.cache()
+ to_cmp = getattr(obj, attr_name)
retrieved = self.checked_class.from_table_row(self.db_conn, row)
+ self.assertNotEqual(to_cmp, getattr(retrieved, attr_name))
self.assertEqual(hash_original, hash(retrieved))
+ # check cache contains what .from_table_row just produced
self.assertEqual({retrieved.id_: retrieved},
self.checked_class.get_cache())
self.check_post(form_data, f'/process?id={id_}', 302,
f'/process?id={id_}')
return form_data
+
+ def check_json_get(self, path: str, expected: dict[str, object]) -> None:
+ """Compare JSON on GET path with expected.
+
+ To simplify comparison of VersionedAttribute histories, transforms
+ timestamp keys of VersionedAttribute history keys into integers
+ counting chronologically forward from 0.
+ """
+ def rewrite_history_keys_in(item: Any) -> Any:
+ if isinstance(item, dict):
+ if '_versioned' in item.keys():
+ for k in item['_versioned']:
+ vals = item['_versioned'][k].values()
+ history = {}
+ for i, val in enumerate(vals):
+ history[i] = val
+ item['_versioned'][k] = history
+ for k in list(item.keys()):
+ rewrite_history_keys_in(item[k])
+ elif isinstance(item, list):
+ item[:] = [rewrite_history_keys_in(i) for i in item]
+ return item
+ self.conn.request('GET', path)
+ response = self.conn.getresponse()
+ self.assertEqual(response.status, 200)
+ retrieved = json_loads(response.read().decode())
+ rewrite_history_keys_in(retrieved)
+ self.assertEqual(expected, retrieved)