Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 79 additions & 37 deletions google/cloud/aiplatform/utils/gcs_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,35 @@
_DEFAULT_STAGING_BUCKET_SALT = str(uuid.uuid4())


def _verify_bucket_ownership(
bucket: storage.Bucket,
expected_project: str,
client: storage.Client,
) -> bool:
"""Verifies that a GCS bucket belongs to the expected project.

This check mitigates bucket squatting attacks where an attacker creates a
bucket with a predictable name in their own project before the victim does.

Args:
bucket: The GCS bucket to verify.
expected_project: The project ID that should own the bucket.
client: Storage client instance.

Returns:
True if the bucket belongs to the expected project, False otherwise.
"""
try:
bucket.reload(client=client)
bucket_project_number = str(bucket.project_number)
expected_project_number = str(
resource_manager_utils.get_project_number(expected_project)
)
return bucket_project_number == expected_project_number
except Exception:
return False


def blob_from_uri(uri: str, client: storage.Client) -> storage.Blob:
"""Create a Blob from a GCS URI, compatible with v2 and v3.

Expand Down Expand Up @@ -169,7 +198,7 @@ def stage_local_data_in_gcs(
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
) -> str:
"""Stages a local data in GCS.
"""Stages a local data in GCS.

The file copied to GCS is the name of the local file prepended with an
"aiplatform-{timestamp}-" string.
Expand All @@ -190,57 +219,70 @@ def stage_local_data_in_gcs(
RuntimeError: When source_path does not exist.
GoogleCloudError: When the upload process fails.
"""
data_path_obj = pathlib.Path(data_path)

if not data_path_obj.exists():
raise RuntimeError(f"Local data does not exist: data_path='{data_path}'")

