Skip to content

Commit

Permalink
Fixes #4005: Add a procedure for RAG (#4077)
Browse files Browse the repository at this point in the history
  • Loading branch information
vga91 committed May 29, 2024
1 parent bc76a3d commit fa61cde
Show file tree
Hide file tree
Showing 6 changed files with 676 additions and 5 deletions.
191 changes: 190 additions & 1 deletion docs/asciidoc/modules/ROOT/pages/ml/openai.adoc
Original file line number Diff line number Diff line change
Expand Up @@ -588,4 +588,193 @@ RETURN *
|===
| name | description
| value | the description of the dataset
|===
|===

== Query with Retrieval-augmented generation (RAG) technique

This procedure `apoc.ml.rag` takes a list of paths or a vector index name, relevant attributes and a natural language question
to create a prompt implementing a Retrieval-augmented generation (RAG) technique.

See https://aws.amazon.com/what-is/retrieval-augmented-generation/[here] for more info about the RAG process.

It uses the `chat/completions` API which is https://platform.openai.com/docs/api-reference/chat/create[documented here^].



.Input Parameters
[%autowidth, opts=header]
|===
| name | description | mandatory
| paths | the list of paths to retrieve and augment the prompt, it can also be a matching query or a vector index name | yes
| attributes | the relevant attributes useful to retrieve and augment the prompt | yes
| question | the user question | yes
| conf | An optional configuration map, please check the next section | no
|===


.Configuration map
[%autowidth, opts=header]
|===
| name | description | mandatory
| getLabelTypes | add the label / rel-type names to the info to augment the prompt | no, default `true`
| embeddings | to search similar embeddings stored into a node vector index (in case of `embeddings: "NODE"`) or relationship vector index (in case of `embeddings: "REL"`) | no, default `"FALSE"`
| topK | number of neighbors to find for each node (in case of `embeddings: "NODE"`) or relationships (in case of `embeddings: "REL"`) | no, default `40`
| apiKey | OpenAI API key | in case `apoc.openai.key` is not defined
| prompt | the base prompt to be augmented with the context | no, default is:

