From a47164e388a4ac15a2f9d9bdbc50b4bca094c086 Mon Sep 17 00:00:00 2001
From: Christian Heller <c.heller@plomlompom.de>
Date: Sun, 14 Jul 2024 16:57:21 +0200
Subject: [PATCH] Tighten mypy controls around consistency of list/tuple/etc.
 elements, or add suggestions towards doing that.

---
 plomtask/db.py                   |  9 ++++++++-
 plomtask/http.py                 |  2 +-
 plomtask/versioned_attributes.py |  4 ++++
 tests/utils.py                   | 32 +++++++++++++++++++++++++-------
 4 files changed, 38 insertions(+), 9 deletions(-)

diff --git a/plomtask/db.py b/plomtask/db.py
index 704b709..6f0d13a 100644
--- a/plomtask/db.py
+++ b/plomtask/db.py
@@ -283,7 +283,7 @@ class BaseModel(Generic[BaseModelId]):
     @property
     def as_dict(self) -> dict[str, object]:
         """Return self as (json.dumps-compatible) dict."""
-        library: dict[str, dict[str | int, object]] = {}
+        library: dict[str, dict[str, object] | dict[int, object]] = {}
         d: dict[str, object] = {'id': self.id_, '_library': library}
         for to_save in self.to_save_simples:
             attr = getattr(self, to_save)
@@ -312,6 +312,12 @@ class BaseModel(Generic[BaseModelId]):
                                library: dict[str, dict[str | int, object]]
                                ) -> int | str:
         """Return self.id_ while writing .as_dict into library."""
+        # NB: For tighter mypy testing, we might prefer the library argument
+        # to be of type dict[str, dict[str, object] | dict[int, object]
+        # instead. But my current coding knowledge only manage to make that
+        # work by turning the code much more complex, so let's leave it at
+        # that for now …
+
         def into_library(library: dict[str, dict[str | int, object]],
                          cls_name: str,
                          id_: str | int,
@@ -326,6 +332,7 @@ class BaseModel(Generic[BaseModelId]):
                     raise HandledException(msg)
             else:
                 library[cls_name][id_] = d
+
         as_dict = self.as_dict
         assert isinstance(as_dict['_library'], dict)
         for cls_name, dict_of_objs in as_dict['_library'].items():
diff --git a/plomtask/http.py b/plomtask/http.py
index db54023..b3b9d7a 100644
--- a/plomtask/http.py
+++ b/plomtask/http.py
@@ -57,7 +57,7 @@ class TaskServer(HTTPServer):
             if isinstance(node, HandledException):
                 return str(node)
             return node
-        library: dict[str, dict[str | int, object]] = {}
+        library: dict[str, dict[str, object] | dict[int, object]] = {}
         for k, v in ctx.items():
             ctx[k] = walk_ctx(v)
         ctx['_library'] = library
diff --git a/plomtask/versioned_attributes.py b/plomtask/versioned_attributes.py
index cfcbf87..f5e17f3 100644
--- a/plomtask/versioned_attributes.py
+++ b/plomtask/versioned_attributes.py
@@ -19,6 +19,10 @@ class VersionedAttribute:
         self.table_name = table_name
         self._default = default
         self.history: dict[str, str | float] = {}
+        # NB: For tighter mypy testing, we might prefer self.history to be
+        # dict[str, float] | dict[str, str] instead, but my current coding
+        # knowledge only manages to make that work by adding much further
+        # complexity, so let's leave it at that for now …
 
     def __hash__(self) -> int:
         history_tuples = tuple((k, v) for k, v in self.history.items())
diff --git a/tests/utils.py b/tests/utils.py
index 8008033..b969424 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -79,7 +79,7 @@ class TestCaseSansDB(TestCaseAugmented):
                            __: str,
                            attr: VersionedAttribute,
                            default: str | float,
-                           to_set: list[str | float]
+                           to_set: list[str] | list[float]
                            ) -> None:
         """Test VersionedAttribute.set() behaves as expected."""
         attr.set(default)
@@ -118,7 +118,7 @@ class TestCaseSansDB(TestCaseAugmented):
                               __: str,
                               attr: VersionedAttribute,
                               default: str | float,
-                              to_set: list[str | float]
+                              to_set: list[str] | list[float]
                               ) -> None:
         """Test VersionedAttribute.newest."""
         # check .newest on empty history returns .default
@@ -137,7 +137,7 @@ class TestCaseSansDB(TestCaseAugmented):
                           __: str,
                           attr: VersionedAttribute,
                           default: str | float,
-                          to_set: list[str | float]
+                          to_set: list[str] | list[float]
                           ) -> None:
         """Test .at() returns values nearest to queried time, or default."""
         # check .at() return default on empty history
@@ -164,7 +164,7 @@ class TestCaseSansDB(TestCaseAugmented):
 
 class TestCaseWithDB(TestCaseAugmented):
     """Module tests not requiring DB setup."""
-    default_ids: tuple[int | str, int | str, int | str] = (1, 2, 3)
+    default_ids: tuple[int, int, int] | tuple[str, str, str] = (1, 2, 3)
 
     def setUp(self) -> None:
         Condition.empty_cache()
@@ -280,7 +280,7 @@ class TestCaseWithDB(TestCaseAugmented):
                                          attr_name: str,
                                          attr: VersionedAttribute,
                                          _: str | float,
-                                         to_set: list[str | float]
+                                         to_set: list[str] | list[float]
                                          ) -> None:
         """Test storage and initialization of versioned attributes."""
 
@@ -419,7 +419,7 @@ class TestCaseWithDB(TestCaseAugmented):
                                         _: str,
                                         attr: VersionedAttribute,
                                         default: str | float,
-                                        to_set: list[str | float]
+                                        to_set: list[str] | list[float]
                                         ) -> None:
         """"Test VersionedAttribute.history_from_row() knows its DB rows."""
         attr.set(to_set[0])
@@ -472,7 +472,7 @@ class TestCaseWithDB(TestCaseAugmented):
                                    attr_name: str,
                                    attr: VersionedAttribute,
                                    _: str | float,
-                                   to_set: list[str | float]
+                                   to_set: list[str] | list[float]
                                    ) -> None:
         """Test singularity of VersionedAttributes on saving."""
         owner.save(self.db_conn)
@@ -521,11 +521,29 @@ class TestCaseWithServer(TestCaseWithDB):
     @staticmethod
     def as_id_list(items: list[dict[str, object]]) -> list[int | str]:
         """Return list of only 'id' fields of items."""
+        # NB: To tighten the mypy test, consider to, instead of returning
+        # list[str | int], returnlist[int] | list[str]. But since so far to me
+        # the only way to make that work seems to be to repaclement of the
+        # currently active last line with complexity of the out-commented code
+        # block beneath, I currently opt for the status quo.
         id_list = []
         for item in items:
             assert isinstance(item['id'], (int, str))
             id_list += [item['id']]
         return id_list
+        # if id_list:
+        #     if isinstance(id_list[0], int):
+        #         for id_ in id_list:
+        #             assert isinstance(id_, int)
+        #         l_int: list[int] = [id_ for id_ in id_list
+        #                             if isinstance(id_, int)]
+        #         return l_int
+        #     for id_ in id_list:
+        #         assert isinstance(id_, str)
+        #     l_str: list[str] = [id_ for id_ in id_list
+        #                         if isinstance(id_, str)]
+        #     return l_str
+        # return []
 
     @staticmethod
     def as_refs(items: list[dict[str, object]]
-- 
2.30.2