Skip to content

Commit

Permalink
Code clean
Browse files Browse the repository at this point in the history
  • Loading branch information
conker84 committed Dec 3, 2023
1 parent e1f0687 commit 9aa5f11
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 59 deletions.
55 changes: 25 additions & 30 deletions extended/src/main/java/apoc/ml/OpenAI.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,7 @@

import java.io.File;
import java.net.MalformedURLException;
import java.util.HashMap;
import java.util.Map;
import java.util.List;
import java.util.*;
import java.util.stream.Collectors;
import java.util.stream.Stream;

Expand All @@ -30,9 +28,23 @@

@Extended
public class OpenAI {
enum ApiType { AZURE, OPENAI }
enum ApiType {
AZURE(null), OPENAI("https://api.openai.com/v1");

private final String defaultUrl;

ApiType(String defaultUrl) {
this.defaultUrl = defaultUrl;
}

public String getEndpoint(Map<String, Object> procConfig, ApocConfig apocConfig) {
return (String) procConfig.getOrDefault(ENDPOINT_CONF_KEY,
apocConfig.getString(APOC_ML_OPENAI_URL, System.getProperty(APOC_ML_OPENAI_URL, defaultUrl)));
}
}

public static final String API_TYPE_CONF_KEY = "apiType";
public static final String APIKEY_CONF_KEY = "apiKey";
public static final String ENDPOINT_CONF_KEY = "endpoint";
public static final String API_VERSION_CONF_KEY = "apiVersion";

Expand All @@ -52,65 +64,48 @@ public EmbeddingResult(long index, String text, List<Double> embedding) {
}

static Stream<Object> executeRequest(String apiKey, Map<String, Object> configuration, String path, String model, String key, Object inputs, String jsonPath, ApocConfig apocConfig) throws JsonProcessingException, MalformedURLException {
apiKey = (String) configuration.getOrDefault(APOC_OPENAI_KEY, apocConfig.getString(APOC_OPENAI_KEY, apiKey));
apiKey = (String) configuration.getOrDefault(APIKEY_CONF_KEY, apocConfig.getString(APOC_OPENAI_KEY, apiKey));
if (apiKey == null || apiKey.isBlank())
throw new IllegalArgumentException("API Key must not be empty");


String apiTypeString = (String) configuration.getOrDefault(API_TYPE_CONF_KEY,
apocConfig.getString(APOC_ML_OPENAI_TYPE, ApiType.OPENAI.name())
);
ApiType apiType = ApiType.valueOf(apiTypeString);
ApiType apiType = ApiType.valueOf(apiTypeString.toUpperCase(Locale.ENGLISH));

String endpoint = (String) configuration.get(ENDPOINT_CONF_KEY);
String apiVersion;
Map<String, Object> headers = new HashMap<>();
String endpoint = apiType.getEndpoint(configuration, apocConfig);

final String apiVersion;
final Map<String, Object> headers = new HashMap<>();
headers.put("Content-Type", "application/json");
switch (apiType) {
case AZURE -> {
endpoint = getEndpoint(endpoint, apocConfig, "");
apiVersion = "?api-version=" + configuration.getOrDefault(API_VERSION_CONF_KEY, apocConfig.getString(APOC_ML_OPENAI_AZURE_VERSION));
headers.put("api-key", apiKey);
}
default -> {
endpoint = getEndpoint(endpoint, apocConfig, "https://api.openai.com/v1");
apiVersion = "";
headers.put("Authorization", "Bearer " + apiKey);
}
}

var config = new HashMap<>(configuration);
// we remove these keys from config, since the json payload is calculated starting from the config map
Stream.of(ENDPOINT_CONF_KEY, API_TYPE_CONF_KEY, API_VERSION_CONF_KEY).forEach(config::remove);
Stream.of(ENDPOINT_CONF_KEY, API_TYPE_CONF_KEY, API_VERSION_CONF_KEY, APIKEY_CONF_KEY).forEach(config::remove);
config.putIfAbsent("model", model);
config.put(key, inputs);

String payload = new ObjectMapper().writeValueAsString(config);
String payload = JsonUtil.OBJECT_MAPPER.writeValueAsString(config);

// new URL(endpoint), path) can produce a wrong path, since endpoint can have for example embedding,
// eg: https://my-resource.openai.azure.com/openai/deployments/apoc-embeddings-model
// therefore is better to join the not-empty path pieces
var url = Stream.of(endpoint, path, apiVersion)
.filter(StringUtils::isNotBlank)
.collect(Collectors.joining(File.separator));
.collect(Collectors.joining("/"));
return JsonUtil.loadJson(url, headers, payload, jsonPath, true, List.of());
}

private static String getEndpoint(String endpointConfMap, ApocConfig apocConfig, String defaultUrl) {
if (endpointConfMap != null) {
return endpointConfMap;
}

String apocConfUrl = apocConfig.getString(APOC_ML_OPENAI_URL, null);
if (apocConfUrl != null) {
return apocConfUrl;
}

return System.getProperty(APOC_ML_OPENAI_URL, defaultUrl);
}


@Procedure("apoc.ml.openai.embedding")
@Description("apoc.openai.embedding([texts], api_key, configuration) - returns the embeddings for a given text")
public Stream<EmbeddingResult> getEmbedding(@Name("texts") List<String> texts, @Name("api_key") String apiKey, @Name(value = "configuration", defaultValue = "{}") Map<String, Object> configuration) throws Exception {
Expand Down
38 changes: 17 additions & 21 deletions extended/src/test/java/apoc/ml/OpenAIAzureIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,22 +8,18 @@
import org.neo4j.test.rule.ImpermanentDbmsRule;

import java.util.Map;
import java.util.stream.Stream;

import static apoc.ApocConfig.apocConfig;
import static apoc.ml.OpenAI.API_TYPE_CONF_KEY;
import static apoc.ml.OpenAI.API_VERSION_CONF_KEY;
import static apoc.ml.OpenAI.ENDPOINT_CONF_KEY;
import static apoc.ml.OpenAITestUtils.getStringObjectMap;
import static apoc.ml.OpenAITestResultUtils.assertChatCompletion;
import static apoc.util.TestUtil.testCall;
import static org.junit.Assume.assumeNotNull;

public class OpenAIAzureIT {
// In Azure, the endpoints can be different
private static String OPENAI_EMBEDDING_URL;
private static String OPENAI_CHAT_URL;
private static String OPENAI_COMPLETION_URL;

private static String OPENAI_URL;

private static String OPENAI_AZURE_API_VERSION;

private static String OPENAI_KEY;
Expand All @@ -35,37 +31,37 @@ public class OpenAIAzureIT {
public static void setUp() throws Exception {
OPENAI_KEY = System.getenv("OPENAI_KEY");
// Azure OpenAI base URLs
OPENAI_EMBEDDING_URL = System.getenv("OPENAI_EMBEDDING_URL");
OPENAI_CHAT_URL = System.getenv("OPENAI_CHAT_URL");
OPENAI_COMPLETION_URL = System.getenv("OPENAI_COMPLETION_URL");
OPENAI_URL = System.getenv("OPENAI_URL");

// Azure OpenAI query url (`<baseURL>/<type>/?api-version=<OPENAI_AZURE_API_VERSION>)
// Azure OpenAI query url (`<baseURL>/<type>/?api-version=<OPENAI_AZURE_API_VERSION>`)
OPENAI_AZURE_API_VERSION = System.getenv("OPENAI_AZURE_API_VERSION");

apocConfig().setProperty("ajeje", "brazorf");


/*
Stream.of(OPENAI_EMBEDDING_URL,
OPENAI_CHAT_URL,
OPENAI_COMPLETION_URL,
OPENAI_COMPLETION_URL,
OPENAI_AZURE_API_VERSION,
OPENAI_KEY)
.forEach(key -> assumeNotNull("No " + key + " environment configured", key));
*/

TestUtil.registerProcedure(db, OpenAI.class);
}

@Test
public void embedding() {
testCall(db, "CALL apoc.ml.openai.embedding(['Some Text'], $apiKey, $conf)",
getParams(OPENAI_EMBEDDING_URL),
OpenAITestUtils::extracted);
getParams(),
OpenAITestResultUtils::assertEmbeddings);
}


@Test
public void completion() {
testCall(db, "CALL apoc.ml.openai.completion('What color is the sky? Answer in one word: ', $apiKey, $conf)",
getParams(OPENAI_COMPLETION_URL), OpenAITestUtils::extracted1);
getParams(), OpenAITestResultUtils::assertCompletion);
}

@Test
Expand All @@ -74,14 +70,14 @@ public void chatCompletion() {
CALL apoc.ml.openai.chat([
{role:"system", content:"Only answer with a single word"},
{role:"user", content:"What planet do humans live on?"}
], $apiKey, $conf)
""", getParams(OPENAI_CHAT_URL),
(row) -> getStringObjectMap(row, "gpt-35-turbo"));
], $apiKey, $conf)
""", getParams(),
(row) -> assertChatCompletion(row, "gpt-35-turbo"));
}

private static Map<String, Object> getParams(String endpoint) {
private static Map<String, Object> getParams() {
return Map.of("apiKey", OPENAI_KEY,
"conf", Map.of(ENDPOINT_CONF_KEY, endpoint,
"conf", Map.of(ENDPOINT_CONF_KEY, OPENAI_URL,
API_TYPE_CONF_KEY, OpenAI.ApiType.AZURE.name(),
API_VERSION_CONF_KEY, OPENAI_AZURE_API_VERSION
)
Expand Down
8 changes: 4 additions & 4 deletions extended/src/test/java/apoc/ml/OpenAIIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import java.util.Map;

import static apoc.ml.OpenAITestUtils.getStringObjectMap;
import static apoc.ml.OpenAITestResultUtils.assertChatCompletion;
import static apoc.util.TestUtil.testCall;

public class OpenAIIT {
Expand All @@ -33,14 +33,14 @@ public void setUp() throws Exception {
@Test
public void getEmbedding() {
testCall(db, "CALL apoc.ml.openai.embedding(['Some Text'], $apiKey)", Map.of("apiKey",openaiKey),
OpenAITestUtils::extracted);
OpenAITestResultUtils::assertEmbeddings);
}

@Test
public void completion() {
testCall(db, "CALL apoc.ml.openai.completion('What color is the sky? Answer in one word: ', $apiKey)",
Map.of("apiKey",openaiKey),
OpenAITestUtils::extracted1);
OpenAITestResultUtils::assertCompletion);
}

@Test
Expand All @@ -51,7 +51,7 @@ public void chatCompletion() {
{role:"user", content:"What planet do humans live on?"}
], $apiKey)
""", Map.of("apiKey",openaiKey),
(row) -> getStringObjectMap(row, "gpt-3.5-turbo"));
(row) -> assertChatCompletion(row, "gpt-3.5-turbo"));

/*
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,15 @@
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;

public class OpenAITestUtils {
public static void extracted(Map<String, Object> row) {
public class OpenAITestResultUtils {
public static void assertEmbeddings(Map<String, Object> row) {
assertEquals(0L, row.get("index"));
assertEquals("Some Text", row.get("text"));
var embedding = (List<Double>) row.get("embedding");
assertEquals(1536, embedding.size());
}

public static void extracted1(Map<String, Object> row) {
public static void assertCompletion(Map<String, Object> row) {
var result = (Map<String,Object>) row.get("value");
assertTrue(result.get("created") instanceof Number);
assertTrue(result.containsKey("choices"));
Expand All @@ -29,7 +29,7 @@ public static void extracted1(Map<String, Object> row) {
assertEquals("text_completion", result.get("object"));
}

public static void getStringObjectMap(Map<String, Object> row, String modelId) {
public static void assertChatCompletion(Map<String, Object> row, String modelId) {
var result = (Map<String,Object>) row.get("value");
assertTrue(result.get("created") instanceof Number);
assertTrue(result.containsKey("choices"));
Expand Down

0 comments on commit 9aa5f11

Please sign in to comment.