X-Git-Url: https://plomlompom.com/repos/foo.html?a=blobdiff_plain;f=new2%2Fplomrogue%2Fio_tcp.py;h=f0a49a97dad632fa4f3dc10770432f424ba90155;hb=3f6aeac609f6337713f6cca17e4f54960ecf4d7f;hp=78e43f539df4aea810a15c3abe61366d807c4993;hpb=f01848a97bb686e2b9c823cdf7fc6b59072dbd79;p=plomrogue2-experiments
diff --git a/new2/plomrogue/io_tcp.py b/new2/plomrogue/io_tcp.py
index 78e43f5..f0a49a9 100644
--- a/new2/plomrogue/io_tcp.py
+++ b/new2/plomrogue/io_tcp.py
@@ -46,6 +46,7 @@ class PlomSocket:
try:
sent = self.socket.send(data[totalsent:])
socket_broken = sent == 0
+ totalsent = totalsent + sent
except OSError as err:
if err.errno == 9: # "Bad file descriptor", when connection broken
socket_broken = True
@@ -53,7 +54,6 @@ class PlomSocket:
raise err
if socket_broken and not silent_connection_break:
raise BrokenSocketConnection
- totalsent = totalsent + sent
def recv(self):
"""Get full send()-prepared message from self.socket.
@@ -77,12 +77,10 @@ class PlomSocket:
data = b''
msg = b''
while True:
- data += self.socket.recv(1024)
+ data = self.socket.recv(1024)
if 0 == len(data):
- return
- cut_off = 0
+ break
for c in data:
- cut_off += 1
if esc:
msg += bytes([c])
esc = False
@@ -93,7 +91,6 @@ class PlomSocket:
yield msg.decode()
except UnicodeDecodeError:
yield None
- data = data[cut_off:]
msg = b''
else:
msg += bytes([c])
@@ -102,14 +99,12 @@ class PlomSocket:
class PlomSocketSSL(PlomSocket):
- def __init__(self, *args, server_side=False, **kwargs):
+ def __init__(self, *args, server_side=False, certfile=None, keyfile=None, **kwargs):
import ssl
- print('DEBUG', args, kwargs)
super().__init__(*args, **kwargs)
if server_side:
self.socket = ssl.wrap_socket(self.socket, server_side=True,
- certfile="server.pem",
- keyfile="key.pem")
+ certfile=certfile, keyfile=keyfile)
else:
self.socket = ssl.wrap_socket(self.socket)
@@ -117,8 +112,7 @@ class PlomSocketSSL(PlomSocket):
class IO_Handler(socketserver.BaseRequestHandler):
- def __init__(self, *args, socket_class=PlomSocket, **kwargs):
- self.socket_class = socket_class
+ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def handle(self):
@@ -149,10 +143,13 @@ class IO_Handler(socketserver.BaseRequestHandler):
import uuid
import queue
import threading
- if self.socket_class == PlomSocketSSL:
- plom_socket = self.socket_class(self.request, server_side=True)
+ if self.server.socket_class == PlomSocketSSL:
+ plom_socket = self.server.socket_class(self.request,
+ server_side=True,
+ certfile=self.server.certfile,
+ keyfile=self.server.keyfile)
else:
- plom_socket = self.socket_class(self.request)
+ plom_socket = self.server.socket_class(self.request)
print('CONNECTION FROM:', str(self.client_address))
connection_id = uuid.uuid4()
queue_in = queue.Queue()
@@ -176,13 +173,6 @@ class IO_Handler(socketserver.BaseRequestHandler):
-class IO_HandlerSSL(IO_Handler):
-
- def __init__(self, *args, **kwargs):
- super().__init__(*args, socket_class=PlomSocketSSL, **kwargs)
-
-
-
class PlomTCPServer(socketserver.ThreadingTCPServer):
"""Bind together threaded IO handling server and message queue.
@@ -192,8 +182,9 @@ class PlomTCPServer(socketserver.ThreadingTCPServer):
"""
- def __init__(self, queue, port, host='127.0.0.1', io_handler=IO_Handler, *args, **kwargs):
- super().__init__((host, port), io_handler, *args, **kwargs)
+ def __init__(self, queue, port, host='127.0.0.1', *args, **kwargs):
+ super().__init__((host, port), IO_Handler, *args, **kwargs)
+ self.socket_class = PlomSocket
self.queue_out = queue
self.daemon_threads = True # Else, server's threads have daemon=False.
self.clients = {}
@@ -202,5 +193,8 @@ class PlomTCPServer(socketserver.ThreadingTCPServer):
class PlomTCPServerSSL(PlomTCPServer):
- def __init__(self, *args, **kwargs):
- super().__init__(*args, host='0.0.0.0', io_handler=IO_HandlerSSL, **kwargs)
+ def __init__(self, *args, certfile=None, keyfile=None, **kwargs):
+ super().__init__(*args, host='0.0.0.0', **kwargs)
+ self.certfile = certfile
+ self.keyfile = keyfile
+ self.socket_class = PlomSocketSSL