77import urllib .parse
88from collections .abc import AsyncIterator , Generator , Sequence
99from types import TracebackType
10- from typing import Any , Callable
10+ from typing import Any , Callable , Literal
1111
1212from ..client import ClientProtocol , backoff
1313from ..datastructures import HeadersLike
1818from ..http11 import USER_AGENT , Response
1919from ..protocol import CONNECTING , Event
2020from ..typing import LoggerLike , Origin , Subprotocol
21- from ..uri import WebSocketURI , parse_uri
21+ from ..uri import Proxy , WebSocketURI , get_proxy , parse_proxy , parse_uri
2222from .compatibility import TimeoutError , asyncio_timeout
2323from .connection import Connection
2424
@@ -208,6 +208,10 @@ class connect:
208208 user_agent_header: Value of the ``User-Agent`` request header.
209209 It defaults to ``"Python/x.y.z websockets/X.Y"``.
210210 Setting it to :obj:`None` removes the header.
211+ proxy: If a proxy is configured, it is used by default. Set ``proxy``
212+ to :obj:`None` to disable the proxy or to the address of a proxy
213+ to override the system configuration. See the :doc:`proxy docs
214+ <../../topics/proxies>` for details.
211215 process_exception: When reconnecting automatically, tell whether an
212216 error is transient or fatal. The default behavior is defined by
213217 :func:`process_exception`. Refer to its documentation for details.
@@ -279,6 +283,7 @@ def __init__(
279283 # HTTP
280284 additional_headers : HeadersLike | None = None ,
281285 user_agent_header : str | None = USER_AGENT ,
286+ proxy : str | Literal [True ] | None = True ,
282287 process_exception : Callable [[Exception ], Exception | None ] = process_exception ,
283288 # Timeouts
284289 open_timeout : float | None = 10 ,
@@ -333,6 +338,7 @@ def protocol_factory(uri: WebSocketURI) -> ClientConnection:
333338 )
334339 return connection
335340
341+ self .proxy = proxy
336342 self .protocol_factory = protocol_factory
337343 self .handshake_args = (
338344 additional_headers ,
@@ -346,9 +352,20 @@ def protocol_factory(uri: WebSocketURI) -> ClientConnection:
346352 async def create_connection (self ) -> ClientConnection :
347353 """Create TCP or Unix connection."""
348354 loop = asyncio .get_running_loop ()
355+ kwargs = self .connection_kwargs .copy ()
349356
350357 ws_uri = parse_uri (self .uri )
351- kwargs = self .connection_kwargs .copy ()
358+
359+ proxy = self .proxy
360+ proxy_uri : Proxy | None = None
361+ if kwargs .get ("unix" , False ):
362+ proxy = None
363+ if kwargs .get ("sock" ) is not None :
364+ proxy = None
365+ if proxy is True :
366+ proxy = get_proxy (ws_uri )
367+ if proxy is not None :
368+ proxy_uri = parse_proxy (proxy )
352369
353370 def factory () -> ClientConnection :
354371 return self .protocol_factory (ws_uri )
@@ -365,6 +382,47 @@ def factory() -> ClientConnection:
365382 if kwargs .pop ("unix" , False ):
366383 _ , connection = await loop .create_unix_connection (factory , ** kwargs )
367384 else :
385+ if proxy_uri is not None :
386+ if proxy_uri .scheme [:5 ] == "socks" :
387+ try :
388+ from python_socks import ProxyType
389+ from python_socks .async_ .asyncio import Proxy
390+ except ImportError :
391+ raise ImportError (
392+ "python-socks is required to use a SOCKS proxy"
393+ )
394+ if proxy_uri .scheme == "socks5h" :
395+ proxy_type = ProxyType .SOCKS5
396+ rdns = True
397+ elif proxy_uri .scheme == "socks5" :
398+ proxy_type = ProxyType .SOCKS5
399+ rdns = False
400+ # We use mitmproxy for testing and it doesn't support SOCKS4.
401+ elif proxy_uri .scheme == "socks4a" : # pragma: no cover
402+ proxy_type = ProxyType .SOCKS4
403+ rdns = True
404+ elif proxy_uri .scheme == "socks4" : # pragma: no cover
405+ proxy_type = ProxyType .SOCKS4
406+ rdns = False
407+ # Proxy types are enforced in parse_proxy().
408+ else :
409+ raise AssertionError ("unsupported SOCKS proxy" )
410+ socks_proxy = Proxy (
411+ proxy_type ,
412+ proxy_uri .host ,
413+ proxy_uri .port ,
414+ proxy_uri .username ,
415+ proxy_uri .password ,
416+ rdns ,
417+ )
418+ kwargs ["sock" ] = await socks_proxy .connect (
419+ ws_uri .host ,
420+ ws_uri .port ,
421+ local_addr = kwargs .pop ("local_addr" , None ),
422+ )
423+ # Proxy types are enforced in parse_proxy().
424+ else :
425+ raise AssertionError ("unsupported proxy" )
368426 if kwargs .get ("sock" ) is None :
369427 kwargs .setdefault ("host" , ws_uri .host )
370428 kwargs .setdefault ("port" , ws_uri .port )
0 commit comments