def _verify_bucket_ownership(
    bucket: storage.Bucket,
    expected_project: str,
    client: storage.Client,
) -> bool:
    """Verifies that a GCS bucket belongs to the expected project.

    This check mitigates bucket squatting attacks where an attacker creates a
    bucket with a predictable name in their own project before the victim does.

    Args:
        bucket: The GCS bucket to verify.
        expected_project: The project ID that should own the bucket.
        client: Storage client instance used to reload the bucket metadata.

    Returns:
        True if the bucket belongs to the expected project, False otherwise.
        Any failure during verification (permissions, network, etc.) is
        treated as "not owned" — the check fails closed.
    """
    try:
        # Reload pulls fresh bucket metadata so project_number is populated.
        bucket.reload(client=client)
        bucket_project_number = str(bucket.project_number)
        expected_project_number = str(
            resource_manager_utils.get_project_number(expected_project)
        )
        return bucket_project_number == expected_project_number
    except Exception:  # pylint: disable=broad-except
        # Fail closed, but log the cause: silently returning False on a
        # transient error (e.g. 403 on reload, Resource Manager outage) would
        # be indistinguishable from a genuine ownership mismatch and makes
        # the resulting "bucket squatting" error hard to debug.
        _logger.warning(
            "Could not verify ownership of bucket %r for project %r; "
            "treating it as not owned.",
            bucket.name,
            expected_project,
            exc_info=True,
        )
        return False
def stage_local_data_in_gcs(
    data_path: str,
    staging_gcs_dir: Optional[str] = None,
    project: Optional[str] = None,
    location: Optional[str] = None,
    credentials: Optional[auth_credentials.Credentials] = None,
) -> str:
    # NOTE(review): the first three parameters are reconstructed from body
    # usage — the diff hunk only shows `location`/`credentials`; confirm
    # against the full file before merging.
    """Stages local data in GCS so Vertex AI services can reference it.

    The data is copied under
    "{staging_gcs_dir}/vertex_ai_auto_staging/{timestamp}/"; a single file
    keeps its original base name inside that timestamped directory.

    Args:
        data_path: Required. Path of the local file or directory to stage.
        staging_gcs_dir: Optional. GCS directory to stage the data in. If not
            provided, the staging bucket from `aiplatform.init` is used; if
            that is also unset, a regional per-project staging bucket is
            looked up or created.
        project: Optional. Project that owns the auto-created staging bucket.
            Defaults to the initialized project.
        location: Optional. Region for the auto-created staging bucket.
            Defaults to the initialized location.
        credentials: Optional. Credentials used for GCS operations.
            Defaults to the initialized credentials.

    Returns:
        The GCS URI of the staged data.

    Raises:
        RuntimeError: When data_path does not exist locally.
        ValueError: When the auto-derived staging bucket already exists but is
            not owned by the target project (possible bucket squatting).
        GoogleCloudError: When the upload process fails.
    """
    data_path_obj = pathlib.Path(data_path)

    if not data_path_obj.exists():
        raise RuntimeError(f"Local data does not exist: data_path='{data_path}'")

    staging_gcs_dir = staging_gcs_dir or initializer.global_config.staging_bucket
    if not staging_gcs_dir:
        project = project or initializer.global_config.project
        location = location or initializer.global_config.location
        credentials = credentials or initializer.global_config.credentials
        # Creating the bucket if it does not exist.
        # Currently we only do this when staging_gcs_dir is not specified.
        # The buckets that we create are regional.
        # This prevents errors when some service required regional bucket.
        # E.g. "FailedPrecondition: 400 The Cloud Storage bucket of `gs://...`
        # is in location `us`. It must be in the same regional location as the
        # service location `us-central1`."
        # We are making the bucket name region-specific since the bucket is
        # regional. Bucket names are capped at 63 characters.
        staging_bucket_name = (
            project
            + "-vertex-staging-"
            + location
            + "-"
            + _DEFAULT_STAGING_BUCKET_SALT
        )[:63]
        client = storage.Client(project=project, credentials=credentials)
        staging_bucket = storage.Bucket(client=client, name=staging_bucket_name)
        if not staging_bucket.exists():
            _logger.info(f'Creating staging GCS bucket "{staging_bucket_name}"')
            staging_bucket = client.create_bucket(
                bucket_or_name=staging_bucket,
                project=project,
                location=location,
            )
        else:
            # Verify bucket ownership to prevent bucket squatting attacks.
            # See b/469987320 for details.
            if not _verify_bucket_ownership(staging_bucket, project, client):
                raise ValueError(
                    f'Staging bucket "{staging_bucket_name}" exists but does '
                    f'not belong to project "{project}". This may indicate a '
                    f"bucket squatting attack. Please provide an explicit "
                    f"staging_bucket parameter or configure one via "
                    f"aiplatform.init(staging_bucket='gs://your-bucket')."
                )
        staging_gcs_dir = "gs://" + staging_bucket_name

    timestamp = datetime.datetime.now().isoformat(sep="-", timespec="milliseconds")
    staging_gcs_subdir = (
        staging_gcs_dir.rstrip("/") + "/vertex_ai_auto_staging/" + timestamp
    )

    # A single file keeps its base name; a directory is uploaded as-is under
    # the timestamped prefix.
    staged_data_uri = staging_gcs_subdir
    if data_path_obj.is_file():
        staged_data_uri = staging_gcs_subdir + "/" + data_path_obj.name

    _logger.info(f'Uploading "{data_path}" to "{staged_data_uri}"')
    upload_to_gcs(
        source_path=data_path,
        destination_uri=staged_data_uri,
        project=project,
        credentials=credentials,
    )
    return staged_data_uri
