diff --git a/src/google/adk/tools/load_web_page.py b/src/google/adk/tools/load_web_page.py index e7419c9fbf..2165640ffd 100644 --- a/src/google/adk/tools/load_web_page.py +++ b/src/google/adk/tools/load_web_page.py @@ -16,8 +16,13 @@ """Tool for web browse.""" +from urllib.parse import urlparse + import requests +# Default timeout in seconds for HTTP requests. +_DEFAULT_TIMEOUT_SECONDS = 10 + def load_web_page(url: str) -> str: """Fetches the content in the url and returns the text in it. @@ -30,8 +35,21 @@ def load_web_page(url: str) -> str: """ from bs4 import BeautifulSoup - # Set allow_redirects=False to prevent SSRF attacks via redirection. - response = requests.get(url, allow_redirects=False) + parsed = urlparse(url) + if parsed.scheme not in ('http', 'https'): + return ( + f'Invalid URL scheme: {parsed.scheme}. Only http and https are allowed.' + ) + + try: + # Set allow_redirects=False to prevent SSRF attacks via redirection. + response = requests.get( + url, allow_redirects=False, timeout=_DEFAULT_TIMEOUT_SECONDS + ) + except requests.exceptions.Timeout: + return f'Request timed out while fetching url: {url}' + except requests.exceptions.ConnectionError: + return f'Connection error while fetching url: {url}' if response.status_code == 200: soup = BeautifulSoup(response.content, 'lxml') diff --git a/tests/unittests/tools/test_load_web_page.py b/tests/unittests/tools/test_load_web_page.py new file mode 100644 index 0000000000..a2a4724b02 --- /dev/null +++ b/tests/unittests/tools/test_load_web_page.py @@ -0,0 +1,144 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Unit tests for the load_web_page tool."""

from unittest import mock

from google.adk.tools.load_web_page import load_web_page
import requests

# Patch target for the HTTP call issued by load_web_page.
_REQUESTS_GET = 'google.adk.tools.load_web_page.requests.get'


def _mock_beautiful_soup(text='This is a test paragraph with enough words'):
  """Create a mock BeautifulSoup class that returns the given text.

  Args:
    text: The value the mocked soup's ``get_text`` returns.

  Returns:
    A mock callable standing in for the ``BeautifulSoup`` class; calling it
    yields a mock soup whose ``get_text`` returns ``text``.
  """
  mock_soup = mock.Mock()
  mock_soup.get_text.return_value = text
  return mock.Mock(return_value=mock_soup)


def _mock_response(status_code):
  """Create a mock HTTP response carrying only the given status code."""
  response = mock.Mock()
  response.status_code = status_code
  return response


class TestLoadWebPage:
  """Tests for scheme validation, error handling and request arguments."""

  @mock.patch(_REQUESTS_GET)
  def test_invalid_scheme_file(self, mock_get):
    """file:// URLs are rejected before any HTTP request is issued."""
    result = load_web_page('file:///etc/passwd')
    assert 'Invalid URL scheme' in result
    assert 'file' in result
    mock_get.assert_not_called()

  @mock.patch(_REQUESTS_GET)
  def test_invalid_scheme_ftp(self, mock_get):
    """ftp:// URLs are rejected before any HTTP request is issued."""
    result = load_web_page('ftp://example.com/file')
    assert 'Invalid URL scheme' in result
    assert 'ftp' in result
    mock_get.assert_not_called()

  @mock.patch(_REQUESTS_GET)
  def test_invalid_scheme_empty(self, mock_get):
    """A string without any scheme is rejected without an HTTP request."""
    result = load_web_page('not-a-url')
    assert 'Invalid URL scheme' in result
    mock_get.assert_not_called()

  @mock.patch(_REQUESTS_GET)
  def test_timeout_returns_error_message(self, mock_get):
    """A requests Timeout is turned into an error string, not raised."""
    mock_get.side_effect = requests.exceptions.Timeout()
    result = load_web_page('https://example.com')
    assert 'timed out' in result

  @mock.patch(_REQUESTS_GET)
  def test_connection_error_returns_error_message(self, mock_get):
    """A requests ConnectionError is turned into an error string."""
    mock_get.side_effect = requests.exceptions.ConnectionError()
    result = load_web_page('https://example.com')
    assert 'Connection error' in result

  @mock.patch(_REQUESTS_GET)
  def test_successful_request(self, mock_get):
    """A 200 response is parsed and its text is returned."""
    fake_soup_cls = _mock_beautiful_soup()
    fake_bs4 = mock.Mock(BeautifulSoup=fake_soup_cls)

    mock_response = _mock_response(200)
    # The exact markup is irrelevant here because the parser is mocked.
    mock_response.content = (
        b'<html><body><p>This is a test paragraph with enough words</p>'
        b'</body></html>'
    )
    mock_get.return_value = mock_response

    # load_web_page imports bs4 inside the function body, so placing a fake
    # module in sys.modules is enough to intercept it. This avoids patching
    # builtins.__import__, where the "original" import captured after the
    # patch is the mock itself and any unrelated import would recurse.
    with mock.patch.dict('sys.modules', {'bs4': fake_bs4}):
      result = load_web_page('https://example.com')

    mock_get.assert_called_once_with(
        'https://example.com', allow_redirects=False, timeout=10
    )
    # The response body is what gets handed to the parser.
    fake_soup_cls.assert_called_once()
    assert fake_soup_cls.call_args[0][0] is mock_response.content
    assert 'test paragraph' in result

  @mock.patch(_REQUESTS_GET)
  def test_failed_request_non_200(self, mock_get):
    """A non-200 status yields a failure message."""
    mock_get.return_value = _mock_response(404)

    result = load_web_page('https://example.com/missing')
    assert 'Failed to fetch url' in result

  @mock.patch(_REQUESTS_GET)
  def test_timeout_parameter_passed(self, mock_get):
    """The request is issued with the 10 second default timeout."""
    mock_get.return_value = _mock_response(404)

    load_web_page('http://example.com')

    _, kwargs = mock_get.call_args
    assert kwargs['timeout'] == 10

  @mock.patch(_REQUESTS_GET)
  def test_allow_redirects_false(self, mock_get):
    """Redirect following stays disabled on the outgoing request."""
    mock_get.return_value = _mock_response(404)

    load_web_page('https://example.com')

    _, kwargs = mock_get.call_args
    assert kwargs['allow_redirects'] is False

  @mock.patch(_REQUESTS_GET)
  def test_http_scheme_accepted(self, mock_get):
    """http:// URLs pass scheme validation and reach requests.get."""
    mock_get.return_value = _mock_response(404)

    result = load_web_page('http://example.com')
    assert 'Invalid URL scheme' not in result
    mock_get.assert_called_once()

  @mock.patch(_REQUESTS_GET)
  def test_https_scheme_accepted(self, mock_get):
    """https:// URLs pass scheme validation and reach requests.get."""
    mock_get.return_value = _mock_response(404)

    result = load_web_page('https://example.com')
    assert 'Invalid URL scheme' not in result
    mock_get.assert_called_once()