home · contact · privacy
Further code simplifications.
authorChristian Heller <c.heller@plomlompom.de>
Sat, 11 Jan 2025 06:57:24 +0000 (07:57 +0100)
committerChristian Heller <c.heller@plomlompom.de>
Sat, 11 Jan 2025 06:57:24 +0000 (07:57 +0100)
plomtask/db.py
plomtask/http.py
tests/processes.py
tests/utils.py

index 2ce7a61f6e85991af36a64334246efac6f3190af..be849b62822e9e122110c9c4975ed48da55d0227 100644 (file)
@@ -537,8 +537,7 @@ class BaseModel:
                                      for key in self.to_save_simples])
         table_name = self.table_name
         cursor = db_conn.exec(f'REPLACE INTO {table_name} VALUES', values)
-        if not isinstance(self.id_, str):
-            self.id_ = cursor.lastrowid
+        self.id_ = cursor.lastrowid
         self.cache()
         for attr_name in self.to_save_versioned():
             getattr(self, attr_name).save(db_conn)
index b6c6845489a4feb79aea1d25ff461a0b1eb8f287..a4d2ed42d63ff3af34921e9b19050e1d39cd226e 100644 (file)
@@ -165,7 +165,7 @@ class TaskHandler(BaseHTTPRequestHandler):
                     library[cls_name] = {}
                 if item.id_ not in library[cls_name]:
                     d, refs = item.as_dict_and_refs
-                    id_key = '?' if item.id_ is None else item.id_
+                    id_key = -1 if item.id_ is None else item.id_
                     library[cls_name][id_key] = d
                     for ref in refs:
                         update_library_with(ref)
@@ -189,7 +189,7 @@ class TaskHandler(BaseHTTPRequestHandler):
                 return str(node)
             return node
 
-        library: dict[str, dict[str | int, object]] = {}
+        library: dict[str, dict[int, object]] = {}
         for k, v in ctx.items():
             ctx[k] = flatten(v)
         ctx['_library'] = library
index 78d396ee208334dc66d3eb79c4e7c7739953a3fc..a762ebeb7970696302811f2f0919144de0cd8cd2 100644 (file)
@@ -368,7 +368,7 @@ class TestsWithServer(TestCaseWithServer):
         self.check_filter(exp, 'processes', 'sort_by', '-owners', [1, 2, 3])
         # test pattern matching on title
         exp.set('sort_by', 'title')
-        exp.lib_del('Process', '1')
+        exp.lib_del('Process', 1)
         self.check_filter(exp, 'processes', 'pattern', 'ba', [2, 3])
         # test pattern matching on description
         exp.lib_wipe('Process')
index b24335740e2cec1b5d0892e3d8fe08eb0b2735bb..ce5f2e5233ad4e91d2ba64923127c87997ff677d 100644 (file)
@@ -74,22 +74,22 @@ class TestCaseAugmented(TestCase):
         return wrapper
 
     @classmethod
-    def _make_from_defaults(cls, id_: float | str | None) -> Any:
+    def _make_from_defaults(cls, id_: int | None) -> Any:
         return cls.checked_class(id_, **cls.default_init_kwargs)
 
 
 class TestCaseSansDB(TestCaseAugmented):
     """Tests requiring no DB setup."""
-    legal_ids: list[str] | list[int] = [1, 5]
-    illegal_ids: list[str] | list[int] = [0]
+    _legal_ids: list[int] = [1, 5]
+    _illegal_ids: list[int] = [0]
 
     @TestCaseAugmented._run_if_sans_db
     def test_id_validation(self) -> None:
         """Test .id_ validation/setting."""
-        for id_ in self.illegal_ids:
+        for id_ in self._illegal_ids:
             with self.assertRaises(HandledException):
                 self._make_from_defaults(id_)
-        for id_ in self.legal_ids:
+        for id_ in self._legal_ids:
             obj = self._make_from_defaults(id_)
             self.assertEqual(obj.id_, id_)
 
@@ -187,7 +187,7 @@ class TestCaseSansDB(TestCaseAugmented):
 
 class TestCaseWithDB(TestCaseAugmented):
     """Module tests not requiring DB setup."""
-    default_ids: tuple[int, int, int] | tuple[str, str, str] = (1, 2, 3)
+    _default_ids: tuple[int, int, int] = (1, 2, 3)
 
     def setUp(self) -> None:
         Condition.empty_cache()
@@ -202,7 +202,7 @@ class TestCaseWithDB(TestCaseAugmented):
         self.db_conn.close()
         remove_file(self.db_file.path)
 
-    def _load_from_db(self, id_: int | str) -> list[object]:
+    def _load_from_db(self, id_: int) -> list[object]:
         db_found: list[object] = []
         for row in self.db_conn.row_where(self.checked_class.table_name,
                                           'id', id_):
