home · contact · privacy
Add SSL capabilities to TCP socket library.
[plomrogue2-experiments] / new2 / plomrogue / io_tcp.py
1 import socketserver
2
3
4 # Avoid "Address already in use" errors.
5 socketserver.TCPServer.allow_reuse_address = True
6
7
8
9 class PlomSocket:
10
11     def __init__(self, socket):
12         self.socket = socket
13
14     def send(self, message, silent_connection_break=False):
15         """Send via self.socket, encoded/delimited as way recv() expects.
16
17         In detail, all \ and $ in message are escaped with prefixed \,
18         and an unescaped $ is appended as a message delimiter. Then,
19         socket.send() is called as often as necessary to ensure
20         message is sent fully, as socket.send() due to buffering may
21         not send all of it right away.
22
23         Assuming socket is blocking, it's rather improbable that
24         socket.send() will be partial / return a positive value less
25         than the (byte) length of msg – but not entirely out of the
26         question. See: - <http://stackoverflow.com/q/19697218> -
27         <http://stackoverflow.com/q/2618736> -
28         <http://stackoverflow.com/q/8900474>
29
30         This also handles a socket.send() return value of 0, which
31         might be possible or not (?) for blocking sockets: -
32         <http://stackoverflow.com/q/34919846>
33
34         """
35         from plomrogue.errors import BrokenSocketConnection
36         escaped_message = ''
37         for char in message:
38             if char in ('\\', '$'):
39                 escaped_message += '\\'
40             escaped_message += char
41         escaped_message += '$'
42         data = escaped_message.encode()
43         totalsent = 0
44         while totalsent < len(data):
45             socket_broken = False
46             try:
47                 sent = self.socket.send(data[totalsent:])
48                 socket_broken = sent == 0
49             except OSError as err:
50                 if err.errno == 9:  # "Bad file descriptor", when connection broken
51                     socket_broken = True
52                 else:
53                     raise err
54             if socket_broken and not silent_connection_break:
55                 raise BrokenSocketConnection
56             totalsent = totalsent + sent
57
58     def recv(self):
59         """Get full send()-prepared message from self.socket.
60
61         In detail, socket.recv() is looped over for sequences of bytes
62         that can be decoded as a Unicode string delimited by an
63         unescaped $, with \ and $ escapable by \. If a sequence of
64         characters that ends in an unescaped $ cannot be decoded as
65         Unicode, None is returned as its representation. Stop once
66         socket.recv() returns nothing.
67
68         Under the hood, the TCP stack receives packets that construct
69         the input payload in an internal buffer; socket.recv(BUFSIZE)
70         pops up to BUFSIZE bytes from that buffer, without knowledge
71         either about the input's segmentation into packets, or whether
72         the input is segmented in any other meaningful way; that's why
73         we do our own message segmentation with $ as a delimiter.
74
75         """
76         esc = False
77         data = b''
78         msg = b''
79         while True:
80             data += self.socket.recv(1024)
81             if 0 == len(data):
82                 return
83             cut_off = 0
84             for c in data:
85                 cut_off += 1
86                 if esc:
87                     msg += bytes([c])
88                     esc = False
89                 elif chr(c) == '\\':
90                     esc = True
91                 elif chr(c) == '$':
92                     try:
93                         yield msg.decode()
94                     except UnicodeDecodeError:
95                         yield None
96                     data = data[cut_off:]
97                     msg = b''
98                 else:
99                     msg += bytes([c])
100
101
102
103 class PlomSocketSSL(PlomSocket):
104
105     def __init__(self, *args, server_side=False, **kwargs):
106         import ssl
107         print('DEBUG', args, kwargs)
108         super().__init__(*args, **kwargs)
109         if server_side:
110             self.socket = ssl.wrap_socket(self.socket, server_side=True,
111                                           certfile="server.pem",
112                                           keyfile="key.pem")
113         else:
114             self.socket = ssl.wrap_socket(self.socket)
115
116
117
118 class IO_Handler(socketserver.BaseRequestHandler):
119
120     def __init__(self, *args, socket_class=PlomSocket, **kwargs):
121         self.socket_class = socket_class
122         super().__init__(*args, **kwargs)
123
124     def handle(self):
125         """Move messages between network socket and game IO loop via queues.
126
127         On start (a new connection from client to server), sets up a
128         new queue, sends it via self.server.queue_out to the game IO
129         loop thread, and from then on receives messages to send back
130         from the game IO loop via that new queue.
131
132         At the same time, loops over socket's recv to get messages
133         from the outside into the game IO loop by way of
134         self.server.queue_out into the game IO. Ends connection once a
135         'QUIT' message is received from socket, and then also calls
136         for a kill of its own queue.
137
138         """
139
140         def send_queue_messages(plom_socket, queue_in, thread_alive):
141             """Send messages via socket from queue_in while thread_alive[0]."""
142             while thread_alive[0]:
143                 try:
144                     msg = queue_in.get(timeout=1)
145                 except queue.Empty:
146                     continue
147                 plom_socket.send(msg, True)
148
149         import uuid
150         import queue
151         import threading
152         if self.socket_class == PlomSocketSSL:
153             plom_socket = self.socket_class(self.request, server_side=True)
154         else:
155             plom_socket = self.socket_class(self.request)
156         print('CONNECTION FROM:', str(self.client_address))
157         connection_id = uuid.uuid4()
158         queue_in = queue.Queue()
159         self.server.clients[connection_id] = queue_in
160         thread_alive = [True]
161         t = threading.Thread(target=send_queue_messages,
162                              args=(plom_socket, queue_in, thread_alive))
163         t.start()
164         for message in plom_socket.recv():
165             if message is None:
166                 plom_socket.send('BAD MESSAGE', True)
167             elif 'QUIT' == message:
168                 plom_socket.send('BYE', True)
169                 break
170             else:
171                 self.server.queue_out.put((connection_id, message))
172         del self.server.clients[connection_id]
173         thread_alive[0] = False
174         print('CONNECTION CLOSED FROM:', str(self.client_address))
175         plom_socket.socket.close()
176
177
178
179 class IO_HandlerSSL(IO_Handler):
180
181     def __init__(self, *args, **kwargs):
182         super().__init__(*args, socket_class=PlomSocketSSL, **kwargs)
183
184
185
186 class PlomTCPServer(socketserver.ThreadingTCPServer):
187     """Bind together threaded IO handling server and message queue.
188
189     By default this only serves to localhost connections.  For remote
190     connections, consider using PlomTCPServerSSL for more security,
191     which defaults to serving all connections.
192
193     """
194
195     def __init__(self, queue, port, host='127.0.0.1', io_handler=IO_Handler, *args, **kwargs):
196         super().__init__((host, port), io_handler, *args, **kwargs)
197         self.queue_out = queue
198         self.daemon_threads = True  # Else, server's threads have daemon=False.
199         self.clients = {}
200
201
202
203 class PlomTCPServerSSL(PlomTCPServer):
204
205     def __init__(self, *args, **kwargs):
206         super().__init__(*args, host='0.0.0.0', io_handler=IO_HandlerSSL, **kwargs)