home · contact · privacy
Use same tests for .source and .params.
authorChristian Heller <c.heller@plomlompom.de>
Tue, 19 Aug 2025 12:30:59 +0000 (14:30 +0200)
committerChristian Heller <c.heller@plomlompom.de>
Tue, 19 Aug 2025 12:30:59 +0000 (14:30 +0200)
ircplom/client.py

index e78c2e42e8968489f81c13d932a65a74d6a87219..7e2bdc3dc5a0e9bdf846e0e4d7ac97b467e5dcf1 100644 (file)
@@ -39,93 +39,66 @@ _NUMERICS_TO_IGNORE = (
 )
 
 
-class _MsgSource(Enum):
+class _MsgTok(Enum):
     NONE = auto()
+    ANY = auto()
     USER_ADDRESS = auto()
     USER_ADDRESS_ME = auto()
     SERVER = auto()
-
-
-class _MsgParam(Enum):
-    ANY = auto()
     CHANNEL = auto()
     NICKNAME = auto()
     NICKNAME_ME = auto()
 
 
 class _MsgParseExpectation(NamedTuple):
+    source: _MsgTok
     verb: str
     len_min_params: int = 0
     len_max_params: int = 0
-    params: tuple[str | _MsgParam, ...] = tuple()
-    source: Optional[_MsgSource] = None
+    params: tuple[str | _MsgTok, ...] = tuple()
 
 
 _EXPECTATIONS: tuple[_MsgParseExpectation, ...] = (
-   _MsgParseExpectation('005', 3, 15, source=_MsgSource.SERVER),
-   _MsgParseExpectation('353', 4, source=_MsgSource.SERVER),
-   _MsgParseExpectation('366', 3, source=_MsgSource.SERVER),
-   _MsgParseExpectation('372', 2, source=_MsgSource.SERVER),
-   _MsgParseExpectation('376', 2, source=_MsgSource.SERVER),
-   _MsgParseExpectation('396', 3, source=_MsgSource.SERVER),
-   _MsgParseExpectation('401', 3, source=_MsgSource.SERVER),
-   _MsgParseExpectation('432', 3, source=_MsgSource.SERVER),
-   _MsgParseExpectation('433', 3, source=_MsgSource.SERVER),
-   _MsgParseExpectation('900', 4, source=_MsgSource.SERVER),
-   _MsgParseExpectation('903', params=(_MsgParam.NICKNAME_ME, _MsgParam.ANY),
-                        source=_MsgSource.SERVER),
-   _MsgParseExpectation('904', params=(_MsgParam.NICKNAME_ME, _MsgParam.ANY),
-                        source=_MsgSource.SERVER),
-   _MsgParseExpectation('AUTHENTICATE', params=('+',), source=_MsgSource.NONE),
-   _MsgParseExpectation('CAP', 3, 15, source=_MsgSource.SERVER),
-   _MsgParseExpectation('ERROR', params=(_MsgParam.ANY,),
-                        source=_MsgSource.NONE),
-   _MsgParseExpectation('JOIN', params=(_MsgParam.CHANNEL,),
-                        source=_MsgSource.USER_ADDRESS_ME),
-   _MsgParseExpectation('JOIN', params=(_MsgParam.CHANNEL,),
-                        source=_MsgSource.USER_ADDRESS),
-   _MsgParseExpectation('MODE', params=(_MsgParam.NICKNAME_ME, _MsgParam.ANY),
-                        source=_MsgSource.USER_ADDRESS_ME),
-   _MsgParseExpectation('NICK', params=(_MsgParam.NICKNAME,),
-                        source=_MsgSource.USER_ADDRESS_ME),
-   _MsgParseExpectation('NOTICE', 2, source=_MsgSource.USER_ADDRESS),
-   _MsgParseExpectation('NOTICE', 2, source=_MsgSource.SERVER),
-   _MsgParseExpectation('PART', params=(_MsgParam.CHANNEL,),
-                        source=_MsgSource.USER_ADDRESS_ME),
-   _MsgParseExpectation('PART', params=(_MsgParam.CHANNEL,),
-                        source=_MsgSource.USER_ADDRESS),
-   _MsgParseExpectation('PART', params=(_MsgParam.CHANNEL, _MsgParam.ANY),
-                        source=_MsgSource.USER_ADDRESS),
-   _MsgParseExpectation('PING', params=(_MsgParam.ANY,),
-                        source=_MsgSource.NONE),
-   _MsgParseExpectation('PRIVMSG', 2, source=_MsgSource.USER_ADDRESS),
-   _MsgParseExpectation('PRIVMSG', 2, source=_MsgSource.SERVER),
-   _MsgParseExpectation('QUIT', params=(_MsgParam.ANY,),
-                        source=_MsgSource.USER_ADDRESS),
+   _MsgParseExpectation(_MsgTok.SERVER, '005', 3, 15),
+   _MsgParseExpectation(_MsgTok.SERVER, '353', 4),
+   _MsgParseExpectation(_MsgTok.SERVER, '366', 3),
+   _MsgParseExpectation(_MsgTok.SERVER, '372', 2),
+   _MsgParseExpectation(_MsgTok.SERVER, '376', 2),
+   _MsgParseExpectation(_MsgTok.SERVER, '396', 3),
+   _MsgParseExpectation(_MsgTok.SERVER, '401', 3),
+   _MsgParseExpectation(_MsgTok.SERVER, '432', 3),
+   _MsgParseExpectation(_MsgTok.SERVER, '433', 3),
+   _MsgParseExpectation(_MsgTok.SERVER, '900', 4),
+   _MsgParseExpectation(_MsgTok.SERVER, '903',
+                        params=(_MsgTok.NICKNAME_ME, _MsgTok.ANY)),
+   _MsgParseExpectation(_MsgTok.SERVER, '904',
+                        params=(_MsgTok.NICKNAME_ME, _MsgTok.ANY)),
+   _MsgParseExpectation(_MsgTok.NONE, 'AUTHENTICATE', params=('+',)),
+   _MsgParseExpectation(_MsgTok.SERVER, 'CAP', 3, 15),
+   _MsgParseExpectation(_MsgTok.NONE, 'ERROR', params=(_MsgTok.ANY,)),
+   _MsgParseExpectation(_MsgTok.USER_ADDRESS_ME, 'JOIN',
+                        params=(_MsgTok.CHANNEL,)),
+   _MsgParseExpectation(_MsgTok.USER_ADDRESS, 'JOIN',
+                        params=(_MsgTok.CHANNEL,)),
+   _MsgParseExpectation(_MsgTok.USER_ADDRESS_ME, 'MODE',
+                        params=(_MsgTok.NICKNAME_ME, _MsgTok.ANY)),
+   _MsgParseExpectation(_MsgTok.USER_ADDRESS_ME, 'NICK',
+                        params=(_MsgTok.NICKNAME,)),
+   _MsgParseExpectation(_MsgTok.USER_ADDRESS, 'NOTICE', 2),
+   _MsgParseExpectation(_MsgTok.SERVER, 'NOTICE', 2),
+   _MsgParseExpectation(_MsgTok.USER_ADDRESS_ME, 'PART',
+                        params=(_MsgTok.CHANNEL,)),
+   _MsgParseExpectation(_MsgTok.USER_ADDRESS, 'PART',
+                        params=(_MsgTok.CHANNEL,)),
+   _MsgParseExpectation(_MsgTok.USER_ADDRESS, 'PART',
+                        params=(_MsgTok.CHANNEL, _MsgTok.ANY)),
+   _MsgParseExpectation(_MsgTok.NONE, 'PING', params=(_MsgTok.ANY,)),
+   _MsgParseExpectation(_MsgTok.USER_ADDRESS, 'PRIVMSG', 2),
+   _MsgParseExpectation(_MsgTok.SERVER, 'PRIVMSG', 2),
+   _MsgParseExpectation(_MsgTok.USER_ADDRESS, 'QUIT', params=(_MsgTok.ANY,)),
 )
 
 
