From 475bdf61e96966a15c3808cef630118577435c71 Mon Sep 17 00:00:00 2001 From: Christian Heller Date: Thu, 14 Aug 2025 16:37:01 +0200 Subject: [PATCH] Enforce more type safety on ClientDb. --- ircplom/client.py | 55 ++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 50 insertions(+), 5 deletions(-) diff --git a/ircplom/client.py b/ircplom/client.py index 40f343e..563c69e 100644 --- a/ircplom/client.py +++ b/ircplom/client.py @@ -6,7 +6,7 @@ from dataclasses import dataclass, InitVar from enum import Enum, auto from getpass import getuser from threading import Thread -from typing import Any, Callable, NamedTuple, Optional +from typing import Callable, NamedTuple, Optional # ourselves from ircplom.events import ( AffectiveEvent, CrashingException, ExceptionEvent, QueueMixin) @@ -131,7 +131,7 @@ class _CapsManager: self._db.append('caps_LS', param, keep_confirmed=True) case 'DEL': for param in params[-1].split(): - del self._db.caps_LS[param] + self._db.caps_LS.remove(param) case 'ACK' | 'NAK': for name in params[-1].split(): if params[0] == 'ACK': @@ -176,19 +176,55 @@ class IrcConnSetup(NamedTuple): class ClientDbBase: 'Optimized for dealing with variable confirmation of values.' + client_host: str + isupports: list[str] + motd: list[str] + nickname: str def __init__(self) -> None: self._dict: dict[str, ClientDbType] = {} self._confirmeds: list[str] = [] - def __getattr__(self, key: str) -> Any: + def __getattr__(self, key: str) -> Optional[ClientDbType]: if key in self._dict and key in self._confirmeds: - return self._dict[key] + value = self._dict[key] + self._typecheck(key, value) + return value return None + @classmethod + def _typecheck(cls, key: str, value: ClientDbType) -> None: + candidates = [c.__annotations__[key] for c in cls.__mro__ + if c is not object and key in c.__annotations__] + if not candidates: + raise CrashingException(f'{cls} lacks annotation for {key}') + type_ = candidates[0] + fail = True + type_found = str(type(value)) + if not isinstance(type_, type): # gotta be GenericAlias, … + assert hasattr(type_, '__origin__') # … which, if for list … + assert type_.__origin__ is list # … (only probable … + if isinstance(value, type_.__origin__): # … candidate so far), … + fail = False # be ok if list emtpy # … stores members' … + assert hasattr(type_, '__args__') # … types at .__args__ + subtypes_found = set() + for subtype in type_.__args__: + for x in [x for x in value if not isinstance(x, subtype)]: + fail = True + subtypes_found.add(str(type(x))) + type_found = f'{type_.__origin__}: ' + '|'.join(subtypes_found) + elif isinstance(value, type_): + return + if fail: + raise CrashingException( + f'wrong type for {key}: {type_found} (should be: {type_}, ' + f'provided value: {value})') + def set(self, key: str, value: Optional[ClientDbType], confirm=False ) -> tuple[bool, bool]: 'Ensures setting, returns if changed value or confirmation.' + if value is not None: + self._typecheck(key, value) old_value, was_confirmed = self.get_force(key) value_changed = (value is not None) and value != old_value if value is None: @@ -208,10 +244,19 @@ class ClientDbBase: def get_force(self, key: str) -> tuple[Optional[ClientDbType], bool]: 'Get even if only stored unconfirmed, tell if confirmed..' - return (self._dict.get(key, None), key in self._confirmeds) + value = self._dict.get(key, None) + if value is not None: + self._typecheck(key, value) + return (value, key in self._confirmeds) class _ClientDb(ClientDbBase): + caps_LS: list[str] + caps_LIST: list[str] + hostname: str + password: str + port: int + realname: str def append(self, key: str, value: str, keep_confirmed=False) -> None: 'To list[str] keyed by key, append value; if non-existant, create it.' -- 2.30.2