home · contact · privacy
Enforce more type safety on ClientDb.
authorChristian Heller <c.heller@plomlompom.de>
Thu, 14 Aug 2025 14:37:01 +0000 (16:37 +0200)
committerChristian Heller <c.heller@plomlompom.de>
Thu, 14 Aug 2025 14:37:01 +0000 (16:37 +0200)
ircplom/client.py

index 40f343e4767e6e36b75a2af0ab26a02bc91b4e43..563c69e1cacf0aed462c6ead439713fb419f813b 100644 (file)
@@ -6,7 +6,7 @@ from dataclasses import dataclass, InitVar
 from enum import Enum, auto
 from getpass import getuser
 from threading import Thread
-from typing import Any, Callable, NamedTuple, Optional
+from typing import Callable, NamedTuple, Optional
 # ourselves
 from ircplom.events import (
         AffectiveEvent, CrashingException, ExceptionEvent, QueueMixin)
@@ -131,7 +131,7 @@ class _CapsManager:
                     self._db.append('caps_LS', param, keep_confirmed=True)
             case 'DEL':
                 for param in params[-1].split():
-                    del self._db.caps_LS[param]
+                    self._db.caps_LS.remove(param)
             case 'ACK' | 'NAK':
                 for name in params[-1].split():
                     if params[0] == 'ACK':
@@ -176,19 +176,55 @@ class IrcConnSetup(NamedTuple):
 
 class ClientDbBase:
     'Optimized for dealing with variable confirmation of values.'
+    client_host: str
+    isupports: list[str]
+    motd: list[str]
+    nickname: str
 
     def __init__(self) -> None:
         self._dict: dict[str, ClientDbType] = {}
         self._confirmeds: list[str] = []
 
-    def __getattr__(self, key: str) -> Any:
+    def __getattr__(self, key: str) -> Optional[ClientDbType]:
         if key in self._dict and key in self._confirmeds:
-            return self._dict[key]
+            value = self._dict[key]
+            self._typecheck(key, value)
+            return value
         return None
 
+    @classmethod
+    def _typecheck(cls, key: str, value: ClientDbType) -> None:
+        candidates = [c.__annotations__[key] for c in cls.__mro__
+                      if c is not object and key in c.__annotations__]
+        if not candidates:
+            raise CrashingException(f'{cls} lacks annotation for {key}')
+        type_ = candidates[0]
+        fail = True
+        type_found = str(type(value))
+        if not isinstance(type_, type):              # gotta be GenericAlias, …
+            assert hasattr(type_, '__origin__')      # … which, if for list …
+            assert type_.__origin__ is list          # … (only probable …
+            if isinstance(value, type_.__origin__):  # … candidate so far), …
+                fail = False  # be ok if list emtpy  # … stores members' …
+                assert hasattr(type_, '__args__')    # … types at .__args__
+                subtypes_found = set()
+                for subtype in type_.__args__:
+                    for x in [x for x in value if not isinstance(x, subtype)]:
+                        fail = True
+                        subtypes_found.add(str(type(x)))
+                type_found = f'{type_.__origin__}: ' + '|'.join(subtypes_found)
+        elif isinstance(value, type_):
+            return
+        if fail:
+            raise CrashingException(
+                    f'wrong type for {key}: {type_found} (should be: {type_}, '
+                    f'provided value: {value})')
+
     def set(self, key: str, value: Optional[ClientDbType], confirm=False
             ) -> tuple[bool, bool]:
         'Ensures setting, returns if changed value or confirmation.'
+        if value is not None:
+            self._typecheck(key, value)
         old_value, was_confirmed = self.get_force(key)
         value_changed = (value is not None) and value != old_value
         if value is None:
@@ -208,10 +244,19 @@ class ClientDbBase:
 
     def get_force(self, key: str) -> tuple[Optional[ClientDbType], bool]:
         'Get even if only stored unconfirmed, tell if confirmed..'
-        return (self._dict.get(key, None), key in self._confirmeds)
+        value = self._dict.get(key, None)
+        if value is not None:
+            self._typecheck(key, value)
+        return (value, key in self._confirmeds)
 
 
 class _ClientDb(ClientDbBase):
+    caps_LS: list[str]
+    caps_LIST: list[str]
+    hostname: str
+    password: str
+    port: int
+    realname: str
 
     def append(self, key: str, value: str, keep_confirmed=False) -> None:
         'To list[str] keyed by key, append value; if non-existant, create it.'