From 54718ad30592553dcaf15f3b1384113b4862df6d Mon Sep 17 00:00:00 2001 From: Tatu Saloranta Date: Tue, 14 May 2024 12:39:40 -0700 Subject: [PATCH] Fix #1088: handle `null` valuad service params --- .../operation/EmbeddingProviderFactory.java | 5 +++- .../operation/VertexAIEmbeddingClient.java | 9 ++++--- .../operation/VoyageAIEmbeddingClient.java | 4 ++-- .../operation/EmbeddingGatewayClientTest.java | 24 +++++++++++++++++++ 4 files changed, 34 insertions(+), 8 deletions(-) diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/EmbeddingProviderFactory.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/EmbeddingProviderFactory.java index 18f4e1466c..9c15900c6a 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/EmbeddingProviderFactory.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/EmbeddingProviderFactory.java @@ -23,7 +23,7 @@ public class EmbeddingProviderFactory { @GrpcClient("embedding") EmbeddingService embeddingService; - private interface ProviderConstructor { + interface ProviderConstructor { EmbeddingProvider create( EmbeddingProviderConfigStore.RequestProperties requestProperties, String baseUrl, @@ -55,6 +55,9 @@ public EmbeddingProvider getConfiguration( Map vectorizeServiceParameters, Map authentication, String commandName) { + if (vectorizeServiceParameters == null) { + vectorizeServiceParameters = Map.of(); + } return addService( tenant, authToken, diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/VertexAIEmbeddingClient.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/VertexAIEmbeddingClient.java index a231d3c4ff..584dfc48dd 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/VertexAIEmbeddingClient.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/VertexAIEmbeddingClient.java @@ -33,18 +33,17 @@ public class VertexAIEmbeddingClient implements EmbeddingProvider { private static final String PROJECT_ID = "projectId"; - private Map vectorizeServiceParameters; - public VertexAIEmbeddingClient( EmbeddingProviderConfigStore.RequestProperties requestProperties, String baseUrl, String modelName, int dimension, - Map vectorizeServiceParameters) { + Map serviceParameters) { this.requestProperties = requestProperties; this.modelName = modelName; - this.vectorizeServiceParameters = vectorizeServiceParameters; - baseUrl = baseUrl.replace(PROJECT_ID, vectorizeServiceParameters.get(PROJECT_ID).toString()); + String projectId = + (serviceParameters == null) ? "" : String.valueOf(serviceParameters.get(PROJECT_ID)); + baseUrl = baseUrl.replace(PROJECT_ID, projectId); embeddingProvider = QuarkusRestClientBuilder.newBuilder() .baseUri(URI.create(baseUrl)) diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/VoyageAIEmbeddingClient.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/VoyageAIEmbeddingClient.java index 04deb0e234..f452f9d2b2 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/VoyageAIEmbeddingClient.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/VoyageAIEmbeddingClient.java @@ -38,13 +38,13 @@ public VoyageAIEmbeddingClient( String baseUrl, String modelName, int dimension, - Map vectorizeServiceParameters) { + Map serviceParameters) { this.requestProperties = requestProperties; this.modelName = modelName; // use configured input_type if available requestTypeQuery = requestProperties.requestTypeQuery().orElse(null); requestTypeIndex = requestProperties.requestTypeIndex().orElse(null); - Object v = vectorizeServiceParameters.get("autoTruncate"); + Object v = (serviceParameters == null) ? null : serviceParameters.get("autoTruncate"); autoTruncate = (v instanceof Boolean) ? (Boolean) v : null; embeddingProvider = diff --git a/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/EmbeddingGatewayClientTest.java b/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/EmbeddingGatewayClientTest.java index 5a9ceb47cc..3e4386051e 100644 --- a/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/EmbeddingGatewayClientTest.java +++ b/src/test/java/io/stargate/sgv2/jsonapi/service/embedding/operation/EmbeddingGatewayClientTest.java @@ -16,6 +16,7 @@ import io.stargate.sgv2.jsonapi.exception.JsonApiException; import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProviderConfigStore; import io.stargate.sgv2.jsonapi.service.embedding.gateway.EmbeddingGatewayClient; +import java.util.Arrays; import java.util.List; import java.util.Map; import java.util.Optional; @@ -27,6 +28,29 @@ public class EmbeddingGatewayClientTest { public static final String TESTING_COMMAND_NAME = "test_command"; + // for [data-api#1088] (NPE for VoyageAI provider) + @Test + void verifyDirectConstructionWithNullServiceParameters() { + List providerCtors = + Arrays.asList( + AzureOpenAIEmbeddingClient::new, + CohereEmbeddingClient::new, + HuggingFaceEmbeddingClient::new, + JinaAIEmbeddingClient::new, + MistralEmbeddingClient::new, + NvidiaEmbeddingClient::new, + OpenAIEmbeddingClient::new, + UpstageAIEmbeddingClient::new, + VertexAIEmbeddingClient::new, + VoyageAIEmbeddingClient::new); + for (EmbeddingProviderFactory.ProviderConstructor ctor : providerCtors) { + EmbeddingProviderConfigStore.RequestProperties requestProperties = + EmbeddingProviderConfigStore.RequestProperties.of( + 3, 5, 5000, Optional.empty(), Optional.empty()); + assertThat(ctor.create(requestProperties, "baseUrl", "modelName", 5, null)).isNotNull(); + } + } + @Test void handleValidResponse() { EmbeddingService embeddingService = mock(EmbeddingService.class);