@@ -281,12 +281,11 @@ class TestCaseWithDB(TestCaseAugmented):
     @TestCaseAugmented._run_if_with_db_but_not_server
     def test_saving_and_caching(self) -> None:
         """Test effects of .cache() and .save()."""
-        id1 = self.default_ids[0]
+        id1 = self._default_ids[0]
         # check failure to cache without ID (if None-ID input possible)
-        if isinstance(id1, int):
-            obj0 = self._make_from_defaults(None)
-            with self.assertRaises(HandledException):
-                obj0.cache()
+        obj0 = self._make_from_defaults(None)
+        with self.assertRaises(HandledException):
+            obj0.cache()
         # check mere object init itself doesn't even store in cache
         obj1 = self._make_from_defaults(id1)
         self.assertEqual(self.checked_class.get_cache(), {})
@@ -295,12 +294,11 @@ class TestCaseWithDB(TestCaseAugmented):
         self.assertEqual(self.checked_class.get_cache(), {id1: obj1})
         found_in_db = self._load_from_db(id1)
         self.assertEqual(found_in_db, [])
-        # check .save() sets ID (for int IDs), updates cache, and fills DB
+        # check .save() sets ID, updates cache, and fills DB
         # (expect ID to be set to id1, despite obj1 already having that as ID:
         # it's generated by cursor.lastrowid on the DB table, and with obj1
         # not written there, obj2 should get it first!)
-        id_input = None if isinstance(id1, int) else id1
-        obj2 = self._make_from_defaults(id_input)
+        obj2 = self._make_from_defaults(None)
         obj2.save(self.db_conn)
         self.assertEqual(self.checked_class.get_cache(), {id1: obj2})
         # NB: we'll only compare hashes because obj2 itself disappears on
@@ -316,7 +314,7 @@ class TestCaseWithDB(TestCaseAugmented):
     @TestCaseAugmented._run_if_with_db_but_not_server
     def test_by_id(self) -> None:
         """Test .by_id()."""
-        id1, id2, _ = self.default_ids
+        id1, id2, _ = self._default_ids
         # check failure if not yet saved
         obj1 = self._make_from_defaults(id1)
         with self.assertRaises(NotFoundException):
@@ -337,15 +335,14 @@ class TestCaseWithDB(TestCaseAugmented):
             with self.assertRaises(HandledException):
                 self.checked_class.by_id_or_create(self.db_conn, None)
             return
-        # check ID input of None creates, on saving, ID=1,2,… for int IDs
-        if isinstance(self.default_ids[0], int):
-            for n in range(2):
-                item = self.checked_class.by_id_or_create(self.db_conn, None)
-                self.assertEqual(item.id_, None)
-                item.save(self.db_conn)
-                self.assertEqual(item.id_, n+1)
+        # check ID input of None creates, on saving, ID=1,2,…
+        for n in range(2):
+            item = self.checked_class.by_id_or_create(self.db_conn, None)
+            self.assertEqual(item.id_, None)
+            item.save(self.db_conn)
+            self.assertEqual(item.id_, n+1)
         # check .by_id_or_create acts like normal instantiation (sans saving)
-        id_ = self.default_ids[2]
+        id_ = self._default_ids[2]
         item = self.checked_class.by_id_or_create(self.db_conn, id_)
         self.assertEqual(item.id_, id_)
         with self.assertRaises(NotFoundException):
@@ -355,8 +352,7 @@ class TestCaseWithDB(TestCaseAugmented):
     @TestCaseAugmented._run_if_with_db_but_not_server
     def test_from_table_row(self) -> None:
         """Test .from_table_row() properly reads in class directly from DB."""
-        id_ = self.default_ids[0]
-        obj = self._make_from_defaults(id_)
+        obj = self._make_from_defaults(self._default_ids[0])
         obj.save(self.db_conn)
         for row in self.db_conn.row_where(self.checked_class.table_name,
                                           'id', obj.id_):
@@ -402,7 +398,7 @@ class TestCaseWithDB(TestCaseAugmented):
     @TestCaseAugmented._run_if_with_db_but_not_server
     def test_all(self) -> None:
         """Test .all() and its relation to cache and savings."""
-        id1, id2, id3 = self.default_ids
+        id1, id2, id3 = self._default_ids
         item1 = self._make_from_defaults(id1)
         item2 = self._make_from_defaults(id2)
         item3 = self._make_from_defaults(id3)
@@ -420,7 +416,7 @@ class TestCaseWithDB(TestCaseAugmented):
     @TestCaseAugmented._run_if_with_db_but_not_server
     def test_singularity(self) -> None:
         """Test pointers made for single object keep pointing to it."""
