diff --git a/extended/src/main/java/apoc/ml/OpenAI.java b/extended/src/main/java/apoc/ml/OpenAI.java index e379ed3c3e..692ac1e844 100644 --- a/extended/src/main/java/apoc/ml/OpenAI.java +++ b/extended/src/main/java/apoc/ml/OpenAI.java @@ -12,9 +12,7 @@ import java.io.File; import java.net.MalformedURLException; -import java.util.HashMap; -import java.util.Map; -import java.util.List; +import java.util.*; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -30,9 +28,23 @@ @Extended public class OpenAI { - enum ApiType { AZURE, OPENAI } + enum ApiType { + AZURE(null), OPENAI("https://api.openai.com/v1"); + + private final String defaultUrl; + + ApiType(String defaultUrl) { + this.defaultUrl = defaultUrl; + } + + public String getEndpoint(Map procConfig, ApocConfig apocConfig) { + return (String) procConfig.getOrDefault(ENDPOINT_CONF_KEY, + apocConfig.getString(APOC_ML_OPENAI_URL, System.getProperty(APOC_ML_OPENAI_URL, defaultUrl))); + } + } public static final String API_TYPE_CONF_KEY = "apiType"; + public static final String APIKEY_CONF_KEY = "apiKey"; public static final String ENDPOINT_CONF_KEY = "endpoint"; public static final String API_VERSION_CONF_KEY = "apiVersion"; @@ -52,29 +64,26 @@ public EmbeddingResult(long index, String text, List embedding) { } static Stream executeRequest(String apiKey, Map configuration, String path, String model, String key, Object inputs, String jsonPath, ApocConfig apocConfig) throws JsonProcessingException, MalformedURLException { - apiKey = (String) configuration.getOrDefault(APOC_OPENAI_KEY, apocConfig.getString(APOC_OPENAI_KEY, apiKey)); + apiKey = (String) configuration.getOrDefault(APIKEY_CONF_KEY, apocConfig.getString(APOC_OPENAI_KEY, apiKey)); if (apiKey == null || apiKey.isBlank()) throw new IllegalArgumentException("API Key must not be empty"); - String apiTypeString = (String) configuration.getOrDefault(API_TYPE_CONF_KEY, apocConfig.getString(APOC_ML_OPENAI_TYPE, ApiType.OPENAI.name()) ); - ApiType apiType = ApiType.valueOf(apiTypeString); + ApiType apiType = ApiType.valueOf(apiTypeString.toUpperCase(Locale.ENGLISH)); - String endpoint = (String) configuration.get(ENDPOINT_CONF_KEY); - - String apiVersion; - Map headers = new HashMap<>(); + String endpoint = apiType.getEndpoint(configuration, apocConfig); + + final String apiVersion; + final Map headers = new HashMap<>(); headers.put("Content-Type", "application/json"); switch (apiType) { case AZURE -> { - endpoint = getEndpoint(endpoint, apocConfig, ""); apiVersion = "?api-version=" + configuration.getOrDefault(API_VERSION_CONF_KEY, apocConfig.getString(APOC_ML_OPENAI_AZURE_VERSION)); headers.put("api-key", apiKey); } default -> { - endpoint = getEndpoint(endpoint, apocConfig, "https://api.openai.com/v1"); apiVersion = ""; headers.put("Authorization", "Bearer " + apiKey); } @@ -82,35 +91,21 @@ static Stream executeRequest(String apiKey, Map configur var config = new HashMap<>(configuration); // we remove these keys from config, since the json payload is calculated starting from the config map - Stream.of(ENDPOINT_CONF_KEY, API_TYPE_CONF_KEY, API_VERSION_CONF_KEY).forEach(config::remove); + Stream.of(ENDPOINT_CONF_KEY, API_TYPE_CONF_KEY, API_VERSION_CONF_KEY, APIKEY_CONF_KEY).forEach(config::remove); config.putIfAbsent("model", model); config.put(key, inputs); - String payload = new ObjectMapper().writeValueAsString(config); + String payload = JsonUtil.OBJECT_MAPPER.writeValueAsString(config); // new URL(endpoint), path) can produce a wrong path, since endpoint can have for example embedding, // eg: https://my-resource.openai.azure.com/openai/deployments/apoc-embeddings-model // therefore is better to join the not-empty path pieces var url = Stream.of(endpoint, path, apiVersion) .filter(StringUtils::isNotBlank) - .collect(Collectors.joining(File.separator)); + .collect(Collectors.joining("/")); return JsonUtil.loadJson(url, headers, payload, jsonPath, true, List.of()); } - private static String getEndpoint(String endpointConfMap, ApocConfig apocConfig, String defaultUrl) { - if (endpointConfMap != null) { - return endpointConfMap; - } - - String apocConfUrl = apocConfig.getString(APOC_ML_OPENAI_URL, null); - if (apocConfUrl != null) { - return apocConfUrl; - } - - return System.getProperty(APOC_ML_OPENAI_URL, defaultUrl); - } - - @Procedure("apoc.ml.openai.embedding") @Description("apoc.openai.embedding([texts], api_key, configuration) - returns the embeddings for a given text") public Stream getEmbedding(@Name("texts") List texts, @Name("api_key") String apiKey, @Name(value = "configuration", defaultValue = "{}") Map configuration) throws Exception { diff --git a/extended/src/test/java/apoc/ml/OpenAIAzureIT.java b/extended/src/test/java/apoc/ml/OpenAIAzureIT.java index 4e0ba02d70..b629442328 100644 --- a/extended/src/test/java/apoc/ml/OpenAIAzureIT.java +++ b/extended/src/test/java/apoc/ml/OpenAIAzureIT.java @@ -8,22 +8,18 @@ import org.neo4j.test.rule.ImpermanentDbmsRule; import java.util.Map; -import java.util.stream.Stream; import static apoc.ApocConfig.apocConfig; import static apoc.ml.OpenAI.API_TYPE_CONF_KEY; import static apoc.ml.OpenAI.API_VERSION_CONF_KEY; import static apoc.ml.OpenAI.ENDPOINT_CONF_KEY; -import static apoc.ml.OpenAITestUtils.getStringObjectMap; +import static apoc.ml.OpenAITestResultUtils.assertChatCompletion; import static apoc.util.TestUtil.testCall; -import static org.junit.Assume.assumeNotNull; public class OpenAIAzureIT { // In Azure, the endpoints can be different - private static String OPENAI_EMBEDDING_URL; - private static String OPENAI_CHAT_URL; - private static String OPENAI_COMPLETION_URL; - + private static String OPENAI_URL; + private static String OPENAI_AZURE_API_VERSION; private static String OPENAI_KEY; @@ -35,21 +31,21 @@ public class OpenAIAzureIT { public static void setUp() throws Exception { OPENAI_KEY = System.getenv("OPENAI_KEY"); // Azure OpenAI base URLs - OPENAI_EMBEDDING_URL = System.getenv("OPENAI_EMBEDDING_URL"); - OPENAI_CHAT_URL = System.getenv("OPENAI_CHAT_URL"); - OPENAI_COMPLETION_URL = System.getenv("OPENAI_COMPLETION_URL"); + OPENAI_URL = System.getenv("OPENAI_URL"); - // Azure OpenAI query url (`//?api-version=) + // Azure OpenAI query url (`//?api-version=`) OPENAI_AZURE_API_VERSION = System.getenv("OPENAI_AZURE_API_VERSION"); apocConfig().setProperty("ajeje", "brazorf"); - + + /* Stream.of(OPENAI_EMBEDDING_URL, OPENAI_CHAT_URL, - OPENAI_COMPLETION_URL, + OPENAI_COMPLETION_URL, OPENAI_AZURE_API_VERSION, OPENAI_KEY) .forEach(key -> assumeNotNull("No " + key + " environment configured", key)); + */ TestUtil.registerProcedure(db, OpenAI.class); } @@ -57,15 +53,15 @@ public static void setUp() throws Exception { @Test public void embedding() { testCall(db, "CALL apoc.ml.openai.embedding(['Some Text'], $apiKey, $conf)", - getParams(OPENAI_EMBEDDING_URL), - OpenAITestUtils::extracted); + getParams(), + OpenAITestResultUtils::assertEmbeddings); } @Test public void completion() { testCall(db, "CALL apoc.ml.openai.completion('What color is the sky? Answer in one word: ', $apiKey, $conf)", - getParams(OPENAI_COMPLETION_URL), OpenAITestUtils::extracted1); + getParams(), OpenAITestResultUtils::assertCompletion); } @Test @@ -74,14 +70,14 @@ public void chatCompletion() { CALL apoc.ml.openai.chat([ {role:"system", content:"Only answer with a single word"}, {role:"user", content:"What planet do humans live on?"} - ], $apiKey, $conf) - """, getParams(OPENAI_CHAT_URL), - (row) -> getStringObjectMap(row, "gpt-35-turbo")); + ], $apiKey, $conf) + """, getParams(), + (row) -> assertChatCompletion(row, "gpt-35-turbo")); } - private static Map getParams(String endpoint) { + private static Map getParams() { return Map.of("apiKey", OPENAI_KEY, - "conf", Map.of(ENDPOINT_CONF_KEY, endpoint, + "conf", Map.of(ENDPOINT_CONF_KEY, OPENAI_URL, API_TYPE_CONF_KEY, OpenAI.ApiType.AZURE.name(), API_VERSION_CONF_KEY, OPENAI_AZURE_API_VERSION ) diff --git a/extended/src/test/java/apoc/ml/OpenAIIT.java b/extended/src/test/java/apoc/ml/OpenAIIT.java index 387c754f2c..f459ab2c77 100644 --- a/extended/src/test/java/apoc/ml/OpenAIIT.java +++ b/extended/src/test/java/apoc/ml/OpenAIIT.java @@ -10,7 +10,7 @@ import java.util.Map; -import static apoc.ml.OpenAITestUtils.getStringObjectMap; +import static apoc.ml.OpenAITestResultUtils.assertChatCompletion; import static apoc.util.TestUtil.testCall; public class OpenAIIT { @@ -33,14 +33,14 @@ public void setUp() throws Exception { @Test public void getEmbedding() { testCall(db, "CALL apoc.ml.openai.embedding(['Some Text'], $apiKey)", Map.of("apiKey",openaiKey), - OpenAITestUtils::extracted); + OpenAITestResultUtils::assertEmbeddings); } @Test public void completion() { testCall(db, "CALL apoc.ml.openai.completion('What color is the sky? Answer in one word: ', $apiKey)", Map.of("apiKey",openaiKey), - OpenAITestUtils::extracted1); + OpenAITestResultUtils::assertCompletion); } @Test @@ -51,7 +51,7 @@ public void chatCompletion() { {role:"user", content:"What planet do humans live on?"} ], $apiKey) """, Map.of("apiKey",openaiKey), - (row) -> getStringObjectMap(row, "gpt-3.5-turbo")); + (row) -> assertChatCompletion(row, "gpt-3.5-turbo")); /* { diff --git a/extended/src/test/java/apoc/ml/OpenAITestUtils.java b/extended/src/test/java/apoc/ml/OpenAITestResultUtils.java similarity index 88% rename from extended/src/test/java/apoc/ml/OpenAITestUtils.java rename to extended/src/test/java/apoc/ml/OpenAITestResultUtils.java index b35494bbd7..11f3a25293 100644 --- a/extended/src/test/java/apoc/ml/OpenAITestUtils.java +++ b/extended/src/test/java/apoc/ml/OpenAITestResultUtils.java @@ -6,15 +6,15 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertTrue; -public class OpenAITestUtils { - public static void extracted(Map row) { +public class OpenAITestResultUtils { + public static void assertEmbeddings(Map row) { assertEquals(0L, row.get("index")); assertEquals("Some Text", row.get("text")); var embedding = (List) row.get("embedding"); assertEquals(1536, embedding.size()); } - public static void extracted1(Map row) { + public static void assertCompletion(Map row) { var result = (Map) row.get("value"); assertTrue(result.get("created") instanceof Number); assertTrue(result.containsKey("choices")); @@ -29,7 +29,7 @@ public static void extracted1(Map row) { assertEquals("text_completion", result.get("object")); } - public static void getStringObjectMap(Map row, String modelId) { + public static void assertChatCompletion(Map row, String modelId) { var result = (Map) row.get("value"); assertTrue(result.get("created") instanceof Number); assertTrue(result.containsKey("choices"));