From a455c9f392345a42c0dec745db3b65080172db77 Mon Sep 17 00:00:00 2001
From: Christian Heller <c.heller@plomlompom.de>
Date: Fri, 15 Feb 2019 03:32:51 +0100
Subject: [PATCH] Refactor socket code.

---
 client-curses.py  | 37 +++++++++---------
 plom_socket.py    | 95 +++++++++++++++++++++++++++++++++++++++++++++++
 plom_socket_io.py | 86 ------------------------------------------
 server_/io.py     | 24 +++++-------
 4 files changed, 123 insertions(+), 119 deletions(-)
 create mode 100644 plom_socket.py
 delete mode 100644 plom_socket_io.py

diff --git a/client-curses.py b/client-curses.py
index a8add97..f65d3a0 100755
--- a/client-curses.py
+++ b/client-curses.py
@@ -1,6 +1,6 @@
 #!/usr/bin/env python3
 import curses
-import plom_socket_io
+import plom_socket
 import socket
 import threading
 from parser import ArgError, Parser
@@ -186,8 +186,8 @@ ASCII_printable = ' !"#$%&\'\(\)*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWX'
                   'YZ[\\]^_\`abcdefghijklmnopqrstuvwxyz{|}~'
 
 
-def recv_loop(socket, game):
-    for msg in plom_socket_io.recv(s):
+def recv_loop(plom_socket, game):
+    for msg in plom_socket.recv():
         game.handle_input(msg)
 
 
@@ -353,8 +353,8 @@ class TurnWidget(Widget):
 
 class TUI:
 
-    def __init__(self, socket, game):
-        self.socket = socket
+    def __init__(self, plom_socket, game):
+        self.socket = plom_socket
         self.game = game
         self.parser = Parser(self.game)
         self.to_update = {'edit': False}
@@ -400,26 +400,26 @@ class TUI:
                 elif map_mode:
                     if type(self.game.world.map_) == MapSquare:
                         if key == 'a':
-                            plom_socket_io.send(self.socket, 'TASK:MOVE LEFT')
+                            self.socket.send('TASK:MOVE LEFT')
                         elif key == 'd':
-                            plom_socket_io.send(self.socket, 'TASK:MOVE RIGHT')
+                            self.socket.send('TASK:MOVE RIGHT')
                         elif key == 'w':
-                            plom_socket_io.send(self.socket, 'TASK:MOVE UP')
+                            self.socket.send('TASK:MOVE UP')
                         elif key == 's':
-                            plom_socket_io.send(self.socket, 'TASK:MOVE DOWN')
+                            self.socket.send('TASK:MOVE DOWN')
                     elif type(self.game.world.map_) == MapHex:
                         if key == 'w':
-                            plom_socket_io.send(self.socket, 'TASK:MOVE UPLEFT')
+                            self.socket.send('TASK:MOVE UPLEFT')
                         elif key == 'e':
-                            plom_socket_io.send(self.socket, 'TASK:MOVE UPRIGHT')
+                            self.socket.send('TASK:MOVE UPRIGHT')
                         if key == 's':
-                            plom_socket_io.send(self.socket, 'TASK:MOVE LEFT')
+                            self.socket.send('TASK:MOVE LEFT')
                         elif key == 'd':
-                            plom_socket_io.send(self.socket, 'TASK:MOVE RIGHT')
+                            self.socket.send('TASK:MOVE RIGHT')
                         if key == 'x':
-                            plom_socket_io.send(self.socket, 'TASK:MOVE DOWNLEFT')
+                            self.socket.send('TASK:MOVE DOWNLEFT')
                         elif key == 'c':
-                            plom_socket_io.send(self.socket, 'TASK:MOVE DOWNRIGHT')
+                            self.socket.send('TASK:MOVE DOWNRIGHT')
                 else:
                     if len(key) == 1 and key in ASCII_printable and \
                             len(self.to_send) < len(self.edit):
@@ -429,7 +429,7 @@ class TUI:
                         self.to_send[:] = self.to_send[:-1]
                         self.to_update['edit'] = True
                     elif key == '\n':  # Return key
-                        plom_socket_io.send(self.socket, ''.join(self.to_send))
+                        self.socket.send(''.join(self.to_send))
                         self.to_send[:] = []
                         self.to_update['edit'] = True
             except curses.error:
@@ -439,7 +439,8 @@ class TUI:
 
 
 s = socket.create_connection(('127.0.0.1', 5000))
