Skip to content

Commit 87391aa

Browse files
committed
fix(fetch): add DNS rebinding TOCTOU protection via SSRFSafeTransport
Address review feedback on SSRF protection:
- Add SSRFSafeTransport, a custom async transport that resolves DNS, validates the resolved IP, and replaces the hostname with the validated IP before connecting. This eliminates the TOCTOU window between validate_url_for_ssrf() and the actual HTTP request.
- Integrate SSRFSafeTransport into fetch_url() and check_may_autonomously_fetch_url(), replacing direct AsyncClient usage.
- Add 6 DNS rebinding tests, including a full attack scenario simulation.
- Update existing tests to match the new transport-based architecture.
1 parent c99c358 commit 87391aa

3 files changed

Lines changed: 239 additions & 17 deletions

File tree

src/fetch/src/mcp_server_fetch/server.py

Lines changed: 91 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from typing import Annotated, Tuple
66
from urllib.parse import urlparse, urlunparse
77

8+
import httpx
89
import markdownify
910
import readabilipy.simple_json
1011
from mcp.shared.exceptions import McpError
@@ -237,9 +238,9 @@ def validate_url_for_ssrf(url: str) -> None:
237238
McpError: If the URL is potentially dangerous
238239
239240
Security Note:
240-
This validation happens BEFORE the request is made, but DNS rebinding
241-
attacks could still occur. For maximum security, use network-level
242-
controls (firewall rules, egress filtering).
241+
This validation provides early rejection of obviously dangerous URLs.
242+
DNS rebinding protection is handled at the transport layer by
243+
SSRFSafeTransport, which validates resolved IPs at connection time.
243244
"""
244245
try:
245246
parsed = urlparse(url)
@@ -332,6 +333,89 @@ def validate_url_for_ssrf(url: str) -> None:
332333
))
333334

334335

336+
class SSRFSafeTransport(httpx.AsyncBaseTransport):
    """
    Custom async transport that prevents DNS rebinding attacks.

    DNS rebinding TOCTOU (Time-of-Check-Time-of-Use) attack:
    1. validate_url_for_ssrf() resolves DNS → gets public IP → passes check
    2. Attacker's DNS server changes the record to a private IP (e.g., 169.254.169.254)
    3. httpx resolves DNS again → gets private IP → connects to internal service

    This transport eliminates the TOCTOU window by:
    1. Resolving DNS ourselves (off the event loop thread)
    2. Validating the resolved IP
    3. Replacing the hostname in the URL with the validated IP
    4. Preserving the original Host header for correct HTTP routing
    5. Preserving the original hostname for TLS SNI / certificate matching

    IP-literal hosts are validated here as well. The transport is invoked for
    every request the client sends, including redirect hops, and those hops
    never pass through validate_url_for_ssrf().

    Args:
        proxy: Optional proxy URL forwarded to the wrapped AsyncHTTPTransport.
        verify: Whether the wrapped transport verifies TLS certificates.
    """

    def __init__(self, proxy: str | None = None, verify: bool = True):
        kwargs: dict = {"verify": verify}
        if proxy:
            kwargs["proxy"] = proxy
        self._transport = httpx.AsyncHTTPTransport(**kwargs)

    async def handle_async_request(self, request: httpx.Request) -> httpx.Response:
        """Validate the destination IP, pin the request to it, and send it.

        Args:
            request: The outgoing request handed over by the httpx client.

        Returns:
            The response produced by the wrapped AsyncHTTPTransport.

        Raises:
            McpError: If the hostname cannot be resolved, or the destination
                address is private/reserved while ALLOW_PRIVATE_IPS is false.
        """
        # Function-scope import: only the resolver call below needs asyncio.
        import asyncio

        hostname = request.url.host

        try:
            ipaddress.ip_address(hostname)
        except ValueError:
            is_ip_literal = False  # It's a hostname, resolve it below
        else:
            is_ip_literal = True

        if is_ip_literal:
            # Do not assume the caller validated this address: a redirect
            # response can point straight at an internal IP literal.
            if not ALLOW_PRIVATE_IPS and _is_ip_private_or_reserved(hostname):
                raise McpError(ErrorData(
                    code=INVALID_PARAMS,
                    message=f"SSRF protection: request to private/internal IP "
                            f"'{hostname}' blocked at connection time. "
                            f"Set MCP_FETCH_ALLOW_PRIVATE_IPS=true to allow internal network access.",
                ))
            return await self._transport.handle_async_request(request)

        # Resolve DNS without blocking the event loop; loop.getaddrinfo() is
        # the asynchronous counterpart of socket.getaddrinfo().
        try:
            addr_info = await asyncio.get_running_loop().getaddrinfo(
                hostname, None, family=socket.AF_UNSPEC, type=socket.SOCK_STREAM
            )
        except (socket.gaierror, UnicodeError) as e:
            # UnicodeError is raised for hostnames that cannot be IDNA-encoded.
            raise McpError(ErrorData(
                code=INVALID_PARAMS,
                message=f"Failed to resolve hostname '{hostname}': {str(e)}",
            )) from e

        if not addr_info:
            raise McpError(ErrorData(
                code=INVALID_PARAMS,
                message=f"Failed to resolve hostname '{hostname}': no addresses found",
            ))

        # Each entry is (family, type, proto, canonname, sockaddr); the first
        # element of sockaddr is the address string. Only this address is
        # validated because it is the only one we will connect to.
        resolved_ip = addr_info[0][4][0]

        # Validate resolved IP against SSRF rules
        if not ALLOW_PRIVATE_IPS and _is_ip_private_or_reserved(resolved_ip):
            raise McpError(ErrorData(
                code=INVALID_PARAMS,
                message=f"DNS rebinding protection: hostname '{hostname}' resolved to "
                        f"private/internal IP '{resolved_ip}' at connection time. "
                        f"Set MCP_FETCH_ALLOW_PRIVATE_IPS=true to allow internal network access.",
            ))

        # Connecting by IP would otherwise make TLS use the IP for SNI and
        # certificate hostname matching; "sni_hostname" keeps the original
        # name for both. A caller-supplied value is respected.
        extensions = dict(request.extensions)
        if request.url.scheme == "https":
            extensions.setdefault(
                "sni_hostname", request.url.raw_host.decode("ascii")
            )

        # Replace hostname with validated IP to prevent DNS rebinding.
        # The Host header is already set to the original hostname by httpx
        # and is carried over unchanged with the rest of the headers.
        new_request = httpx.Request(
            method=request.method,
            url=request.url.copy_with(host=resolved_ip),
            headers=request.headers,
            stream=request.stream,
            extensions=extensions,
        )

        return await self._transport.handle_async_request(new_request)

    async def aclose(self):
        """Close the wrapped transport and its connection pool."""
        await self._transport.aclose()

    async def __aenter__(self):
        """Enter the wrapped transport's context and return this wrapper."""
        await self._transport.__aenter__()
        return self

    async def __aexit__(self, exc_type=None, exc_val=None, exc_tb=None):
        """Exit the wrapped transport's context, forwarding exception info."""
        await self._transport.__aexit__(exc_type, exc_val, exc_tb)
