2525 _parse_obfuscated_ip ,
2626 fetch_url ,
2727 extract_content_from_html ,
28+ SSRFSafeTransport ,
2829 BLOCKED_HOSTNAMES ,
2930 CLOUD_METADATA_IPS ,
3031)
@@ -134,8 +135,9 @@ async def test_ssl_disabled_allows_self_signed(self, reset_env):
134135 import mcp_server_fetch .server as server_module
135136 importlib .reload (server_module )
136137
137- # Mock httpx.AsyncClient to verify verify=False is passed
138- with patch ('httpx.AsyncClient' ) as mock_client :
138+ # Mock httpx.AsyncClient and AsyncHTTPTransport to verify verify=False is passed
139+ with patch ('httpx.AsyncClient' ) as mock_client , \
140+ patch ('httpx.AsyncHTTPTransport' ) as mock_transport_class :
139141 mock_response = MagicMock ()
140142 mock_response .status_code = 200
141143 mock_response .text = "<html><body>Test</body></html>"
@@ -155,9 +157,9 @@ async def test_ssl_disabled_allows_self_signed(self, reset_env):
155157 "TestAgent/1.0"
156158 )
157159
158- # Verify AsyncClient was called with verify=False
159- mock_client .assert_called_once ()
160- call_kwargs = mock_client .call_args [1 ]
160+ # Verify AsyncHTTPTransport was created with verify=False
161+ mock_transport_class .assert_called_once ()
162+ call_kwargs = mock_transport_class .call_args [1 ]
161163 assert call_kwargs .get ('verify' ) is False
162164
163165
@@ -646,6 +648,141 @@ def test_url_with_port_bypass_attempt(self):
646648 validate_url_for_ssrf ("http://127.0.0.1:65535/" )
647649
648650
651+ # =============================================================================
652+ # 8. DNS REBINDING PROTECTION TESTS
653+ # =============================================================================
654+
655+ class TestDNSRebindingProtection :
656+ """Test suite for DNS rebinding TOCTOU protection via SSRFSafeTransport."""
657+
658+ @pytest .mark .asyncio
659+ async def test_transport_blocks_private_ip_at_connection_time (self ):
660+ """SSRFSafeTransport must block requests when DNS resolves to private IP."""
661+ import httpx
662+
663+ transport = SSRFSafeTransport (verify = False )
664+
665+ # Simulate DNS resolving to a private IP (127.0.0.1)
666+ with patch ("socket.getaddrinfo" ) as mock_dns :
667+ mock_dns .return_value = [
668+ (2 , 1 , 6 , '' , ('127.0.0.1' , 0 )),
669+ ]
670+ request = httpx .Request ("GET" , "http://evil-rebind.example.com/secret" )
671+
672+ with pytest .raises (McpError , match = "DNS rebinding protection" ):
673+ await transport .handle_async_request (request )
674+
675+ @pytest .mark .asyncio
676+ async def test_transport_blocks_metadata_ip_at_connection_time (self ):
677+ """SSRFSafeTransport must block DNS rebinding to cloud metadata IP."""
678+ import httpx
679+
680+ transport = SSRFSafeTransport (verify = False )
681+
682+ # Simulate DNS rebinding: attacker DNS returns metadata IP
683+ with patch ("socket.getaddrinfo" ) as mock_dns :
684+ mock_dns .return_value = [
685+ (2 , 1 , 6 , '' , ('169.254.169.254' , 0 )),
686+ ]
687+ request = httpx .Request ("GET" , "http://evil-rebind.example.com/metadata" )
688+
689+ with pytest .raises (McpError , match = "DNS rebinding protection" ):
690+ await transport .handle_async_request (request )
691+
692+ @pytest .mark .asyncio
693+ async def test_transport_allows_public_ip (self ):
694+ """SSRFSafeTransport must allow requests when DNS resolves to public IP."""
695+ import httpx
696+
697+ transport = SSRFSafeTransport (verify = False )
698+
699+ # Simulate DNS resolving to a public IP
700+ with patch ("socket.getaddrinfo" ) as mock_dns , \
701+ patch .object (transport , '_transport' ) as mock_inner :
702+ mock_dns .return_value = [
703+ (2 , 1 , 6 , '' , ('93.184.216.34' , 0 )),
704+ ]
705+ mock_response = httpx .Response (200 , text = "OK" )
706+ mock_inner .handle_async_request = AsyncMock (return_value = mock_response )
707+
708+ request = httpx .Request ("GET" , "http://example.com/page" )
709+ response = await transport .handle_async_request (request )
710+
711+ assert response .status_code == 200
712+ # Verify the inner transport was called with the IP-based URL
713+ called_request = mock_inner .handle_async_request .call_args [0 ][0 ]
714+ assert called_request .url .host == "93.184.216.34"
715+ # Verify Host header preserved
716+ assert called_request .headers ["host" ] == "example.com"
717+
718+ @pytest .mark .asyncio
719+ async def test_transport_skips_validation_for_direct_ip (self ):
720+ """SSRFSafeTransport should skip DNS resolution for direct IP URLs."""
721+ import httpx
722+
723+ transport = SSRFSafeTransport (verify = False )
724+
725+ # Direct IP URL - should go straight to inner transport (IP already validated by validate_url_for_ssrf)
726+ with patch .object (transport , '_transport' ) as mock_inner , \
727+ patch ("socket.getaddrinfo" ) as mock_dns :
728+ mock_response = httpx .Response (200 , text = "OK" )
729+ mock_inner .handle_async_request = AsyncMock (return_value = mock_response )
730+
731+ request = httpx .Request ("GET" , "http://93.184.216.34/page" )
732+ await transport .handle_async_request (request )
733+
734+ # DNS should NOT be called for direct IP
735+ mock_dns .assert_not_called ()
736+ mock_inner .handle_async_request .assert_called_once ()
737+
738+ @pytest .mark .asyncio
739+ async def test_transport_blocks_dns_failure (self ):
740+ """SSRFSafeTransport must raise error when DNS resolution fails."""
741+ import httpx
742+ import socket as socket_module
743+
744+ transport = SSRFSafeTransport (verify = False )
745+
746+ with patch ("socket.getaddrinfo" ) as mock_dns :
747+ mock_dns .side_effect = socket_module .gaierror ("Name resolution failed" )
748+ request = httpx .Request ("GET" , "http://nonexistent.example.com/" )
749+
750+ with pytest .raises (McpError , match = "Failed to resolve" ):
751+ await transport .handle_async_request (request )
752+
753+ @pytest .mark .asyncio
754+ async def test_dns_rebinding_scenario (self ):
755+ """
756+ Full DNS rebinding attack scenario:
757+ 1. validate_url_for_ssrf() sees public IP (passes)
758+ 2. SSRFSafeTransport resolves DNS again and sees private IP (blocks)
759+ """
760+ import httpx
761+
762+ call_count = 0
763+
764+ def rebinding_dns (hostname , * args , ** kwargs ):
765+ nonlocal call_count
766+ call_count += 1
767+ if call_count == 1 :
768+ # First call (validate_url_for_ssrf): return public IP
769+ return [(2 , 1 , 6 , '' , ('93.184.216.34' , 0 ))]
770+ else :
771+ # Second call (SSRFSafeTransport): return private IP (rebinding!)
772+ return [(2 , 1 , 6 , '' , ('169.254.169.254' , 0 ))]
773+
774+ with patch ("socket.getaddrinfo" , side_effect = rebinding_dns ):
775+ # First validation passes (public IP)
776+ validate_url_for_ssrf ("http://evil-rebind.example.com/" )
777+
778+ # But transport-level check catches the rebinding
779+ transport = SSRFSafeTransport (verify = False )
780+ request = httpx .Request ("GET" , "http://evil-rebind.example.com/metadata" )
781+
782+ with pytest .raises (McpError , match = "DNS rebinding protection" ):
783+ await transport .handle_async_request (request )
784+
785+
649786# =============================================================================
650787# RUN CONFIGURATION
651788# =============================================================================
0 commit comments