66import asyncio
77import collections .abc
88import sys
9+ import urllib .request
910
1011from .exceptions import (
1112 InvalidHandshake , InvalidMessage , InvalidStatusCode , NegotiationError
1819)
1920from .http import USER_AGENT , basic_auth_header , build_headers , read_response
2021from .protocol import WebSocketCommonProtocol
21- from .uri import parse_uri
22+ from .uri import parse_proxy_uri , parse_uri
2223
2324
2425__all__ = ['connect' , 'WebSocketClientProtocol' ]
2526
27+ USE_SYSTEM_PROXY = object ()
28+
2629
2730class WebSocketClientProtocol (WebSocketCommonProtocol ):
2831 """
@@ -195,6 +198,37 @@ def process_subprotocol(headers, available_subprotocols):
195198
196199 return subprotocol
197200
201+ @asyncio .coroutine
202+ def proxy_connect (self , proxy_uri , uri , ssl = None ):
203+ assert ssl is None , "proxying TLS/SSL connections isn't supported yet"
204+
205+ request = ['CONNECT {uri.host}:{uri.port} HTTP/1.1' .format (uri = uri )]
206+
207+ headers = []
208+
209+ if uri .port == (443 if uri .secure else 80 ): # pragma: no cover
210+ headers .append (('Host' , uri .host ))
211+ else :
212+ headers .append (('Host' , '{uri.host}:{uri.port}' .format (uri = uri )))
213+
214+ if proxy_uri .user_info :
215+ headers .append ((
216+ 'Proxy-Authorization' ,
217+ basic_auth_header (* proxy_uri .user_info ),
218+ ))
219+
220+ request .extend ('{}: {}' .format (k , v ) for k , v in headers )
221+ request .append ('\r \n ' )
222+ request = '\r \n ' .join (request ).encode ()
223+
224+ self .writer .write (request )
225+
226+ status_code , headers = yield from read_response (self .reader )
227+
228+ if not 200 <= status_code < 300 :
229+ # TODO improve error handling
230+ raise ValueError ("proxy error: HTTP {}" .format (status_code ))
231+
198232 @asyncio .coroutine
199233 def handshake (self , uri , origin = None , available_extensions = None ,
200234 available_subprotocols = None , extra_headers = None ):
@@ -223,10 +257,10 @@ def handshake(self, uri, origin=None, available_extensions=None,
223257 if uri .port == (443 if uri .secure else 80 ): # pragma: no cover
224258 set_header ('Host' , uri .host )
225259 else :
226- set_header ('Host' , '{}:{}' .format (uri . host , uri . port ))
260+ set_header ('Host' , '{uri.host }:{uri.port }' .format (uri = uri ))
227261
228262 if uri .user_info :
229- set_header (* basic_auth_header (* uri .user_info ))
263+ set_header ('Authorization' , basic_auth_header (* uri .user_info ))
230264
231265 if origin is not None :
232266 set_header ('Origin' , origin )
@@ -318,6 +352,12 @@ class Connect:
318352 * ``compression`` is a shortcut to configure compression extensions;
319353 by default it enables the "permessage-deflate" extension; set it to
320354 ``None`` to disable compression
355+ * ``proxy`` defines the HTTP proxy for establishing the connection; by
356+ default, :func:`connect` uses proxies configured in the environment or
357+ the system (see :func:`~urllib.request.getproxies` for details); set
358+ ``proxy`` to ``None`` to disable this behavior
359+ * ``proxy_ssl`` may be set to a :class:`~ssl.SSLContext` to enforce TLS
360+ settings for connecting to a ``https://`` proxy; it defaults to ``True``
321361
322362 :func:`connect` raises :exc:`~websockets.uri.InvalidURI` if ``uri`` is
323363 invalid and :exc:`~websockets.handshake.InvalidHandshake` if the opening
@@ -331,7 +371,9 @@ def __init__(self, uri, *,
331371 read_limit = 2 ** 16 , write_limit = 2 ** 16 ,
332372 loop = None , legacy_recv = False , klass = None ,
333373 origin = None , extensions = None , subprotocols = None ,
334- extra_headers = None , compression = 'deflate' , ** kwds ):
374+ extra_headers = None , compression = 'deflate' ,
375+ proxy_uri = USE_SYSTEM_PROXY , proxy_ssl = None ,
376+ ssl = None , sock = None , ** kwds ):
335377 if loop is None :
336378 loop = asyncio .get_event_loop ()
337379
@@ -345,10 +387,13 @@ def __init__(self, uri, *,
345387
346388 uri = parse_uri (uri )
347389 if uri .secure :
348- kwds .setdefault ('ssl' , True )
349- elif kwds .get ('ssl' ) is not None :
350- raise ValueError ("connect() received a SSL context for a ws:// "
351- "URI, use a wss:// URI to enable TLS" )
390+ if ssl is None :
391+ ssl = True
392+ elif ssl is not None :
393+ raise ValueError (
394+ "connect() received a TLS/SSL context for a ws:// URI;"
395+ "use a wss:// URI to enable TLS" ,
396+ )
352397
353398 if compression == 'deflate' :
354399 if extensions is None :
@@ -372,18 +417,43 @@ def __init__(self, uri, *,
372417 extra_headers = extra_headers ,
373418 )
374419
375- if kwds .get ('sock' ) is None :
376- host , port = uri .host , uri .port
377- else :
420+ if proxy_uri is USE_SYSTEM_PROXY :
421+ proxies = urllib .request .getproxies ()
422+ if urllib .request .proxy_bypass (
423+ '{uri.host}:{uri.port}' .format (uri = uri )):
424+ proxy_uri = None
425+ else :
426+ # RFC 6455 recommends to prefer the proxy configured for HTTPS
427+ # connections over the proxy configured for HTTP connections.
428+ proxy_uri = proxies .get ('https' , proxies .get ('http' ))
429+
430+ if proxy_uri is not None :
431+ proxy_uri = parse_proxy_uri (proxy_uri )
432+ if proxy_uri .secure :
433+ if proxy_ssl is None :
434+ proxy_ssl = True
435+ elif proxy_ssl is not None :
436+ raise ValueError (
437+ "connect() received a TLS/SSL context for a HTTP proxy; "
438+ "use a HTTPS proxy to enable TLS" ,
439+ )
440+
441+ if sock is not None :
378442 # If sock is given, host and port mustn't be specified.
379- host , port = None , None
443+ conn_host , conn_port , conn_ssl = None , None , ssl
444+ elif proxy_uri is not None :
445+ conn_host , conn_port , conn_ssl = (
446+ proxy_uri .host , proxy_uri .port , proxy_ssl )
447+ else :
448+ conn_host , conn_port , conn_ssl = uri .host , uri .port , ssl
380449
450+ self ._proxy_uri = proxy_uri
381451 self ._uri = uri
382- self ._origin = origin
452+ self ._ssl = ssl
383453
384454 # This is a coroutine object.
385455 self ._creating_connection = loop .create_connection (
386- factory , host , port , ** kwds )
456+ factory , conn_host , conn_port , ssl = conn_ssl , sock = sock , ** kwds )
387457
388458 @asyncio .coroutine
389459 def __aenter__ (self ):
@@ -397,8 +467,12 @@ def __await__(self):
397467 transport , protocol = yield from self ._creating_connection
398468
399469 try :
470+ if self ._proxy_uri is not None :
471+ yield from protocol .proxy_connect (
472+ self ._proxy_uri , self ._uri , self ._ssl )
400473 yield from protocol .handshake (
401- self ._uri , origin = self ._origin ,
474+ self ._uri ,
475+ origin = protocol .origin ,
402476 available_extensions = protocol .available_extensions ,
403477 available_subprotocols = protocol .available_subprotocols ,
404478 extra_headers = protocol .extra_headers ,
0 commit comments