diff --git a/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/VertexAI.java b/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/VertexAI.java index 8284e5821d88..500d902e3089 100644 --- a/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/VertexAI.java +++ b/java-vertexai/google-cloud-vertexai/src/main/java/com/google/cloud/vertexai/VertexAI.java @@ -166,8 +166,15 @@ private VertexAI( this.llmClientSupplier = Suppliers.memoize(llmClientSupplierOpt.orElse(this::newLlmUtilityClient)); - this.apiEndpoint = - apiEndpoint.orElse(String.format("%s-aiplatform.googleapis.com", this.location)); + if (apiEndpoint.isPresent()) { + this.apiEndpoint = apiEndpoint.get(); + } else { + if ("global".equals(this.location)) { + this.apiEndpoint = "aiplatform.googleapis.com"; + } else { + this.apiEndpoint = String.format("%s-aiplatform.googleapis.com", this.location); + } + } } /** Builder for {@link VertexAI}. */ diff --git a/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/VertexAITest.java b/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/VertexAITest.java index 58a20773dc30..3d005dc1feb0 100644 --- a/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/VertexAITest.java +++ b/java-vertexai/google-cloud-vertexai/src/test/java/com/google/cloud/vertexai/VertexAITest.java @@ -48,7 +48,9 @@ public final class VertexAITest { private static final String TEST_PROJECT = "test_project"; private static final String TEST_LOCATION = "test_location"; + private static final String GLOBAL_LOCATION = "global"; private static final String TEST_ENDPOINT = "test_endpoint"; + private static final String GLOBAL_ENDPOINT = "aiplatform.googleapis.com"; private static final String TEST_DEFAULT_ENDPOINT = String.format("%s-aiplatform.googleapis.com", TEST_LOCATION); private static final Optional EMPTY_ENV_VAR_OPTIONAL = Optional.ofNullable(null); @@ -344,6 +346,27 @@ public void testInstantiateVertexAI_builderLocationFromCLOUD_ML_REGION_shouldCon } } + @Test + public void testInstantiateVertexAI_builderLocationFromGLOBAL_REGION_shouldContainRightFields() { + try (MockedStatic mockedStaticVertexAI = mockStatic(VertexAI.class)) { + mockedStaticVertexAI + .when(() -> VertexAI.getEnvironmentVariable("GOOGLE_CLOUD_REGION")) + .thenReturn(Optional.of(GLOBAL_LOCATION)); + mockedStaticVertexAI + .when(() -> VertexAI.getEnvironmentVariable("CLOUD_ML_REGION")) + .thenReturn(Optional.empty()); + mockedStaticVertexAI + .when(() -> VertexAI.getEnvironmentVariable("GOOGLE_CLOUD_PROJECT")) + .thenReturn(Optional.of(TEST_PROJECT)); + + vertexAi = new VertexAI.Builder().build(); + + assertThat(vertexAi.getProjectId()).isEqualTo(TEST_PROJECT); + assertThat(vertexAi.getLocation()).isEqualTo(GLOBAL_LOCATION); + assertThat(vertexAi.getApiEndpoint()).isEqualTo(GLOBAL_ENDPOINT); + } + } + @Test public void testInstantiateVertexAI_builderWithScopes_throwsIlegalArgumentException() throws IllegalArgumentException {