"You are a customer service agent that helps a customer with answering questions about a service.
Use the following context to answer the `user question` at the end.
Make sure not to make any changes to the context if possible when prepare answers to provide accurate responses.
If you don't know the answer, just say \`Sorry, I don't know`, don't try to make up an answer."
|===


Using the apoc.ml.rag procedure we can reduce AI hallucinations (i.e. false or misleading responses),
providing relevant and up-to-date information to our procedure via the 1st parameter.

For example, by executing the following procedure (with the `gpt-3.5-turbo` model, last updated in January 2022)
we have a hallucination

.Query call
[source,cypher]
----
CALL apoc.ml.openai.chat([
{role:"user", content: "Which athletes won the gold medal in mixed doubles's curling at the 2022 Winter Olympics?"}
], $apiKey)
----

.Example response
[opts="header"]
|===
| value
| The gold medal in curling at the 2022 Winter Olympics was won by the Swedish men's team and the Russian women's team.
|===

So, we can use the RAG technique to provide real results.
For example with the given dataset (with data taken from https://en.wikipedia.org/wiki/Curling_at_the_2022_Winter_Olympics[this wikipedia page]):

.wikipedia dataset
[source,cypher]
----
CREATE (mixed2022:Discipline {title:"Mixed doubles's curling", year: 2022})
WITH mixed2022
CREATE (:Athlete {name: 'Stefania Constantini', country: 'Italy', irrelevant: 'asdasd'})-[:HAS_MEDAL {medal: 'Gold', irrelevant2: 'asdasd'}]->(mixed2022)
CREATE (:Athlete {name: 'Amos Mosaner', country: 'Italy', irrelevant: 'qweqwe'})-[:HAS_MEDAL {medal: 'Gold', irrelevant2: 'rwerew'}]->(mixed2022)
CREATE (:Athlete {name: 'Kristin Skaslien', country: 'Norway', irrelevant: 'dfgdfg'})-[:HAS_MEDAL {medal: 'Silver', irrelevant2: 'gdfg'}]->(mixed2022)
CREATE (:Athlete {name: 'Magnus Nedregotten', country: 'Norway', irrelevant: 'xcvxcv'})-[:HAS_MEDAL {medal: 'Silver', irrelevant2: 'asdasd'}]->(mixed2022)
CREATE (:Athlete {name: 'Almida de Val', country: 'Sweden', irrelevant: 'rtyrty'})-[:HAS_MEDAL {medal: 'Bronze', irrelevant2: 'bfbfb'}]->(mixed2022)
CREATE (:Athlete {name: 'Oskar Eriksson', country: 'Sweden', irrelevant: 'qwresdc'})-[:HAS_MEDAL {medal: 'Bronze', irrelevant2: 'juju'}]->(mixed2022)
----

we can execute:

.Query call
[source,cypher]
----
MATCH path=(:Athlete)-[:HAS_MEDAL]->(Discipline)
WITH collect(path) AS paths
CALL apoc.ml.rag(paths,
["name", "country", "medal", "title", "year"],
"Which athletes won the gold medal in mixed doubles's curling at the 2022 Winter Olympics?",
{apiKey: $apiKey}
) YIELD value
RETURN value
----

.Example response
[opts="header"]
|===
| value
| The gold medal in curling at the 2022 Winter Olympics was won by Stefania Constantini and Amos Mosaner from Italy.
|===

or:

.Query call
[source,cypher]
----
MATCH path=(:Athlete)-[:HAS_MEDAL]->(Discipline)
WITH collect(path) AS paths
CALL apoc.ml.rag(paths,
["name", "country", "medal", "title", "year"],
"Which athletes won the silver medal in mixed doubles's curling at the 2022 Winter Olympics?",
{apiKey: $apiKey}
) YIELD value
RETURN value
----

.Example response
[opts="header"]
|===
| value
| The gold medal in curling at the 2022 Winter Olympics was won by Kristin Skaslien and Magnus Nedregotten from Norway.
|===

We can also pass a string query returning paths/relationships/nodes, for example:

[source,cypher]
----
CALL apoc.ml.rag("MATCH path=(:Athlete)-[:HAS_MEDAL]->(Discipline) WITH collect(path) AS paths",
["name", "country", "medal", "title", "year"],
"Which athletes won the gold medal in mixed doubles's curling at the 2022 Winter Olympics?",
{apiKey: $apiKey}
) YIELD value
RETURN value
----

.Example response
[opts="header"]
|===
| value
| The gold medal in curling at the 2022 Winter Olympics was won by Stefania Constantini and Amos Mosaner from Italy.
|===

or we can pass a vector index name as the 1st parameter, in case we stored useful info into embedding nodes.
For example, given this node vector index:

[source,cypher]
----
CREATE VECTOR INDEX `rag-embeddings`
FOR (n:RagEmbedding) ON (n.embedding)
OPTIONS {indexConfig: {
`vector.dimensions`: 1536,
`vector.similarity_function`: 'cosine'
}}
----

and some (:RagEmbedding) nodes with the `text` properties, we can execute:

[source,cypher]
----
CALL apoc.ml.rag("rag-embeddings",
["text"],
"Which athletes won the gold medal in mixed doubles's curling at the 2022 Winter Olympics?",
{apiKey: $apiKey, embeddings: "NODE", topK: 20}
) YIELD value
RETURN value
----

or, with a relationship vector index:


[source,cypher]
----
CREATE VECTOR INDEX `rag-rel-embeddings`
FOR ()-[r:RAG_EMBEDDING]-() ON (r.embedding)
OPTIONS {indexConfig: {
`vector.dimensions`: 1536,
`vector.similarity_function`: 'cosine'
}}
----

and some [:RagEmbedding] relationships with the `text` properties, we can execute:

[source,cypher]
----
CALL apoc.ml.rag("rag-rel-embeddings",
["text"],
"Which athletes won the gold medal in mixed doubles's curling at the 2022 Winter Olympics?",
{apiKey: $apiKey, embeddings: "REL", topK: 20}
) YIELD value
RETURN value
----
105 changes: 102 additions & 3 deletions extended/src/main/java/apoc/ml/Prompt.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,14 @@
import apoc.util.Util;
import apoc.util.collection.Iterators;
import com.fasterxml.jackson.core.JsonProcessingException;
import org.apache.commons.text.WordUtils;
import org.jetbrains.annotations.NotNull;
import org.neo4j.graphdb.Entity;
import org.neo4j.graphdb.GraphDatabaseService;
import org.neo4j.graphdb.Node;
import org.neo4j.graphdb.Path;
import org.neo4j.graphdb.QueryExecutionException;
import org.neo4j.graphdb.Relationship;
import org.neo4j.graphdb.Transaction;
import org.neo4j.graphdb.security.URLAccessChecker;
import org.neo4j.internal.kernel.api.procs.ProcedureCallContext;
Expand All @@ -30,7 +35,8 @@

@Extended
public class Prompt {

public static final String API_KEY_CONF = "apiKey";

@Context
public Transaction tx;
@Context
Expand All @@ -44,7 +50,101 @@ public class Prompt {
@Context
public URLAccessChecker urlAccessChecker;


@Procedure(mode = Mode.READ)
@Description("Takes a query in cypher and in natural language and returns the results in natural language")
public Stream<StringResult> rag(@Name("paths") Object paths,
@Name("attributes") List<String> attributes,
@Name("question") String question,
@Name(value = "conf", defaultValue = "{}") Map<String, Object> conf) throws Exception {

RagConfig config = new RagConfig(conf);

String[] arrayAttrs = attributes.toArray(String[]::new);

StringBuilder context = new StringBuilder();

// -- Retrieve
if (paths instanceof List pathList) {

for (var listItem : pathList) {
// -- Augment
augment(config, arrayAttrs, context, listItem);
}

} else if (paths instanceof String queryOrIndex) {
config.getEmbeddings()
.getQuery(queryOrIndex, question, tx, config)
.forEachRemaining(row -> row
.values()
// -- Augment
.forEach( val -> augment(config, arrayAttrs, context, val) )
);
} else {
throw new RuntimeException("The first parameter must be a List or a String");
}

// - Generate
String contextPrompt = """
---- Start context ----
%s
---- End context ----
""".formatted(context);

String prompt = config.getBasePrompt() + contextPrompt;

String result = prompt("\nQuestion:" + question,
prompt,
null,
null,
conf,
List.of()
);
return Stream.of(new StringResult(result));
}

private void augment(RagConfig config, String[] objects, StringBuilder context, Object listItem) {
if (listItem instanceof Path p) {
for (Entity entity : p) {
augmentEntity(config, objects, context, entity);
}
} else if (listItem instanceof Entity e) {
augmentEntity(config, objects, context, e);
} else {
throw new RuntimeException("The list `%s` must have node/type/path items".formatted(listItem));
}
}

private void augmentEntity(RagConfig config, String[] objects, StringBuilder context, Entity entity) {
Map<String, Object> props = entity.getProperties(objects);
if (config.isGetLabelTypes()) {
String labelsOrType = entity instanceof Node node
? Util.joinLabels(node.getLabels(), ",")
: ((Relationship) entity).getType().name();
labelsOrType = WordUtils.capitalize(labelsOrType, '_');
props.put("context description", labelsOrType);
}
String obj = props.entrySet().stream()
.filter(i -> i.getValue() != null)
.map(i -> i.getKey() + ": " + i.getValue() + "\n")
.collect(Collectors.joining("\n---\n"));
context.append(obj);
}

public static final String BACKTICKS = "```";

