From ca61b0ee39690ef3b4fc33c14655b73d4b31f5f9 Mon Sep 17 00:00:00 2001
From: Christian Heller <c.heller@plomlompom.de>
Date: Thu, 27 May 2021 01:26:45 +0200
Subject: [PATCH] Refactor client connection code.

---
 plomrogue_client/socket.py | 124 +++++++++++++++++++++++++++++++++++++
 1 file changed, 124 insertions(+)
 create mode 100644 plomrogue_client/socket.py

diff --git a/plomrogue_client/socket.py b/plomrogue_client/socket.py
new file mode 100644
index 0000000..8330153
--- /dev/null
+++ b/plomrogue_client/socket.py
@@ -0,0 +1,124 @@
+#!/usr/bin/env python3
+import queue
+import threading
+import time
+import datetime
+from plomrogue.errors import BrokenSocketConnection
+from plomrogue.io_tcp import PlomSocket
+from ws4py.client import WebSocketBaseClient
+
+
+
+class WebSocketClient(WebSocketBaseClient):
+
+    def __init__(self, recv_handler, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.recv_handler = recv_handler
+        self.connect()
+
+    def received_message(self, message):
+        if message.is_text:
+            message = str(message)
+            self.recv_handler(message)
+
+    @property
+    def plom_closed(self):
+        return self.client_terminated
+
+
+
+class PlomSocketClient(PlomSocket):
+
+    def __init__(self, recv_handler, url):
+        import socket
+        self.recv_handler = recv_handler
+        host, port = url.split(':')
+        super().__init__(socket.create_connection((host, port)))
+
+    def close(self):
+        self.socket.close()
+
+    def run(self):
+        import ssl
+        try:
+            for msg in self.recv():
+                if msg == 'NEED_SSL':
+                    self.socket = ssl.wrap_socket(self.socket)
+                    continue
+                self.recv_handler(msg)
+        except BrokenSocketConnection:
+            pass  # we assume socket will be known as dead by now
+
+
+
+class ClientSocket():
+
+    def __init__(self, host, logger=None):
+        self.socket = None
+        self.host = host
+        self.queue = queue.Queue()
+        self.disconnected = True
+        self.force_instant_connect = True
+        self.interval = datetime.timedelta(seconds=5)
+        self.last_ping = datetime.datetime.now() - self.interval
+        self.logger = logger
+
+    def log(self, msg):
+        if self.logger:
+            self.logger(msg)
+
+    def connect(self):
+
+        def handle_recv(msg):
+            if msg == 'BYE':
+                self.socket.close()
+            else:
+                self.queue.put(msg)
+
+        self.log('attempting connect')
+        socket_client_class = PlomSocketClient
+        if self.host.startswith('ws://') or self.host.startswith('wss://'):
+            socket_client_class = WebSocketClient
+        try:
+            self.socket = socket_client_class(handle_recv, self.host)
+            self.socket_thread = threading.Thread(target=self.socket.run)
+            self.socket_thread.start()
+            self.disconnected = False
+            time.sleep(0.1)  # give potential SSL negotation some time …
+            self.log('connected')
+        except ConnectionRefusedError:
+            self.log('server connect failure')
+            self.disconnected = True
+
+    def send(self, msg):
+        try:
+            if self.socket is None:
+                raise BrokenSocketConnection
+            if hasattr(self.socket, 'plom_closed') and self.socket.plom_closed:
+                raise BrokenSocketConnection
+            self.socket.send(msg)
+        except (BrokenPipeError, BrokenSocketConnection):
+            self.log('server disconnected :(')
+            self.disconnected = True
+            self.force_instant_connect = True
+
+    def keep_connection_alive(self):
+        if self.disconnected and self.force_instant_connect:
+            self.force_instant_connect = False
+            self.connect()
+        now = datetime.datetime.now()
+        if now - self.last_ping > self.interval:
+            if self.disconnected:
+                self.connect()
+            else:
+                self.send('PING')
+            self.last_ping = now
+
+    def get_message(self):
+        while True:
+            try:
+                msg = self.queue.get(block=False)
+                yield msg 
+            except queue.Empty:
+                break
+        return None
-- 
2.30.2