-class _IrcMsg(IrcMessage):
-    'Extends IrcMessage with some conveniences.'
-
-    def source_is_server(self, compare: str) -> bool:
-        'Return if .source parse-able as our server.'
-        return '.' in self.source\
-            and compare.split('.')[-2:] == self.source.split('.')[-2:]\
-            and not ('!' in self.source or '@' in self.source)
-
-    @property
-    def source_to_user_address(self) -> Optional[tuple[str, ...]]:
-        'Parse .source into toks of full user address.'
-        toks = self.source.split('!')
-        if len(toks) != 2:
-            return None
-        toks = toks[0:1] + toks[1].split('@')
-        if len(toks) != 3:
-            return None
-        return tuple(toks)
-
-
 class LogScope(Enum):
     'Where log messages should go.'
     ALL = auto()
@@ -175,7 +148,7 @@ class _IrcConnection(BaseIrcConnection, ClientIdMixin):
 
     def _make_recv_event(self, msg: IrcMessage) -> ClientEvent:
         return ClientEvent.affector('handle_msg', client_id=self.client_id
-                                    ).kw(msg=_IrcMsg.from_raw(msg.raw))
+                                    ).kw(msg=msg)
 
     def _on_handled_loop_exception(self, e: IrcConnAbortException
                                    ) -> ClientEvent:
@@ -597,60 +570,66 @@ class Client(ABC, ClientQueueMixin):
         self._log(f'connection broken: {e}', alert=True)
         self.close()
 
