4a7105d17916a7237f3df6e59d65ca82375f8803
[SubU] /
1 import io
2 import socket
3 import ssl
4
5 from ..exceptions import ProxySchemeUnsupported
6 from ..packages import six
7
8 SSL_BLOCKSIZE = 16384
9
10
11 class SSLTransport:
12     """
13     The SSLTransport wraps an existing socket and establishes an SSL connection.
14
15     Contrary to Python's implementation of SSLSocket, it allows you to chain
16     multiple TLS connections together. It's particularly useful if you need to
17     implement TLS within TLS.
18
19     The class supports most of the socket API operations.
20     """
21
22     @staticmethod
23     def _validate_ssl_context_for_tls_in_tls(ssl_context):
24         """
25         Raises a ProxySchemeUnsupported if the provided ssl_context can't be used
26         for TLS in TLS.
27
28         The only requirement is that the ssl_context provides the 'wrap_bio'
29         methods.
30         """
31
32         if not hasattr(ssl_context, "wrap_bio"):
33             if six.PY2:
34                 raise ProxySchemeUnsupported(
35                     "TLS in TLS requires SSLContext.wrap_bio() which isn't "
36                     "supported on Python 2"
37                 )
38             else:
39                 raise ProxySchemeUnsupported(
40                     "TLS in TLS requires SSLContext.wrap_bio() which isn't "
41                     "available on non-native SSLContext"
42                 )
43
44     def __init__(
45         self, socket, ssl_context, server_hostname=None, suppress_ragged_eofs=True
46     ):
47         """
48         Create an SSLTransport around socket using the provided ssl_context.
49         """
50         self.incoming = ssl.MemoryBIO()
51         self.outgoing = ssl.MemoryBIO()
52
53         self.suppress_ragged_eofs = suppress_ragged_eofs
54         self.socket = socket
55
56         self.sslobj = ssl_context.wrap_bio(
57             self.incoming, self.outgoing, server_hostname=server_hostname
58         )
59
60         # Perform initial handshake.
61         self._ssl_io_loop(self.sslobj.do_handshake)
62
63     def __enter__(self):
64         return self
65
66     def __exit__(self, *_):
67         self.close()
68
69     def fileno(self):
70         return self.socket.fileno()
71
72     def read(self, len=1024, buffer=None):
73         return self._wrap_ssl_read(len, buffer)
74
75     def recv(self, len=1024, flags=0):
76         if flags != 0:
77             raise ValueError("non-zero flags not allowed in calls to recv")
78         return self._wrap_ssl_read(len)
79
80     def recv_into(self, buffer, nbytes=None, flags=0):
81         if flags != 0:
82             raise ValueError("non-zero flags not allowed in calls to recv_into")
83         if buffer and (nbytes is None):
84             nbytes = len(buffer)
85         elif nbytes is None:
86             nbytes = 1024
87         return self.read(nbytes, buffer)
88
89     def sendall(self, data, flags=0):
90         if flags != 0:
91             raise ValueError("non-zero flags not allowed in calls to sendall")
92         count = 0
93         with memoryview(data) as view, view.cast("B") as byte_view:
94             amount = len(byte_view)
95             while count < amount:
96                 v = self.send(byte_view[count:])
97                 count += v
98
99     def send(self, data, flags=0):
100         if flags != 0:
101             raise ValueError("non-zero flags not allowed in calls to send")
102         response = self._ssl_io_loop(self.sslobj.write, data)
103         return response
104
105     def makefile(
106         self, mode="r", buffering=None, encoding=None, errors=None, newline=None
107     ):
108         """
109         Python's httpclient uses makefile and buffered io when reading HTTP
110         messages and we need to support it.
111
112         This is unfortunately a copy and paste of socket.py makefile with small
113         changes to point to the socket directly.
114         """
115         if not set(mode) <= {"r", "w", "b"}:
116             raise ValueError("invalid mode %r (only r, w, b allowed)" % (mode,))
117
118         writing = "w" in mode
119         reading = "r" in mode or not writing
120         assert reading or writing
121         binary = "b" in mode
122         rawmode = ""
123         if reading:
124             rawmode += "r"
125         if writing:
126             rawmode += "w"
127         raw = socket.SocketIO(self, rawmode)
128         self.socket._io_refs += 1
129         if buffering is None:
130             buffering = -1
131         if buffering < 0:
132             buffering = io.DEFAULT_BUFFER_SIZE
133         if buffering == 0:
134             if not binary:
135                 raise ValueError("unbuffered streams must be binary")
136             return raw
137         if reading and writing:
138             buffer = io.BufferedRWPair(raw, raw, buffering)
139         elif reading:
140             buffer = io.BufferedReader(raw, buffering)
141         else:
142             assert writing
143             buffer = io.BufferedWriter(raw, buffering)
144         if binary:
145             return buffer
146         text = io.TextIOWrapper(buffer, encoding, errors, newline)
147         text.mode = mode
148         return text
149
150     def unwrap(self):
151         self._ssl_io_loop(self.sslobj.unwrap)
152
153     def close(self):
154         self.socket.close()
155
156     def getpeercert(self, binary_form=False):
157         return self.sslobj.getpeercert(binary_form)
158
159     def version(self):
160         return self.sslobj.version()
161
162     def cipher(self):
163         return self.sslobj.cipher()
164
165     def selected_alpn_protocol(self):
166         return self.sslobj.selected_alpn_protocol()
167
168     def selected_npn_protocol(self):
169         return self.sslobj.selected_npn_protocol()
170
171     def shared_ciphers(self):
172         return self.sslobj.shared_ciphers()
173
174     def compression(self):
175         return self.sslobj.compression()
176
177     def settimeout(self, value):
178         self.socket.settimeout(value)
179
180     def gettimeout(self):
181         return self.socket.gettimeout()
182
183     def _decref_socketios(self):
184         self.socket._decref_socketios()
185
186     def _wrap_ssl_read(self, len, buffer=None):
187         try:
188             return self._ssl_io_loop(self.sslobj.read, len, buffer)
189         except ssl.SSLError as e:
190             if e.errno == ssl.SSL_ERROR_EOF and self.suppress_ragged_eofs:
191                 return 0  # eof, return 0.
192             else:
193                 raise
194
195     def _ssl_io_loop(self, func, *args):
196         """Performs an I/O loop between incoming/outgoing and the socket."""
197         should_loop = True
198         ret = None
199
200         while should_loop:
201             errno = None
202             try:
203                 ret = func(*args)
204             except ssl.SSLError as e:
205                 if e.errno not in (ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE):
206                     # WANT_READ, and WANT_WRITE are expected, others are not.
207                     raise e
208                 errno = e.errno
209
210             buf = self.outgoing.read()
211             self.socket.sendall(buf)
212
213             if errno is None:
214                 should_loop = False
215             elif errno == ssl.SSL_ERROR_WANT_READ:
216                 buf = self.socket.recv(SSL_BLOCKSIZE)
217                 if buf:
218                     self.incoming.write(buf)
219                 else:
220                     self.incoming.write_eof()
221         return ret