From: Christian Heller Date: Thu, 11 Jul 2024 15:12:53 +0000 (+0200) Subject: Lots of refactoring. X-Git-Url: https://plomlompom.com/repos/%7B%7Bdb.prefix%7D%7D/static/%7B%7B%20web_path%20%7D%7D/decks/%27%29;%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20chunks.push%28escapeHTML%28span%5B2%5D%29%29;%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20chunks.push%28%27?a=commitdiff_plain;h=6c5e67fe668548619df9e903e033d29b3a216b75;p=plomtask Lots of refactoring. --- diff --git a/plomtask/conditions.py b/plomtask/conditions.py index 15dcb9d..e752e91 100644 --- a/plomtask/conditions.py +++ b/plomtask/conditions.py @@ -8,8 +8,8 @@ from plomtask.exceptions import HandledException class Condition(BaseModel[int]): """Non-Process dependency for ProcessSteps and Todos.""" table_name = 'conditions' - to_save = ['is_active'] - to_save_versioned = ['title', 'description'] + to_save_simples = ['is_active'] + versioned_defaults = {'title': 'UNNAMED', 'description': ''} to_search = ['title.newest', 'description.newest'] can_create_by_id = True sorters = {'is_active': lambda c: c.is_active, @@ -18,9 +18,10 @@ class Condition(BaseModel[int]): def __init__(self, id_: int | None, is_active: bool = False) -> None: super().__init__(id_) self.is_active = is_active - self.title = VersionedAttribute(self, 'condition_titles', 'UNNAMED') - self.description = VersionedAttribute(self, 'condition_descriptions', - '') + for name in ['title', 'description']: + attr = VersionedAttribute(self, f'condition_{name}s', + self.versioned_defaults[name]) + setattr(self, name, attr) def remove(self, db_conn: DatabaseConnection) -> None: """Remove from DB, with VersionedAttributes. diff --git a/plomtask/days.py b/plomtask/days.py index 2320130..18c9769 100644 --- a/plomtask/days.py +++ b/plomtask/days.py @@ -11,7 +11,7 @@ from plomtask.dating import (DATE_FORMAT, valid_date) class Day(BaseModel[str]): """Individual days defined by their dates.""" table_name = 'days' - to_save = ['comment'] + to_save_simples = ['comment'] add_to_dict = ['todos'] can_create_by_id = True diff --git a/plomtask/db.py b/plomtask/db.py index 13cdaef..f1169c3 100644 --- a/plomtask/db.py +++ b/plomtask/db.py @@ -232,9 +232,9 @@ BaseModelInstance = TypeVar('BaseModelInstance', bound='BaseModel[Any]') class BaseModel(Generic[BaseModelId]): """Template for most of the models we use/derive from the DB.""" table_name = '' - to_save: list[str] = [] - to_save_versioned: list[str] = [] + to_save_simples: list[str] = [] to_save_relations: list[tuple[str, str, str, int]] = [] + versioned_defaults: dict[str, str | float] = {} add_to_dict: list[str] = [] id_: None | BaseModelId cache_: dict[BaseModelId, Self] @@ -253,11 +253,12 @@ class BaseModel(Generic[BaseModelId]): self.id_ = id_ def __hash__(self) -> int: - hashable = [self.id_] + [getattr(self, name) for name in self.to_save] + hashable = [self.id_] + [getattr(self, name) + for name in self.to_save_simples] for definition in self.to_save_relations: attr = getattr(self, definition[2]) hashable += [tuple(rel.id_ for rel in attr)] - for name in self.to_save_versioned: + for name in self.to_save_versioned(): hashable += [hash(getattr(self, name))] return hash(tuple(hashable)) @@ -274,20 +275,25 @@ class BaseModel(Generic[BaseModelId]): assert isinstance(other.id_, int) return self.id_ < other.id_ + @classmethod + def to_save_versioned(cls) -> list[str]: + """Return keys of cls.versioned_defaults assuming we wanna save 'em.""" + return list(cls.versioned_defaults.keys()) + @property def as_dict(self) -> dict[str, object]: """Return self as (json.dumps-compatible) dict.""" library: dict[str, dict[str | int, object]] = {} d: dict[str, object] = {'id': self.id_, '_library': library} - for to_save in self.to_save: + for to_save in self.to_save_simples: attr = getattr(self, to_save) if hasattr(attr, 'as_dict_into_reference'): d[to_save] = attr.as_dict_into_reference(library) else: d[to_save] = attr - if len(self.to_save_versioned) > 0: + if len(self.to_save_versioned()) > 0: d['_versioned'] = {} - for k in self.to_save_versioned: + for k in self.to_save_versioned(): attr = getattr(self, k) assert isinstance(d['_versioned'], dict) d['_versioned'][k] = attr.history @@ -438,7 +444,7 @@ class BaseModel(Generic[BaseModelId]): """Make from DB row (sans relations), update DB cache with it.""" obj = cls(*row) assert obj.id_ is not None - for attr_name in cls.to_save_versioned: + for attr_name in cls.to_save_versioned(): attr = getattr(obj, attr_name) table_name = attr.table_name for row_ in db_conn.row_where(table_name, 'parent', obj.id_): @@ -549,7 +555,7 @@ class BaseModel(Generic[BaseModelId]): """Write self to DB and cache and ensure .id_. Write both to DB, and to cache. To DB, write .id_ and attributes - listed in cls.to_save[_versioned|_relations]. + listed in cls.to_save_[simples|versioned|_relations]. Ensure self.id_ by setting it to what the DB command returns as the last saved row's ID (cursor.lastrowid), EXCEPT if self.id_ already @@ -557,14 +563,14 @@ class BaseModel(Generic[BaseModelId]): only the case with the Day class, where it's to be a date string. """ values = tuple([self.id_] + [getattr(self, key) - for key in self.to_save]) + for key in self.to_save_simples]) table_name = self.table_name cursor = db_conn.exec_on_vals(f'REPLACE INTO {table_name} VALUES', values) if not isinstance(self.id_, str): self.id_ = cursor.lastrowid # type: ignore[assignment] self.cache() - for attr_name in self.to_save_versioned: + for attr_name in self.to_save_versioned(): getattr(self, attr_name).save(db_conn) for table, column, attr_name, key_index in self.to_save_relations: assert isinstance(self.id_, (int, str)) @@ -576,7 +582,7 @@ class BaseModel(Generic[BaseModelId]): """Remove from DB and cache, including dependencies.""" if self.id_ is None or self._get_cached(self.id_) is None: raise HandledException('cannot remove unsaved item') - for attr_name in self.to_save_versioned: + for attr_name in self.to_save_versioned(): getattr(self, attr_name).remove(db_conn) for table, column, attr_name, _ in self.to_save_relations: db_conn.delete_where(table, column, self.id_) diff --git a/plomtask/processes.py b/plomtask/processes.py index bb1de3a..9870ab3 100644 --- a/plomtask/processes.py +++ b/plomtask/processes.py @@ -25,8 +25,7 @@ class Process(BaseModel[int], ConditionsRelations): """Template for, and metadata for, Todos, and their arrangements.""" # pylint: disable=too-many-instance-attributes table_name = 'processes' - to_save = ['calendarize'] - to_save_versioned = ['title', 'description', 'effort'] + to_save_simples = ['calendarize'] to_save_relations = [('process_conditions', 'process', 'conditions', 0), ('process_blockers', 'process', 'blockers', 0), ('process_enables', 'process', 'enables', 0), @@ -34,6 +33,7 @@ class Process(BaseModel[int], ConditionsRelations): ('process_step_suppressions', 'process', 'suppressed_steps', 0)] add_to_dict = ['explicit_steps'] + versioned_defaults = {'title': 'UNNAMED', 'description': '', 'effort': 1.0} to_search = ['title.newest', 'description.newest'] can_create_by_id = True sorters = {'steps': lambda p: len(p.explicit_steps), @@ -44,9 +44,10 @@ class Process(BaseModel[int], ConditionsRelations): def __init__(self, id_: int | None, calendarize: bool = False) -> None: BaseModel.__init__(self, id_) ConditionsRelations.__init__(self) - self.title = VersionedAttribute(self, 'process_titles', 'UNNAMED') - self.description = VersionedAttribute(self, 'process_descriptions', '') - self.effort = VersionedAttribute(self, 'process_efforts', 1.0) + for name in ['title', 'description', 'effort']: + attr = VersionedAttribute(self, f'process_{name}s', + self.versioned_defaults[name]) + setattr(self, name, attr) self.explicit_steps: list[ProcessStep] = [] self.suppressed_steps: list[ProcessStep] = [] self.calendarize = calendarize @@ -210,7 +211,7 @@ class Process(BaseModel[int], ConditionsRelations): class ProcessStep(BaseModel[int]): """Sub-unit of Processes.""" table_name = 'process_steps' - to_save = ['owner_id', 'step_process_id', 'parent_step_id'] + to_save_simples = ['owner_id', 'step_process_id', 'parent_step_id'] def __init__(self, id_: int | None, owner_id: int, step_process_id: int, parent_step_id: int | None) -> None: diff --git a/plomtask/todos.py b/plomtask/todos.py index f5388b5..1f55ae7 100644 --- a/plomtask/todos.py +++ b/plomtask/todos.py @@ -39,8 +39,8 @@ class Todo(BaseModel[int], ConditionsRelations): # pylint: disable=too-many-instance-attributes # pylint: disable=too-many-public-methods table_name = 'todos' - to_save = ['process_id', 'is_done', 'date', 'comment', 'effort', - 'calendarize'] + to_save_simples = ['process_id', 'is_done', 'date', 'comment', 'effort', + 'calendarize'] to_save_relations = [('todo_conditions', 'todo', 'conditions', 0), ('todo_blockers', 'todo', 'blockers', 0), ('todo_enables', 'todo', 'enables', 0), @@ -231,6 +231,7 @@ class Todo(BaseModel[int], ConditionsRelations): @property def title(self) -> VersionedAttribute: """Shortcut to .process.title.""" + assert isinstance(self.process.title, VersionedAttribute) return self.process.title @property diff --git a/plomtask/versioned_attributes.py b/plomtask/versioned_attributes.py index 8861c98..cfcbf87 100644 --- a/plomtask/versioned_attributes.py +++ b/plomtask/versioned_attributes.py @@ -17,12 +17,12 @@ class VersionedAttribute: parent: Any, table_name: str, default: str | float) -> None: self.parent = parent self.table_name = table_name - self.default = default + self._default = default self.history: dict[str, str | float] = {} def __hash__(self) -> int: history_tuples = tuple((k, v) for k, v in self.history.items()) - hashable = (self.parent.id_, self.table_name, self.default, + hashable = (self.parent.id_, self.table_name, self._default, history_tuples) return hash(hashable) @@ -31,11 +31,16 @@ class VersionedAttribute: """Return most recent timestamp.""" return sorted(self.history.keys())[-1] + @property + def value_type_name(self) -> str: + """Return string of name of attribute value type.""" + return type(self._default).__name__ + @property def newest(self) -> str | float: - """Return most recent value, or self.default if self.history empty.""" + """Return most recent value, or self._default if self.history empty.""" if 0 == len(self.history): - return self.default + return self._default return self.history[self._newest_timestamp] def reset_timestamp(self, old_str: str, new_str: str) -> None: @@ -89,7 +94,7 @@ class VersionedAttribute: queried_time += ' 23:59:59.999' sorted_timestamps = sorted(self.history.keys()) if 0 == len(sorted_timestamps): - return self.default + return self._default selected_timestamp = sorted_timestamps[0] for timestamp in sorted_timestamps[1:]: if timestamp > queried_time: diff --git a/tests/conditions.py b/tests/conditions.py index f84533e..69dcc66 100644 --- a/tests/conditions.py +++ b/tests/conditions.py @@ -9,14 +9,12 @@ from plomtask.exceptions import HandledException class TestsSansDB(TestCaseSansDB): """Tests requiring no DB setup.""" checked_class = Condition - versioned_defaults_to_test = {'title': 'UNNAMED', 'description': ''} class TestsWithDB(TestCaseWithDB): """Tests requiring DB, but not server setup.""" checked_class = Condition default_init_kwargs = {'is_active': False} - test_versioneds = {'title': str, 'description': str} def test_remove(self) -> None: """Test .remove() effects on DB and cache.""" diff --git a/tests/processes.py b/tests/processes.py index 1b20e21..501a163 100644 --- a/tests/processes.py +++ b/tests/processes.py @@ -10,20 +10,18 @@ from plomtask.todos import Todo class TestsSansDB(TestCaseSansDB): """Module tests not requiring DB setup.""" checked_class = Process - versioned_defaults_to_test = {'title': 'UNNAMED', 'description': '', - 'effort': 1.0} class TestsSansDBProcessStep(TestCaseSansDB): """Module tests not requiring DB setup.""" checked_class = ProcessStep - default_init_args = [2, 3, 4] + default_init_kwargs = {'owner_id': 2, 'step_process_id': 3, + 'parent_step_id': 4} class TestsWithDB(TestCaseWithDB): """Module tests requiring DB setup.""" checked_class = Process - test_versioneds = {'title': str, 'description': str, 'effort': float} def three_processes(self) -> tuple[Process, Process, Process]: """Return three saved processes.""" diff --git a/tests/todos.py b/tests/todos.py index dd57ee4..6b6276f 100644 --- a/tests/todos.py +++ b/tests/todos.py @@ -10,15 +10,12 @@ from plomtask.exceptions import (NotFoundException, BadFormatException, class TestsWithDB(TestCaseWithDB, TestCaseSansDB): """Tests requiring DB, but not server setup. - NB: We subclass TestCaseSansDB too, to pull in its .test_id_validation, - which for Todo wouldn't run without a DB being set up due to the need for - Processes with set IDs. + NB: We subclass TestCaseSansDB too, to run any tests there that due to any + Todo requiring a _saved_ Process wouldn't run without a DB. """ checked_class = Todo default_init_kwargs = {'process': None, 'is_done': False, 'date': '2024-01-01'} - # solely used for TestCaseSansDB.test_id_setting - default_init_args = [None, False, '2024-01-01'] def setUp(self) -> None: super().setUp() @@ -31,7 +28,6 @@ class TestsWithDB(TestCaseWithDB, TestCaseSansDB): self.cond2 = Condition(None) self.cond2.save(self.db_conn) self.default_init_kwargs['process'] = self.proc - self.default_init_args[0] = self.proc def test_Todo_init(self) -> None: """Test creation of Todo and what they default to.""" diff --git a/tests/utils.py b/tests/utils.py index 52cc66e..a03de94 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -20,50 +20,65 @@ from plomtask.versioned_attributes import VersionedAttribute, TIMESTAMP_FMT from plomtask.exceptions import NotFoundException, HandledException -def _within_checked_class(f: Callable[..., None]) -> Callable[..., None]: - def wrapper(self: TestCase) -> None: - if hasattr(self, 'checked_class'): - f(self) - return wrapper +VERSIONED_VALS: dict[str, + list[str] | list[float]] = {'str': ['A', 'B'], + 'float': [0.3, 1.1]} -vals_str: list[Any] = ['A', 'B'] -vals_float: list[Any] = [0.3, 1.1] - - -class TestCaseSansDB(TestCase): - """Tests requiring no DB setup.""" +class TestCaseAugmented(TestCase): + """Tester core providing helpful basic internal decorators and methods.""" checked_class: Any - default_init_args: list[Any] = [] - versioned_defaults_to_test: dict[str, str | float] = {} - legal_ids = [1, 5] - illegal_ids = [0] + default_init_kwargs: dict[str, Any] = {} @staticmethod - def _for_versioned_attr(f: Callable[..., None]) -> Callable[..., None]: - def wrapper(self: TestCaseSansDB) -> None: - owner = self.checked_class(self.legal_ids[0], - *self.default_init_args) - for attr_name, default in self.versioned_defaults_to_test.items(): - to_set = vals_str if isinstance(default, str) else vals_float + def _within_checked_class(f: Callable[..., None]) -> Callable[..., None]: + def wrapper(self: TestCase) -> None: + if hasattr(self, 'checked_class'): + f(self) + return wrapper + + @classmethod + def _on_versioned_attributes(cls, + f: Callable[..., None] + ) -> Callable[..., None]: + @cls._within_checked_class + def wrapper(self: TestCase) -> None: + assert isinstance(self, TestCaseAugmented) + for attr_name in self.checked_class.to_save_versioned(): + default = self.checked_class.versioned_defaults[attr_name] + owner = self.checked_class(None, **self.default_init_kwargs) attr = getattr(owner, attr_name) - f(self, attr, default, to_set) + to_set = VERSIONED_VALS[attr.value_type_name] + f(self, owner, attr_name, attr, default, to_set) return wrapper - @_within_checked_class + @classmethod + def _make_from_defaults(cls, id_: float | str | None) -> Any: + return cls.checked_class(id_, **cls.default_init_kwargs) + + +class TestCaseSansDB(TestCaseAugmented): + """Tests requiring no DB setup.""" + legal_ids = [1, 5] + illegal_ids = [0] + + @TestCaseAugmented._within_checked_class def test_id_validation(self) -> None: """Test .id_ validation/setting.""" for id_ in self.illegal_ids: with self.assertRaises(HandledException): - self.checked_class(id_, *self.default_init_args) + self._make_from_defaults(id_) for id_ in self.legal_ids: - obj = self.checked_class(id_, *self.default_init_args) + obj = self._make_from_defaults(id_) self.assertEqual(obj.id_, id_) - @_within_checked_class - @_for_versioned_attr - def test_versioned_set(self, attr: VersionedAttribute, - default: str | float, to_set: list[str | float] + @TestCaseAugmented._on_versioned_attributes + def test_versioned_set(self, + _: Any, + __: str, + attr: VersionedAttribute, + default: str | float, + to_set: list[str | float] ) -> None: """Test VersionedAttribute.set() behaves as expected.""" attr.set(default) @@ -76,27 +91,33 @@ class TestCaseSansDB(TestCase): self.assertEqual(list(attr.history.keys())[0], timestamp) # check that different value _will_ be set/added attr.set(to_set[0]) - timesorted_vals = [attr.history[t] for t in attr.history.keys()] + timesorted_vals = [attr.history[t] for + t in sorted(attr.history.keys())] expected = [default, to_set[0]] self.assertEqual(timesorted_vals, expected) # check that a previously used value can be set if not most recent attr.set(default) - timesorted_vals = [attr.history[t] for t in attr.history.keys()] + timesorted_vals = [attr.history[t] for + t in sorted(attr.history.keys())] expected = [default, to_set[0], default] self.assertEqual(timesorted_vals, expected) # again check for same value not being set twice in a row, even for # later items attr.set(to_set[1]) - timesorted_vals = [attr.history[t] for t in attr.history.keys()] + timesorted_vals = [attr.history[t] for + t in sorted(attr.history.keys())] expected = [default, to_set[0], default, to_set[1]] self.assertEqual(timesorted_vals, expected) attr.set(to_set[1]) self.assertEqual(timesorted_vals, expected) - @_within_checked_class - @_for_versioned_attr - def test_versioned_newest(self, attr: VersionedAttribute, - default: str | float, to_set: list[str | float] + @TestCaseAugmented._on_versioned_attributes + def test_versioned_newest(self, + _: Any, + __: str, + attr: VersionedAttribute, + default: str | float, + to_set: list[str | float] ) -> None: """Test VersionedAttribute.newest.""" # check .newest on empty history returns .default @@ -109,10 +130,14 @@ class TestCaseSansDB(TestCase): attr.set(default) self.assertEqual(attr.newest, default) - @_within_checked_class - @_for_versioned_attr - def test_versioned_at(self, attr: VersionedAttribute, default: str | float, - to_set: list[str | float]) -> None: + @TestCaseAugmented._on_versioned_attributes + def test_versioned_at(self, + _: Any, + __: str, + attr: VersionedAttribute, + default: str | float, + to_set: list[str | float] + ) -> None: """Test .at() returns values nearest to queried time, or default.""" # check .at() return default on empty history timestamp_a = datetime.now().strftime(TIMESTAMP_FMT) @@ -136,12 +161,9 @@ class TestCaseSansDB(TestCase): self.assertEqual(attr.at(timestamp_after_c), to_set[1]) -class TestCaseWithDB(TestCase): +class TestCaseWithDB(TestCaseAugmented): """Module tests not requiring DB setup.""" - checked_class: Any default_ids: tuple[int | str, int | str, int | str] = (1, 2, 3) - default_init_kwargs: dict[str, Any] = {} - test_versioneds: dict[str, type] = {} def setUp(self) -> None: Condition.empty_cache() @@ -156,18 +178,6 @@ class TestCaseWithDB(TestCase): self.db_conn.close() remove_file(self.db_file.path) - @staticmethod - def _for_versioned_attr(f: Callable[..., None]) -> Callable[..., None]: - def wrapper(self: TestCaseWithDB) -> None: - for attr_name, type_ in self.test_versioneds.items(): - owner = self.checked_class(None, **self.default_init_kwargs) - to_set = vals_str if str == type_ else vals_float - attr = getattr(owner, attr_name) - attr.set(to_set[0]) - attr.set(to_set[1]) - f(self, owner, attr_name, attr, to_set) - return wrapper - def _load_from_db(self, id_: int | str) -> list[object]: db_found: list[object] = [] for row in self.db_conn.row_where(self.checked_class.table_name, @@ -177,7 +187,7 @@ class TestCaseWithDB(TestCase): return db_found def _change_obj(self, obj: object) -> str: - attr_name: str = self.checked_class.to_save[-1] + attr_name: str = self.checked_class.to_save_simples[-1] attr = getattr(obj, attr_name) new_attr: str | int | float | bool if isinstance(attr, (int, float)): @@ -203,11 +213,13 @@ class TestCaseWithDB(TestCase): hashes_db_found = [hash(x) for x in db_found] self.assertEqual(sorted(hashes_content), sorted(hashes_db_found)) - @_within_checked_class - @_for_versioned_attr - def test_saving_versioned_attributes(self, owner: Any, attr_name: str, + @TestCaseAugmented._on_versioned_attributes + def test_saving_versioned_attributes(self, + owner: Any, + attr_name: str, attr: VersionedAttribute, - vals: list[str | float] + _: str | float, + to_set: list[str | float] ) -> None: """Test storage and initialization of versioned attributes.""" @@ -218,6 +230,7 @@ class TestCaseWithDB(TestCase): attr_vals_saved += [row[2]] return attr_vals_saved + attr.set(to_set[0]) # check that without attr.save() no rows in DB rows = self.db_conn.row_where(attr.table_name, 'parent', owner.id_) self.assertEqual([], rows) @@ -227,30 +240,31 @@ class TestCaseWithDB(TestCase): # check owner.save() created entries as expected in attr table owner.save(self.db_conn) attr_vals_saved = retrieve_attr_vals(attr) - self.assertEqual(vals, attr_vals_saved) + self.assertEqual([to_set[0]], attr_vals_saved) # check changing attr val without save affects owner in memory … - attr.set(vals[0]) + attr.set(to_set[1]) cmp_attr = getattr(owner, attr_name) + self.assertEqual(to_set, list(cmp_attr.history.values())) self.assertEqual(cmp_attr.history, attr.history) # … but does not yet affect DB attr_vals_saved = retrieve_attr_vals(attr) - self.assertEqual(vals, attr_vals_saved) + self.assertEqual([to_set[0]], attr_vals_saved) # check individual attr.save also stores new val to DB attr.save(self.db_conn) attr_vals_saved = retrieve_attr_vals(attr) - self.assertEqual(vals + [vals[0]], attr_vals_saved) + self.assertEqual(to_set, attr_vals_saved) - @_within_checked_class + @TestCaseAugmented._within_checked_class def test_saving_and_caching(self) -> None: """Test effects of .cache() and .save().""" id1 = self.default_ids[0] # check failure to cache without ID (if None-ID input possible) if isinstance(id1, int): - obj0 = self.checked_class(None, **self.default_init_kwargs) + obj0 = self._make_from_defaults(None) with self.assertRaises(HandledException): obj0.cache() # check mere object init itself doesn't even store in cache - obj1 = self.checked_class(id1, **self.default_init_kwargs) + obj1 = self._make_from_defaults(id1) self.assertEqual(self.checked_class.get_cache(), {}) # check .cache() fills cache, but not DB obj1.cache() @@ -262,7 +276,7 @@ class TestCaseWithDB(TestCase): # it's generated by cursor.lastrowid on the DB table, and with obj1 # not written there, obj2 should get it first!) id_input = None if isinstance(id1, int) else id1 - obj2 = self.checked_class(id_input, **self.default_init_kwargs) + obj2 = self._make_from_defaults(id_input) obj2.save(self.db_conn) self.assertEqual(self.checked_class.get_cache(), {id1: obj2}) # NB: we'll only compare hashes because obj2 itself disappears on @@ -275,23 +289,23 @@ class TestCaseWithDB(TestCase): with self.assertRaises(HandledException): obj1.save(self.db_conn) - @_within_checked_class + @TestCaseAugmented._within_checked_class def test_by_id(self) -> None: """Test .by_id().""" id1, id2, _ = self.default_ids # check failure if not yet saved - obj1 = self.checked_class(id1, **self.default_init_kwargs) + obj1 = self._make_from_defaults(id1) with self.assertRaises(NotFoundException): self.checked_class.by_id(self.db_conn, id1) # check identity of cached and retrieved obj1.cache() self.assertEqual(obj1, self.checked_class.by_id(self.db_conn, id1)) # check identity of saved and retrieved - obj2 = self.checked_class(id2, **self.default_init_kwargs) + obj2 = self._make_from_defaults(id2) obj2.save(self.db_conn) self.assertEqual(obj2, self.checked_class.by_id(self.db_conn, id2)) - @_within_checked_class + @TestCaseAugmented._within_checked_class def test_by_id_or_create(self) -> None: """Test .by_id_or_create.""" # check .by_id_or_create fails if wrong class @@ -314,11 +328,11 @@ class TestCaseWithDB(TestCase): self.checked_class.by_id(self.db_conn, item.id_) self.assertEqual(self.checked_class(item.id_), item) - @_within_checked_class + @TestCaseAugmented._within_checked_class def test_from_table_row(self) -> None: """Test .from_table_row() properly reads in class directly from DB.""" id_ = self.default_ids[0] - obj = self.checked_class(id_, **self.default_init_kwargs) + obj = self._make_from_defaults(id_) obj.save(self.db_conn) assert isinstance(obj.id_, type(id_)) for row in self.db_conn.row_where(self.checked_class.table_name, @@ -338,23 +352,21 @@ class TestCaseWithDB(TestCase): self.assertEqual({retrieved.id_: retrieved}, self.checked_class.get_cache()) - @_within_checked_class - @_for_versioned_attr - def test_versioned_history_from_row(self, owner: Any, _: str, + @TestCaseAugmented._on_versioned_attributes + def test_versioned_history_from_row(self, + owner: Any, + _: str, attr: VersionedAttribute, - vals: list[str | float] + default: str | float, + to_set: list[str | float] ) -> None: """"Test VersionedAttribute.history_from_row() knows its DB rows.""" - vals += [vals[1] * 2] - attr.set(vals[2]) - attr.set(vals[1]) - attr.set(vals[2]) + attr.set(to_set[0]) + attr.set(to_set[1]) owner.save(self.db_conn) # make empty VersionedAttribute, fill from rows, compare to owner's - for row in self.db_conn.row_where(owner.table_name, 'id', - owner.id_): - loaded_attr = VersionedAttribute(owner, attr.table_name, - attr.default) + for row in self.db_conn.row_where(owner.table_name, 'id', owner.id_): + loaded_attr = VersionedAttribute(owner, attr.table_name, default) for row in self.db_conn.row_where(attr.table_name, 'parent', owner.id_): loaded_attr.history_from_row(row) @@ -363,13 +375,13 @@ class TestCaseWithDB(TestCase): for timestamp, value in attr.history.items(): self.assertEqual(value, loaded_attr.history[timestamp]) - @_within_checked_class + @TestCaseAugmented._within_checked_class def test_all(self) -> None: """Test .all() and its relation to cache and savings.""" - id_1, id_2, id_3 = self.default_ids - item1 = self.checked_class(id_1, **self.default_init_kwargs) - item2 = self.checked_class(id_2, **self.default_init_kwargs) - item3 = self.checked_class(id_3, **self.default_init_kwargs) + id1, id2, id3 = self.default_ids + item1 = self._make_from_defaults(id1) + item2 = self._make_from_defaults(id2) + item3 = self._make_from_defaults(id3) # check .all() returns empty list on un-cached items self.assertEqual(self.checked_class.all(self.db_conn), []) # check that all() shows only cached/saved items @@ -381,11 +393,11 @@ class TestCaseWithDB(TestCase): self.assertEqual(sorted(self.checked_class.all(self.db_conn)), sorted([item1, item2, item3])) - @_within_checked_class + @TestCaseAugmented._within_checked_class def test_singularity(self) -> None: """Test pointers made for single object keep pointing to it.""" id1 = self.default_ids[0] - obj = self.checked_class(id1, **self.default_init_kwargs) + obj = self._make_from_defaults(id1) obj.save(self.db_conn) # change object, expect retrieved through .by_id to carry change attr_name = self._change_obj(obj) @@ -393,25 +405,27 @@ class TestCaseWithDB(TestCase): retrieved = self.checked_class.by_id(self.db_conn, id1) self.assertEqual(new_attr, getattr(retrieved, attr_name)) - @_within_checked_class - @_for_versioned_attr - def test_versioned_singularity(self, owner: Any, attr_name: str, + @TestCaseAugmented._on_versioned_attributes + def test_versioned_singularity(self, + owner: Any, + attr_name: str, attr: VersionedAttribute, - vals: list[str | float] + _: str | float, + to_set: list[str | float] ) -> None: """Test singularity of VersionedAttributes on saving.""" owner.save(self.db_conn) # change obj, expect retrieved through .by_id to carry change - attr.set(vals[0]) + attr.set(to_set[0]) retrieved = self.checked_class.by_id(self.db_conn, owner.id_) attr_retrieved = getattr(retrieved, attr_name) self.assertEqual(attr.history, attr_retrieved.history) - @_within_checked_class + @TestCaseAugmented._within_checked_class def test_remove(self) -> None: """Test .remove() effects on DB and cache.""" id_ = self.default_ids[0] - obj = self.checked_class(id_, **self.default_init_kwargs) + obj = self._make_from_defaults(id_) # check removal only works after saving with self.assertRaises(HandledException): obj.remove(self.db_conn)