Skip to content

Commit 1e83cfb

Browse files
committed
Add basic support for proxies.
This is missing proper error handling, tests and support for WSS.
1 parent f3df6ae commit 1e83cfb

File tree

4 files changed

+92
-18
lines changed

4 files changed

+92
-18
lines changed

docs/api.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ Client
5454

5555
.. automodule:: websockets.client
5656

57-
.. autofunction:: connect(uri, *, create_protocol=None, timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, origin=None, extensions=None, subprotocols=None, extra_headers=None, compression='deflate', **kwds)
57+
.. autofunction:: connect(uri, *, create_protocol=None, timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, origin=None, extensions=None, subprotocols=None, extra_headers=None, compression='deflate', proxy_uri=USE_SYSTEM_PROXY, proxy_ssl=None, **kwds)
5858

5959
.. autoclass:: WebSocketClientProtocol(*, host=None, port=None, secure=None, timeout=10, max_size=2 ** 20, max_queue=2 ** 5, read_limit=2 ** 16, write_limit=2 ** 16, loop=None, origin=None, extensions=None, subprotocols=None, extra_headers=None)
6060

websockets/client.py

Lines changed: 89 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import asyncio
77
import collections.abc
88
import sys
9+
import urllib.request
910

1011
from .exceptions import (
1112
InvalidHandshake, InvalidMessage, InvalidStatusCode, NegotiationError
@@ -18,11 +19,13 @@
1819
)
1920
from .http import USER_AGENT, basic_auth_header, build_headers, read_response
2021
from .protocol import WebSocketCommonProtocol
21-
from .uri import parse_uri
22+
from .uri import parse_proxy_uri, parse_uri
2223

2324

2425
__all__ = ['connect', 'WebSocketClientProtocol']
2526

27+
USE_SYSTEM_PROXY = object()
28+
2629