staging_gcs_dir = staging_gcs_dir or initializer.global_config.staging_bucket
if not staging_gcs_dir:
project = project or initializer.global_config.project
location = location or initializer.global_config.location
credentials = credentials or initializer.global_config.credentials
# Creating the bucket if it does not exist.
# Currently we only do this when staging_gcs_dir is not specified.
# The buckets that we create are regional.
# This prevents errors when some service required regional bucket.
# E.g. "FailedPrecondition: 400 The Cloud Storage bucket of `gs://...`
# is in location `us`. It must be in the same regional location as the
# service location `us-central1`."
# We are making the bucket name region-specific since the bucket is
# regional.
staging_bucket_name = (
data_path_obj = pathlib.Path(data_path)

if not data_path_obj.exists():
raise RuntimeError(f"Local data does not exist: data_path='{data_path}'")

staging_gcs_dir = staging_gcs_dir or initializer.global_config.staging_bucket
if not staging_gcs_dir:
project = project or initializer.global_config.project
location = location or initializer.global_config.location
credentials = credentials or initializer.global_config.credentials
# Creating the bucket if it does not exist.
# Currently we only do this when staging_gcs_dir is not specified.
# The buckets that we create are regional.
# This prevents errors when some service required regional bucket.
# E.g. "FailedPrecondition: 400 The Cloud Storage bucket of `gs://...`
# is in location `us`. It must be in the same regional location as the
# service location `us-central1`."
# We are making the bucket name region-specific since the bucket is
# regional.
staging_bucket_name = (
project + "-vertex-staging-" + location + "-" + _DEFAULT_STAGING_BUCKET_SALT
)[:63]
client = storage.Client(project=project, credentials=credentials)
staging_bucket = storage.Bucket(client=client, name=staging_bucket_name)
if not staging_bucket.exists():
_logger.info(f'Creating staging GCS bucket "{staging_bucket_name}"')
staging_bucket = client.create_bucket(
client = storage.Client(project=project, credentials=credentials)
staging_bucket = storage.Bucket(client=client, name=staging_bucket_name)
if not staging_bucket.exists():
_logger.info(f'Creating staging GCS bucket "{staging_bucket_name}"')
staging_bucket = client.create_bucket(
bucket_or_name=staging_bucket,
project=project,
location=location,
)
staging_gcs_dir = "gs://" + staging_bucket_name
else:
# Verify bucket ownership to prevent bucket squatting attacks.
# See b/469987320 for details.
if not _verify_bucket_ownership(staging_bucket, project, client):
raise ValueError(
f'Staging bucket "{staging_bucket_name}" exists but does '
f'not belong to project "{project}". This may indicate a '
f"bucket squatting attack. Please provide an explicit "
f"staging_bucket parameter or configure one via "
f"aiplatform.init(staging_bucket='gs://your-bucket')."
)
staging_gcs_dir = "gs://" + staging_bucket_name

timestamp = datetime.datetime.now().isoformat(sep="-", timespec="milliseconds")
staging_gcs_subdir = (
staging_gcs_dir.rstrip("/") + "/vertex_ai_auto_staging/" + timestamp
)
timestamp = datetime.datetime.now().isoformat(
sep="-", timespec="milliseconds"
)
staging_gcs_subdir = (
staging_gcs_dir.rstrip("/") + "/vertex_ai_auto_staging/" + timestamp
)

staged_data_uri = staging_gcs_subdir
if data_path_obj.is_file():
staged_data_uri = staging_gcs_subdir + "/" + data_path_obj.name
staged_data_uri = staging_gcs_subdir
if data_path_obj.is_file():
staged_data_uri = staging_gcs_subdir + "/" + data_path_obj.name

_logger.info(f'Uploading "{data_path}" to "{staged_data_uri}"')
upload_to_gcs(
_logger.info(f'Uploading "{data_path}" to "{staged_data_uri}"')
upload_to_gcs(
source_path=data_path,
destination_uri=staged_data_uri,
project=project,
credentials=credentials,
)

return staged_data_uri
return staged_data_uri


def generate_gcs_directory_for_pipeline_artifacts(
Expand Down
77 changes: 77 additions & 0 deletions tests/unit/aiplatform/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -668,6 +668,83 @@ def test_validate_gcs_path(self):
with pytest.raises(ValueError, match=err_msg):
gcs_utils.validate_gcs_path(test_invalid_path)

@patch.object(
    gcs_utils.resource_manager_utils,
    "get_project_number",
    return_value=12345,
)
@patch.object(storage.Bucket, "reload")
def test_verify_bucket_ownership_matching_project(
    self, mock_reload, mock_get_project_number
):
    """Ownership check succeeds when bucket and project numbers agree."""
    client_stub = mock.MagicMock(spec=storage.Client)
    bucket_stub = mock.MagicMock(spec=storage.Bucket)
    # Same number the patched get_project_number reports for the project.
    bucket_stub.project_number = 12345
    owned = gcs_utils._verify_bucket_ownership(
        bucket_stub, "test-project", client_stub
    )
    assert owned

@patch.object(
    gcs_utils.resource_manager_utils,
    "get_project_number",
    return_value=12345,
)
@patch.object(storage.Bucket, "reload")
def test_verify_bucket_ownership_different_project(
    self, mock_reload, mock_get_project_number
):
    """Ownership check fails when the bucket lives in another project."""
    client_stub = mock.MagicMock(spec=storage.Client)
    bucket_stub = mock.MagicMock(spec=storage.Bucket)
    # Deliberately different from the patched project number (12345).
    bucket_stub.project_number = 99999
    owned = gcs_utils._verify_bucket_ownership(
        bucket_stub, "test-project", client_stub
    )
    assert not owned

@patch.object(storage.Bucket, "exists", return_value=True)
@patch.object(storage, "Client")
@patch.object(gcs_utils, "_verify_bucket_ownership", return_value=False)
def test_stage_local_data_in_gcs_rejects_squatted_bucket(
    self, mock_verify, mock_storage_client, mock_bucket_exists, json_file
):
    """Staging raises when the default bucket exists but fails ownership."""
    fake_config = mock.MagicMock()
    fake_config.configure_mock(
        project="victim-project",
        location="us-central1",
        staging_bucket=None,
        credentials=None,
    )
    # No explicit staging bucket, so the default-bucket path (and its
    # ownership verification) is exercised.
    with patch.object(gcs_utils.initializer, "global_config", fake_config):
        with pytest.raises(ValueError, match="bucket squatting"):
            gcs_utils.stage_local_data_in_gcs(json_file)

@patch.object(storage.Bucket, "exists", return_value=True)
@patch.object(storage, "Client")
@patch.object(gcs_utils, "_verify_bucket_ownership", return_value=True)
@patch("google.cloud.storage.Blob.upload_from_filename")
def test_stage_local_data_in_gcs_accepts_owned_bucket(
    self,
    mock_upload,
    mock_verify,
    mock_storage_client,
    mock_bucket_exists,
    json_file,
    mock_datetime,
):
    """Staging proceeds when the existing bucket passes the ownership check."""
    fake_config = mock.MagicMock()
    fake_config.configure_mock(
        project="my-project",
        location="us-central1",
        staging_bucket=None,
        credentials=None,
    )
    with patch.object(gcs_utils.initializer, "global_config", fake_config):
        staged_uri = gcs_utils.stage_local_data_in_gcs(json_file)
    # Upload went to a GCS URI and ownership was verified exactly once.
    assert staged_uri.startswith("gs://")
    mock_verify.assert_called_once()


class TestPipelineUtils:
SAMPLE_JOB_SPEC = {
Expand Down
Loading