+plom_socket = plom_socket.PlomSocket(s)
 game = Game()
-t = threading.Thread(target=recv_loop, args=(s, game))
+t = threading.Thread(target=recv_loop, args=(plom_socket, game))
 t.start()
-TUI(s, game)
+TUI(plom_socket, game)
diff --git a/plom_socket.py b/plom_socket.py
new file mode 100644
index 0000000..e43560d
--- /dev/null
+++ b/plom_socket.py
@@ -0,0 +1,95 @@
+class BrokenSocketConnection(Exception):
+    pass
+
+
+
+class PlomSocket:
+
+    def __init__(self, socket):
+        self.socket = socket
+
+    def send(self, message, silent_connection_break=False):
+        """Send via self.socket, encoded/delimited as way recv() expects.
+
+        In detail, all \ and $ in message are escaped with prefixed \,
+        and an unescaped $ is appended as a message delimiter. Then,
+        socket.send() is called as often as necessary to ensure
+        message is sent fully, as socket.send() due to buffering may
+        not send all of it right away.
+
+        Assuming socket is blocking, it's rather improbable that
+        socket.send() will be partial / return a positive value less
+        than the (byte) length of msg – but not entirely out of the
+        question. See: - <http://stackoverflow.com/q/19697218> -
+        <http://stackoverflow.com/q/2618736> -
+        <http://stackoverflow.com/q/8900474>
+
+        This also handles a socket.send() return value of 0, which
+        might be possible or not (?) for blocking sockets: -
+        <http://stackoverflow.com/q/34919846>
+
+        """
+        escaped_message = ''
+        for char in message:
+            if char in ('\\', '$'):
+                escaped_message += '\\'
+            escaped_message += char
+        escaped_message += '$'
+        data = escaped_message.encode()
+        totalsent = 0
+        while totalsent < len(data):
+            socket_broken = False
+            try:
+                sent = self.socket.send(data[totalsent:])
+                socket_broken = sent == 0
+            except OSError as err:
+                if err.errno == 9:  # "Bad file descriptor", when connection broken
+                    socket_broken = True
+                else:
+                    raise err
+            if socket_broken and not silent_connection_break:
+                raise BrokenSocketConnection
+            totalsent = totalsent + sent
+
+    def recv(self):
+        """Get full send()-prepared message from self.socket.
+
+        In detail, socket.recv() is looped over for sequences of bytes
+        that can be decoded as a Unicode string delimited by an
+        unescaped $, with \ and $ escapable by \. If a sequence of
+        characters that ends in an unescaped $ cannot be decoded as
+        Unicode, None is returned as its representation. Stop once
+        socket.recv() returns nothing.
+
+        Under the hood, the TCP stack receives packets that construct
+        the input payload in an internal buffer; socket.recv(BUFSIZE)
+        pops up to BUFSIZE bytes from that buffer, without knowledge
+        either about the input's segmentation into packets, or whether
+        the input is segmented in any other meaningful way; that's why
+        we do our own message segmentation with $ as a delimiter.
+
+        """
+        esc = False
+        data = b''
+        msg = b''
+        while True:
+            data += self.socket.recv(1024)
+            if 0 == len(data):
+                return
+            cut_off = 0
+            for c in data:
+                cut_off += 1
+                if esc:
+                    msg += bytes([c])
+                    esc = False
+                elif chr(c) == '\\':
+                    esc = True
+                elif chr(c) == '$':
+                    try:
+                        yield msg.decode()
+                    except UnicodeDecodeError:
+                        yield None
+                    data = data[cut_off:]
+                    msg = b''
+                else:
+                    msg += bytes([c])
diff --git a/plom_socket_io.py b/plom_socket_io.py
deleted file mode 100644
index ebde3c1..0000000
--- a/plom_socket_io.py
+++ /dev/null
@@ -1,86 +0,0 @@
-class BrokenSocketConnection(Exception):
-    pass
-
-
-def send(socket, message):
-    """Send message via socket, encoded and delimited the way recv() expects.
-
-    In detail, all \ and $ in message are escaped with prefixed \, and an
-    unescaped $ is appended as a message delimiter. Then, socket.send() is
-    called as often as necessary to ensure message is sent fully, as
-    socket.send() due to buffering may not send all of it right away.
-
-    Assuming socket is blocking, it's rather improbable that socket.send() will
-    be partial / return a positive value less than the (byte) length of msg –
-    but not entirely out of the question. See:
-    - <http://stackoverflow.com/q/19697218>
-    - <http://stackoverflow.com/q/2618736>
-    - <http://stackoverflow.com/q/8900474>
-
-    This also handles a socket.send() return value of 0, which might be
-    possible or not (?) for blocking sockets:
-    - <http://stackoverflow.com/q/34919846>
-    """
-    escaped_message = ''
-    for char in message:
-        if char in ('\\', '$'):
-            escaped_message += '\\'
-        escaped_message += char
-    escaped_message += '$'
-    data = escaped_message.encode()
-    totalsent = 0
-    while totalsent < len(data):
-        socket_broken = False
-        try:
-            sent = socket.send(data[totalsent:])
-            socket_broken = sent == 0
-        except OSError as err:
-            if err.errno == 9:  # "Bad file descriptor", when connection broken
-                socket_broken = True
-            else:
-                raise err
-        if socket_broken:
-            raise BrokenSocketConnection
-        totalsent = totalsent + sent
-
-
-def recv(socket):
-    """Get full send()-prepared message from socket.
-
-    In detail, socket.recv() is looped over for sequences of bytes that can be
-    decoded as a Unicode string delimited by an unescaped $, with \ and $
-    escapable by \. If a sequence of characters that ends in an unescaped $
-    cannot be decoded as Unicode, None is returned as its representation. Stop
-    once socket.recv() returns nothing.
-
-    Under the hood, the TCP stack receives packets that construct the input
-    payload in an internal buffer; socket.recv(BUFSIZE) pops up to BUFSIZE
-    bytes from that buffer, without knowledge either about the input's
-    segmentation into packets, or whether the input is segmented in any other
-    meaningful way; that's why we do our own message segmentation with $ as a
-    delimiter.
-    """
-    esc = False
-    data = b''
-    msg = b''
-    while True:
-        data += socket.recv(1024)
-        if 0 == len(data):
-            return
-        cut_off = 0
-        for c in data:
-            cut_off += 1
-            if esc:
-                msg += bytes([c])
-                esc = False
-            elif chr(c) == '\\':
-                esc = True
-            elif chr(c) == '$':
-                try:
-                    yield msg.decode()
-                except UnicodeDecodeError:
-                    yield None
-                data = data[cut_off:]
-                msg = b''
-            else:
-                msg += bytes([c])
diff --git a/server_/io.py b/server_/io.py
index 501399f..c8f6114 100644
--- a/server_/io.py
+++ b/server_/io.py
@@ -45,45 +45,39 @@ class IO_Handler(socketserver.BaseRequestHandler):
         instructions.
 
         """
-        import plom_socket_io
 
-        def caught_send(socket, message):
-            """Send message by socket, catch broken socket connection error."""
-            try:
-                plom_socket_io.send(socket, message)
-            except plom_socket_io.BrokenSocketConnection:
-                pass
-
-        def send_queue_messages(socket, queue_in, thread_alive):
+        def send_queue_messages(plom_socket, queue_in, thread_alive):
             """Send messages via socket from queue_in while thread_alive[0]."""
             while thread_alive[0]:
                 try:
                     msg = queue_in.get(timeout=1)
                 except queue.Empty:
                     continue
-                caught_send(socket, msg)
+                plom_socket.send(msg, True)
 
         import uuid
+        import plom_socket
+        plom_socket = plom_socket.PlomSocket(self.request)
         print('CONNECTION FROM:', str(self.client_address))
         connection_id = uuid.uuid4()
         queue_in = queue.Queue()
         self.server.queue_out.put(('ADD_QUEUE', connection_id, queue_in))
         thread_alive = [True]
         t = threading.Thread(target=send_queue_messages,
-                             args=(self.request, queue_in, thread_alive))
+                             args=(plom_socket, queue_in, thread_alive))
         t.start()
-        for message in plom_socket_io.recv(self.request):
+        for message in plom_socket.recv():
             if message is None:
-                caught_send(self.request, 'BAD MESSAGE')
+                plom_socket.send('BAD MESSAGE', True)
             elif 'QUIT' == message:
-                caught_send(self.request, 'BYE')
+                plom_socket.send('BYE', True)
                 break
             else:
                 self.server.queue_out.put(('COMMAND', connection_id, message))
         self.server.queue_out.put(('KILL_QUEUE', connection_id))
         thread_alive[0] = False
         print('CONNECTION CLOSED FROM:', str(self.client_address))
-        self.request.close()
+        plom_socket.socket.close()
 
 
 class GameIO():
-- 
2.30.2