From 478d293913c37ed1bc98ab65db9658c58d6f7081 Mon Sep 17 00:00:00 2001
From: Christian Heller <c.heller@plomlompom.de>
Date: Mon, 9 Nov 2020 23:55:46 +0100
Subject: [PATCH] Improve SSL negotation.

---
 new2/plomrogue/io_tcp.py  | 13 +++++--------
 new2/rogue_chat_curses.py |  6 +++++-
 2 files changed, 10 insertions(+), 9 deletions(-)

diff --git a/new2/plomrogue/io_tcp.py b/new2/plomrogue/io_tcp.py
index b030f1b..09b9db1 100644
--- a/new2/plomrogue/io_tcp.py
+++ b/new2/plomrogue/io_tcp.py
@@ -103,14 +103,12 @@ class PlomSocket:
 
 class PlomSocketSSL(PlomSocket):
 
-    def __init__(self, *args, server_side=False, certfile=None, keyfile=None, **kwargs):
+    def __init__(self, *args, certfile, keyfile, **kwargs):
         import ssl
         super().__init__(*args, **kwargs)
-        if server_side:
-            self.socket = ssl.wrap_socket(self.socket, server_side=True,
-                                          certfile=certfile, keyfile=keyfile)
-        else:
-            self.socket = ssl.wrap_socket(self.socket)
+        self.send('NEED_SSL')
+        self.socket = ssl.wrap_socket(self.socket, server_side=True,
+                                      certfile=certfile, keyfile=keyfile)
 
 
 
@@ -149,7 +147,6 @@ class IO_Handler(socketserver.BaseRequestHandler):
         import threading
         if self.server.socket_class == PlomSocketSSL:
             plom_socket = self.server.socket_class(self.request,
-                                                   server_side=True,
                                                    certfile=self.server.certfile,
                                                    keyfile=self.server.keyfile)
         else:
@@ -197,7 +194,7 @@ class PlomTCPServer(socketserver.ThreadingTCPServer):
 
 class PlomTCPServerSSL(PlomTCPServer):
 
-    def __init__(self, *args, certfile=None, keyfile=None, **kwargs):
+    def __init__(self, *args, certfile, keyfile, **kwargs):
         super().__init__(*args, host='0.0.0.0', **kwargs)
         self.certfile = certfile
         self.keyfile = keyfile
diff --git a/new2/rogue_chat_curses.py b/new2/rogue_chat_curses.py
index 15d0053..c8a17c8 100755
--- a/new2/rogue_chat_curses.py
+++ b/new2/rogue_chat_curses.py
@@ -39,8 +39,12 @@ class PlomSocketClient(PlomSocket):
         self.socket.close()
 
     def run(self):
+        import ssl
         try:
             for msg in self.recv():
+                if msg == 'NEED_SSL':
+                    self.socket = ssl.wrap_socket(self.socket)
+                    continue
                 self.recv_handler(msg)
         except BrokenSocketConnection:
             pass  # we assume socket will be known as dead by now
@@ -606,4 +610,4 @@ class TUI:
                 self.send('TASK:WRITE ' + key)
                 self.switch_mode('play')
 
-TUI('127.0.0.1:5000')
+TUI('localhost:5000')
-- 
2.30.2