77import urllib .parse
88from collections .abc import AsyncIterator , Generator , Sequence
99from types import TracebackType
10- from typing import Any , Callable
10+ from typing import Any , Callable , Literal
1111
1212from ..client import ClientProtocol , backoff
1313from ..datastructures import HeadersLike
1818from ..http11 import USER_AGENT , Response
1919from ..protocol import CONNECTING , Event
2020from ..typing import LoggerLike , Origin , Subprotocol
21- from ..uri import WebSocketURI , parse_uri
21+ from ..uri import ProxyURI , WebSocketURI , get_proxy , parse_proxy_uri , parse_uri
2222from .compatibility import TimeoutError , asyncio_timeout
2323from .connection import Connection
2424
@@ -208,6 +208,10 @@ class connect:
208208 user_agent_header: Value of the ``User-Agent`` request header.
209209 It defaults to ``"Python/x.y.z websockets/X.Y"``.
210210 Setting it to :obj:`None` removes the header.
211+ proxy: If a proxy is configured, it is used by default. Set ``proxy``
212+ to :obj:`None` to disable the proxy or to the address of a proxy
213+ to override the system configuration. See the :doc:`proxy docs
214+ <../../topics/proxy>` for details.
211215 process_exception: When reconnecting automatically, tell whether an
212216 error is transient or fatal. The default behavior is defined by
213217 :func:`process_exception`. Refer to its documentation for details.
@@ -279,6 +283,7 @@ def __init__(
279283 # HTTP
280284 additional_headers : HeadersLike | None = None ,
281285 user_agent_header : str | None = USER_AGENT ,
286+ proxy : str | Literal [True ] | None = True ,
282287 process_exception : Callable [[Exception ], Exception | None ] = process_exception ,
283288 # Timeouts
284289 open_timeout : float | None = 10 ,
@@ -333,6 +338,7 @@ def protocol_factory(uri: WebSocketURI) -> ClientConnection:
333338 )
334339 return connection
335340
341+ self .proxy = proxy
336342 self .protocol_factory = protocol_factory
337343 self .handshake_args = (
338344 additional_headers ,
@@ -346,9 +352,20 @@ def protocol_factory(uri: WebSocketURI) -> ClientConnection:
346352 async def create_connection (self ) -> ClientConnection :
347353 """Create TCP or Unix connection."""
348354 loop = asyncio .get_running_loop ()
355+ kwargs = self .connection_kwargs .copy ()
349356
350357 ws_uri = parse_uri (self .uri )
351- kwargs = self .connection_kwargs .copy ()
358+
359+ proxy = self .proxy
360+ proxy_uri : ProxyURI | None = None
361+ if kwargs .pop ("unix" , False ):
362+ proxy = None
363+ if kwargs .get ("sock" ) is not None :
364+ proxy = None
365+ if proxy is True :
366+ proxy = get_proxy (ws_uri )
367+ if proxy is not None :
368+ proxy_uri = parse_proxy_uri (proxy )
352369
353370 def factory () -> ClientConnection :
354371 return self .protocol_factory (ws_uri )
@@ -365,6 +382,38 @@ def factory() -> ClientConnection:
365382 if kwargs .pop ("unix" , False ):
366383 _ , connection = await loop .create_unix_connection (factory , ** kwargs )
367384 else :
385+ if proxy_uri is not None :
386+ if proxy_uri .scheme [:5 ] == "socks" :
387+ try :
388+ from python_socks import ProxyType
389+ from python_socks .async_ .asyncio import Proxy
390+ except ImportError :
391+ raise ImportError (
392+ "python-socks is required to use a SOCKS proxy"
393+ )
394+ if proxy_uri .scheme [:6 ] == "socks5" :
395+ proxy_type = ProxyType .SOCKS5
396+ elif proxy_uri .scheme [:6 ] == "socks4" :
397+ proxy_type = ProxyType .SOCKS4
398+ else :
399+ raise AssertionError ("unsupported SOCKS proxy" )
400+ socks_proxy = Proxy (
401+ proxy_type ,
402+ proxy_uri .host ,
403+ proxy_uri .port ,
404+ proxy_uri .username ,
405+ proxy_uri .password ,
406+ rdns = kwargs .pop ("rdns" , None ),
407+ )
408+ kwargs ["sock" ] = await socks_proxy .connect (
409+ ws_uri .host ,
410+ ws_uri .port ,
411+ local_addr = kwargs .pop ("local_addr" , None ),
412+ )
413+ else :
414+ raise NotImplementedError (
415+ f"proxy scheme not implemented yet: { proxy_uri .scheme } "
416+ )
368417 if kwargs .get ("sock" ) is None :
369418 kwargs .setdefault ("host" , ws_uri .host )
370419 kwargs .setdefault ("port" , ws_uri .port )
0 commit comments