home · contact · privacy
Make curses client capable of websocket _and_ raw tcp connections.
[plomrogue2-experiments] / new2 / plomrogue / io_tcp.py
index 78e43f539df4aea810a15c3abe61366d807c4993..b030f1b9f1c98332084763812bf152666004bb7a 100644 (file)
@@ -6,6 +6,7 @@ socketserver.TCPServer.allow_reuse_address = True
 
 
 
+from plomrogue.errors import BrokenSocketConnection
 class PlomSocket:
 
     def __init__(self, socket):
@@ -32,7 +33,6 @@ class PlomSocket:
         <http://stackoverflow.com/q/34919846>
 
         """
-        from plomrogue.errors import BrokenSocketConnection
         escaped_message = ''
         for char in message:
             if char in ('\\', '$'):
@@ -46,6 +46,7 @@ class PlomSocket:
             try:
                 sent = self.socket.send(data[totalsent:])
                 socket_broken = sent == 0
+                totalsent = totalsent + sent
             except OSError as err:
                 if err.errno == 9:  # "Bad file descriptor", when connection broken
                     socket_broken = True
@@ -53,7 +54,6 @@ class PlomSocket:
                     raise err
             if socket_broken and not silent_connection_break:
                 raise BrokenSocketConnection
-            totalsent = totalsent + sent
 
     def recv(self):
         """Get full send()-prepared message from self.socket.
@@ -77,12 +77,14 @@ class PlomSocket:
         data = b''
         msg = b''
         while True:
-            data += self.socket.recv(1024)
+            try:
+                data = self.socket.recv(1024)
+            except OSError as err:
+                if err.errno == 9:  # "Bad file descriptor", when connection broken
+                    raise BrokenSocketConnection
             if 0 == len(data):
-                return
-            cut_off = 0
+                break
             for c in data:
-                cut_off += 1
                 if esc:
                     msg += bytes([c])
                     esc = False
@@ -93,7 +95,6 @@ class PlomSocket:
                         yield msg.decode()
                     except UnicodeDecodeError:
                         yield None
-                    data = data[cut_off:]
                     msg = b''
                 else:
                     msg += bytes([c])
@@ -102,14 +103,12 @@ class PlomSocket:
 
 class PlomSocketSSL(PlomSocket):
 
-    def __init__(self, *args, server_side=False, **kwargs):
+    def __init__(self, *args, server_side=False, certfile=None, keyfile=None, **kwargs):
         import ssl
-        print('DEBUG', args, kwargs)
         super().__init__(*args, **kwargs)
         if server_side:
             self.socket = ssl.wrap_socket(self.socket, server_side=True,
-                                          certfile="server.pem",
-                                          keyfile="key.pem")
+                                          certfile=certfile, keyfile=keyfile)
         else:
             self.socket = ssl.wrap_socket(self.socket)
 
@@ -117,8 +116,7 @@ class PlomSocketSSL(PlomSocket):
 
 class IO_Handler(socketserver.BaseRequestHandler):
 
-    def __init__(self, *args, socket_class=PlomSocket, **kwargs):
-        self.socket_class = socket_class
+    def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
 
     def handle(self):
@@ -149,10 +147,13 @@ class IO_Handler(socketserver.BaseRequestHandler):
         import uuid
         import queue
         import threading
-        if self.socket_class == PlomSocketSSL:
-            plom_socket = self.socket_class(self.request, server_side=True)
+        if self.server.socket_class == PlomSocketSSL:
+            plom_socket = self.server.socket_class(self.request,
+                                                   server_side=True,
+                                                   certfile=self.server.certfile,
+                                                   keyfile=self.server.keyfile)
         else:
-            plom_socket = self.socket_class(self.request)
+            plom_socket = self.server.socket_class(self.request)
         print('CONNECTION FROM:', str(self.client_address))
         connection_id = uuid.uuid4()
         queue_in = queue.Queue()
@@ -176,13 +177,6 @@ class IO_Handler(socketserver.BaseRequestHandler):
 
 
 
-class IO_HandlerSSL(IO_Handler):
-
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, socket_class=PlomSocketSSL, **kwargs)
-
-
-
 class PlomTCPServer(socketserver.ThreadingTCPServer):
     """Bind together threaded IO handling server and message queue.
 
@@ -192,8 +186,9 @@ class PlomTCPServer(socketserver.ThreadingTCPServer):
 
     """
 
-    def __init__(self, queue, port, host='127.0.0.1', io_handler=IO_Handler, *args, **kwargs):
-        super().__init__((host, port), io_handler, *args, **kwargs)
+    def __init__(self, queue, port, host='127.0.0.1', *args, **kwargs):
+        super().__init__((host, port), IO_Handler, *args, **kwargs)
+        self.socket_class = PlomSocket
         self.queue_out = queue
         self.daemon_threads = True  # Else, server's threads have daemon=False.
         self.clients = {}
@@ -202,5 +197,8 @@ class PlomTCPServer(socketserver.ThreadingTCPServer):
 
 class PlomTCPServerSSL(PlomTCPServer):
 
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, host='0.0.0.0', io_handler=IO_HandlerSSL, **kwargs)
+    def __init__(self, *args, certfile=None, keyfile=None, **kwargs):
+        super().__init__(*args, host='0.0.0.0', **kwargs)
+        self.certfile = certfile
+        self.keyfile = keyfile
+        self.socket_class = PlomSocketSSL