diff --git a/httpcore/_async/http_proxy.py b/httpcore/_async/http_proxy.py index cc9d9206..c2e8e53f 100644 --- a/httpcore/_async/http_proxy.py +++ b/httpcore/_async/http_proxy.py @@ -285,61 +285,69 @@ async def handle_async_request(self, request: Request) -> Response: headers=connect_headers, extensions=request.extensions, ) - connect_response = await self._connection.handle_async_request( - connect_request - ) - - if connect_response.status < 200 or connect_response.status > 299: - reason_bytes = connect_response.extensions.get("reason_phrase", b"") - reason_str = reason_bytes.decode("ascii", errors="ignore") - msg = "%d %s" % (connect_response.status, reason_str) - await self._connection.aclose() - raise ProxyError(msg) - - stream = connect_response.extensions["network_stream"] - - # Upgrade the stream to SSL - ssl_context = ( - default_ssl_context() - if self._ssl_context is None - else self._ssl_context - ) - alpn_protocols = ["http/1.1", "h2"] if self._http2 else ["http/1.1"] - ssl_context.set_alpn_protocols(alpn_protocols) - - kwargs = { - "ssl_context": ssl_context, - "server_hostname": self._remote_origin.host.decode("ascii"), - "timeout": timeout, - } - async with Trace("start_tls", logger, request, kwargs) as trace: - stream = await stream.start_tls(**kwargs) - trace.return_value = stream - - # Determine if we should be using HTTP/1.1 or HTTP/2 - ssl_object = stream.get_extra_info("ssl_object") - http2_negotiated = ( - ssl_object is not None - and ssl_object.selected_alpn_protocol() == "h2" - ) - # Create the HTTP/1.1 or HTTP/2 connection - if http2_negotiated or (self._http2 and not self._http1): - from .http2 import AsyncHTTP2Connection + try: + connect_response = await self._connection.handle_async_request( + connect_request + ) - self._connection = AsyncHTTP2Connection( - origin=self._remote_origin, - stream=stream, - keepalive_expiry=self._keepalive_expiry, + if connect_response.status < 200 or connect_response.status > 299: + reason_bytes = connect_response.extensions.get( + "reason_phrase", b"" + ) + reason_str = reason_bytes.decode("ascii", errors="ignore") + msg = "%d %s" % (connect_response.status, reason_str) + await self._connection.aclose() + raise ProxyError(msg) + + stream = connect_response.extensions["network_stream"] + + # Upgrade the stream to SSL + ssl_context = ( + default_ssl_context() + if self._ssl_context is None + else self._ssl_context ) - else: - self._connection = AsyncHTTP11Connection( - origin=self._remote_origin, - stream=stream, - keepalive_expiry=self._keepalive_expiry, + alpn_protocols = ["http/1.1", "h2"] if self._http2 else ["http/1.1"] + ssl_context.set_alpn_protocols(alpn_protocols) + + kwargs = { + "ssl_context": ssl_context, + "server_hostname": self._remote_origin.host.decode("ascii"), + "timeout": timeout, + } + async with Trace("start_tls", logger, request, kwargs) as trace: + stream = await stream.start_tls(**kwargs) + trace.return_value = stream + + # Determine if we should be using HTTP/1.1 or HTTP/2 + ssl_object = stream.get_extra_info("ssl_object") + http2_negotiated = ( + ssl_object is not None + and ssl_object.selected_alpn_protocol() == "h2" ) - self._connected = True + # Create the HTTP/1.1 or HTTP/2 connection + if http2_negotiated or (self._http2 and not self._http1): + from .http2 import AsyncHTTP2Connection + + self._connection = AsyncHTTP2Connection( + origin=self._remote_origin, + stream=stream, + keepalive_expiry=self._keepalive_expiry, + ) + else: + self._connection = AsyncHTTP11Connection( + origin=self._remote_origin, + stream=stream, + keepalive_expiry=self._keepalive_expiry, + ) + + self._connected = True + except Exception: + await self._connection.aclose() + raise + return await self._connection.handle_async_request(request) def can_handle_request(self, origin: Origin) -> bool: diff --git a/httpcore/_async/socks_proxy.py b/httpcore/_async/socks_proxy.py index b363f55a..cb77660d 100644 --- a/httpcore/_async/socks_proxy.py +++ b/httpcore/_async/socks_proxy.py @@ -45,6 +45,7 @@ async def _init_socks5_connection( host: bytes, port: int, auth: tuple[bytes, bytes] | None = None, + timeout: float | None = None, # <--- FIX 1: Add timeout argument ) -> None: conn = socksio.socks5.SOCKS5Connection() @@ -56,10 +57,12 @@ async def _init_socks5_connection( ) conn.send(socksio.socks5.SOCKS5AuthMethodsRequest([auth_method])) outgoing_bytes = conn.data_to_send() - await stream.write(outgoing_bytes) + await stream.write(outgoing_bytes, timeout=timeout) # <--- FIX 2: Pass timeout # Auth method response - incoming_bytes = await stream.read(max_bytes=4096) + incoming_bytes = await stream.read( + max_bytes=4096, timeout=timeout + ) # <--- FIX 3: Pass timeout response = conn.receive_data(incoming_bytes) assert isinstance(response, socksio.socks5.SOCKS5AuthReply) if response.method != auth_method: @@ -75,10 +78,12 @@ async def _init_socks5_connection( username, password = auth conn.send(socksio.socks5.SOCKS5UsernamePasswordRequest(username, password)) outgoing_bytes = conn.data_to_send() - await stream.write(outgoing_bytes) + await stream.write(outgoing_bytes, timeout=timeout) # <--- FIX 4: Pass timeout # Username/password response - incoming_bytes = await stream.read(max_bytes=4096) + incoming_bytes = await stream.read( + max_bytes=4096, timeout=timeout + ) # <--- FIX 5: Pass timeout response = conn.receive_data(incoming_bytes) assert isinstance(response, socksio.socks5.SOCKS5UsernamePasswordReply) if not response.success: @@ -91,10 +96,12 @@ async def _init_socks5_connection( ) ) outgoing_bytes = conn.data_to_send() - await stream.write(outgoing_bytes) + await stream.write(outgoing_bytes, timeout=timeout) # <--- FIX 6: Pass timeout # Connect response - incoming_bytes = await stream.read(max_bytes=4096) + incoming_bytes = await stream.read( + max_bytes=4096, timeout=timeout + ) # <--- FIX 7: Pass timeout response = conn.receive_data(incoming_bytes) assert isinstance(response, socksio.socks5.SOCKS5Reply) if response.reply_code != socksio.socks5.SOCKS5ReplyCode.SUCCEEDED: @@ -122,33 +129,6 @@ def __init__( ) -> None: """ A connection pool for making HTTP requests. - - Parameters: - proxy_url: The URL to use when connecting to the proxy server. - For example `"http://127.0.0.1:8080/"`. - ssl_context: An SSL context to use for verifying connections. - If not specified, the default `httpcore.default_ssl_context()` - will be used. - max_connections: The maximum number of concurrent HTTP connections that - the pool should allow. Any attempt to send a request on a pool that - would exceed this amount will block until a connection is available. - max_keepalive_connections: The maximum number of idle HTTP connections - that will be maintained in the pool. - keepalive_expiry: The duration in seconds that an idle HTTP connection - may be maintained for before being expired from the pool. - http1: A boolean indicating if HTTP/1.1 requests should be supported - by the connection pool. Defaults to True. - http2: A boolean indicating if HTTP/2 requests should be supported by - the connection pool. Defaults to False. - retries: The maximum number of retries when trying to establish - a connection. - local_address: Local address to connect from. Can also be used to - connect using a particular address family. Using - `local_address="0.0.0.0"` will connect using an `AF_INET` address - (IPv4), while using `local_address="::"` will connect using an - `AF_INET6` address (IPv6). - uds: Path to a Unix Domain Socket to use instead of TCP sockets. - network_backend: A backend instance to use for handling network I/O. """ super().__init__( ssl_context=ssl_context, @@ -237,6 +217,7 @@ async def handle_async_request(self, request: Request) -> Response: "host": self._remote_origin.host.decode("ascii"), "port": self._remote_origin.port, "auth": self._proxy_auth, + "timeout": timeout, # <--- FIX 8: Pass timeout argument } async with Trace( "setup_socks5_connection", logger, request, kwargs diff --git a/httpcore/_sync/http_proxy.py b/httpcore/_sync/http_proxy.py index ecca88f7..394c8847 100644 --- a/httpcore/_sync/http_proxy.py +++ b/httpcore/_sync/http_proxy.py @@ -285,61 +285,69 @@ def handle_request(self, request: Request) -> Response: headers=connect_headers, extensions=request.extensions, ) - connect_response = self._connection.handle_request( - connect_request - ) - - if connect_response.status < 200 or connect_response.status > 299: - reason_bytes = connect_response.extensions.get("reason_phrase", b"") - reason_str = reason_bytes.decode("ascii", errors="ignore") - msg = "%d %s" % (connect_response.status, reason_str) - self._connection.close() - raise ProxyError(msg) - - stream = connect_response.extensions["network_stream"] - - # Upgrade the stream to SSL - ssl_context = ( - default_ssl_context() - if self._ssl_context is None - else self._ssl_context - ) - alpn_protocols = ["http/1.1", "h2"] if self._http2 else ["http/1.1"] - ssl_context.set_alpn_protocols(alpn_protocols) - - kwargs = { - "ssl_context": ssl_context, - "server_hostname": self._remote_origin.host.decode("ascii"), - "timeout": timeout, - } - with Trace("start_tls", logger, request, kwargs) as trace: - stream = stream.start_tls(**kwargs) - trace.return_value = stream - - # Determine if we should be using HTTP/1.1 or HTTP/2 - ssl_object = stream.get_extra_info("ssl_object") - http2_negotiated = ( - ssl_object is not None - and ssl_object.selected_alpn_protocol() == "h2" - ) - # Create the HTTP/1.1 or HTTP/2 connection - if http2_negotiated or (self._http2 and not self._http1): - from .http2 import HTTP2Connection + try: + connect_response = self._connection.handle_request( + connect_request + ) - self._connection = HTTP2Connection( - origin=self._remote_origin, - stream=stream, - keepalive_expiry=self._keepalive_expiry, + if connect_response.status < 200 or connect_response.status > 299: + reason_bytes = connect_response.extensions.get( + "reason_phrase", b"" + ) + reason_str = reason_bytes.decode("ascii", errors="ignore") + msg = "%d %s" % (connect_response.status, reason_str) + self._connection.close() + raise ProxyError(msg) + + stream = connect_response.extensions["network_stream"] + + # Upgrade the stream to SSL + ssl_context = ( + default_ssl_context() + if self._ssl_context is None + else self._ssl_context ) - else: - self._connection = HTTP11Connection( - origin=self._remote_origin, - stream=stream, - keepalive_expiry=self._keepalive_expiry, + alpn_protocols = ["http/1.1", "h2"] if self._http2 else ["http/1.1"] + ssl_context.set_alpn_protocols(alpn_protocols) + + kwargs = { + "ssl_context": ssl_context, + "server_hostname": self._remote_origin.host.decode("ascii"), + "timeout": timeout, + } + with Trace("start_tls", logger, request, kwargs) as trace: + stream = stream.start_tls(**kwargs) + trace.return_value = stream + + # Determine if we should be using HTTP/1.1 or HTTP/2 + ssl_object = stream.get_extra_info("ssl_object") + http2_negotiated = ( + ssl_object is not None + and ssl_object.selected_alpn_protocol() == "h2" ) - self._connected = True + # Create the HTTP/1.1 or HTTP/2 connection + if http2_negotiated or (self._http2 and not self._http1): + from .http2 import HTTP2Connection + + self._connection = HTTP2Connection( + origin=self._remote_origin, + stream=stream, + keepalive_expiry=self._keepalive_expiry, + ) + else: + self._connection = HTTP11Connection( + origin=self._remote_origin, + stream=stream, + keepalive_expiry=self._keepalive_expiry, + ) + + self._connected = True + except Exception: + self._connection.close() + raise + return self._connection.handle_request(request) def can_handle_request(self, origin: Origin) -> bool: diff --git a/httpcore/_sync/socks_proxy.py b/httpcore/_sync/socks_proxy.py index 0ca96ddf..4f89aca0 100644 --- a/httpcore/_sync/socks_proxy.py +++ b/httpcore/_sync/socks_proxy.py @@ -45,6 +45,7 @@ def _init_socks5_connection( host: bytes, port: int, auth: tuple[bytes, bytes] | None = None, + timeout: float | None = None, # <--- FIX 1: Add timeout argument ) -> None: conn = socksio.socks5.SOCKS5Connection() @@ -56,10 +57,12 @@ def _init_socks5_connection( ) conn.send(socksio.socks5.SOCKS5AuthMethodsRequest([auth_method])) outgoing_bytes = conn.data_to_send() - stream.write(outgoing_bytes) + stream.write(outgoing_bytes, timeout=timeout) # <--- FIX 2: Pass timeout # Auth method response - incoming_bytes = stream.read(max_bytes=4096) + incoming_bytes = stream.read( + max_bytes=4096, timeout=timeout + ) # <--- FIX 3: Pass timeout response = conn.receive_data(incoming_bytes) assert isinstance(response, socksio.socks5.SOCKS5AuthReply) if response.method != auth_method: @@ -75,10 +78,12 @@ def _init_socks5_connection( username, password = auth conn.send(socksio.socks5.SOCKS5UsernamePasswordRequest(username, password)) outgoing_bytes = conn.data_to_send() - stream.write(outgoing_bytes) + stream.write(outgoing_bytes, timeout=timeout) # <--- FIX 4: Pass timeout # Username/password response - incoming_bytes = stream.read(max_bytes=4096) + incoming_bytes = stream.read( + max_bytes=4096, timeout=timeout + ) # <--- FIX 5: Pass timeout response = conn.receive_data(incoming_bytes) assert isinstance(response, socksio.socks5.SOCKS5UsernamePasswordReply) if not response.success: @@ -91,10 +96,12 @@ def _init_socks5_connection( ) ) outgoing_bytes = conn.data_to_send() - stream.write(outgoing_bytes) + stream.write(outgoing_bytes, timeout=timeout) # <--- FIX 6: Pass timeout # Connect response - incoming_bytes = stream.read(max_bytes=4096) + incoming_bytes = stream.read( + max_bytes=4096, timeout=timeout + ) # <--- FIX 7: Pass timeout response = conn.receive_data(incoming_bytes) assert isinstance(response, socksio.socks5.SOCKS5Reply) if response.reply_code != socksio.socks5.SOCKS5ReplyCode.SUCCEEDED: @@ -122,33 +129,6 @@ def __init__( ) -> None: """ A connection pool for making HTTP requests. - - Parameters: - proxy_url: The URL to use when connecting to the proxy server. - For example `"http://127.0.0.1:8080/"`. - ssl_context: An SSL context to use for verifying connections. - If not specified, the default `httpcore.default_ssl_context()` - will be used. - max_connections: The maximum number of concurrent HTTP connections that - the pool should allow. Any attempt to send a request on a pool that - would exceed this amount will block until a connection is available. - max_keepalive_connections: The maximum number of idle HTTP connections - that will be maintained in the pool. - keepalive_expiry: The duration in seconds that an idle HTTP connection - may be maintained for before being expired from the pool. - http1: A boolean indicating if HTTP/1.1 requests should be supported - by the connection pool. Defaults to True. - http2: A boolean indicating if HTTP/2 requests should be supported by - the connection pool. Defaults to False. - retries: The maximum number of retries when trying to establish - a connection. - local_address: Local address to connect from. Can also be used to - connect using a particular address family. Using - `local_address="0.0.0.0"` will connect using an `AF_INET` address - (IPv4), while using `local_address="::"` will connect using an - `AF_INET6` address (IPv6). - uds: Path to a Unix Domain Socket to use instead of TCP sockets. - network_backend: A backend instance to use for handling network I/O. """ super().__init__( ssl_context=ssl_context, @@ -237,6 +217,7 @@ def handle_request(self, request: Request) -> Response: "host": self._remote_origin.host.decode("ascii"), "port": self._remote_origin.port, "auth": self._proxy_auth, + "timeout": timeout, # <--- FIX 8: Pass timeout argument } with Trace( "setup_socks5_connection", logger, request, kwargs diff --git a/reproduce_httpcore.py b/reproduce_httpcore.py new file mode 100644 index 00000000..be45af37 --- /dev/null +++ b/reproduce_httpcore.py @@ -0,0 +1,63 @@ +import httpcore +import socket +import threading +import time + +# --- Same Server Setup as before --- +TIMEOUT = 2.0 +HANG_TIME = 20 + +def get_free_port(): + with socket.socket() as s: + s.bind(('', 0)) + return s.getsockname()[1] + +def blackhole_proxy_server(port, stop_event): + server = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + server.bind(('127.0.0.1', port)) + server.listen(1) + server.settimeout(1.0) + while not stop_event.is_set(): + try: + client, _ = server.accept() + time.sleep(HANG_TIME) # Hang the handshake + client.close() + except socket.timeout: continue + except: break + server.close() + +# --- The Test --- +def run_test(): + proxy_port = get_free_port() + stop_event = threading.Event() + t = threading.Thread(target=blackhole_proxy_server, args=(proxy_port, stop_event)) + t.start() + time.sleep(0.5) + + print(f"[*] Testing httpcore SOCKSProxy with {TIMEOUT}s timeout...") + start_time = time.time() + + # We use the low-level SOCKSProxy directly + with httpcore.SOCKSProxy( + proxy_url=f"socks5://127.0.0.1:{proxy_port}" + ) as pool: + try: + # We assume httpcore 1.0+ style request + pool.request( + "GET", + "http://example.com", + extensions={'timeout': {'connect': TIMEOUT, 'read': TIMEOUT}} + ) + except httpcore.TimeoutException: + print("[SUCCESS] Caught timeout correctly!") + except Exception as e: + print(f"[ERROR] {e}") + finally: + duration = time.time() - start_time + print(f"[*] Duration: {duration:.2f}s") + stop_event.set() + t.join() + +if __name__ == "__main__": + run_test() \ No newline at end of file diff --git a/reproduce_httpcore_async.py b/reproduce_httpcore_async.py new file mode 100644 index 00000000..3e577cfe --- /dev/null +++ b/reproduce_httpcore_async.py @@ -0,0 +1,69 @@ +import httpcore +import socket +import threading +import time +import asyncio + +# --- Server Setup (Same as before) --- +HANG_TIME = 20 +TIMEOUT = 2.0 + +def get_free_port(): + with socket.socket() as s: + s.bind(('', 0)) + return s.getsockname()[1] + +def blackhole_proxy_server(port, stop_event): + server = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + server.bind(('127.0.0.1', port)) + server.listen(1) + server.settimeout(1.0) + + while not stop_event.is_set(): + try: + client, _ = server.accept() + # print("[Server] Accepted connection, sleeping...") + time.sleep(HANG_TIME) + client.close() + except socket.timeout: continue + except: break + server.close() + +# --- Async Test --- +async def run_async_test(port): + print(f"[*] Testing ASYNC httpcore SOCKSProxy with {TIMEOUT}s timeout...") + start_time = time.time() + + async with httpcore.AsyncSOCKSProxy( + proxy_url=f"socks5://127.0.0.1:{port}" + ) as pool: + try: + await pool.request( + "GET", + "http://example.com", + extensions={'timeout': {'connect': TIMEOUT, 'read': TIMEOUT}} + ) + except httpcore.TimeoutException: + print("[SUCCESS] Caught timeout correctly!") + except Exception as e: + print(f"[ERROR] {type(e).__name__}: {e}") + finally: + duration = time.time() - start_time + print(f"[*] Duration: {duration:.2f}s") + +def main(): + port = get_free_port() + stop_event = threading.Event() + t = threading.Thread(target=blackhole_proxy_server, args=(port, stop_event)) + t.start() + time.sleep(0.5) + + try: + asyncio.run(run_async_test(port)) + finally: + stop_event.set() + t.join() + +if __name__ == "__main__": + main() \ No newline at end of file