-        id1 = self.default_ids[0]
+        id1 = self._default_ids[0]
         obj = self._make_from_defaults(id1)
         obj.save(self.db_conn)
         # change object, expect retrieved through .by_id to carry change
@@ -449,8 +445,7 @@ class TestCaseWithDB(TestCaseAugmented):
     @TestCaseAugmented._run_if_with_db_but_not_server
     def test_remove(self) -> None:
         """Test .remove() effects on DB and cache."""
-        id_ = self.default_ids[0]
-        obj = self._make_from_defaults(id_)
+        obj = self._make_from_defaults(self._default_ids[0])
         # check removal only works after saving
         with self.assertRaises(HandledException):
             obj.remove(self.db_conn)
@@ -483,25 +478,11 @@ class Expected:
     _fields: dict[str, Any]
     _on_empty_make_temp: tuple[str, str]
 
-    def __init__(self,
-                 todos: list[dict[str, Any]] | None = None,
-                 procs: list[dict[str, Any]] | None = None,
-                 procsteps: list[dict[str, Any]] | None = None,
-                 conds: list[dict[str, Any]] | None = None,
-                 days: list[dict[str, Any]] | None = None
-                 ) -> None:
-        # pylint: disable=too-many-arguments
+    def __init__(self) -> None:
         for name in ['_default_dict', '_fields', '_forced']:
             if not hasattr(self, name):
                 setattr(self, name, {})
-        self._lib = {}
-        for title, items in [('Todo', todos),
-                             ('Process', procs),
-                             ('ProcessStep', procsteps),
-                             ('Condition', conds),
-                             ('Day', days)]:
-            if items:
-                self._lib[title] = self._as_refs(items)
+        self._lib: dict[str, dict[int, dict[str, Any]]] = {}
         for k, v in self._default_dict.items():
             if k not in self._fields:
                 self._fields[k] = v
@@ -547,16 +528,15 @@ class Expected:
             d[k] = v
         if make_temp:
             json = json_dumps(d)
-            id_ = id_ if id_ is not None else '?'
+            id_ = id_ if id_ is not None else -1
             self.lib_del(category, id_)
             d = json_loads(json)
         return d
 
-    def lib_get(self, category: str, id_: str | int) -> dict[str, Any]:
+    def lib_get(self, category: str, id_: int) -> dict[str, Any]:
         """From library, return item of category and id_, or empty dict."""
-        str_id = str(id_)
-        if category in self._lib and str_id in self._lib[category]:
-            return self._lib[category][str_id]
+        if category in self._lib and id_ in self._lib[category]:
+            return self._lib[category][id_]
         return {}
 
     def lib_all(self, category: str) -> list[dict[str, Any]]:
@@ -569,12 +549,14 @@ class Expected:
         """Update library for category with items."""
         if category not in self._lib:
             self._lib[category] = {}
-        for k, v in self._as_refs(items).items():
-            self._lib[category][k] = v
+        for item in items:
+            id_ = item['id'] if item['id'] is not None else -1
+            assert isinstance(id_, int)
+            self._lib[category][id_] = item
 
-    def lib_del(self, category: str, id_: str | int) -> None:
+    def lib_del(self, category: str, id_: int) -> None:
         """Remove category element of id_ from library."""
-        del self._lib[category][str(id_)]
+        del self._lib[category][id_]
         if 0 == len(self._lib[category]):
             del self._lib[category]
 
@@ -591,20 +573,6 @@ class Expected:
         """Set ._forced field to ensure value in .as_dict."""
         self._forced[field_name] = value
 
-    def unforce(self, field_name: str) -> None:
-        """Unset ._forced field."""
-        del self._forced[field_name]
-
-    @staticmethod
-    def _as_refs(items: list[dict[str, object]]
-                 ) -> dict[str, dict[str, object]]:
-        """Return dictionary of items by their 'id' fields."""
-        refs = {}
-        for item in items:
-            id_ = str(item['id']) if item['id'] is not None else '?'
-            refs[id_] = item
-        return refs
-
     @staticmethod
     def as_ids(items: list[dict[str, Any]]) -> list[int] | list[str]:
         """Return list of only 'id' fields of items."""
@@ -1007,7 +975,8 @@ class TestCaseWithServer(TestCaseWithDB):
         self.assertEqual(response.status, 200)
         retrieved = json_loads(response.read().decode())
         rewrite_history_keys_in(retrieved)
-        cmp = expected.as_dict
+        # to convert ._lib int keys to str
+        cmp = json_loads(json_dumps(expected.as_dict))
         try:
             self.assertEqual(cmp, retrieved)
         except AssertionError as e: