home · contact · privacy
Further refine msg source expectations.
authorChristian Heller <c.heller@plomlompom.de>
Tue, 19 Aug 2025 03:39:02 +0000 (05:39 +0200)
committerChristian Heller <c.heller@plomlompom.de>
Tue, 19 Aug 2025 03:39:02 +0000 (05:39 +0200)
ircplom/client.py

index d1c2deef9ffcdb3570af291d322c00c1bbceea2d..648e284fe4000373214250e90615950be8c4e133 100644 (file)
@@ -39,35 +39,40 @@ _NUMERICS_TO_IGNORE = (
 )
 
 
+class _MsgSource(Enum):
+    NONE = auto()
+    SERVER = auto()
+
+
 class _MsgParseExpectation(NamedTuple):
     len_min_params: int = 0
     len_max_params: int = 0
     params: tuple[str, ...] = tuple()
-    source: Optional[str] = None
+    source: Optional[_MsgSource] = None
 
 
 _EXPECTATIONS: dict[str, _MsgParseExpectation] = {
-   '005': _MsgParseExpectation(3, 15),
-   '353': _MsgParseExpectation(4),
-   '366': _MsgParseExpectation(3),
-   '372': _MsgParseExpectation(2),
-   '376': _MsgParseExpectation(2),
-   '396': _MsgParseExpectation(3),
-   '401': _MsgParseExpectation(3),
-   '432': _MsgParseExpectation(3),
-   '433': _MsgParseExpectation(3),
-   '900': _MsgParseExpectation(4),
-   '903': _MsgParseExpectation(2),
-   '904': _MsgParseExpectation(2),
-   'AUTHENTICATE': _MsgParseExpectation(params=('+',), source=''),
-   'CAP': _MsgParseExpectation(3, 15),
-   'ERROR': _MsgParseExpectation(1, source=''),
+   '005': _MsgParseExpectation(3, 15, source=_MsgSource.SERVER),
+   '353': _MsgParseExpectation(4, source=_MsgSource.SERVER),
+   '366': _MsgParseExpectation(3, source=_MsgSource.SERVER),
+   '372': _MsgParseExpectation(2, source=_MsgSource.SERVER),
+   '376': _MsgParseExpectation(2, source=_MsgSource.SERVER),
+   '396': _MsgParseExpectation(3, source=_MsgSource.SERVER),
+   '401': _MsgParseExpectation(3, source=_MsgSource.SERVER),
+   '432': _MsgParseExpectation(3, source=_MsgSource.SERVER),
+   '433': _MsgParseExpectation(3, source=_MsgSource.SERVER),
+   '900': _MsgParseExpectation(4, source=_MsgSource.SERVER),
+   '903': _MsgParseExpectation(2, source=_MsgSource.SERVER),
+   '904': _MsgParseExpectation(2, source=_MsgSource.SERVER),
+   'AUTHENTICATE': _MsgParseExpectation(params=('+',), source=_MsgSource.NONE),
+   'CAP': _MsgParseExpectation(3, 15, source=_MsgSource.SERVER),
+   'ERROR': _MsgParseExpectation(1, source=_MsgSource.NONE),
    'JOIN': _MsgParseExpectation(1),
    'MODE': _MsgParseExpectation(2),
    'NICK': _MsgParseExpectation(1),
    'NOTICE': _MsgParseExpectation(2),
    'PART': _MsgParseExpectation(1, 2),
-   'PING': _MsgParseExpectation(1, source=''),
+   'PING': _MsgParseExpectation(1, source=_MsgSource.NONE),
    'PRIVMSG': _MsgParseExpectation(2),
    'QUIT': _MsgParseExpectation(1),
 }
@@ -76,19 +81,19 @@ _EXPECTATIONS: dict[str, _MsgParseExpectation] = {
 class _IrcMsg(IrcMessage):
     'Extends IrcMessage with some conveniences.'
 
-    def match(self, verb: str) -> bool:
-        'Test .verb, .params.'
-        if not verb == self.verb:
-            return False
-        expect = _EXPECTATIONS[verb]
-        if expect.source is not None and self.source != expect.source:
-            return False
-        if expect.params:
-            return self.params == expect.params
-        n_msg_params = len(self.params)
-        if expect.len_max_params <= expect.len_min_params:
-            return n_msg_params == expect.len_min_params
-        return expect.len_min_params <= n_msg_params <= expect.len_max_params
+    def match(self, verb: str) -> bool:
+        'Test .verb, .params.'
+        if not verb == self.verb:
+            return False
+        expect = _EXPECTATIONS[verb]
+        if expect.source is not None and self.source != expect.source:
+            return False
+        if expect.params:
+            return self.params == expect.params
+        n_msg_params = len(self.params)
+        if expect.len_max_params <= expect.len_min_params:
+            return n_msg_params == expect.len_min_params
+        return expect.len_min_params <= n_msg_params <= expect.len_max_params
 
     @property
     def nick_from_source(self) -> str:
@@ -567,6 +572,25 @@ class Client(ABC, ClientQueueMixin):
         self._log(f'connection broken: {e}', alert=True)
         self.close()
 
+    def _match_msg(self, msg: _IrcMsg, verb: str) -> bool:
+        'Test .verb, .params.'
+        if not msg.verb == verb:
+            return False
+        expect = _EXPECTATIONS[verb]
+        if expect.source is _MsgSource.NONE and msg.source != '':
+            return False
+        if expect.source is _MsgSource.SERVER and (
+                '!' in msg.source or '@' in msg.source or '.' not in msg.source
+                or self._db.hostname.split('.')[-2:]
+                != msg.source.split('.')[-2:]):
+            return False
+        if expect.params:
+            return msg.params == expect.params
+        n_msg_params = len(msg.params)
+        if expect.len_max_params <= expect.len_min_params:
+            return n_msg_params == expect.len_min_params
+        return expect.len_min_params <= n_msg_params <= expect.len_max_params
+
     def handle_msg(self, msg: _IrcMsg) -> None:
         'Log msg.raw, then process incoming msg into appropriate client steps.'
         self._log(msg.raw, scope=LogScope.RAW, out=False)
@@ -574,26 +598,27 @@ class Client(ABC, ClientQueueMixin):
             self.set_nick(msg.params[0], confirmed=True)
         if _NumericsToIgnore.contain(msg.verb):
             return
-        if msg.match('005'):  # RPL_ISUPPORT
+        if self._match_msg(msg, '005'):  # RPL_ISUPPORT
             self._db.process_isupport(msg.params[1:-1])
-        elif msg.match('353') and msg.params[1] == '=':  # RPL_NAMREPLY
+        elif self._match_msg(msg, '353')\
+                and msg.params[1] == '=':  # RPL_NAMREPLY
             for user in msg.params[3].split():
                 self._db.chan(msg.params[2]).append_completable(
                         'users', user.lstrip('~&@%+'))
-        elif msg.match('366'):  # RPL_ENDOFNAMES
+        elif self._match_msg(msg, '366'):  # RPL_ENDOFNAMES
             self._db.chan(msg.params[1]).declare_complete('users')
-        elif msg.match('372'):  # RPL_MOTD
+        elif self._match_msg(msg, '372'):  # RPL_MOTD
             self._db.append_completable('motd', msg.params[1])
-        elif msg.match('376'):  # RPL_ENDOFMOTD
+        elif self._match_msg(msg, '376'):  # RPL_ENDOFMOTD
             self._db.declare_complete('motd')
-        elif msg.match('396'):  # RPL_VISIBLEHOST
+        elif self._match_msg(msg, '396'):  # RPL_VISIBLEHOST
             # '@'-split because <https://defs.ircdocs.horse/defs/numerics>
             # claims: "<hostname> can also be in the form <user@hostname>"
             self._db.client_host = msg.params[1].split('@')[-1]
-        elif msg.match('401'):  # ERR_NOSUCHNICK
+        elif self._match_msg(msg, '401'):  # ERR_NOSUCHNICK
             self._log(f'{msg.params[1]} not online', scope=LogScope.CHAT,
                       target=msg.params[1], alert=True)
-        elif msg.match('432'):  # ERR_ERRONEOUSNICKNAME
+        elif self._match_msg(msg, '432'):  # ERR_ERRONEOUSNICKNAME
             alert = 'nickname refused for bad format'
             if msg.params[0] == '*':
                 alert += ', giving up'
@@ -601,23 +626,24 @@ class Client(ABC, ClientQueueMixin):
             else:
                 self.set_nick(msg.params[0], confirmed=True)
             self._log(alert, alert=True)
-        elif msg.match('433'):  # ERR_NICKNAMEINUSE
+        elif self._match_msg(msg, '433'):  # ERR_NICKNAMEINUSE
             self._log('nickname already in use, trying increment', alert=True)
             self.set_nick(self._db.nick_incremented)
-        elif msg.match('900'):  # RPL_LOGGEDIN
+        elif self._match_msg(msg, '900'):  # RPL_LOGGEDIN
             self._db.nickname, remainder = msg.params[1].split('!', maxsplit=1)
             self._db.username, self._db.client_host = remainder.split('@')
             self._db.sasl_account = msg.params[2]
-        elif msg.match('903') or msg.match('904'):  # RPL_SASLSUCCESS, …
+        elif self._match_msg(msg, '903')\
+                or self._match_msg(msg, '904'):  # RPL_SASLSUCCESS, …
             self._db.sasl_auth_state = 'WIN' if msg.verb == '903' else 'FAIL'
             self._caps.end_negotiation()                  # … or ERR_SASLFAIL
-        elif msg.match('AUTHENTICATE'):
+        elif self._match_msg(msg, 'AUTHENTICATE'):
             auth = b64encode((self._db.nick_wanted + '\0'
                               + self._db.nick_wanted + '\0'
                               + self._db.password
                               ).encode('utf-8')).decode('utf-8')
             self.send(IrcMessage('AUTHENTICATE', (auth,)))
-        elif msg.match('CAP'):
+        elif self._match_msg(msg, 'CAP'):
             if (self._caps.process_msg(msg.params[1:])
                     and self._db.caps.has('sasl')
                     and 'PLAIN' in self._db.caps['sasl'].data.split(',')):
@@ -626,32 +652,34 @@ class Client(ABC, ClientQueueMixin):
                     self.send(IrcMessage('AUTHENTICATE', ('PLAIN',)))
                 else:
                     self._caps.end_negotiation()
-        elif msg.match('ERROR'):
+        elif self._match_msg(msg, 'ERROR'):
             self.close()
-        elif msg.match('JOIN'):
+        elif self._match_msg(msg, 'JOIN'):
             channel = msg.params[0]
             log_msg = f'{msg.nick_from_source} {msg.verb.lower()}s {channel}'
             self._log(log_msg, scope=LogScope.CHAT, target=channel)
             if msg.nick_from_source != self._db.nickname:
                 self._db.chan(channel).append_completable(
                         'users', msg.nick_from_source, stay_complete=True)
-        elif msg.match('MODE') and msg.params[0] == self._db.nickname:
+        elif self._match_msg(msg, 'MODE')\
+                and msg.params[0] == self._db.nickname:
             self._db.user_modes = msg.params[1]
-        elif msg.match('NICK') and msg.nick_from_source == self._db.nickname:
+        elif self._match_msg(msg, 'NICK')\
+                and msg.nick_from_source == self._db.nickname:
             self.set_nick(msg.params[0], confirmed=True)
-        elif msg.match('NOTICE') and (msg.params[0] != '*'
-                                      or not self._db.nickname):
+        elif self._match_msg(msg, 'NOTICE')\
+                and (msg.params[0] != '*' or not self._db.nickname):
             kw: dict[str, str | LogScope] = {}
             if '!' in msg.source:
                 kw |= {'sender': msg.nick_from_source, 'scope': LogScope.CHAT}
             self._log(msg.params[-1], out=False, target=msg.params[0],
                       as_notice=True, **kw)
-        elif msg.match('PRIVMSG') and msg.params[0] != '*':
+        elif self._match_msg(msg, 'PRIVMSG') and msg.params[0] != '*':
             kw = {}
             if '!' in msg.source:
                 kw |= {'sender': msg.nick_from_source, 'scope': LogScope.CHAT}
             self._log(msg.params[-1], out=False, target=msg.params[0], **kw)
-        elif msg.match('PART'):
+        elif self._match_msg(msg, 'PART'):
             channel = msg.params[0]
             log_msg = f'{msg.nick_from_source} {msg.verb.lower()}s {channel}'
             if len(msg.params) == 2:
@@ -662,9 +690,9 @@ class Client(ABC, ClientQueueMixin):
             else:
                 self._db.chan(channel).remove_completable(
                         'users', msg.nick_from_source, stay_complete=True)
-        elif msg.match('PING'):
+        elif self._match_msg(msg, 'PING'):
             self.send(IrcMessage(verb='PONG', params=(msg.params[0],)))
-        elif msg.match('QUIT'):
+        elif self._match_msg(msg, 'QUIT'):
             user = msg.nick_from_source
             for chan_name in self._db.chan_names:
                 chan = self._db.chan(chan_name)