From 68fe36b1e8381cc30c859c503a04f7449f8a5a58 Mon Sep 17 00:00:00 2001 From: Tatu Saloranta Date: Fri, 17 May 2024 09:04:10 -0700 Subject: [PATCH] Fix #1098: prevent NPE for unknown embedding provider id (#1099) --- .../operation/EmbeddingProviderFactory.java | 43 ++++++++----------- 1 file changed, 19 insertions(+), 24 deletions(-) diff --git a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/EmbeddingProviderFactory.java b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/EmbeddingProviderFactory.java index 9c15900c6a..dc815be2d7 100644 --- a/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/EmbeddingProviderFactory.java +++ b/src/main/java/io/stargate/sgv2/jsonapi/service/embedding/operation/EmbeddingProviderFactory.java @@ -4,7 +4,6 @@ import io.stargate.embedding.gateway.EmbeddingService; import io.stargate.sgv2.jsonapi.config.OperationsConfig; import io.stargate.sgv2.jsonapi.exception.ErrorCode; -import io.stargate.sgv2.jsonapi.exception.JsonApiException; import io.stargate.sgv2.jsonapi.service.embedding.configuration.EmbeddingProviderConfigStore; import io.stargate.sgv2.jsonapi.service.embedding.configuration.ProviderConstants; import io.stargate.sgv2.jsonapi.service.embedding.gateway.EmbeddingGatewayClient; @@ -96,33 +95,29 @@ private synchronized EmbeddingProvider addService( } if (configuration.serviceProvider().equals(ProviderConstants.CUSTOM)) { + Optional> clazz = configuration.implementationClass(); + if (!clazz.isPresent()) { + throw ErrorCode.VECTORIZE_SERVICE_TYPE_UNAVAILABLE.toApiException("custom class undefined"); + } try { - Optional> clazz = configuration.implementationClass(); - if (clazz.isPresent()) { - final EmbeddingProvider customEmbeddingProvider = - (EmbeddingProvider) clazz.get().getConstructor().newInstance(); - return customEmbeddingProvider; - } else { - throw new JsonApiException( - ErrorCode.VECTORIZE_SERVICE_TYPE_UNAVAILABLE, - ErrorCode.VECTORIZE_SERVICE_TYPE_UNAVAILABLE.getMessage() + "custom class undefined"); - } + return (EmbeddingProvider) clazz.get().getConstructor().newInstance(); } catch (Exception e) { - throw new JsonApiException( - ErrorCode.VECTORIZE_SERVICE_TYPE_UNAVAILABLE, - ErrorCode.VECTORIZE_SERVICE_TYPE_UNAVAILABLE.getMessage() - + "custom class provided does not resolve to EmbeddingProvider " - + configuration.implementationClass().get().getCanonicalName()); + throw ErrorCode.VECTORIZE_SERVICE_TYPE_UNAVAILABLE.toApiException( + "custom class provided ('%s') does not resolve to EmbeddingProvider", + clazz.get().getCanonicalName()); } } - return providersMap - .get(configuration.serviceProvider()) - .create( - configuration.requestConfiguration(), - configuration.baseUrl(), - modelName, - dimension, - vectorizeServiceParameters); + ProviderConstructor ctor = providersMap.get(configuration.serviceProvider()); + if (ctor == null) { + throw ErrorCode.VECTORIZE_SERVICE_TYPE_UNAVAILABLE.toApiException( + "unknown service provider '%s'", configuration.serviceProvider()); + } + return ctor.create( + configuration.requestConfiguration(), + configuration.baseUrl(), + modelName, + dimension, + vectorizeServiceParameters); } }