X-Git-Url: https://plomlompom.com/repos/foo.html?a=blobdiff_plain;f=new2%2Fplomrogue%2Fio_tcp.py;h=09b9db1bbf3d0f47539a2c4aaea27aacd9c0c18c;hb=8f4f247a8c36610a5cd4eb03ddb26dcc701e38ab;hp=203021cd06b9b7c34e2eab8efa8749c1e6188155;hpb=16f6849a62d13dd8b7a1104258139cef9462eb8a;p=plomrogue2-experiments
diff --git a/new2/plomrogue/io_tcp.py b/new2/plomrogue/io_tcp.py
index 203021c..09b9db1 100644
--- a/new2/plomrogue/io_tcp.py
+++ b/new2/plomrogue/io_tcp.py
@@ -6,6 +6,7 @@ socketserver.TCPServer.allow_reuse_address = True
+from plomrogue.errors import BrokenSocketConnection
class PlomSocket:
def __init__(self, socket):
@@ -32,7 +33,6 @@ class PlomSocket:
"""
- from plomrogue.errors import BrokenSocketConnection
escaped_message = ''
for char in message:
if char in ('\\', '$'):
@@ -46,6 +46,7 @@ class PlomSocket:
try:
sent = self.socket.send(data[totalsent:])
socket_broken = sent == 0
+ totalsent = totalsent + sent
except OSError as err:
if err.errno == 9: # "Bad file descriptor", when connection broken
socket_broken = True
@@ -53,7 +54,6 @@ class PlomSocket:
raise err
if socket_broken and not silent_connection_break:
raise BrokenSocketConnection
- totalsent = totalsent + sent
def recv(self):
"""Get full send()-prepared message from self.socket.
@@ -77,12 +77,14 @@ class PlomSocket:
data = b''
msg = b''
while True:
- data += self.socket.recv(1024)
+ try:
+ data = self.socket.recv(1024)
+ except OSError as err:
+ if err.errno == 9: # "Bad file descriptor", when connection broken
+ raise BrokenSocketConnection
if 0 == len(data):
- return
- cut_off = 0
+ break
for c in data:
- cut_off += 1
if esc:
msg += bytes([c])
esc = False
@@ -93,7 +95,6 @@ class PlomSocket:
yield msg.decode()
except UnicodeDecodeError:
yield None
- data = data[cut_off:]
msg = b''
else:
msg += bytes([c])
@@ -102,14 +103,12 @@ class PlomSocket:
class PlomSocketSSL(PlomSocket):
- def __init__(self, *args, server_side=False, certfile=None, keyfile=None, **kwargs):
+ def __init__(self, *args, certfile, keyfile, **kwargs):
import ssl
super().__init__(*args, **kwargs)
- if server_side:
- self.socket = ssl.wrap_socket(self.socket, server_side=True,
- certfile=certfile, keyfile=keyfile)
- else:
- self.socket = ssl.wrap_socket(self.socket)
+ self.send('NEED_SSL')
+ self.socket = ssl.wrap_socket(self.socket, server_side=True,
+ certfile=certfile, keyfile=keyfile)
@@ -148,7 +147,6 @@ class IO_Handler(socketserver.BaseRequestHandler):
import threading
if self.server.socket_class == PlomSocketSSL:
plom_socket = self.server.socket_class(self.request,
- server_side=True,
certfile=self.server.certfile,
keyfile=self.server.keyfile)
else:
@@ -196,7 +194,7 @@ class PlomTCPServer(socketserver.ThreadingTCPServer):
class PlomTCPServerSSL(PlomTCPServer):
- def __init__(self, *args, certfile=None, keyfile=None, **kwargs):
+ def __init__(self, *args, certfile, keyfile, **kwargs):
super().__init__(*args, host='0.0.0.0', **kwargs)
self.certfile = certfile
self.keyfile = keyfile