home · contact · privacy
Clean up and refactor into own manager class CAPS negotation.
authorChristian Heller <c.heller@plomlompom.de>
Mon, 4 Aug 2025 17:16:32 +0000 (19:16 +0200)
committerChristian Heller <c.heller@plomlompom.de>
Mon, 4 Aug 2025 17:16:32 +0000 (19:16 +0200)
ircplom/client.py

index 81c53846160caa3d57f2010fec5241b7d6e663ee..d315ced267190278d1487ef92f79932d14c403e2 100644 (file)
@@ -4,7 +4,7 @@ from abc import ABC, abstractmethod
 from dataclasses import dataclass
 from getpass import getuser
 from threading import Thread
-from typing import Optional
+from typing import Callable, Optional
 from uuid import UUID, uuid4
 # ourselves
 from ircplom.events import (AffectiveEvent, ExceptionEvent, Logger,
@@ -20,6 +20,8 @@ _LOG_PREFIX_SEND_FMT = '> '
 _LOG_PREFIX_SEND_RAW = '=>|'
 _LOG_PREFIX_RECV_RAW = '<-|'
 
+_NAMES_DESIRED_SERVER_CAPS = ('server-time', 'account-tag', 'sasl')
+
 
 @dataclass
 class ClientIdMixin:
@@ -70,6 +72,82 @@ class _ServerCapability:
         return listing
 
 
+class _CapsManager:
+
+    def __init__(self, sender: Callable[[IrcMessage], None]) -> None:
+        self._send = sender
+        self._challenges: dict[str, bool] = {}
+        self._dict: dict[str, _ServerCapability] = {}
+
+    def clear(self) -> None:
+        'Reset all negotiation knowledge to zero.'
+        self._challenges.clear()
+        self._dict.clear()
+
+    def process_msg(self, params: tuple[str, ...]) -> list[str]:
+        'Parse CAP params to negot. steps, DB inputs; once done return latter.'
+        if self._challenge_met('END'):
+            return []
+        match params[0]:
+            case 'LS' | 'LIST':
+                self._collect_caps(params)
+            case 'ACK' | 'NAK':
+                for cap_name in params[-1].split():
+                    self._challenge_set(f'REQ:{cap_name}', done=True)
+                    self._dict[cap_name].enabled = params[0] == 'ACK'
+        if self._challenge_met('LIST'):
+            self.challenge('END')
+            self._challenge_set('END', done=True)
+            return (['server capabilities (enabled: "+"):']
+                    + [cap.str_for_log(cap_name)
+                       for cap_name, cap in self._dict.items()])
+        if self._challenge_met('LS'):
+            for cap_name in _NAMES_DESIRED_SERVER_CAPS:
+                if (cap_name in self._dict
+                        and (not self._dict[cap_name].enabled)):
+                    self.challenge('REQ', cap_name, key_fused=True)
+            self.challenge('LIST')
+        return []
+
+    def challenge(self, *params, key_fused: bool = False) -> None:
+        'Run CAP command with params, handle cap neg. state.'
+        challenge_key = ':'.join(params) if key_fused else params[0]
+        if self._challenged(challenge_key):
+            return
+        self._send(IrcMessage(verb='CAP', params=params))
+        self._challenge_set(challenge_key)
+
+    def _challenge_met(self, step: str) -> bool:
+        return self._challenges.get(step, False)
+
+    def _challenged(self, step: str) -> bool:
+        return step in self._challenges
+
+    def _challenge_set(self, step: str, done: bool = False) -> None:
+        self._challenges[step] = done
+
+    def _collect_caps(self, params: tuple[str, ...]) -> None:
+        verb = params[0]
+        items = params[-1].strip().split()
+        is_final_line = params[1] != '*'
+        if self._challenge_met(verb):
+            if verb == 'LS':
+                self._dict.clear()
+            else:
+                for cap in self._dict.values():
+                    cap.enabled = False
+            self._challenge_set(verb)
+        for item in items:
+            if verb == 'LS':
+                splitted = item.split('=', maxsplit=1)
+                self._dict[splitted[0]] = _ServerCapability(
+                        enabled=False, data=''.join(splitted[1:]))
+            else:
+                self._dict[item].enabled = True
+        if is_final_line:
+            self._challenge_set(verb, done=True)
+
+
 @dataclass
 class IrcConnSetup:
     'All we need to know to set up a new Client connection.'
@@ -85,9 +163,8 @@ class Client(ABC, ClientQueueMixin):
 
     def __init__(self, conn_setup: IrcConnSetup, **kwargs) -> None:
         super().__init__(**kwargs)
+        self._caps = _CapsManager(self.send)
         self.conn_setup = conn_setup
-        self._cap_neg_states: dict[str, bool] = {}
-        self.caps: dict[str, _ServerCapability] = {}
         self.id_ = uuid4()
         self.log = Logger(self._log)
         self.update_login(nick_confirmed=False,
@@ -113,54 +190,12 @@ class Client(ABC, ClientQueueMixin):
     def on_connect(self) -> None:
         'Steps to perform right after connection.'
         self.log.add(msg='connected to server', chat=CHAT_GLOB)
-        self.try_send_cap('LS', ('302',))
+        self._caps.challenge('LS', ('302',))
         self.send(IrcMessage(verb='USER',
                              params=(getuser(), '0', '*',
                                      self.conn_setup.realname)))
         self.send(IrcMessage(verb='NICK', params=(self.conn_setup.nickname,)))
 
-    def cap_neg_done(self, negotiation_step: str) -> bool:
-        'Whether negotiation_step is registered as finished.'
-        return self._cap_neg_states.get(negotiation_step, False)
-
-    def cap_neg(self, negotiation_step: str) -> bool:
-        'Whether negotiation_step is registered at all (started or finished).'
-        return negotiation_step in self._cap_neg_states
-
-    def cap_neg_set(self, negotiation_step: str, done: bool = False) -> None:
-        'Declare negotiation_step started, or (if done) finished.'
-        self._cap_neg_states[negotiation_step] = done
-
-    def try_send_cap(self, *params, key_fused: bool = False) -> None:
-        'Run CAP command with params, handle cap neg. state.'
-        neg_state_key = ':'.join(params) if key_fused else params[0]
-        if self.cap_neg(neg_state_key):
-            return
-        self.send(IrcMessage(verb='CAP', params=params))
-        self.cap_neg_set(neg_state_key)
-
-    def collect_caps(self, params: tuple[str, ...]) -> None:
-        'Record available and enabled server capabilities.'
-        verb = params[0]
-        items = params[-1].strip().split()
-        is_final_line = params[1] != '*'
-        if self.cap_neg_done(verb):
-            if verb == 'LS':
-                self.caps.clear()
-            else:
-                for cap in self.caps.values():
-                    cap.enabled = False
-            self.cap_neg_set(verb)
-        for item in items:
-            if verb == 'LS':
-                splitted = item.split('=', maxsplit=1)
-                self.caps[splitted[0]] = _ServerCapability(
-                        enabled=False, data=''.join(splitted[1:]))
-            else:
-                self.caps[item].enabled = True
-        if is_final_line:
-            self.cap_neg_set(verb, done=True)
-
     @abstractmethod
     def _log(self, msg: str, chat: str = '') -> None:
         '''Write msg into log of chat, whatever shape that may have.
@@ -199,6 +234,7 @@ class Client(ABC, ClientQueueMixin):
     def close(self) -> None:
         'Close both recv Loop and socket.'
         self.log.add(msg='disconnecting from server', chat=CHAT_GLOB)
+        self._caps.clear()
         if self.conn:
             self.conn.close()
         self.conn = None
@@ -218,28 +254,8 @@ class Client(ABC, ClientQueueMixin):
                 self.log.add(msg=str(msg.params), prefix=_LOG_PREFIX_PRIVMSG,
                              chat=msg.source)
             case 'CAP':
-                match msg.params[1]:
-                    case 'LS' | 'LIST':
-                        self.collect_caps(msg.params[1:])
-                    case 'ACK' | 'NAK':
-                        cap_names = msg.params[-1].split()
-                        for cap_name in cap_names:
-                            self.cap_neg_set(f'REQ:{cap_name}', done=True)
-                            self.caps[cap_name].enabled = (msg.params[1]
-                                                           == 'ACK')
-                if self.cap_neg_done('LIST'):
-                    self.try_send_cap('END')
-                    if not self.cap_neg('printing'):
-                        self.log.add('server capabilities (enabled: "+"):')
-                        for cap_name, cap in self.caps.items():
-                            self.log.add(cap.str_for_log(cap_name))
-                        self.cap_neg_set('printing', done=True)
-                elif self.cap_neg_done('LS'):
-                    for cap_name in ('server-time', 'account-tag', 'sasl'):
-                        if (cap_name in self.caps
-                                and (not self.caps[cap_name].enabled)):
-                            self.try_send_cap('REQ', cap_name, key_fused=True)
-                    self.try_send_cap('LIST')
+                for to_log in self._caps.process_msg(msg.params[1:]):
+                    self.log.add(to_log)
 
 
 @dataclass