X-Git-Url: https://plomlompom.com/repos/do_todos?a=blobdiff_plain;f=new2%2Fplomrogue%2Fio_tcp.py;h=09b9db1bbf3d0f47539a2c4aaea27aacd9c0c18c;hb=8f4f247a8c36610a5cd4eb03ddb26dcc701e38ab;hp=5dd2508a8708d5cca43836aceb396d9231ed1d0a;hpb=b747db93005261dbf46c657099be0bf687ad2ce3;p=plomrogue2-experiments diff --git a/new2/plomrogue/io_tcp.py b/new2/plomrogue/io_tcp.py index 5dd2508..09b9db1 100644 --- a/new2/plomrogue/io_tcp.py +++ b/new2/plomrogue/io_tcp.py @@ -6,6 +6,7 @@ socketserver.TCPServer.allow_reuse_address = True +from plomrogue.errors import BrokenSocketConnection class PlomSocket: def __init__(self, socket): @@ -32,7 +33,6 @@ class PlomSocket: """ - from plomrogue.errors import BrokenSocketConnection escaped_message = '' for char in message: if char in ('\\', '$'): @@ -46,6 +46,7 @@ class PlomSocket: try: sent = self.socket.send(data[totalsent:]) socket_broken = sent == 0 + totalsent = totalsent + sent except OSError as err: if err.errno == 9: # "Bad file descriptor", when connection broken socket_broken = True @@ -53,7 +54,6 @@ class PlomSocket: raise err if socket_broken and not silent_connection_break: raise BrokenSocketConnection - totalsent = totalsent + sent def recv(self): """Get full send()-prepared message from self.socket. @@ -77,12 +77,14 @@ class PlomSocket: data = b'' msg = b'' while True: - data += self.socket.recv(1024) + try: + data = self.socket.recv(1024) + except OSError as err: + if err.errno == 9: # "Bad file descriptor", when connection broken + raise BrokenSocketConnection if 0 == len(data): - return - cut_off = 0 + break for c in data: - cut_off += 1 if esc: msg += bytes([c]) esc = False @@ -93,7 +95,6 @@ class PlomSocket: yield msg.decode() except UnicodeDecodeError: yield None - data = data[cut_off:] msg = b'' else: msg += bytes([c]) @@ -102,21 +103,18 @@ class PlomSocket: class PlomSocketSSL(PlomSocket): - def __init__(self, *args, server_side=False, certfile=None, keyfile=None, **kwargs): + def __init__(self, *args, certfile, keyfile, **kwargs): import ssl super().__init__(*args, **kwargs) - if server_side: - self.socket = ssl.wrap_socket(self.socket, server_side=True, - certfile=certfile, keyfile=keyfile) - else: - self.socket = ssl.wrap_socket(self.socket) + self.send('NEED_SSL') + self.socket = ssl.wrap_socket(self.socket, server_side=True, + certfile=certfile, keyfile=keyfile) class IO_Handler(socketserver.BaseRequestHandler): - def __init__(self, *args, socket_class=PlomSocket, **kwargs): - self.socket_class = socket_class + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) def handle(self): @@ -147,12 +145,12 @@ class IO_Handler(socketserver.BaseRequestHandler): import uuid import queue import threading - if self.socket_class == PlomSocketSSL: - plom_socket = self.socket_class(self.request, server_side=True, - certfile=self.server.certfile, - keyfile=self.server.keyfile) + if self.server.socket_class == PlomSocketSSL: + plom_socket = self.server.socket_class(self.request, + certfile=self.server.certfile, + keyfile=self.server.keyfile) else: - plom_socket = self.socket_class(self.request) + plom_socket = self.server.socket_class(self.request) print('CONNECTION FROM:', str(self.client_address)) connection_id = uuid.uuid4() queue_in = queue.Queue() @@ -176,13 +174,6 @@ class IO_Handler(socketserver.BaseRequestHandler): -class IO_HandlerSSL(IO_Handler): - - def __init__(self, *args, **kwargs): - super().__init__(*args, socket_class=PlomSocketSSL, **kwargs) - - - class PlomTCPServer(socketserver.ThreadingTCPServer): """Bind together threaded IO handling server and message queue. @@ -192,8 +183,9 @@ class PlomTCPServer(socketserver.ThreadingTCPServer): """ - def __init__(self, queue, port, host='127.0.0.1', io_handler=IO_Handler, *args, **kwargs): - super().__init__((host, port), io_handler, *args, **kwargs) + def __init__(self, queue, port, host='127.0.0.1', *args, **kwargs): + super().__init__((host, port), IO_Handler, *args, **kwargs) + self.socket_class = PlomSocket self.queue_out = queue self.daemon_threads = True # Else, server's threads have daemon=False. self.clients = {} @@ -202,7 +194,8 @@ class PlomTCPServer(socketserver.ThreadingTCPServer): class PlomTCPServerSSL(PlomTCPServer): - def __init__(self, *args, certfile=None, keyfile=None, **kwargs): + def __init__(self, *args, certfile, keyfile, **kwargs): + super().__init__(*args, host='0.0.0.0', **kwargs) self.certfile = certfile self.keyfile = keyfile - super().__init__(*args, host='0.0.0.0', io_handler=IO_HandlerSSL, **kwargs) + self.socket_class = PlomSocketSSL