X-Git-Url: https://plomlompom.com/repos/do_todos?a=blobdiff_plain;f=new2%2Fplomrogue%2Fio_tcp.py;h=09b9db1bbf3d0f47539a2c4aaea27aacd9c0c18c;hb=8f4f247a8c36610a5cd4eb03ddb26dcc701e38ab;hp=f0a49a97dad632fa4f3dc10770432f424ba90155;hpb=31d5ee2ee37c0e82e776053a1311c99dda2255e7;p=plomrogue2-experiments diff --git a/new2/plomrogue/io_tcp.py b/new2/plomrogue/io_tcp.py index f0a49a9..09b9db1 100644 --- a/new2/plomrogue/io_tcp.py +++ b/new2/plomrogue/io_tcp.py @@ -6,6 +6,7 @@ socketserver.TCPServer.allow_reuse_address = True +from plomrogue.errors import BrokenSocketConnection class PlomSocket: def __init__(self, socket): @@ -32,7 +33,6 @@ class PlomSocket: """ - from plomrogue.errors import BrokenSocketConnection escaped_message = '' for char in message: if char in ('\\', '$'): @@ -77,7 +77,11 @@ class PlomSocket: data = b'' msg = b'' while True: - data = self.socket.recv(1024) + try: + data = self.socket.recv(1024) + except OSError as err: + if err.errno == 9: # "Bad file descriptor", when connection broken + raise BrokenSocketConnection if 0 == len(data): break for c in data: @@ -99,14 +103,12 @@ class PlomSocket: class PlomSocketSSL(PlomSocket): - def __init__(self, *args, server_side=False, certfile=None, keyfile=None, **kwargs): + def __init__(self, *args, certfile, keyfile, **kwargs): import ssl super().__init__(*args, **kwargs) - if server_side: - self.socket = ssl.wrap_socket(self.socket, server_side=True, - certfile=certfile, keyfile=keyfile) - else: - self.socket = ssl.wrap_socket(self.socket) + self.send('NEED_SSL') + self.socket = ssl.wrap_socket(self.socket, server_side=True, + certfile=certfile, keyfile=keyfile) @@ -145,7 +147,6 @@ class IO_Handler(socketserver.BaseRequestHandler): import threading if self.server.socket_class == PlomSocketSSL: plom_socket = self.server.socket_class(self.request, - server_side=True, certfile=self.server.certfile, keyfile=self.server.keyfile) else: @@ -193,7 +194,7 @@ class PlomTCPServer(socketserver.ThreadingTCPServer): class PlomTCPServerSSL(PlomTCPServer): - def __init__(self, *args, certfile=None, keyfile=None, **kwargs): + def __init__(self, *args, certfile, keyfile, **kwargs): super().__init__(*args, host='0.0.0.0', **kwargs) self.certfile = certfile self.keyfile = keyfile