Skip to content

Commit

Permalink
Fixes #4133: switch default model for cypher/schema interactions to g…
Browse files Browse the repository at this point in the history
…pt-4o
  • Loading branch information
gmarcostam authored and vga91 committed Jul 16, 2024
1 parent 1ebbd99 commit 68afdf8
Show file tree
Hide file tree
Showing 5 changed files with 220 additions and 12 deletions.
10 changes: 5 additions & 5 deletions docs/asciidoc/modules/ROOT/pages/ml/genai.adoc
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ RETURN m.title
| retryWithError | If true, in case of error retry the api adding the following messages to the body request:
{`"role":"user", "content": "The previous Cypher Statement throws the following error, consider it to return the correct statement: `<errorMessage>`"}, {"role":"assistant", "content":"Cypher Statement (in backticks):"}` | no, default `false`
| apiKey | OpenAI API key | in case `apoc.openai.key` is not defined
| model | The Open AI model | no, default `gpt-3.5-turbo`
| model | The Open AI model | no, default `gpt-4o`
| sample | The number of nodes to skip, e.g. a sample of 1000 will read every 1000th node. It's used as a parameter to `apoc.meta.data` procedure that computes the schema | no, default is a random number
|===

Expand Down Expand Up @@ -138,7 +138,7 @@ RETURN *
|===
| name | description | mandatory
| apiKey | OpenAI API key | in case `apoc.openai.key` is not defined
| model | The Open AI model | no, default `gpt-3.5-turbo`
| model | The Open AI model | no, default `gpt-4o`
| sample | The number of nodes to skip, e.g. a sample of 1000 will read every 1000th node. It's used as a parameter to `apoc.meta.data` procedure that computes the schema | no, default is a random number
|===

Expand Down Expand Up @@ -203,7 +203,7 @@ RETURN DISTINCT a.name
| name | description | mandatory
| count | The number of queries to retrieve | no, default `1`
| apiKey | OpenAI API key | in case `apoc.openai.key` is not defined
| model | The Open AI model | no, default `gpt-3.5-turbo`
| model | The Open AI model | no, default `gpt-4o`
| sample | The number of nodes to skip, e.g. a sample of 1000 will read every 1000th node. It's used as a parameter to `apoc.meta.data` procedure that computes the schema | no, default is a random number
|===

Expand Down Expand Up @@ -250,7 +250,7 @@ overall, this graph database schema provides a simple yet powerful representatio
| name | description | mandatory
| retries | The number of retries in case of API call failures | no, default `3`
| apiKey | OpenAI API key | in case `apoc.openai.key` is not defined
| model | The Open AI model | no, default `gpt-3.5-turbo`
| model | The Open AI model | no, default `gpt-4o`
| sample | The number of nodes to skip, e.g. a sample of 1000 will read every 1000th node. It's used as a parameter to `apoc.meta.data` procedure that computes the schema | no, default is a random number
|===

Expand Down Expand Up @@ -329,7 +329,7 @@ RETURN *
|===
| name | description | mandatory
| apiKey | OpenAI API key | in case `apoc.openai.key` is not defined
| model | The Open AI model | no, default `gpt-3.5-turbo`
| model | The Open AI model | no, default `gpt-4o`
| sample | The number of nodes to skip, e.g. a sample of 1000 will read every 1000th node. It's used as a parameter to `apoc.meta.data` procedure that computes the schema | no, default is a random number
|===

Expand Down
2 changes: 1 addition & 1 deletion docs/asciidoc/modules/ROOT/pages/ml/openai.adoc
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ This procedure `apoc.ml.openai.chat` takes a list of maps of chat exchanges betw

It uses the `/chat/create` API which is https://platform.openai.com/docs/api-reference/chat/create[documented here^].

Additional configuration is passed to the API, the default model used is `gpt-3.5-turbo`.
Additional configuration is passed to the API, the default model used is `gpt-4o`.

.Chat Completion Call
[source,cypher]
Expand Down
3 changes: 2 additions & 1 deletion extended/src/main/java/apoc/ml/OpenAI.java
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,8 @@ public Stream<MapResult> chatCompletion(@Name("messages") List<Map<String, Objec
if (messages == null) {
throw new RuntimeException(ERROR_NULL_INPUT);
}
return executeRequest(apiKey, configuration, "chat/completions", "gpt-3.5-turbo", "messages", messages, "$", apocConfig, urlAccessChecker)
configuration.putIfAbsent("model", "gpt-4o");
return executeRequest(apiKey, configuration, "chat/completions", (String) configuration.get("model"), "messages", messages, "$", apocConfig, urlAccessChecker)
.map(v -> (Map<String,Object>)v).map(MapResult::new);
// https://platform.openai.com/docs/api-reference/chat/create
/*
Expand Down
2 changes: 1 addition & 1 deletion extended/src/main/java/apoc/ml/Prompt.java
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@ private String prompt(String userQuestion, String systemPrompt, String assistant
prompt.addAll(otherPrompts);

String apiKey = (String) conf.get(API_KEY_CONF);
String model = (String) conf.getOrDefault("model", "gpt-3.5-turbo");
String model = (String) conf.getOrDefault("model", "gpt-4o");
String result = OpenAI.executeRequest(apiKey, Map.of(), "chat/completions",
model, "messages", prompt, "$", apocConfig, urlAccessChecker)
.map(v -> (Map<String, Object>) v)
Expand Down
215 changes: 211 additions & 4 deletions extended/src/test/java/apoc/ml/PromptIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,28 @@ public void testQuery() {
testResult(db, """
CALL apoc.ml.query($query, {retries: $retries, apiKey: $apiKey})
""",
Map.of(
"query", "What movies has Tom Hanks acted in?",
"retries", 3L,
"apiKey", OPENAI_KEY
),
(r) -> {
List<Map<String, Object>> list = r.stream().toList();
Assertions.assertThat(list).hasSize(12);
Assertions.assertThat(list.stream()
.map(m -> m.get("query"))
.filter(Objects::nonNull)
.map(Object::toString)
.map(String::trim))
.isNotEmpty();
});
}

@Test
public void testQueryGpt35Turbo() {
testResult(db, """
CALL apoc.ml.query($query, {model: 'gpt-3.5-turbo', retries: $retries, apiKey: $apiKey})
""",
Map.of(
"query", "What movies has Tom Hanks acted in?",
"retries", 2L,
Expand All @@ -100,6 +122,23 @@ public void testQuery() {
});
}

@Test
public void testQueryGpt35TurboUsingRetryWithError() {
testResult(db, """
CALL apoc.ml.query($query, {model: 'gpt-3.5-turbo', retries: $retries, apiKey: $apiKey, retryWithError: true})
""",
Map.of(
"query", UUID.randomUUID().toString(),
"retries", 10L,
"apiKey", OPENAI_KEY
),
(r) -> {
// check that it returns a Cypher result, also empty, without errors
List<Map<String, Object>> maps = Iterators.asList(r);
assertNotNull(maps);
});
}

@Test
public void testQueryUsingRetryWithError() {
testResult(db, """
Expand Down Expand Up @@ -154,6 +193,49 @@ public void testCypher() {
});
}

@Test
public void testCypherGpt35Turbo() {
long numOfQueries = 4L;
testResult(db, """
CALL apoc.ml.cypher($query, {model: 'gpt-3.5-turbo', count: $numOfQueries, apiKey: $apiKey})
""",
Map.of(
"query", "Who are the actors which also directed a movie?",
"numOfQueries", numOfQueries,
"apiKey", OPENAI_KEY
),
(r) -> {
List<Map<String, Object>> list = r.stream().toList();
Assertions.assertThat(list).hasSize((int) numOfQueries);
Assertions.assertThat(list.stream()
.map(m -> m.get("query"))
.filter(Objects::nonNull)
.map(Object::toString)
.filter(StringUtils::isNotEmpty))
.hasSize((int) numOfQueries);
});
}

@Test
public void testFromCypherGpt35Turbo() {
testCall(db, """
CALL apoc.ml.fromCypher($query, {model: 'gpt-3.5-turbo', retries: $retries, apiKey: $apiKey})
""",
Map.of(
"query", "MATCH (p:Person {name: \"Tom Hanks\"})-[:ACTED_IN]->(m:Movie) RETURN m",
"retries", 2L,
"apiKey", OPENAI_KEY
),
(r) -> {
String value = ( (String) r.get("value") ).toLowerCase();
String message = "Current value is: " + value;
assertTrue(message,
value.contains("movie"));
assertTrue(message,
value.contains("person") || value.contains("people") || value.contains("actor"));
});
}

@Test
public void testFromCypher() {
testCall(db, """
Expand All @@ -174,6 +256,25 @@ public void testFromCypher() {
});
}

@Test
public void testSchemaFromQueriesGpt35Turbo() {
List<String> queries = List.of("MATCH p=(n:Movie)--() RETURN p", "MATCH (n:Person) RETURN n", "MATCH (n:Movie) RETURN n", "MATCH p=(n)-[r]->() RETURN r");

testCall(db, """
CALL apoc.ml.fromQueries($queries, {model: 'gpt-3.5-turbo', apiKey: $apiKey})
""",
Map.of(
"queries", queries,
"apiKey", OPENAI_KEY
),
(r) -> {

String value = ((String) r.get("value")).toLowerCase();
Assertions.assertThat(value).containsIgnoringCase("movie");
Assertions.assertThat(value).containsAnyOf("person", "people");
});
}

@Test
public void testSchemaFromQueries() {
List<String> queries = List.of("MATCH p=(n:Movie)--() RETURN p", "MATCH (n:Person) RETURN n", "MATCH (n:Movie) RETURN n", "MATCH p=(n)-[r]->() RETURN r");
Expand All @@ -192,6 +293,24 @@ public void testSchemaFromQueries() {
Assertions.assertThat(value).containsAnyOf("person", "people");
});
}

@Test
public void testSchemaFromQueriesWithSingleQueryGpt35Turbo() {
List<String> queries = List.of("MATCH (n:Movie) RETURN n");

testCall(db, """
CALL apoc.ml.fromQueries($queries, {model: 'gpt-3.5-turbo', apiKey: $apiKey})
""",
Map.of(
"queries", queries,
"apiKey", OPENAI_KEY
),
(r) -> {
String value = ((String) r.get("value")).toLowerCase();
Assertions.assertThat(value).containsIgnoringCase("movie");
Assertions.assertThat(value).doesNotContainIgnoringCase("person", "people");
});
}

@Test
public void testSchemaFromQueriesWithSingleQuery() {
Expand All @@ -211,6 +330,24 @@ public void testSchemaFromQueriesWithSingleQuery() {
});
}

@Test
public void testSchemaFromQueriesWithWrongQueryGpt35Turbo() {
List<String> queries = List.of("MATCH (n:Movie) RETURN a");
try {
testCall(db, """
CALL apoc.ml.fromQueries($queries, {model: 'gpt-3.5-turbo', apiKey: $apiKey})
""",
Map.of(
"queries", queries,
"apiKey", OPENAI_KEY
),
(r) -> fail());
} catch (Exception e) {
Assertions.assertThat(e.getMessage()).contains(" Variable `a` not defined");
}

}

@Test
public void testSchemaFromQueriesWithWrongQuery() {
List<String> queries = List.of("MATCH (n:Movie) RETURN a");
Expand All @@ -228,6 +365,23 @@ public void testSchemaFromQueriesWithWrongQuery() {
}

}

@Test
public void testSchemaFromEmptyQueriesGpt35Turbo() {
List<String> queries = List.of("MATCH (n:Movie) RETURN 1");

testCall(db, """
CALL apoc.ml.fromQueries($queries, {model: 'gpt-3.5-turbo', apiKey: $apiKey})
""",
Map.of(
"queries", queries,
"apiKey", OPENAI_KEY
),
(r) -> {
String value = ((String) r.get("value")).toLowerCase();
Assertions.assertThat(value).containsAnyOf("does not contain", "empty", "undefined", "doesn't have");
});
}

@Test
public void testSchemaFromEmptyQueries() {
Expand All @@ -246,15 +400,67 @@ public void testSchemaFromEmptyQueries() {
});
}

@Test
public void ragWithRelevantAttributesComparedToIrrelevantOneAndChatProcedureGpt35Turbo() {
String question = "Which athletes won the gold medal in mixed doubles's curling at the 2022 Winter Olympics?";

// -- test with hallucinations, wrong winner names
testCall(db, """
CALL apoc.ml.openai.chat([
{role:"user", content: $question}
], $apiKey, { model: 'gpt-3.5-turbo' })""",
map("apiKey", OPENAI_KEY, "question", question),
r -> {
var result = (Map<String,Object>) r.get("value");

Map message = ((List<Map<String,Map>>) result.get("choices")).get(0).get("message");
assertEquals("assistant", message.get("role"));
String value = (String) message.get("content");

String msg = "Current value is: " + value;
assertTrue(msg, value.contains("gold medal"));
assertNot2022Winners(value);
});

// -- test RAG with irrilevant attributes
testCall(db, QUERY_RAG,
map("attributes", List.of("irrelevant", "irrelevant2"),
"question", "Which athletes won the gold medal in curling at the 2022 Winter Olympics?",
"conf", map(API_KEY_CONF, OPENAI_KEY)
),
(r) -> {
String value = (String) r.get("value");
String message = "Current value is: " + value;
assertTrue(message, value.contains(UNKNOWN_ANSWER));

assertNot2022Winners(value);
});

// -- test RAG with relevant attributes
testCall(db, QUERY_RAG,
map(
"attributes", RAG_ATTRIBUTES,
"question", "Which athletes won the gold medal in curling at the 2022 Winter Olympics?",
"conf", map("apiKey", OPENAI_KEY)
),
(r) -> {
String value = (String) r.get("value");
String message = "Current value is: " + value;
assertTrue(message, value.contains("Stefania Constantini"));
assertTrue(message, value.contains("Amos Mosaner"));
assertTrue(message, value.contains("Italy"));
});
}

@Test
public void ragWithRelevantAttributesComparedToIrrelevantOneAndChatProcedure() {
String question = "Which athletes won the gold medal in mixed doubles's curling at the 2022 Winter Olympics?";
String question = "Which athletes won the gold medal in mixed doubles's curling at the 2022 Winter Olympics?";

// -- test with hallucinations, wrong winner names
testCall(db, """
CALL apoc.ml.openai.chat([
{role:"user", content: $question}
], $apiKey)""",
], $apiKey)""",
map("apiKey", OPENAI_KEY, "question", question),
r -> {
var result = (Map<String,Object>) r.get("value");
Expand All @@ -265,7 +471,7 @@ public void ragWithRelevantAttributesComparedToIrrelevantOneAndChatProcedure() {

String msg = "Current value is: " + value;
assertTrue(msg, value.contains("gold medal"));
assertNot2022Winners(value);
assert2022Winners(value);
});

// -- test RAG with irrilevant attributes
Expand Down Expand Up @@ -451,7 +657,8 @@ private static void assertNot2022Winners(String value) {
}

private static void assert2022Winners(String value) {
assertThat(value).contains("Stefania Constantini", "Amos Mosaner", "Italy");
assertThat(value).contains("Stefania Constantini", "Amos Mosaner");
assertThat(value).containsAnyOf("Italy", "Italian");
}

}

0 comments on commit 68afdf8

Please sign in to comment.