+ ) + @patch.object(storage.Bucket, "reload") + def test_verify_bucket_ownership_different_project( + self, mock_reload, mock_get_project_number + ): + mock_client = mock.MagicMock(spec=storage.Client) + mock_bucket = mock.MagicMock(spec=storage.Bucket) + mock_bucket.project_number = 99999 + assert not gcs_utils._verify_bucket_ownership( + mock_bucket, "test-project", mock_client + ) + + @patch.object(storage.Bucket, "exists", return_value=True) + @patch.object(storage, "Client") + @patch.object( + gcs_utils, "_verify_bucket_ownership", return_value=False + ) + def test_stage_local_data_in_gcs_rejects_squatted_bucket( + self, mock_verify, mock_storage_client, mock_bucket_exists, json_file + ): + mock_config = mock.MagicMock() + mock_config.project = "victim-project" + mock_config.location = "us-central1" + mock_config.staging_bucket = None + mock_config.credentials = None + with patch.object(gcs_utils.initializer, "global_config", mock_config): + with pytest.raises( + ValueError, + match="bucket squatting", + ): + gcs_utils.stage_local_data_in_gcs(json_file) + + @patch.object(storage.Bucket, "exists", return_value=True) + @patch.object(storage, "Client") + @patch.object( + gcs_utils, "_verify_bucket_ownership", return_value=True + ) + @patch("google.cloud.storage.Blob.upload_from_filename") + def test_stage_local_data_in_gcs_accepts_owned_bucket( + self, + mock_upload, + mock_verify, + mock_storage_client, + mock_bucket_exists, + json_file, + mock_datetime, + ): + mock_config = mock.MagicMock() + mock_config.project = "my-project" + mock_config.location = "us-central1" + mock_config.staging_bucket = None + mock_config.credentials = None + with patch.object(gcs_utils.initializer, "global_config", mock_config): + result = gcs_utils.stage_local_data_in_gcs(json_file) + assert result.startswith("gs://") + mock_verify.assert_called_once() + class TestPipelineUtils: SAMPLE_JOB_SPEC = {