Skip to content

Commit 9e0bae6

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
todo:archit
feat: One-line summary of change for external release notes. PiperOrigin-RevId: 877503411
1 parent fba5350 commit 9e0bae6

6 files changed

Lines changed: 275 additions & 56 deletions

File tree

google/cloud/aiplatform/constants/base.py

Lines changed: 47 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -19,53 +19,53 @@
1919

2020

2121
DEFAULT_REGION = "us-central1"
22-
SUPPORTED_REGIONS = frozenset(
23-
{
24-
"africa-south1",
25-
"asia-east1",
26-
"asia-east2",
27-
"asia-northeast1",
28-
"asia-northeast2",
29-
"asia-northeast3",
30-
"asia-south1",
31-
"asia-south2",
32-
"asia-southeast1",
33-
"asia-southeast2",
34-
"australia-southeast1",
35-
"australia-southeast2",
36-
"europe-central2",
37-
"europe-north1",
38-
"europe-north2",
39-
"europe-southwest1",
40-
"europe-west1",
41-
"europe-west2",
42-
"europe-west3",
43-
"europe-west4",
44-
"europe-west6",
45-
"europe-west8",
46-
"europe-west9",
47-
"europe-west12",
48-
"global",
49-
"me-central1",
50-
"me-central2",
51-
"me-west1",
52-
"northamerica-northeast1",
53-
"northamerica-northeast2",
54-
"southamerica-east1",
55-
"southamerica-west1",
56-
"us-central1",
57-
"us-east1",
58-
"us-east4",
59-
"us-east5",
60-
"us-east7",
61-
"us-south1",
62-
"us-west1",
63-
"us-west2",
64-
"us-west3",
65-
"us-west4",
66-
"us-west8",
67-
}
68-
)
22+
DEFAULT_UNIVERSE_DOMAIN = "googleapis.com"
23+
SUPPORTED_REGIONS = frozenset({
24+
"africa-south1",
25+
"asia-east1",
26+
"asia-east2",
27+
"asia-northeast1",
28+
"asia-northeast2",
29+
"asia-northeast3",
30+
"asia-south1",
31+
"asia-south2",
32+
"asia-southeast1",
33+
"asia-southeast2",
34+
"australia-southeast1",
35+
"australia-southeast2",
36+
"europe-central2",
37+
"europe-north1",
38+
"europe-north2",
39+
"europe-southwest1",
40+
"europe-west1",
41+
"europe-west2",
42+
"europe-west3",
43+
"europe-west4",
44+
"europe-west6",
45+
"europe-west8",
46+
"europe-west9",
47+
"europe-west12",
48+
"global",
49+
"me-central1",
50+
"me-central2",
51+
"me-west1",
52+
"northamerica-northeast1",
53+
"northamerica-northeast2",
54+
"southamerica-east1",
55+
"southamerica-west1",
56+
"us-central1",
57+
"us-east1",
58+
"us-east4",
59+
"us-east5",
60+
"us-east7",
61+
"us-south1",
62+
"us-west1",
63+
"us-west2",
64+
"us-west3",
65+
"us-west4",
66+
"us-west8",
67+
"u-us-prp1",
68+
})
6969

7070
API_BASE_PATH = "aiplatform.googleapis.com"
7171
PREDICTION_API_BASE_PATH = API_BASE_PATH

google/cloud/aiplatform/initializer.py

Lines changed: 50 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@ def __init__(self):
133133
self._request_metadata = None
134134
self._resource_type = None
135135
self._async_rest_credentials = None
136+
self._universe_domain = None
136137