-    def _match_msg(self, msg: _IrcMsg, verb: str):
+    def _match_msg(self, msg: IrcMessage, verb: str):
         'Test .source, .verb, .params.'
         for ex in [ex for ex in _EXPECTATIONS if verb == ex.verb == msg.verb]:
-            to_return: dict[str, Any] = {'': ''}
-            if ex.source is _MsgSource.NONE:
-                if msg.source != '':
-                    continue
-            elif ex.source is _MsgSource.SERVER:
-                if not msg.source_is_server(self._db.hostname):
+            if not ex.params:
+                len_p = len(msg.params)
+                if len_p < ex.len_min_params:
                     continue
-            elif ex.source in {_MsgSource.USER_ADDRESS,
-                               _MsgSource.USER_ADDRESS_ME}:
-                if (toks := msg.source_to_user_address):
-                    to_return['sender'] = toks[0]
-                    if ex.source is _MsgSource.USER_ADDRESS_ME:
-                        if not self._db.nickname == toks[0]:
-                            continue
-                        to_return['sender_me'] = to_return['sender']
-                else:
+                if ex.len_max_params and len_p > ex.len_max_params:
                     continue
-            len_params = len(msg.params)
-            if ex.params:
-                if len_params != len(ex.params):
+                if (not ex.len_max_params) and len_p != ex.len_min_params:
                     continue
-                for idx, exp_param in enumerate(ex.params):
-                    param = msg.params[idx]
-                    if isinstance(exp_param, str) and exp_param != param:
-                        continue
-                    if exp_param is _MsgParam.CHANNEL:
-                        if param[0] != '#':
-                            continue
-                        to_return['ch_name'] = param
-                        to_return['channel'] = self._db.chan(param)
-                    elif exp_param in {_MsgParam.NICKNAME,
-                                       _MsgParam.NICKNAME_ME}:
-                        if param[0] in '~&@%+# ':
-                            continue
-                        to_return['nickname'] = param
-                        if exp_param is _MsgParam.NICKNAME_ME:
-                            if param != self._db.nickname:
-                                continue
-                        to_return['nickname_me'] = to_return['nickname']
-                    elif exp_param is _MsgParam.ANY:
-                        to_return['any'] = param
-            elif len_params < ex.len_min_params:
-                continue
-            elif ex.len_max_params and len_params > ex.len_max_params:
-                continue
-            elif (not ex.len_max_params) and len_params != ex.len_min_params:
+            to_return: dict[str, Any] = {'': ''}
+            ex_tok_fields = tuple([ex.source] + list(ex.params))
+            msg_tok_fields = tuple([msg.source] + list(msg.params))
+            if ex.params and len(ex_tok_fields) != len(msg_tok_fields):
                 continue
-            return to_return
+            passing = True
+            for idx, ex_tok in enumerate(ex_tok_fields):
+                passing = False
+                msg_tok = msg_tok_fields[idx]
+                if ex_tok is _MsgTok.NONE and msg_tok != '':
+                    break
+                if ex_tok is _MsgTok.SERVER\
+                        and ('.' not in msg_tok or set('@!') & set(msg_tok)):
+                    break
+                key_nick = 'sender' if not idx else 'nickname'
+                key_nick_me = f'{key_nick}_me'
+                if ex_tok is _MsgTok.CHANNEL:
+                    if msg_tok[0] != '#':
+                        break
+                    to_return |= {'ch_name': msg_tok,
+                                  'channel': self._db.chan(msg_tok)}
+                elif ex_tok in {_MsgTok.NICKNAME, _MsgTok.NICKNAME_ME}:
+                    if msg_tok[0] in '~&@%+# ':
+                        break
+                    to_return[key_nick] = msg_tok
+                    if ex_tok is _MsgTok.NICKNAME_ME:
+                        if msg_tok != self._db.nickname:
+                            break
+                        to_return[key_nick_me] = to_return[key_nick]
+                elif ex_tok in {_MsgTok.USER_ADDRESS, _MsgTok.USER_ADDRESS_ME}:
+                    toks = msg_tok.split('!')
+                    if len(toks) != 2:
+                        break
+                    toks = toks[0:1] + toks[1].split('@')
+                    if len(toks) != 3:
+                        break
+                    to_return[key_nick] = toks[0]
+                    if ex_tok is _MsgTok.USER_ADDRESS_ME:
+                        if not self._db.nickname == toks[0]:
+                            break
+                        to_return[key_nick_me] = to_return[key_nick]
+                elif ex_tok is _MsgTok.ANY:
+                    to_return['any'] = msg_tok
+                passing = True
+            if passing:
+                return to_return
         return False
 
-    def handle_msg(self, msg: _IrcMsg) -> None:
+    def handle_msg(self, msg: IrcMessage) -> None:
         'Log msg.raw, then process incoming msg into appropriate client steps.'
         self._log(msg.raw, scope=LogScope.RAW, out=False)
         if _NumericsToConfirmNickname.contain(msg.verb):