home · contact · privacy
Improve SSL negotation.
[plomrogue2-experiments] / new2 / plomrogue / io_tcp.py
1 import socketserver
2
3
4 # Avoid "Address already in use" errors.
5 socketserver.TCPServer.allow_reuse_address = True
6
7
8
9 from plomrogue.errors import BrokenSocketConnection
10 class PlomSocket:
11
12     def __init__(self, socket):
13         self.socket = socket
14
15     def send(self, message, silent_connection_break=False):
16         """Send via self.socket, encoded/delimited as way recv() expects.
17
18         In detail, all \ and $ in message are escaped with prefixed \,
19         and an unescaped $ is appended as a message delimiter. Then,
20         socket.send() is called as often as necessary to ensure
21         message is sent fully, as socket.send() due to buffering may
22         not send all of it right away.
23
24         Assuming socket is blocking, it's rather improbable that
25         socket.send() will be partial / return a positive value less
26         than the (byte) length of msg – but not entirely out of the
27         question. See: - <http://stackoverflow.com/q/19697218> -
28         <http://stackoverflow.com/q/2618736> -
29         <http://stackoverflow.com/q/8900474>
30
31         This also handles a socket.send() return value of 0, which
32         might be possible or not (?) for blocking sockets: -
33         <http://stackoverflow.com/q/34919846>
34
35         """
36         escaped_message = ''
37         for char in message:
38             if char in ('\\', '$'):
39                 escaped_message += '\\'
40             escaped_message += char
41         escaped_message += '$'
42         data = escaped_message.encode()
43         totalsent = 0
44         while totalsent < len(data):
45             socket_broken = False
46             try:
47                 sent = self.socket.send(data[totalsent:])
48                 socket_broken = sent == 0
49                 totalsent = totalsent + sent
50             except OSError as err:
51                 if err.errno == 9:  # "Bad file descriptor", when connection broken
52                     socket_broken = True
53                 else:
54                     raise err
55             if socket_broken and not silent_connection_break:
56                 raise BrokenSocketConnection
57
58     def recv(self):
59         """Get full send()-prepared message from self.socket.
60
61         In detail, socket.recv() is looped over for sequences of bytes
62         that can be decoded as a Unicode string delimited by an
63         unescaped $, with \ and $ escapable by \. If a sequence of
64         characters that ends in an unescaped $ cannot be decoded as
65         Unicode, None is returned as its representation. Stop once
66         socket.recv() returns nothing.
67
68         Under the hood, the TCP stack receives packets that construct
69         the input payload in an internal buffer; socket.recv(BUFSIZE)
70         pops up to BUFSIZE bytes from that buffer, without knowledge
71         either about the input's segmentation into packets, or whether
72         the input is segmented in any other meaningful way; that's why
73         we do our own message segmentation with $ as a delimiter.
74
75         """
76         esc = False
77         data = b''
78         msg = b''
79         while True:
80             try:
81                 data = self.socket.recv(1024)
82             except OSError as err:
83                 if err.errno == 9:  # "Bad file descriptor", when connection broken
84                     raise BrokenSocketConnection
85             if 0 == len(data):
86                 break
87             for c in data:
88                 if esc:
89                     msg += bytes([c])
90                     esc = False
91                 elif chr(c) == '\\':
92                     esc = True
93                 elif chr(c) == '$':
94                     try:
95                         yield msg.decode()
96                     except UnicodeDecodeError:
97                         yield None
98                     msg = b''
99                 else:
100                     msg += bytes([c])
101
102
103
104 class PlomSocketSSL(PlomSocket):
105
106     def __init__(self, *args, certfile, keyfile, **kwargs):
107         import ssl
108         super().__init__(*args, **kwargs)
109         self.send('NEED_SSL')
110         self.socket = ssl.wrap_socket(self.socket, server_side=True,
111                                       certfile=certfile, keyfile=keyfile)
112
113
114
115 class IO_Handler(socketserver.BaseRequestHandler):
116
117     def __init__(self, *args, **kwargs):
118         super().__init__(*args, **kwargs)
119
120     def handle(self):
121         """Move messages between network socket and game IO loop via queues.
122
123         On start (a new connection from client to server), sets up a
124         new queue, sends it via self.server.queue_out to the game IO
125         loop thread, and from then on receives messages to send back
126         from the game IO loop via that new queue.
127
128         At the same time, loops over socket's recv to get messages
129         from the outside into the game IO loop by way of
130         self.server.queue_out into the game IO. Ends connection once a
131         'QUIT' message is received from socket, and then also calls
132         for a kill of its own queue.
133
134         """
135
136         def send_queue_messages(plom_socket, queue_in, thread_alive):
137             """Send messages via socket from queue_in while thread_alive[0]."""
138             while thread_alive[0]:
139                 try:
140                     msg = queue_in.get(timeout=1)
141                 except queue.Empty:
142                     continue
143                 plom_socket.send(msg, True)
144
145         import uuid
146         import queue
147         import threading
148         if self.server.socket_class == PlomSocketSSL:
149             plom_socket = self.server.socket_class(self.request,
150                                                    certfile=self.server.certfile,
151                                                    keyfile=self.server.keyfile)
152         else:
153             plom_socket = self.server.socket_class(self.request)
154         print('CONNECTION FROM:', str(self.client_address))
155         connection_id = uuid.uuid4()
156         queue_in = queue.Queue()
157         self.server.clients[connection_id] = queue_in
158         thread_alive = [True]
159         t = threading.Thread(target=send_queue_messages,
160                              args=(plom_socket, queue_in, thread_alive))
161         t.start()
162         for message in plom_socket.recv():
163             if message is None:
164                 plom_socket.send('BAD MESSAGE', True)
165             elif 'QUIT' == message:
166                 plom_socket.send('BYE', True)
167                 break
168             else:
169                 self.server.queue_out.put((connection_id, message))
170         del self.server.clients[connection_id]
171         thread_alive[0] = False
172         print('CONNECTION CLOSED FROM:', str(self.client_address))
173         plom_socket.socket.close()
174
175
176
177 class PlomTCPServer(socketserver.ThreadingTCPServer):
178     """Bind together threaded IO handling server and message queue.
179
180     By default this only serves to localhost connections.  For remote
181     connections, consider using PlomTCPServerSSL for more security,
182     which defaults to serving all connections.
183
184     """
185
186     def __init__(self, queue, port, host='127.0.0.1', *args, **kwargs):
187         super().__init__((host, port), IO_Handler, *args, **kwargs)
188         self.socket_class = PlomSocket
189         self.queue_out = queue
190         self.daemon_threads = True  # Else, server's threads have daemon=False.
191         self.clients = {}
192
193
194
195 class PlomTCPServerSSL(PlomTCPServer):
196
197     def __init__(self, *args, certfile, keyfile, **kwargs):
198         super().__init__(*args, host='0.0.0.0', **kwargs)
199         self.certfile = certfile
200         self.keyfile = keyfile
201         self.socket_class = PlomSocketSSL