137138
def init(
138139
self,
@@ -153,6 +154,7 @@ def init(
153154
api_key: Optional[str] = None,
154155
api_transport: Optional[str] = None,
155156
request_metadata: Optional[Sequence[Tuple[str, str]]] = None,
157+
universe_domain: Optional[str] = None,
156158
):
157159
"""Updates common initialization parameters with provided options.
158160
@@ -220,6 +222,8 @@ def init(
220222
beta state (preview).
221223
request_metadata:
222224
Optional. Additional gRPC metadata to send with every client request.
225+
universe_domain (str):
226+
Optional. The universe domain.
223227
Raises:
224228
ValueError:
225229
If experiment_description is provided but experiment is not.
@@ -291,6 +295,8 @@ def init(
291295
self._request_metadata = request_metadata
292296
if api_key is not None:
293297
self._api_key = api_key
298+
if universe_domain is not None:
299+
self._universe_domain = universe_domain
294300
self._resource_type = None
295301

296302
# Finally, perform secondary state updates
@@ -348,6 +354,11 @@ def api_key(self) -> Optional[str]:
348354
"""API Key, if provided."""
349355
return self._api_key
350356

357+
@property
358+
def universe_domain(self) -> Optional[str]:
359+
"""Default universe domain, if provided."""
360+
return self._universe_domain
361+
351362
@property
352363
def project(self) -> str:
353364
"""Default project."""
@@ -382,7 +393,11 @@ def location(self) -> str:
382393

383394
location = os.getenv("GOOGLE_CLOUD_REGION") or os.getenv("CLOUD_ML_REGION")
384395
if location:
385-
utils.validate_region(location)
396+
utils.validate_region(
397+
location,
398+
api_endpoint=self.api_endpoint,
399+
universe_domain=self.universe_domain,
400+
)
386401
return location
387402

388403
return constants.DEFAULT_REGION
@@ -449,6 +464,7 @@ def get_client_options(
449464
api_base_path_override: Optional[str] = None,
450465
api_key: Optional[str] = None,
451466
api_path_override: Optional[str] = None,
467+
universe_domain: Optional[str] = None,
452468
) -> client_options.ClientOptions:
453469
"""Creates GAPIC client_options using location and type.
454470
@@ -461,6 +477,7 @@ def get_client_options(
461477
api_base_path_override (str): Optional. Override default API base path.
462478
api_key (str): Optional. API key to use for the client.
463479
api_path_override (str): Optional. Override default api path.
480+
universe_domain (str): Optional. Override default universe domain.
464481
Returns:
465482
clients_options (google.api_core.client_options.ClientOptions):
466483
A ClientOptions object set with regionalized API endpoint, i.e.
@@ -491,14 +508,29 @@ def get_client_options(
491508
region = location_override or self.location
492509
region = region.lower()
493510

494-
utils.validate_region(region)
511+
utils.validate_region(
512+
region,
513+
api_endpoint=self.api_endpoint,
514+
universe_domain=universe_domain or self.universe_domain,
515+
)
495516

496517
service_base_path = api_base_path_override or (
497518
constants.PREDICTION_API_BASE_PATH
498519
if prediction_client
499520
else constants.API_BASE_PATH
500521
)
501522

523+
current_universe_domain = (
524+
universe_domain
525+
or self.universe_domain
526+
or constants.DEFAULT_UNIVERSE_DOMAIN
527+
)
528+
529+
if not api_base_path_override and current_universe_domain != "googleapis.com":
530+
service_base_path = service_base_path.replace(
531+
"googleapis.com", current_universe_domain
532+
)
533+
502534
api_endpoint = (
503535
f"{region}-{service_base_path}"
504536
if not api_path_override
@@ -508,9 +540,14 @@ def get_client_options(
508540
# Project/location take precedence over api_key
509541
if api_key and not self._project:
510542
return client_options.ClientOptions(
511-
api_endpoint=api_endpoint, api_key=api_key
543+
api_endpoint=api_endpoint,
544+
api_key=api_key,
545+
universe_domain=universe_domain or self.universe_domain,
512546
)
513-
return client_options.ClientOptions(api_endpoint=api_endpoint)
547+
return client_options.ClientOptions(
548+
api_endpoint=api_endpoint,
549+
universe_domain=universe_domain or self.universe_domain,
550+
)
514551

515552
def common_location_path(
516553
self, project: Optional[str] = None, location: Optional[str] = None
@@ -524,7 +561,11 @@ def common_location_path(
524561
resource_parent: Formatted parent resource string.
525562
"""
526563
if location:
527-
utils.validate_region(location)
564+
utils.validate_region(
565+
location,
566+
api_endpoint=self.api_endpoint,
567+
universe_domain=self.universe_domain,
568+
)
528569

529570
return "/".join(
530571
[
@@ -546,6 +587,7 @@ def create_client(
546587
api_path_override: Optional[str] = None,
547588
appended_user_agent: Optional[List[str]] = None,
548589
appended_gapic_version: Optional[str] = None,
590+
universe_domain: Optional[str] = None,
549591
) -> _TVertexAiServiceClientWithOverride:
550592
"""Instantiates a given VertexAiServiceClient with optional
551593
overrides.
@@ -565,6 +607,8 @@ def create_client(
565607
separated by spaces.
566608
appended_gapic_version (str):
567609
Optional. GAPIC version suffix appended in the client info.
610+
universe_domain (str):
611+
Optional. universe domain override.
568612
Returns:
569613
client: Instantiated Vertex AI Service client with optional overrides
570614
"""
@@ -607,6 +651,7 @@ def create_client(
607651
api_key=api_key,
608652
api_base_path_override=api_base_path_override,
609653
api_path_override=api_path_override,
654+
universe_domain=universe_domain,
610655
),
611656
"client_info": client_info,
612657
}

google/cloud/aiplatform/utils/__init__.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -299,11 +299,18 @@ def validate_labels(labels: Dict[str, str]):
299299
)
300300

301301

302-
def validate_region(region: str) -> bool:
302+
def validate_region(
303+
region: str,
304+
api_endpoint: Optional[str] = None,
305+
universe_domain: Optional[str] = None,
306+
) -> bool:
303307
"""Validates region against supported regions.
304308
305309
Args:
306310
region: region to validate
311+
api_endpoint: Optional API endpoint.
312+
universe_domain: Optional universe domain.
313+
307314
Returns:
308315
bool: True if no errors raised
309316
Raises:
@@ -316,9 +323,16 @@ def validate_region(region: str) -> bool:
316323

317324
region = region.lower()
318325
if region not in constants.SUPPORTED_REGIONS:
319-
raise ValueError(
320-
f"Unsupported region for Vertex AI, select from {constants.SUPPORTED_REGIONS}"
321-
)
326+
if not (
327+
api_endpoint
328+
or universe_domain
329+
or initializer.global_config.api_endpoint
330+
or initializer.global_config.universe_domain
331+
):
332+
raise ValueError(
333+
"Unsupported region for Vertex AI, select from"
334+
f" {constants.SUPPORTED_REGIONS}"
335+
)
322336

323337
return True
324338

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
# -*- coding: utf-8 -*-
2+
3+
# Copyright 2024 Google LLC
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
18+
from google.cloud import aiplatform
19+
from tests.system.aiplatform import e2e_base
20+
21+
_TEST_PROJECT = "test-project"
22+
_TEST_LOCATION_TPC = "u-us-prp1"
23+
_TEST_ENDPOINT_TPC = "u-us-prp1-aiplatform.apis-tpczero.goog"
24+
_TEST_UNIVERSE_TPC = "apis-tpczero.goog"
25+
26+
27+
class TestTpcInitializer(e2e_base.TestEndToEnd):
28+
"""Tests TPC support in the initializer without monkeypatching."""
29+
30+
_temp_prefix = "test_tpc_initializer_"
31+
32+
def test_init_tpc_sets_global_config(self):
33+
# This verifies that our changes to validate_region allow u-us-prp1
34+
# and that universe_domain is correctly stored.
35+
aiplatform.init(
36+
project=_TEST_PROJECT,
37+
location=_TEST_LOCATION_TPC,
38+
universe_domain=_TEST_UNIVERSE_TPC,
39+
)
40+
41+
assert aiplatform.initializer.global_config.location == _TEST_LOCATION_TPC
42+
assert aiplatform.initializer.global_config.universe_domain == _TEST_UNIVERSE_TPC
43+
44+
def test_tpc_client_creation_plumbing(self):
45+
"""Verifies that clients created after TPC init have correct TPC endpoints."""
46+
aiplatform.init(
47+
project=_TEST_PROJECT,
48+
location=_TEST_LOCATION_TPC,
49+
universe_domain=_TEST_UNIVERSE_TPC,
50+
)
51+
52+
# Instantiate a client (e.g., Endpoint)
53+
# We don't call any methods that trigger real RPCs.
54+
ds = aiplatform.Endpoint.list()
55+
56+
# Check the underlying client's endpoint
57+
# The SDK should have constructed the endpoint using the TPC location and universe domain.
58+
client = aiplatform.initializer.global_config.create_client(
59+
client_class=aiplatform.utils.EndpointClientWithOverride
60+
)
61+
62+
expected_host = f"{_TEST_LOCATION_TPC}-aiplatform.{_TEST_UNIVERSE_TPC}:443"
63+
assert client._transport._host == expected_host

0 commit comments

Comments
 (0)