Skip to content

Commit ee1c26c

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
todo archit
feat: TPC PiperOrigin-RevId: 877503411
1 parent 6b5cc8f commit ee1c26c

File tree

5 files changed

+202
-56
lines changed

5 files changed

+202
-56
lines changed

google/cloud/aiplatform/constants/base.py

Lines changed: 46 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -19,53 +19,52 @@
1919

2020

2121
DEFAULT_REGION = "us-central1"
22-
SUPPORTED_REGIONS = frozenset(
23-
{
24-
"africa-south1",
25-
"asia-east1",
26-
"asia-east2",
27-
"asia-northeast1",
28-
"asia-northeast2",
29-
"asia-northeast3",
30-
"asia-south1",
31-
"asia-south2",
32-
"asia-southeast1",
33-
"asia-southeast2",
34-
"australia-southeast1",
35-
"australia-southeast2",
36-
"europe-central2",
37-
"europe-north1",
38-
"europe-north2",
39-
"europe-southwest1",
40-
"europe-west1",
41-
"europe-west2",
42-
"europe-west3",
43-
"europe-west4",
44-
"europe-west6",
45-
"europe-west8",
46-
"europe-west9",
47-
"europe-west12",
48-
"global",
49-
"me-central1",
50-
"me-central2",
51-
"me-west1",
52-
"northamerica-northeast1",
53-
"northamerica-northeast2",
54-
"southamerica-east1",
55-
"southamerica-west1",
56-
"us-central1",
57-
"us-east1",
58-
"us-east4",
59-
"us-east5",
60-
"us-east7",
61-
"us-south1",
62-
"us-west1",
63-
"us-west2",
64-
"us-west3",
65-
"us-west4",
66-
"us-west8",
67-
}
68-
)
22+
SUPPORTED_REGIONS = frozenset({
23+
"africa-south1",
24+
"asia-east1",
25+
"asia-east2",
26+
"asia-northeast1",
27+
"asia-northeast2",
28+
"asia-northeast3",
29+
"asia-south1",
30+
"asia-south2",
31+
"asia-southeast1",
32+
"asia-southeast2",
33+
"australia-southeast1",
34+
"australia-southeast2",
35+
"europe-central2",
36+
"europe-north1",
37+
"europe-north2",
38+
"europe-southwest1",
39+
"europe-west1",
40+
"europe-west2",
41+
"europe-west3",
42+
"europe-west4",
43+
"europe-west6",
44+
"europe-west8",
45+
"europe-west9",
46+
"europe-west12",
47+
"global",
48+
"me-central1",
49+
"me-central2",
50+
"me-west1",
51+
"northamerica-northeast1",
52+
"northamerica-northeast2",
53+
"southamerica-east1",
54+
"southamerica-west1",
55+
"us-central1",
56+
"us-east1",
57+
"us-east4",
58+
"us-east5",
59+
"us-east7",
60+
"us-south1",
61+
"us-west1",
62+
"us-west2",
63+
"us-west3",
64+
"us-west4",
65+
"us-west8",
66+
"u-us-prp1",
67+
})
6968

7069
API_BASE_PATH = "aiplatform.googleapis.com"
7170
PREDICTION_API_BASE_PATH = API_BASE_PATH

google/cloud/aiplatform/initializer.py

Lines changed: 39 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@ def __init__(self):
133133
self._request_metadata = None
134134
self._resource_type = None
135135
self._async_rest_credentials = None
136+
self._universe_domain = None
136137

