home · contact · privacy
Register game commands and tasks outside of game module.
[plomrogue2-experiments] / new2 / plomrogue / io_tcp.py
index f0a49a97dad632fa4f3dc10770432f424ba90155..09b9db1bbf3d0f47539a2c4aaea27aacd9c0c18c 100644 (file)
@@ -6,6 +6,7 @@ socketserver.TCPServer.allow_reuse_address = True
 
 
 
+from plomrogue.errors import BrokenSocketConnection
 class PlomSocket:
 
     def __init__(self, socket):
@@ -32,7 +33,6 @@ class PlomSocket:
         <http://stackoverflow.com/q/34919846>
 
         """
-        from plomrogue.errors import BrokenSocketConnection
         escaped_message = ''
         for char in message:
             if char in ('\\', '$'):
@@ -77,7 +77,11 @@ class PlomSocket:
         data = b''
         msg = b''
         while True:
-            data = self.socket.recv(1024)
+            try:
+                data = self.socket.recv(1024)
+            except OSError as err:
+                if err.errno == 9:  # "Bad file descriptor", when connection broken
+                    raise BrokenSocketConnection
             if 0 == len(data):
                 break
             for c in data:
@@ -99,14 +103,12 @@ class PlomSocket:
 
 class PlomSocketSSL(PlomSocket):
 
-    def __init__(self, *args, server_side=False, certfile=None, keyfile=None, **kwargs):
+    def __init__(self, *args, certfile, keyfile, **kwargs):
         import ssl
         super().__init__(*args, **kwargs)
-        if server_side:
-            self.socket = ssl.wrap_socket(self.socket, server_side=True,
-                                          certfile=certfile, keyfile=keyfile)
-        else:
-            self.socket = ssl.wrap_socket(self.socket)
+        self.send('NEED_SSL')
+        self.socket = ssl.wrap_socket(self.socket, server_side=True,
+                                      certfile=certfile, keyfile=keyfile)
 
 
 
@@ -145,7 +147,6 @@ class IO_Handler(socketserver.BaseRequestHandler):
         import threading
         if self.server.socket_class == PlomSocketSSL:
             plom_socket = self.server.socket_class(self.request,
-                                                   server_side=True,
                                                    certfile=self.server.certfile,
                                                    keyfile=self.server.keyfile)
         else:
@@ -193,7 +194,7 @@ class PlomTCPServer(socketserver.ThreadingTCPServer):
 
 class PlomTCPServerSSL(PlomTCPServer):
 
-    def __init__(self, *args, certfile=None, keyfile=None, **kwargs):
+    def __init__(self, *args, certfile, keyfile, **kwargs):
         super().__init__(*args, host='0.0.0.0', **kwargs)
         self.certfile = certfile
         self.keyfile = keyfile