417+
418+
335419
def extract_content_from_html(html: str) -> str:
336420
"""Extract and convert HTML content to Markdown format.
337421
@@ -381,14 +465,13 @@ async def check_may_autonomously_fetch_url(url: str, user_agent: str, proxy_url:
381465
- SSL certificate verification (configurable via SSL_VERIFY)
382466
- Comprehensive SSL error handling
383467
"""
384-
import httpx
385-
386468
robot_txt_url = get_robots_txt_url(url)
387469

388470
# SSRF Protection: Validate robots.txt URL before fetching
389471
validate_url_for_ssrf(robot_txt_url)
390472

391-
async with httpx.AsyncClient(proxies=proxy_url, verify=SSL_VERIFY) as client:
473+
transport = SSRFSafeTransport(proxy=proxy_url, verify=SSL_VERIFY)
474+
async with httpx.AsyncClient(transport=transport) as client:
392475
try:
393476
response = await client.get(
394477
robot_txt_url,
@@ -461,12 +544,11 @@ async def fetch_url(
461544
- User-Agent header for transparency
462545
- Comprehensive SSL error handling (catches wrapped exceptions)
463546
"""
464-
import httpx
465-
466547
# SSRF Protection: Validate URL before fetching
467548
validate_url_for_ssrf(url)
468549

469-
async with httpx.AsyncClient(proxies=proxy_url, verify=SSL_VERIFY) as client:
550+
transport = SSRFSafeTransport(proxy=proxy_url, verify=SSL_VERIFY)
551+
async with httpx.AsyncClient(transport=transport) as client:
470552
try:
471553
response = await client.get(
472554
url,

src/fetch/tests/test_security.py

Lines changed: 142 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
_parse_obfuscated_ip,
2626
fetch_url,
2727
extract_content_from_html,
28+
SSRFSafeTransport,
2829
BLOCKED_HOSTNAMES,
2930
CLOUD_METADATA_IPS,
3031
)
@@ -134,8 +135,9 @@ async def test_ssl_disabled_allows_self_signed(self, reset_env):
134135
import mcp_server_fetch.server as server_module
135136
importlib.reload(server_module)
136137

137-
# Mock httpx.AsyncClient to verify verify=False is passed
138-
with patch('httpx.AsyncClient') as mock_client:
138+
# Mock httpx.AsyncClient and AsyncHTTPTransport to verify verify=False is passed
139+
with patch('httpx.AsyncClient') as mock_client, \
140+
patch('httpx.AsyncHTTPTransport') as mock_transport_class:
139141
mock_response = MagicMock()
140142
mock_response.status_code = 200
141143
mock_response.text = "<html><body>Test</body></html>"
@@ -155,9 +157,9 @@ async def test_ssl_disabled_allows_self_signed(self, reset_env):
155157
"TestAgent/1.0"
156158
)
157159

158-
# Verify AsyncClient was called with verify=False
159-
mock_client.assert_called_once()
160-
call_kwargs = mock_client.call_args[1]
160+
# Verify AsyncHTTPTransport was created with verify=False
161+
mock_transport_class.assert_called_once()
162+
call_kwargs = mock_transport_class.call_args[1]
161163
assert call_kwargs.get('verify') is False
162164

163165

@@ -646,6 +648,141 @@ def test_url_with_port_bypass_attempt(self):
646648
validate_url_for_ssrf("http://127.0.0.1:65535/")
647649

648650

651+
# =============================================================================
652+
# 8. DNS REBINDING PROTECTION TESTS
653+
# =============================================================================
654+
655+
class TestDNSRebindingProtection:
656+
"""Test suite for DNS rebinding TOCTOU protection via SSRFSafeTransport."""
657+
658+
@pytest.mark.asyncio
659+
async def test_transport_blocks_private_ip_at_connection_time(self):
660+
"""SSRFSafeTransport must block requests when DNS resolves to private IP."""
661+
import httpx
662+
663+
transport = SSRFSafeTransport(verify=False)
664+
665+
# Simulate DNS resolving to a private IP (127.0.0.1)
666+
with patch("socket.getaddrinfo") as mock_dns:
667+
mock_dns.return_value = [
668+
(2, 1, 6, '', ('127.0.0.1', 0)),
669+
]
670+
request = httpx.Request("GET", "http://evil-rebind.example.com/secret")
671+
672+
with pytest.raises(McpError, match="DNS rebinding protection"):
673+
await transport.handle_async_request(request)
674+
675+
@pytest.mark.asyncio
676+
async def test_transport_blocks_metadata_ip_at_connection_time(self):
677+
"""SSRFSafeTransport must block DNS rebinding to cloud metadata IP."""
678+
import httpx
679+
680+
transport = SSRFSafeTransport(verify=False)
681+
682+
# Simulate DNS rebinding: attacker DNS returns metadata IP
683+
with patch("socket.getaddrinfo") as mock_dns:
684+
mock_dns.return_value = [
685+
(2, 1, 6, '', ('169.254.169.254', 0)),
686+
]
687+
request = httpx.Request("GET", "http://evil-rebind.example.com/metadata")
688+
689+
with pytest.raises(McpError, match="DNS rebinding protection"):
690+
await transport.handle_async_request(request)
691+
692+
@pytest.mark.asyncio
693+
async def test_transport_allows_public_ip(self):
694+
"""SSRFSafeTransport must allow requests when DNS resolves to public IP."""
695+
import httpx
696+
697+
transport = SSRFSafeTransport(verify=False)
698+
699+
# Simulate DNS resolving to a public IP
700+
with patch("socket.getaddrinfo") as mock_dns, \
701+
patch.object(transport, '_transport') as mock_inner:
702+
mock_dns.return_value = [
703+
(2, 1, 6, '', ('93.184.216.34', 0)),
704+
]
705+
mock_response = httpx.Response(200, text="OK")
706+
mock_inner.handle_async_request = AsyncMock(return_value=mock_response)
707+
708+
request = httpx.Request("GET", "http://example.com/page")
709+
response = await transport.handle_async_request(request)
710+
711+
assert response.status_code == 200
712+
# Verify the inner transport was called with the IP-based URL
713+
called_request = mock_inner.handle_async_request.call_args[0][0]
714+
assert called_request.url.host == "93.184.216.34"
715+
# Verify Host header preserved
716+
assert called_request.headers["host"] == "example.com"
717+
718+
@pytest.mark.asyncio
719+
async def test_transport_skips_validation_for_direct_ip(self):
720+
"""SSRFSafeTransport should skip DNS resolution for direct IP URLs."""
721+
import httpx
722+
723+
transport = SSRFSafeTransport(verify=False)
724+
725+
# Direct IP URL - should go straight to inner transport (IP already validated by validate_url_for_ssrf)
726+
with patch.object(transport, '_transport') as mock_inner, \
727+
patch("socket.getaddrinfo") as mock_dns:
728+
mock_response = httpx.Response(200, text="OK")
729+
mock_inner.handle_async_request = AsyncMock(return_value=mock_response)
730+
731+
request = httpx.Request("GET", "http://93.184.216.34/page")
732+
await transport.handle_async_request(request)
733+
734+
# DNS should NOT be called for direct IP
735+
mock_dns.assert_not_called()
736+
mock_inner.handle_async_request.assert_called_once()
737+
738+
@pytest.mark.asyncio
739+
async def test_transport_blocks_dns_failure(self):
740+
"""SSRFSafeTransport must raise error when DNS resolution fails."""
741+
import httpx
742+
import socket as socket_module
743+
744+
transport = SSRFSafeTransport(verify=False)
745+
746+
with patch("socket.getaddrinfo") as mock_dns:
747+
mock_dns.side_effect = socket_module.gaierror("Name resolution failed")
748+
request = httpx.Request("GET", "http://nonexistent.example.com/")
749+
750+
with pytest.raises(McpError, match="Failed to resolve"):
751+
await transport.handle_async_request(request)
752+
753+
@pytest.mark.asyncio
754+
async def test_dns_rebinding_scenario(self):
755+
"""
756+
Full DNS rebinding attack scenario:
757+
1. validate_url_for_ssrf() sees public IP (passes)
758+
2. SSRFSafeTransport resolves DNS again and sees private IP (blocks)
759+
"""
760+
import httpx
761+
762+
call_count = 0
763+
764+
def rebinding_dns(hostname, *args, **kwargs):
765+
nonlocal call_count
766+
call_count += 1
767+
if call_count == 1:
768+
# First call (validate_url_for_ssrf): return public IP
769+
return [(2, 1, 6, '', ('93.184.216.34', 0))]
770+
else:
771+
# Second call (SSRFSafeTransport): return private IP (rebinding!)
772+
return [(2, 1, 6, '', ('169.254.169.254', 0))]
773+
774+
with patch("socket.getaddrinfo", side_effect=rebinding_dns):
775+
# First validation passes (public IP)
776+
validate_url_for_ssrf("http://evil-rebind.example.com/")
777+
778+
# But transport-level check catches the rebinding
779+
transport = SSRFSafeTransport(verify=False)
780+
request = httpx.Request("GET", "http://evil-rebind.example.com/metadata")
781+
782+
with pytest.raises(McpError, match="DNS rebinding protection"):
783+
await transport.handle_async_request(request)
784+
785+
649786
# =============================================================================
650787
# RUN CONFIGURATION
651788
# =============================================================================

src/fetch/tests/test_server.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -305,13 +305,14 @@ async def test_fetch_500_raises_error(self):
305305

306306
@pytest.mark.asyncio
307307
async def test_fetch_with_proxy(self):
308-
"""Test that proxy URL is passed to client."""
308+
"""Test that proxy URL is passed to SSRFSafeTransport."""
309309
mock_response = MagicMock()
310310
mock_response.status_code = 200
311311
mock_response.text = '{"data": "test"}'
312312
mock_response.headers = {"content-type": "application/json"}
313313

314314
with patch("httpx.AsyncClient") as mock_client_class, \
315+
patch("httpx.AsyncHTTPTransport") as mock_transport_class, \
315316
patch("mcp_server_fetch.server.validate_url_for_ssrf"), \
316317
patch("mcp_server_fetch.server.SSL_VERIFY", True):
317318
mock_client = AsyncMock()
@@ -325,5 +326,7 @@ async def test_fetch_with_proxy(self):
325326
proxy_url="http://proxy.example.com:8080"
326327
)
327328

328-
# Verify AsyncClient was called with proxy and verify
329-
mock_client_class.assert_called_once_with(proxies="http://proxy.example.com:8080", verify=True)
329+
# Verify AsyncHTTPTransport was created with proxy and verify
330+
mock_transport_class.assert_called_once_with(
331+
verify=True, proxy="http://proxy.example.com:8080"
332+
)

0 commit comments

Comments
 (0)