public static final String UNKNOWN_ANSWER = "Sorry, I don't know";
static final String RAG_BASE_PROMPT = """
You are a customer service agent that helps a customer with answering questions about a service.
Use the following context to answer the `user question` at the end. Make sure not to make any changes to the context if possible when prepare answers so as to provide accurate responses.
If you don't know the answer, just say `%s`, don't try to make up an answer.
---- Start context ----
%s
---- End context ----
""";

public static final String EXPLAIN_SCHEMA_PROMPT = """
You are an expert in the Neo4j graph database and graph data modeling and have experience in a wide variety of business domains.
Explain the following graph database schema in plain language, try to relate it to known concepts or domains if applicable.
Expand All @@ -70,7 +170,6 @@ Given a graph database schema of entities (nodes) with labels and attributes and
Do not explain, apologize or provide additional detail about the schema, otherwise people will come to harm.
""";


public class PromptMapResult {
public final Map<String, Object> value;
public final String query;
Expand Down Expand Up @@ -212,7 +311,7 @@ private String prompt(String userQuestion, String systemPrompt, String assistant

prompt.addAll(otherPrompts);

String apiKey = (String) conf.get("apiKey");
String apiKey = (String) conf.get(API_KEY_CONF);
String model = (String) conf.getOrDefault("model", "gpt-3.5-turbo");
String result = OpenAI.executeRequest(apiKey, Map.of(), "chat/completions",
model, "messages", prompt, "$", apocConfig, urlAccessChecker)
Expand Down
Loading

0 comments on commit fa61cde

Please sign in to comment.