2730
class WebSocketClientProtocol(WebSocketCommonProtocol):
2831
"""
@@ -195,6 +198,37 @@ def process_subprotocol(headers, available_subprotocols):
195198

196199
return subprotocol
197200

201+
@asyncio.coroutine
202+
def proxy_connect(self, proxy_uri, uri, ssl=None):
203+
assert ssl is None, "proxying TLS/SSL connections isn't supported yet"
204+
205+
request = ['CONNECT {uri.host}:{uri.port} HTTP/1.1'.format(uri=uri)]
206+
207+
headers = []
208+
209+
if uri.port == (443 if uri.secure else 80): # pragma: no cover
210+
headers.append(('Host', uri.host))
211+
else:
212+
headers.append(('Host', '{uri.host}:{uri.port}'.format(uri=uri)))
213+
214+
if proxy_uri.user_info:
215+
headers.append((
216+
'Proxy-Authorization',
217+
basic_auth_header(*proxy_uri.user_info),
218+
))
219+
220+
request.extend('{}: {}'.format(k, v) for k, v in headers)
221+
request.append('\r\n')
222+
request = '\r\n'.join(request).encode()
223+
224+
self.writer.write(request)
225+
226+
status_code, headers = yield from read_response(self.reader)
227+
228+
if not 200 <= status_code < 300:
229+
# TODO improve error handling
230+
raise ValueError("proxy error: HTTP {}".format(status_code))
231+
198232
@asyncio.coroutine
199233
def handshake(self, uri, origin=None, available_extensions=None,
200234
available_subprotocols=None, extra_headers=None):
@@ -223,10 +257,10 @@ def handshake(self, uri, origin=None, available_extensions=None,
223257
if uri.port == (443 if uri.secure else 80): # pragma: no cover
224258
set_header('Host', uri.host)
225259
else:
226-
set_header('Host', '{}:{}'.format(uri.host, uri.port))
260+
set_header('Host', '{uri.host}:{uri.port}'.format(uri=uri))
227261

228262
if uri.user_info:
229-
set_header(*basic_auth_header(*uri.user_info))
263+
set_header('Authorization', basic_auth_header(*uri.user_info))
230264

231265
if origin is not None:
232266
set_header('Origin', origin)
@@ -318,6 +352,12 @@ class Connect:
318352
* ``compression`` is a shortcut to configure compression extensions;
319353
by default it enables the "permessage-deflate" extension; set it to
320354
``None`` to disable compression
355+
* ``proxy`` defines the HTTP proxy for establishing the connection; by
356+
default, :func:`connect` uses proxies configured in the environment or
357+
the system (see :func:`~urllib.request.getproxies` for details); set
358+
``proxy`` to ``None`` to disable this behavior
359+
* ``proxy_ssl`` may be set to a :class:`~ssl.SSLContext` to enforce TLS
360+
settings for connecting to a ``https://`` proxy; it defaults to ``True``
321361
322362
:func:`connect` raises :exc:`~websockets.uri.InvalidURI` if ``uri`` is
323363
invalid and :exc:`~websockets.handshake.InvalidHandshake` if the opening
@@ -331,7 +371,9 @@ def __init__(self, uri, *,
331371
read_limit=2 ** 16, write_limit=2 ** 16,
332372
loop=None, legacy_recv=False, klass=None,
333373
origin=None, extensions=None, subprotocols=None,
334-
extra_headers=None, compression='deflate', **kwds):
374+
extra_headers=None, compression='deflate',
375+
proxy_uri=USE_SYSTEM_PROXY, proxy_ssl=None,
376+
ssl=None, sock=None, **kwds):
335377
if loop is None:
336378
loop = asyncio.get_event_loop()
337379

@@ -345,10 +387,13 @@ def __init__(self, uri, *,
345387

346388
uri = parse_uri(uri)
347389
if uri.secure:
348-
kwds.setdefault('ssl', True)
349-
elif kwds.get('ssl') is not None:
350-
raise ValueError("connect() received a SSL context for a ws:// "
351-
"URI, use a wss:// URI to enable TLS")
390+
if ssl is None:
391+
ssl = True
392+
elif ssl is not None:
393+
raise ValueError(
394+
"connect() received a TLS/SSL context for a ws:// URI;"
395+
"use a wss:// URI to enable TLS",
396+
)
352397

353398
if compression == 'deflate':
354399
if extensions is None:
@@ -372,18 +417,43 @@ def __init__(self, uri, *,
372417
extra_headers=extra_headers,
373418
)
374419

375-
if kwds.get('sock') is None:
376-
host, port = uri.host, uri.port
377-
else:
420+
if proxy_uri is USE_SYSTEM_PROXY:
421+
proxies = urllib.request.getproxies()
422+
if urllib.request.proxy_bypass(
423+
'{uri.host}:{uri.port}'.format(uri=uri)):
424+
proxy_uri = None
425+
else:
426+
# RFC 6455 recommends to prefer the proxy configured for HTTPS
427+
# connections over the proxy configured for HTTP connections.
428+
proxy_uri = proxies.get('https', proxies.get('http'))
429+
430+
if proxy_uri is not None:
431+
proxy_uri = parse_proxy_uri(proxy_uri)
432+
if proxy_uri.secure:
433+
if proxy_ssl is None:
434+
proxy_ssl = True
435+
elif proxy_ssl is not None:
436+
raise ValueError(
437+
"connect() received a TLS/SSL context for a HTTP proxy; "
438+
"use a HTTPS proxy to enable TLS",
439+
)
440+
441+
if sock is not None:
378442
# If sock is given, host and port mustn't be specified.
379-
host, port = None, None
443+
conn_host, conn_port, conn_ssl = None, None, ssl
444+
elif proxy_uri is not None:
445+
conn_host, conn_port, conn_ssl = (
446+
proxy_uri.host, proxy_uri.port, proxy_ssl)
447+
else:
448+
conn_host, conn_port, conn_ssl = uri.host, uri.port, ssl
380449

450+
self._proxy_uri = proxy_uri
381451
self._uri = uri
382-
self._origin = origin
452+
self._ssl = ssl
383453

384454
# This is a coroutine object.
385455
self._creating_connection = loop.create_connection(
386-
factory, host, port, **kwds)
456+
factory, conn_host, conn_port, ssl=conn_ssl, sock=sock, **kwds)
387457

388458
@asyncio.coroutine
389459
def __aenter__(self):
@@ -397,8 +467,12 @@ def __await__(self):
397467
transport, protocol = yield from self._creating_connection
398468

399469
try:
470+
if self._proxy_uri is not None:
471+
yield from protocol.proxy_connect(
472+
self._proxy_uri, self._uri, self._ssl)
400473
yield from protocol.handshake(
401-
self._uri, origin=self._origin,
474+
self._uri,
475+
origin=protocol.origin,
402476
available_extensions=protocol.available_extensions,
403477
available_subprotocols=protocol.available_subprotocols,
404478
extra_headers=protocol.extra_headers,

websockets/http.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,4 +218,4 @@ def basic_auth_header(username, password):
218218
assert ':' not in username
219219
user_pass = '{}:{}'.format(username, password)
220220
basic_credentials = base64.b64encode(user_pass.encode()).decode()
221-
return ('Authorization', 'Basic ' + basic_credentials)
221+
return 'Basic ' + basic_credentials

websockets/test_http.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,5 +133,5 @@ def test_basic_auth_header(self):
133133
# Test vector from RFC 7617.
134134
self.assertEqual(
135135
basic_auth_header("Aladdin", "open sesame"),
136-
('Authorization', 'Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ=='),
136+
'Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==',
137137
)

0 commit comments

Comments
 (0)