From: Christian Heller Date: Sat, 11 Jan 2025 06:57:24 +0000 (+0100) Subject: Further code simplifications. X-Git-Url: https://plomlompom.com/repos/%7B%7Bprefix%7D%7D/static/todos?a=commitdiff_plain;h=04fbe79d4632caba36402ce4fb0156943c77ae9a;p=plomtask Further code simplifications. --- diff --git a/plomtask/db.py b/plomtask/db.py index 2ce7a61..be849b6 100644 --- a/plomtask/db.py +++ b/plomtask/db.py @@ -537,8 +537,7 @@ class BaseModel: for key in self.to_save_simples]) table_name = self.table_name cursor = db_conn.exec(f'REPLACE INTO {table_name} VALUES', values) - if not isinstance(self.id_, str): - self.id_ = cursor.lastrowid + self.id_ = cursor.lastrowid self.cache() for attr_name in self.to_save_versioned(): getattr(self, attr_name).save(db_conn) diff --git a/plomtask/http.py b/plomtask/http.py index b6c6845..a4d2ed4 100644 --- a/plomtask/http.py +++ b/plomtask/http.py @@ -165,7 +165,7 @@ class TaskHandler(BaseHTTPRequestHandler): library[cls_name] = {} if item.id_ not in library[cls_name]: d, refs = item.as_dict_and_refs - id_key = '?' if item.id_ is None else item.id_ + id_key = -1 if item.id_ is None else item.id_ library[cls_name][id_key] = d for ref in refs: update_library_with(ref) @@ -189,7 +189,7 @@ class TaskHandler(BaseHTTPRequestHandler): return str(node) return node - library: dict[str, dict[str | int, object]] = {} + library: dict[str, dict[int, object]] = {} for k, v in ctx.items(): ctx[k] = flatten(v) ctx['_library'] = library diff --git a/tests/processes.py b/tests/processes.py index 78d396e..a762ebe 100644 --- a/tests/processes.py +++ b/tests/processes.py @@ -368,7 +368,7 @@ class TestsWithServer(TestCaseWithServer): self.check_filter(exp, 'processes', 'sort_by', '-owners', [1, 2, 3]) # test pattern matching on title exp.set('sort_by', 'title') - exp.lib_del('Process', '1') + exp.lib_del('Process', 1) self.check_filter(exp, 'processes', 'pattern', 'ba', [2, 3]) # test pattern matching on description exp.lib_wipe('Process') diff --git a/tests/utils.py b/tests/utils.py index b243357..ce5f2e5 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -74,22 +74,22 @@ class TestCaseAugmented(TestCase): return wrapper @classmethod - def _make_from_defaults(cls, id_: float | str | None) -> Any: + def _make_from_defaults(cls, id_: int | None) -> Any: return cls.checked_class(id_, **cls.default_init_kwargs) class TestCaseSansDB(TestCaseAugmented): """Tests requiring no DB setup.""" - legal_ids: list[str] | list[int] = [1, 5] - illegal_ids: list[str] | list[int] = [0] + _legal_ids: list[int] = [1, 5] + _illegal_ids: list[int] = [0] @TestCaseAugmented._run_if_sans_db def test_id_validation(self) -> None: """Test .id_ validation/setting.""" - for id_ in self.illegal_ids: + for id_ in self._illegal_ids: with self.assertRaises(HandledException): self._make_from_defaults(id_) - for id_ in self.legal_ids: + for id_ in self._legal_ids: obj = self._make_from_defaults(id_) self.assertEqual(obj.id_, id_) @@ -187,7 +187,7 @@ class TestCaseSansDB(TestCaseAugmented): class TestCaseWithDB(TestCaseAugmented): """Module tests not requiring DB setup.""" - default_ids: tuple[int, int, int] | tuple[str, str, str] = (1, 2, 3) + _default_ids: tuple[int, int, int] = (1, 2, 3) def setUp(self) -> None: Condition.empty_cache() @@ -202,7 +202,7 @@ class TestCaseWithDB(TestCaseAugmented): self.db_conn.close() remove_file(self.db_file.path) - def _load_from_db(self, id_: int | str) -> list[object]: + def _load_from_db(self, id_: int) -> list[object]: db_found: list[object] = [] for row in self.db_conn.row_where(self.checked_class.table_name, 'id', id_): @@ -281,12 +281,11 @@ class TestCaseWithDB(TestCaseAugmented): @TestCaseAugmented._run_if_with_db_but_not_server def test_saving_and_caching(self) -> None: """Test effects of .cache() and .save().""" - id1 = self.default_ids[0] + id1 = self._default_ids[0] # check failure to cache without ID (if None-ID input possible) - if isinstance(id1, int): - obj0 = self._make_from_defaults(None) - with self.assertRaises(HandledException): - obj0.cache() + obj0 = self._make_from_defaults(None) + with self.assertRaises(HandledException): + obj0.cache() # check mere object init itself doesn't even store in cache obj1 = self._make_from_defaults(id1) self.assertEqual(self.checked_class.get_cache(), {}) @@ -295,12 +294,11 @@ class TestCaseWithDB(TestCaseAugmented): self.assertEqual(self.checked_class.get_cache(), {id1: obj1}) found_in_db = self._load_from_db(id1) self.assertEqual(found_in_db, []) - # check .save() sets ID (for int IDs), updates cache, and fills DB + # check .save() sets ID, updates cache, and fills DB # (expect ID to be set to id1, despite obj1 already having that as ID: # it's generated by cursor.lastrowid on the DB table, and with obj1 # not written there, obj2 should get it first!) - id_input = None if isinstance(id1, int) else id1 - obj2 = self._make_from_defaults(id_input) + obj2 = self._make_from_defaults(None) obj2.save(self.db_conn) self.assertEqual(self.checked_class.get_cache(), {id1: obj2}) # NB: we'll only compare hashes because obj2 itself disappears on @@ -316,7 +314,7 @@ class TestCaseWithDB(TestCaseAugmented): @TestCaseAugmented._run_if_with_db_but_not_server def test_by_id(self) -> None: """Test .by_id().""" - id1, id2, _ = self.default_ids + id1, id2, _ = self._default_ids # check failure if not yet saved obj1 = self._make_from_defaults(id1) with self.assertRaises(NotFoundException): @@ -337,15 +335,14 @@ class TestCaseWithDB(TestCaseAugmented): with self.assertRaises(HandledException): self.checked_class.by_id_or_create(self.db_conn, None) return - # check ID input of None creates, on saving, ID=1,2,… for int IDs - if isinstance(self.default_ids[0], int): - for n in range(2): - item = self.checked_class.by_id_or_create(self.db_conn, None) - self.assertEqual(item.id_, None) - item.save(self.db_conn) - self.assertEqual(item.id_, n+1) + # check ID input of None creates, on saving, ID=1,2,… + for n in range(2): + item = self.checked_class.by_id_or_create(self.db_conn, None) + self.assertEqual(item.id_, None) + item.save(self.db_conn) + self.assertEqual(item.id_, n+1) # check .by_id_or_create acts like normal instantiation (sans saving) - id_ = self.default_ids[2] + id_ = self._default_ids[2] item = self.checked_class.by_id_or_create(self.db_conn, id_) self.assertEqual(item.id_, id_) with self.assertRaises(NotFoundException): @@ -355,8 +352,7 @@ class TestCaseWithDB(TestCaseAugmented): @TestCaseAugmented._run_if_with_db_but_not_server def test_from_table_row(self) -> None: """Test .from_table_row() properly reads in class directly from DB.""" - id_ = self.default_ids[0] - obj = self._make_from_defaults(id_) + obj = self._make_from_defaults(self._default_ids[0]) obj.save(self.db_conn) for row in self.db_conn.row_where(self.checked_class.table_name, 'id', obj.id_): @@ -402,7 +398,7 @@ class TestCaseWithDB(TestCaseAugmented): @TestCaseAugmented._run_if_with_db_but_not_server def test_all(self) -> None: """Test .all() and its relation to cache and savings.""" - id1, id2, id3 = self.default_ids + id1, id2, id3 = self._default_ids item1 = self._make_from_defaults(id1) item2 = self._make_from_defaults(id2) item3 = self._make_from_defaults(id3) @@ -420,7 +416,7 @@ class TestCaseWithDB(TestCaseAugmented): @TestCaseAugmented._run_if_with_db_but_not_server def test_singularity(self) -> None: """Test pointers made for single object keep pointing to it.""" - id1 = self.default_ids[0] + id1 = self._default_ids[0] obj = self._make_from_defaults(id1) obj.save(self.db_conn) # change object, expect retrieved through .by_id to carry change @@ -449,8 +445,7 @@ class TestCaseWithDB(TestCaseAugmented): @TestCaseAugmented._run_if_with_db_but_not_server def test_remove(self) -> None: """Test .remove() effects on DB and cache.""" - id_ = self.default_ids[0] - obj = self._make_from_defaults(id_) + obj = self._make_from_defaults(self._default_ids[0]) # check removal only works after saving with self.assertRaises(HandledException): obj.remove(self.db_conn) @@ -483,25 +478,11 @@ class Expected: _fields: dict[str, Any] _on_empty_make_temp: tuple[str, str] - def __init__(self, - todos: list[dict[str, Any]] | None = None, - procs: list[dict[str, Any]] | None = None, - procsteps: list[dict[str, Any]] | None = None, - conds: list[dict[str, Any]] | None = None, - days: list[dict[str, Any]] | None = None - ) -> None: - # pylint: disable=too-many-arguments + def __init__(self) -> None: for name in ['_default_dict', '_fields', '_forced']: if not hasattr(self, name): setattr(self, name, {}) - self._lib = {} - for title, items in [('Todo', todos), - ('Process', procs), - ('ProcessStep', procsteps), - ('Condition', conds), - ('Day', days)]: - if items: - self._lib[title] = self._as_refs(items) + self._lib: dict[str, dict[int, dict[str, Any]]] = {} for k, v in self._default_dict.items(): if k not in self._fields: self._fields[k] = v @@ -547,16 +528,15 @@ class Expected: d[k] = v if make_temp: json = json_dumps(d) - id_ = id_ if id_ is not None else '?' + id_ = id_ if id_ is not None else -1 self.lib_del(category, id_) d = json_loads(json) return d - def lib_get(self, category: str, id_: str | int) -> dict[str, Any]: + def lib_get(self, category: str, id_: int) -> dict[str, Any]: """From library, return item of category and id_, or empty dict.""" - str_id = str(id_) - if category in self._lib and str_id in self._lib[category]: - return self._lib[category][str_id] + if category in self._lib and id_ in self._lib[category]: + return self._lib[category][id_] return {} def lib_all(self, category: str) -> list[dict[str, Any]]: @@ -569,12 +549,14 @@ class Expected: """Update library for category with items.""" if category not in self._lib: self._lib[category] = {} - for k, v in self._as_refs(items).items(): - self._lib[category][k] = v + for item in items: + id_ = item['id'] if item['id'] is not None else -1 + assert isinstance(id_, int) + self._lib[category][id_] = item - def lib_del(self, category: str, id_: str | int) -> None: + def lib_del(self, category: str, id_: int) -> None: """Remove category element of id_ from library.""" - del self._lib[category][str(id_)] + del self._lib[category][id_] if 0 == len(self._lib[category]): del self._lib[category] @@ -591,20 +573,6 @@ class Expected: """Set ._forced field to ensure value in .as_dict.""" self._forced[field_name] = value - def unforce(self, field_name: str) -> None: - """Unset ._forced field.""" - del self._forced[field_name] - - @staticmethod - def _as_refs(items: list[dict[str, object]] - ) -> dict[str, dict[str, object]]: - """Return dictionary of items by their 'id' fields.""" - refs = {} - for item in items: - id_ = str(item['id']) if item['id'] is not None else '?' - refs[id_] = item - return refs - @staticmethod def as_ids(items: list[dict[str, Any]]) -> list[int] | list[str]: """Return list of only 'id' fields of items.""" @@ -1007,7 +975,8 @@ class TestCaseWithServer(TestCaseWithDB): self.assertEqual(response.status, 200) retrieved = json_loads(response.read().decode()) rewrite_history_keys_in(retrieved) - cmp = expected.as_dict + # to convert ._lib int keys to str + cmp = json_loads(json_dumps(expected.as_dict)) try: self.assertEqual(cmp, retrieved) except AssertionError as e: