home · contact · privacy
f0a49a97dad632fa4f3dc10770432f424ba90155
[plomrogue2-experiments] / new2 / plomrogue / io_tcp.py
1 import socketserver
2
3
4 # Avoid "Address already in use" errors.
5 socketserver.TCPServer.allow_reuse_address = True
6
7
8
9 class PlomSocket:
10
11     def __init__(self, socket):
12         self.socket = socket
13
14     def send(self, message, silent_connection_break=False):
15         """Send via self.socket, encoded/delimited as way recv() expects.
16
17         In detail, all \ and $ in message are escaped with prefixed \,
18         and an unescaped $ is appended as a message delimiter. Then,
19         socket.send() is called as often as necessary to ensure
20         message is sent fully, as socket.send() due to buffering may
21         not send all of it right away.
22
23         Assuming socket is blocking, it's rather improbable that
24         socket.send() will be partial / return a positive value less
25         than the (byte) length of msg – but not entirely out of the
26         question. See: - <http://stackoverflow.com/q/19697218> -
27         <http://stackoverflow.com/q/2618736> -
28         <http://stackoverflow.com/q/8900474>
29
30         This also handles a socket.send() return value of 0, which
31         might be possible or not (?) for blocking sockets: -
32         <http://stackoverflow.com/q/34919846>
33
34         """
35         from plomrogue.errors import BrokenSocketConnection
36         escaped_message = ''
37         for char in message:
38             if char in ('\\', '$'):
39                 escaped_message += '\\'
40             escaped_message += char
41         escaped_message += '$'
42         data = escaped_message.encode()
43         totalsent = 0
44         while totalsent < len(data):
45             socket_broken = False
46             try:
47                 sent = self.socket.send(data[totalsent:])
48                 socket_broken = sent == 0
49                 totalsent = totalsent + sent
50             except OSError as err:
51                 if err.errno == 9:  # "Bad file descriptor", when connection broken
52                     socket_broken = True
53                 else:
54                     raise err
55             if socket_broken and not silent_connection_break:
56                 raise BrokenSocketConnection
57
58     def recv(self):
59         """Get full send()-prepared message from self.socket.
60
61         In detail, socket.recv() is looped over for sequences of bytes
62         that can be decoded as a Unicode string delimited by an
63         unescaped $, with \ and $ escapable by \. If a sequence of
64         characters that ends in an unescaped $ cannot be decoded as
65         Unicode, None is returned as its representation. Stop once
66         socket.recv() returns nothing.
67
68         Under the hood, the TCP stack receives packets that construct
69         the input payload in an internal buffer; socket.recv(BUFSIZE)
70         pops up to BUFSIZE bytes from that buffer, without knowledge
71         either about the input's segmentation into packets, or whether
72         the input is segmented in any other meaningful way; that's why
73         we do our own message segmentation with $ as a delimiter.
74
75         """
76         esc = False
77         data = b''
78         msg = b''
79         while True:
80             data = self.socket.recv(1024)
81             if 0 == len(data):
82                 break
83             for c in data:
84                 if esc:
85                     msg += bytes([c])
86                     esc = False
87                 elif chr(c) == '\\':
88                     esc = True
89                 elif chr(c) == '$':
90                     try:
91                         yield msg.decode()
92                     except UnicodeDecodeError:
93                         yield None
94                     msg = b''
95                 else:
96                     msg += bytes([c])
97
98
99
100 class PlomSocketSSL(PlomSocket):
101
102     def __init__(self, *args, server_side=False, certfile=None, keyfile=None, **kwargs):
103         import ssl
104         super().__init__(*args, **kwargs)
105         if server_side:
106             self.socket = ssl.wrap_socket(self.socket, server_side=True,
107                                           certfile=certfile, keyfile=keyfile)
108         else:
109             self.socket = ssl.wrap_socket(self.socket)
110
111
112
113 class IO_Handler(socketserver.BaseRequestHandler):
114
115     def __init__(self, *args, **kwargs):
116         super().__init__(*args, **kwargs)
117
118     def handle(self):
119         """Move messages between network socket and game IO loop via queues.
120
121         On start (a new connection from client to server), sets up a
122         new queue, sends it via self.server.queue_out to the game IO
123         loop thread, and from then on receives messages to send back
124         from the game IO loop via that new queue.
125
126         At the same time, loops over socket's recv to get messages
127         from the outside into the game IO loop by way of
128         self.server.queue_out into the game IO. Ends connection once a
129         'QUIT' message is received from socket, and then also calls
130         for a kill of its own queue.
131
132         """
133
134         def send_queue_messages(plom_socket, queue_in, thread_alive):
135             """Send messages via socket from queue_in while thread_alive[0]."""
136             while thread_alive[0]:
137                 try:
138                     msg = queue_in.get(timeout=1)
139                 except queue.Empty:
140                     continue
141                 plom_socket.send(msg, True)
142
143         import uuid
144         import queue
145         import threading
146         if self.server.socket_class == PlomSocketSSL:
147             plom_socket = self.server.socket_class(self.request,
148                                                    server_side=True,
149                                                    certfile=self.server.certfile,
150                                                    keyfile=self.server.keyfile)
151         else:
152             plom_socket = self.server.socket_class(self.request)
153         print('CONNECTION FROM:', str(self.client_address))
154         connection_id = uuid.uuid4()
155         queue_in = queue.Queue()
156         self.server.clients[connection_id] = queue_in
157         thread_alive = [True]
158         t = threading.Thread(target=send_queue_messages,
159                              args=(plom_socket, queue_in, thread_alive))
160         t.start()
161         for message in plom_socket.recv():
162             if message is None:
163                 plom_socket.send('BAD MESSAGE', True)
164             elif 'QUIT' == message:
165                 plom_socket.send('BYE', True)
166                 break
167             else:
168                 self.server.queue_out.put((connection_id, message))
169         del self.server.clients[connection_id]
170         thread_alive[0] = False
171         print('CONNECTION CLOSED FROM:', str(self.client_address))
172         plom_socket.socket.close()
173
174
175
176 class PlomTCPServer(socketserver.ThreadingTCPServer):
177     """Bind together threaded IO handling server and message queue.
178
179     By default this only serves to localhost connections.  For remote
180     connections, consider using PlomTCPServerSSL for more security,
181     which defaults to serving all connections.
182
183     """
184
185     def __init__(self, queue, port, host='127.0.0.1', *args, **kwargs):
186         super().__init__((host, port), IO_Handler, *args, **kwargs)
187         self.socket_class = PlomSocket
188         self.queue_out = queue
189         self.daemon_threads = True  # Else, server's threads have daemon=False.
190         self.clients = {}
191
192
193
194 class PlomTCPServerSSL(PlomTCPServer):
195
196     def __init__(self, *args, certfile=None, keyfile=None, **kwargs):
197         super().__init__(*args, host='0.0.0.0', **kwargs)
198         self.certfile = certfile
199         self.keyfile = keyfile
200         self.socket_class = PlomSocketSSL