137138
def init(
138139
self,
@@ -153,6 +154,7 @@ def init(
153154
api_key: Optional[str] = None,
154155
api_transport: Optional[str] = None,
155156
request_metadata: Optional[Sequence[Tuple[str, str]]] = None,
157+
universe_domain: Optional[str] = None,
156158
):
157159
"""Updates common initialization parameters with provided options.
158160
@@ -220,6 +222,8 @@ def init(
220222
beta state (preview).
221223
request_metadata:
222224
Optional. Additional gRPC metadata to send with every client request.
225+
universe_domain (str):
226+
Optional. The universe domain.
223227
Raises:
224228
ValueError:
225229
If experiment_description is provided but experiment is not.
@@ -291,6 +295,8 @@ def init(
291295
self._request_metadata = request_metadata
292296
if api_key is not None:
293297
self._api_key = api_key
298+
if universe_domain is not None:
299+
self._universe_domain = universe_domain
294300
self._resource_type = None
295301

296302
# Finally, perform secondary state updates
@@ -348,6 +354,11 @@ def api_key(self) -> Optional[str]:
348354
"""API Key, if provided."""
349355
return self._api_key
350356

357+
@property
358+
def universe_domain(self) -> Optional[str]:
359+
"""Default universe domain, if provided."""
360+
return self._universe_domain
361+
351362
@property
352363
def project(self) -> str:
353364
"""Default project."""
@@ -382,7 +393,11 @@ def location(self) -> str:
382393

383394
location = os.getenv("GOOGLE_CLOUD_REGION") or os.getenv("CLOUD_ML_REGION")
384395
if location:
385-
utils.validate_region(location)
396+
utils.validate_region(
397+
location,
398+
api_endpoint=self.api_endpoint,
399+
universe_domain=self.universe_domain,
400+
)
386401
return location
387402

388403
return constants.DEFAULT_REGION
@@ -449,6 +464,7 @@ def get_client_options(
449464
api_base_path_override: Optional[str] = None,
450465
api_key: Optional[str] = None,
451466
api_path_override: Optional[str] = None,
467+
universe_domain: Optional[str] = None,
452468
) -> client_options.ClientOptions:
453469
"""Creates GAPIC client_options using location and type.
454470
@@ -461,6 +477,7 @@ def get_client_options(
461477
api_base_path_override (str): Optional. Override default API base path.
462478
api_key (str): Optional. API key to use for the client.
463479
api_path_override (str): Optional. Override default api path.
480+
universe_domain (str): Optional. Override default universe domain.
464481
Returns:
465482
clients_options (google.api_core.client_options.ClientOptions):
466483
A ClientOptions object set with regionalized API endpoint, i.e.
@@ -491,7 +508,11 @@ def get_client_options(
491508
region = location_override or self.location
492509
region = region.lower()
493510

494-
utils.validate_region(region)
511+
utils.validate_region(
512+
region,
513+
api_endpoint=self.api_endpoint,
514+
universe_domain=universe_domain or self.universe_domain,
515+
)
495516

496517
service_base_path = api_base_path_override or (
497518
constants.PREDICTION_API_BASE_PATH
@@ -508,9 +529,14 @@ def get_client_options(
508529
# Project/location take precedence over api_key
509530
if api_key and not self._project:
510531
return client_options.ClientOptions(
511-
api_endpoint=api_endpoint, api_key=api_key
532+
api_endpoint=api_endpoint,
533+
api_key=api_key,
534+
universe_domain=universe_domain or self.universe_domain,
512535
)
513-
return client_options.ClientOptions(api_endpoint=api_endpoint)
536+
return client_options.ClientOptions(
537+
api_endpoint=api_endpoint,
538+
universe_domain=universe_domain or self.universe_domain,
539+
)
514540

515541
def common_location_path(
516542
self, project: Optional[str] = None, location: Optional[str] = None
@@ -524,7 +550,11 @@ def common_location_path(
524550
resource_parent: Formatted parent resource string.
525551
"""
526552
if location:
527-
utils.validate_region(location)
553+
utils.validate_region(
554+
location,
555+
api_endpoint=self.api_endpoint,
556+
universe_domain=self.universe_domain,
557+
)
528558

529559
return "/".join(
530560
[
@@ -546,6 +576,7 @@ def create_client(
546576
api_path_override: Optional[str] = None,
547577
appended_user_agent: Optional[List[str]] = None,
548578
appended_gapic_version: Optional[str] = None,
579+
universe_domain: Optional[str] = None,
549580
) -> _TVertexAiServiceClientWithOverride:
550581
"""Instantiates a given VertexAiServiceClient with optional
551582
overrides.
@@ -565,6 +596,8 @@ def create_client(
565596
separated by spaces.
566597
appended_gapic_version (str):
567598
Optional. GAPIC version suffix appended in the client info.
599+
universe_domain (str):
600+
Optional. universe domain override.
568601
Returns:
569602
client: Instantiated Vertex AI Service client with optional overrides
570603
"""
@@ -607,6 +640,7 @@ def create_client(
607640
api_key=api_key,
608641
api_base_path_override=api_base_path_override,
609642
api_path_override=api_path_override,
643+
universe_domain=universe_domain,
610644
),
611645
"client_info": client_info,
612646
}

google/cloud/aiplatform/utils/__init__.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -299,11 +299,18 @@ def validate_labels(labels: Dict[str, str]):
299299
)
300300

301301

302-
def validate_region(region: str) -> bool:
302+
def validate_region(
303+
region: str,
304+
api_endpoint: Optional[str] = None,
305+
universe_domain: Optional[str] = None,
306+
) -> bool:
303307
"""Validates region against supported regions.
304308
305309
Args:
306310
region: region to validate
311+
api_endpoint: Optional API endpoint.
312+
universe_domain: Optional universe domain.
313+
307314
Returns:
308315
bool: True if no errors raised
309316
Raises:
@@ -316,9 +323,16 @@ def validate_region(region: str) -> bool:
316323

317324
region = region.lower()
318325
if region not in constants.SUPPORTED_REGIONS:
319-
raise ValueError(
320-
f"Unsupported region for Vertex AI, select from {constants.SUPPORTED_REGIONS}"
321-
)
326+
if not (
327+
api_endpoint
328+
or universe_domain
329+
or initializer.global_config.api_endpoint
330+
or initializer.global_config.universe_domain
331+
):
332+
raise ValueError(
333+
"Unsupported region for Vertex AI, select from"
334+
f" {constants.SUPPORTED_REGIONS}"
335+
)
322336

323337
return True
324338

tests/unit/aiplatform/test_initializer.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,10 @@
4747
_TEST_NETWORK = "projects/12345/global/networks/myVPC"
4848
_TEST_SERVICE_ACCOUNT = "test-service-account@test-project.iam.gserviceaccount.com"
4949

50+
_TEST_LOCATION_TPC = "u-us-prp1"
51+
_TEST_ENDPOINT_TPC = "u-us-prp1-aiplatform.apis-tpczero.goog"
52+
_TEST_UNIVERSE_TPC = "apis-tpczero.goog"
53+
5054
# tensorboard
5155
_TEST_TENSORBOARD_ID = "1028944691210842416"
5256
_TEST_TENSORBOARD_NAME = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/tensorboards/{_TEST_TENSORBOARD_ID}"
@@ -547,6 +551,71 @@ def test_create_client_with_request_metadata_prediction_service(self):
547551
for metadata_key in ["global_param", "request_param"]:
548552
assert metadata_key in headers
549553

554+
def test_init_with_universe_domain(self):
555+
"""Test that aiplatform.init supports universe_domain."""
556+
initializer.global_config.init(
557+
project=_TEST_PROJECT, universe_domain=_TEST_UNIVERSE_TPC
558+
)
559+
assert initializer.global_config.universe_domain == _TEST_UNIVERSE_TPC
560+
561+
def test_get_client_options_with_universe_domain(self):
562+
"""Test that ClientOptions correctly inherits universe_domain."""
563+
initializer.global_config.init(
564+
project=_TEST_PROJECT,
565+
location="us-central1",
566+
universe_domain=_TEST_UNIVERSE_TPC,
567+
)
568+
client_options = initializer.global_config.get_client_options()
569+
assert client_options.universe_domain == _TEST_UNIVERSE_TPC
570+
571+
def test_init_tpc_success(self):
572+
"""Test that aiplatform.init succeeds with TPC parameters."""
573+
# Should NOT raise ValueError
574+
initializer.global_config.init(
575+
project=_TEST_PROJECT,
576+
location=_TEST_LOCATION_TPC,
577+
api_endpoint=_TEST_ENDPOINT_TPC,
578+
universe_domain=_TEST_UNIVERSE_TPC,
579+
)
580+
assert initializer.global_config.location == _TEST_LOCATION_TPC
581+
582+
def test_get_impersonated_credentials_tpc(self):
583+
"""Test impersonated credentials for TPC."""
584+
info = {
585+
"type": "impersonated_service_account",
586+
"source_credentials": {
587+
"type": "external_account_authorized_user",
588+
"token_url": "https://sts.apis-tpczero.goog/v1/oauth/token",
589+
},
590+
"service_account_impersonation_url": (
591+
"https://iamcredentials.googleapis.com/v1/projects/-/serviceAccounts/test-sa@test.iam.gserviceaccount.com:generateAccessToken"
592+
),
593+
"universe_domain": _TEST_UNIVERSE_TPC,
594+
}
595+
596+
with mock.patch(
597+
"google.auth._default._get_external_account_authorized_user_credentials"
598+
) as mock_get_ext:
599+
mock_source_creds = mock.Mock()
600+
mock_get_ext.return_value = (mock_source_creds, None)
601+
602+
with mock.patch(
603+
"google.auth.impersonated_credentials.Credentials"
604+
) as mock_imp_creds:
605+
# pylint: disable=protected-access
606+
google.auth._default._get_impersonated_service_account_credentials(
607+
"fake-file", info, ["scope"]
608+
)
609+
610+
# Verify that iam_endpoint_override was set correctly for TPC
611+
_, kwargs = mock_imp_creds.call_args
612+
target_principal = "test-sa@test.iam.gserviceaccount.com"
613+
expected_iam_endpoint = (
614+
f"https://iamcredentials.{_TEST_UNIVERSE_TPC}/v1/projects/"
615+
f"-/serviceAccounts/{target_principal}:generateAccessToken"
616+
)
617+
assert kwargs["iam_endpoint_override"] == expected_iam_endpoint
618+
550619

551620
class TestThreadPool:
552621
def teardown_method(self):

0 commit comments

Comments
 (0)