From c73392617b3d9b829fd6f8f66bd7f40ee976575b Mon Sep 17 00:00:00 2001 From: Giuseppe Villani Date: Wed, 4 Dec 2024 16:08:30 +0100 Subject: [PATCH] [NOID] Fixes #4058: Add support for mixedbread.ai Embedding API (#4060) (#4252) * [NOID] Fixes #4058: Add support for mixedbread.ai Embedding API (#4060) * Fixes #4058: Add support for mixedbread.ai Embedding API * updated extended.txt * Changes after rebase * [NOID] Fixes #3634: Updated ML procs for Azure OpenAI services (#3850) (#3863) (#3885) * Fixes #3634: Updated ML procs for Azure OpenAI services * Code clean * added enpoint env vars * Code clean part 2 * removed unused imports --------- --- docs/asciidoc/modules/ROOT/nav.adoc | 1 + .../asciidoc/modules/ROOT/pages/ml/index.adoc | 1 + .../modules/ROOT/pages/ml/mixedbread.adoc | 182 +++++++++++++ full/src/main/java/apoc/ml/MLUtil.java | 9 + full/src/main/java/apoc/ml/MixedbreadAI.java | 81 ++++++ full/src/main/java/apoc/ml/OpenAI.java | 53 +++- .../java/apoc/ml/OpenAIRequestHandler.java | 19 +- full/src/main/resources/extended.txt | 2 + .../src/test/java/apoc/ml/MixedbreadAIIT.java | 257 ++++++++++++++++++ full/src/test/java/apoc/ml/OpenAIAzureIT.java | 6 +- full/src/test/java/apoc/ml/OpenAITest.java | 3 +- 11 files changed, 593 insertions(+), 21 deletions(-) create mode 100644 docs/asciidoc/modules/ROOT/pages/ml/mixedbread.adoc create mode 100644 full/src/main/java/apoc/ml/MLUtil.java create mode 100644 full/src/main/java/apoc/ml/MixedbreadAI.java create mode 100644 full/src/test/java/apoc/ml/MixedbreadAIIT.java diff --git a/docs/asciidoc/modules/ROOT/nav.adoc b/docs/asciidoc/modules/ROOT/nav.adoc index 1b9e56188a..57646810b3 100644 --- a/docs/asciidoc/modules/ROOT/nav.adoc +++ b/docs/asciidoc/modules/ROOT/nav.adoc @@ -135,6 +135,7 @@ include::partial$generated-documentation/nav.adoc[] * xref:ml/index.adoc[] ** xref:ml/openai.adoc[] + ** xref:ml/mixedbread.adoc[] * xref:background-operations/index.adoc[] ** xref::background-operations/periodic-background.adoc[] diff --git a/docs/asciidoc/modules/ROOT/pages/ml/index.adoc b/docs/asciidoc/modules/ROOT/pages/ml/index.adoc index 73dd5e3169..9c89844781 100644 --- a/docs/asciidoc/modules/ROOT/pages/ml/index.adoc +++ b/docs/asciidoc/modules/ROOT/pages/ml/index.adoc @@ -8,3 +8,4 @@ These procedures generate embeddings, analyze text, complete text, complete chat This section includes: * xref::ml/openai.adoc[] +* xref::ml/mixedbread.adoc[] diff --git a/docs/asciidoc/modules/ROOT/pages/ml/mixedbread.adoc b/docs/asciidoc/modules/ROOT/pages/ml/mixedbread.adoc new file mode 100644 index 0000000000..cc829f28b1 --- /dev/null +++ b/docs/asciidoc/modules/ROOT/pages/ml/mixedbread.adoc @@ -0,0 +1,182 @@ +[[Mixedbread-api]] += Mixedbread API Access +:description: This section describes procedures that can be used to access the Mixedbread API. + + +Here is a list of all available Mixedbread API procedures: + + +[opts=header, cols="1, 4", separator="|"] +|=== +|name| description +|apoc.ml.mixedbread.custom(body, $config)| To create a customizable Mixedbread API call +|apoc.ml.mixedbread.embedding(texts, $config)| To create a Mixedbread API call to generate embeddings +|=== + +The `$config` parameter coincides with the payload to be passed to the http request, +and additionally the following configuration keys. + + +.Common configuration parameter + +|=== +| key | description +| apiType | analogous to `apoc.ml.openai.type` APOC config +| endpoint | analogous to `apoc.ml.openai.url` APOC config +| apiVersion | analogous to `apoc.ml.azure.api.version` APOC config +| path | To customize the url portion added to the base url (defined by the `endpoint` config). +By default, is `/embeddings`, `/completions` and `/chat/completions` for respectively the `apoc.ml.openai.embedding`, `apoc.ml.openai.completion` and `apoc.ml.openai.chat` procedures. +| jsonPath | To customize https://github.com/json-path/JsonPath[JSONPath] of the response. +The default is `$` for the `apoc.ml.openai.chat` and `apoc.ml.openai.completion` procedures, and `$.data` for the `apoc.ml.openai.embedding` procedure. +|=== + +Since embeddings are a super set of the Openai ones, +under-the-hood they leverage the apoc.ml.openai.* procedures, +so we can also create an APOC config `apoc.ml.openai.url` instead of the `endpoint` config. + + +== Generate Embeddings API + +This procedure `apoc.ml.mixedbread.embedding` can take a list of text strings, and will return one row per string, with the embedding data as a 1536 element vector. +It uses the `/embeddings/create` API which is https://www.mixedbread.ai/api-reference/endpoints/embeddings#create-embeddings[documented here^]. + +Additional configuration is passed to the API, the default model used is `mxbai-embed-large-v1`. + + +.Parameters +[%autowidth, opts=header] +|=== +|name | description +| texts | List of text strings +| apiKey | OpenAI API key +| configuration | optional map. See `Configuration` table above +|=== + +.Results +[%autowidth, opts=header] +|=== +|name | description +| index | index entry in original list +| text | line of text from original list +| embedding | embedding list of floatings/binary, + or map of embedding lists of floatings / binaries, in case of multiple encoding_format +|=== + + +.Generate Embeddings Call +[source,cypher] +---- +CALL apoc.ml.mixedbread.embedding(['Some Text'], $apiKey, {}) yield index, text, embedding; +---- + +.Generate Embeddings Response +[%autowidth, opts=header] +|=== +|index | text | embedding +|0 | "Some Text" | [-0.0065358975, -7.9563365E-4, .... -0.010693862, -0.005087272] +|=== + + +.Generate Embeddings Call with custom embedding dimension and model +[source,cypher] +---- +CALL apoc.ml.mixedbread.embedding(['Some Text', 'Other Text'], + $apiKey, + {model: 'mxbai-embed-2d-large-v1', dimensions: 4} +) +---- + +.Generate Embeddings Example Response +[%autowidth, opts=header] +|=== +|index | text | embedding +|0 | "Some Text" | [0.019943237, -0.08843994, 0.068603516, 0.034942627] +|1 | "Other Text" | [0.011482239, -0.09069824, 0.05331421, 0.034088135] +|=== + + +.Generate Embeddings Call with custom embedding dimension, model and encoding_format +[source,cypher] +---- +CALL apoc.ml.mixedbread.embedding(['Some Text', 'garpez'], + $apiKey, + {encoding_format: ["float", "binary", "ubinary", "int8", "uint8", "base64"]} +) +---- + +.Generate Embeddings Example Response +[%autowidth, opts=header] +|=== +| index | text | embedding +| 0 | "Some Text" | {binary: , ubinary: , int8: , uint8: , base64: , float: } +| 0 | "garpez" | {binary: , ubinary: , int8: , uint8: , base64: , float: } +|=== + + + + +== Custom API + +Via the `apoc.ml.mixedbread.custom` we can create a customizable Mixedbread API Request, +returning a generic stream of objects. + +For example, we can use the https://www.mixedbread.ai/api-reference/endpoints/reranking[Reranking API]. + + +.Reranking API Call +[source,cypher] +---- +CALL apoc.ml.mixedbread.custom($apiKey, + { + endpoint: "https://api.mixedbread.ai/v1/reranking", + model: "mixedbread-ai/mxbai-rerank-large-v1", + query: "Who is the author of To Kill a Mockingbird?", + top_k: 3, + input: [ + "To Kill a Mockingbird is a novel by Harper Lee published in 1960. It was immediately successful, winning the Pulitzer Prize, and has become a classic of modern American literature.", + "The novel Moby-Dick was written by Herman Melville and first published in 1851. It is considered a masterpiece of American literature and deals with complex themes of obsession, revenge, and the conflict between good and evil.", + "Harper Lee, an American novelist widely known for her novel To Kill a Mockingbird, was born in 1926 in Monroeville, Alabama. She received the Pulitzer Prize for Fiction in 1961.", + "Jane Austen was an English novelist known primarily for her six major novels, which interpret, critique and comment upon the British landed gentry at the end of the 18th century.", + "The Harry Potter series, which consists of seven fantasy novels written by British author J.K. Rowling, is among the most popular and critically acclaimed books of the modern era.", + "The Great Gatsby, a novel written by American author F. Scott Fitzgerald, was published in 1925. The story is set in the Jazz Age and follows the life of millionaire Jay Gatsby and his pursuit of Daisy Buchanan." + ] + } +) +---- + +.Generate Embeddings Example Response +[%autowidth, opts=header] +|=== +| value +a| +[source,json] +---- +{ + "model": "mixedbread-ai/mxbai-rerank-large-v1", + "return_input": false, + "data": [ + { + "index": 0, + "score": 0.9980469, + "object": "text_document" + }, + { + "index": 2, + "score": 0.9980469, + "object": "text_document" + }, + { + "index": 3, + "score": 0.06915283, + "object": "text_document" + } + ], + "usage": { + "total_tokens": 302, + "prompt_tokens": 302 + }, + "object": "list", + "top_k": 3 +} +---- +|=== diff --git a/full/src/main/java/apoc/ml/MLUtil.java b/full/src/main/java/apoc/ml/MLUtil.java new file mode 100644 index 0000000000..1b9a770b29 --- /dev/null +++ b/full/src/main/java/apoc/ml/MLUtil.java @@ -0,0 +1,9 @@ +package apoc.ml; + +public class MLUtil { + public static final String ENDPOINT_CONF_KEY = "endpoint"; + public static final String API_VERSION_CONF_KEY = "apiVersion"; + public static final String MODEL_CONF_KEY = "model"; + public static final String API_TYPE_CONF_KEY = "apiType"; + public static final String APIKEY_CONF_KEY = "apiKey"; +} diff --git a/full/src/main/java/apoc/ml/MixedbreadAI.java b/full/src/main/java/apoc/ml/MixedbreadAI.java new file mode 100644 index 0000000000..4d478d8d26 --- /dev/null +++ b/full/src/main/java/apoc/ml/MixedbreadAI.java @@ -0,0 +1,81 @@ +package apoc.ml; + +import static apoc.ApocConfig.APOC_ML_OPENAI_URL; +import static apoc.ml.MLUtil.API_TYPE_CONF_KEY; +import static apoc.ml.MLUtil.ENDPOINT_CONF_KEY; +import static apoc.ml.MLUtil.MODEL_CONF_KEY; + +import apoc.ApocConfig; +import apoc.result.ObjectResult; +import java.util.List; +import java.util.Map; +import java.util.stream.Stream; +import org.neo4j.procedure.Context; +import org.neo4j.procedure.Description; +import org.neo4j.procedure.Name; +import org.neo4j.procedure.Procedure; + +public class MixedbreadAI { + + public static final String DEFAULT_MODEL_ID = "mxbai-embed-large-v1"; + public static final String MIXEDBREAD_BASE_URL = "https://api.mixedbread.ai/v1"; + public static final String ERROR_MSG_MISSING_ENDPOINT = String.format( + "The endpoint must be defined via config `%s` or via apoc.conf `%s`", + ENDPOINT_CONF_KEY, APOC_ML_OPENAI_URL); + + public static final String ERROR_MSG_MISSING_MODELID = + String.format("The model must be defined via config `%s`", MODEL_CONF_KEY); + + /** + * embedding is an Object instead of List, as with a Mixedbread request having `"encoding_format": []`, + * the result can be e.g. {... "embedding": { "float": [], "base": , } ...} + * instead of e.g. {... "embedding": [] ...} + */ + public static final class EmbeddingResult { + public final long index; + public final String text; + public final Object embedding; + + public EmbeddingResult(long index, String text, Object embedding) { + this.index = index; + this.text = text; + this.embedding = embedding; + } + } + + @Context + public ApocConfig apocConfig; + + @Procedure("apoc.ml.mixedbread.custom") + @Description("apoc.mixedbread.custom(, configuration) - returns the embeddings for a given text") + public Stream custom( + @Name("api_key") String apiKey, + @Name(value = "configuration", defaultValue = "{}") Map configuration) + throws Exception { + if (!configuration.containsKey(MODEL_CONF_KEY)) { + throw new RuntimeException(ERROR_MSG_MISSING_MODELID); + } + + configuration.put(API_TYPE_CONF_KEY, OpenAIRequestHandler.Type.MIXEDBREAD_CUSTOM.name()); + + return OpenAI.executeRequest(apiKey, configuration, null, null, null, null, null, apocConfig) + .map(ObjectResult::new); + } + + @Procedure("apoc.ml.mixedbread.embedding") + @Description( + "apoc.mixedbread.mixedbread([texts], api_key, configuration) - returns the embeddings for a given text") + public Stream getEmbedding( + @Name("texts") List texts, + @Name("api_key") String apiKey, + @Name(value = "configuration", defaultValue = "{}") Map configuration) + throws Exception { + configuration.putIfAbsent(MODEL_CONF_KEY, DEFAULT_MODEL_ID); + + configuration.put(API_TYPE_CONF_KEY, OpenAIRequestHandler.Type.MIXEDBREAD_EMBEDDING.name()); + return OpenAI.getEmbeddingResult(texts, apiKey, configuration, apocConfig, (map, text) -> { + Long index = (Long) map.get("index"); + return new EmbeddingResult(index, text, map.get("embedding")); + }); + } +} diff --git a/full/src/main/java/apoc/ml/OpenAI.java b/full/src/main/java/apoc/ml/OpenAI.java index f6b19eaea4..89f0787184 100644 --- a/full/src/main/java/apoc/ml/OpenAI.java +++ b/full/src/main/java/apoc/ml/OpenAI.java @@ -2,6 +2,11 @@ import static apoc.ApocConfig.APOC_ML_OPENAI_TYPE; import static apoc.ApocConfig.APOC_OPENAI_KEY; +import static apoc.ml.MLUtil.APIKEY_CONF_KEY; +import static apoc.ml.MLUtil.API_TYPE_CONF_KEY; +import static apoc.ml.MLUtil.API_VERSION_CONF_KEY; +import static apoc.ml.MLUtil.ENDPOINT_CONF_KEY; +import static apoc.ml.MLUtil.MODEL_CONF_KEY; import apoc.ApocConfig; import apoc.Extended; @@ -13,6 +18,7 @@ import java.util.List; import java.util.Locale; import java.util.Map; +import java.util.function.BiFunction; import java.util.stream.Stream; import org.neo4j.procedure.Context; import org.neo4j.procedure.Description; @@ -24,12 +30,6 @@ public class OpenAI { @Context public ApocConfig apocConfig; - public static final String APOC_ML_OPENAI_URL = "apoc.ml.openai.url"; - public static final String API_TYPE_CONF_KEY = "apiType"; - public static final String APIKEY_CONF_KEY = "apiKey"; - public static final String ENDPOINT_CONF_KEY = "endpoint"; - public static final String API_VERSION_CONF_KEY = "apiVersion"; - public static class EmbeddingResult { public final long index; public final String text; @@ -56,19 +56,26 @@ static Stream executeRequest( if (apiKey == null || apiKey.isBlank()) throw new IllegalArgumentException("API Key must not be empty"); String apiTypeString = (String) configuration.getOrDefault( API_TYPE_CONF_KEY, apocConfig.getString(APOC_ML_OPENAI_TYPE, OpenAIRequestHandler.Type.OPENAI.name())); - OpenAIRequestHandler apiType = OpenAIRequestHandler.Type.valueOf(apiTypeString.toUpperCase(Locale.ENGLISH)) - .get(); - - final Map headers = new HashMap<>(); - headers.put("Content-Type", "application/json"); - apiType.addApiKey(headers, apiKey); + OpenAIRequestHandler.Type type = OpenAIRequestHandler.Type.valueOf(apiTypeString.toUpperCase(Locale.ENGLISH)); var config = new HashMap<>(configuration); // we remove these keys from config, since the json payload is calculated starting from the config map Stream.of(ENDPOINT_CONF_KEY, API_TYPE_CONF_KEY, API_VERSION_CONF_KEY, APIKEY_CONF_KEY) .forEach(config::remove); - config.putIfAbsent("model", model); - config.put(key, inputs); + switch (type) { + case MIXEDBREAD_CUSTOM: + // no payload manipulation, taken from the configuration as-is + break; + default: + config.putIfAbsent(MODEL_CONF_KEY, model); + config.put(key, inputs); + } + OpenAIRequestHandler apiType = type.get(); + + final Map headers = new HashMap<>(); + headers.put("Content-Type", "application/json"); + + apiType.addApiKey(headers, apiKey); String payload = JsonUtil.OBJECT_MAPPER.writeValueAsString(config); @@ -99,13 +106,29 @@ public Stream getEmbedding( "model": "text-embedding-ada-002", "usage": { "prompt_tokens": 8, "total_tokens": 8 } } */ + + return getEmbeddingResult(texts, apiKey, configuration, apocConfig, (map, text) -> { + Long index = (Long) map.get("index"); + return new EmbeddingResult(index, text, (List) map.get("embedding")); + }); + } + + public static Stream getEmbeddingResult( + List texts, + String apiKey, + Map configuration, + ApocConfig apocConfig, + BiFunction embeddingMapping) + throws JsonProcessingException, MalformedURLException { Stream resultStream = executeRequest( apiKey, configuration, "embeddings", "text-embedding-ada-002", "input", texts, "$.data", apocConfig); + return resultStream .flatMap(v -> ((List>) v).stream()) .map(m -> { Long index = (Long) m.get("index"); - return new EmbeddingResult(index, texts.get(index.intValue()), (List) m.get("embedding")); + String text = texts.get(index.intValue()); + return embeddingMapping.apply(m, text); }); } diff --git a/full/src/main/java/apoc/ml/OpenAIRequestHandler.java b/full/src/main/java/apoc/ml/OpenAIRequestHandler.java index a7d6d539ff..7f78626be5 100644 --- a/full/src/main/java/apoc/ml/OpenAIRequestHandler.java +++ b/full/src/main/java/apoc/ml/OpenAIRequestHandler.java @@ -2,8 +2,10 @@ import static apoc.ApocConfig.APOC_ML_OPENAI_AZURE_VERSION; import static apoc.ApocConfig.APOC_ML_OPENAI_URL; -import static apoc.ml.OpenAI.API_VERSION_CONF_KEY; -import static apoc.ml.OpenAI.ENDPOINT_CONF_KEY; +import static apoc.ml.MLUtil.API_VERSION_CONF_KEY; +import static apoc.ml.MLUtil.ENDPOINT_CONF_KEY; +import static apoc.ml.MixedbreadAI.ERROR_MSG_MISSING_ENDPOINT; +import static apoc.ml.MixedbreadAI.MIXEDBREAD_BASE_URL; import apoc.ApocConfig; import java.util.Map; @@ -41,6 +43,8 @@ public String getFullUrl(String method, Map procConfig, ApocConf enum Type { AZURE(new Azure(null)), + MIXEDBREAD_EMBEDDING(new OpenAi(MIXEDBREAD_BASE_URL)), + MIXEDBREAD_CUSTOM(new Custom()), OPENAI(new OpenAi("https://api.openai.com/v1")); private final OpenAIRequestHandler handler; @@ -89,4 +93,15 @@ public void addApiKey(Map headers, String apiKey) { headers.put("Authorization", "Bearer " + apiKey); } } + + static class Custom extends OpenAi { + public Custom() { + super(null); + } + + @Override + public String getDefaultUrl() { + throw new RuntimeException(ERROR_MSG_MISSING_ENDPOINT); + } + } } diff --git a/full/src/main/resources/extended.txt b/full/src/main/resources/extended.txt index 16911f5f9f..0775bdc7e2 100644 --- a/full/src/main/resources/extended.txt +++ b/full/src/main/resources/extended.txt @@ -103,6 +103,8 @@ apoc.ml.fromCypher apoc.ml.fromQueries apoc.ml.query apoc.ml.schema +apoc.ml.mixedbread.custom +apoc.ml.mixedbread.embedding apoc.ml.openai.chat apoc.ml.openai.completion apoc.ml.openai.embedding diff --git a/full/src/test/java/apoc/ml/MixedbreadAIIT.java b/full/src/test/java/apoc/ml/MixedbreadAIIT.java new file mode 100644 index 0000000000..4f7a72290f --- /dev/null +++ b/full/src/test/java/apoc/ml/MixedbreadAIIT.java @@ -0,0 +1,257 @@ +package apoc.ml; + +import static apoc.ml.MLUtil.*; +import static apoc.ml.MixedbreadAI.*; +import static apoc.util.TestUtil.testCall; +import static apoc.util.TestUtil.testResult; +import static apoc.util.Util.map; +import static java.util.Collections.emptyMap; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.fail; + +import apoc.util.TestUtil; +import java.util.List; +import java.util.Map; +import java.util.Set; +import org.junit.Assume; +import org.junit.BeforeClass; +import org.junit.ClassRule; +import org.junit.Test; +import org.neo4j.test.rule.DbmsRule; +import org.neo4j.test.rule.ImpermanentDbmsRule; + +public class MixedbreadAIIT { + + @ClassRule + public static DbmsRule db = new ImpermanentDbmsRule(); + + private static String apiKey; + + @BeforeClass + public static void setUp() throws Exception { + String keyIdEnv = "MIXEDBREAD_API_KEY"; + apiKey = System.getenv(keyIdEnv); + + Assume.assumeNotNull("No MIXEDBREAD_API_KEY environment configured", apiKey); + + TestUtil.registerProcedure(db, MixedbreadAI.class); + } + + @Test + public void getEmbedding() { + testResult( + db, + "CALL apoc.ml.mixedbread.embedding(['Some Text', 'Other Text'], $apiKey, $conf)", + map("apiKey", apiKey, "conf", emptyMap()), + r -> { + Map row = r.next(); + assertEmbedding(row, 0L, "Some Text", 1024); + + row = r.next(); + assertEmbedding(row, 1L, "Other Text", 1024); + + assertFalse(r.hasNext()); + }); + } + + @Test + public void getEmbeddingWithNulls() { + testResult( + db, + "CALL apoc.ml.mixedbread.embedding([null, 'Some Text', null, 'Another Text'], $apiKey, $conf)", + Map.of("apiKey", apiKey, "conf", emptyMap()), + (r) -> { + Map row = r.next(); + assertEquals(1024, ((List) row.get("embedding")).size()); + assertEquals("Some Text", row.get("text")); + + row = r.next(); + assertEquals(1024, ((List) row.get("embedding")).size()); + assertEquals("Another Text", row.get("text")); + + row = r.next(); + assertNullEmbedding(row); + + row = r.next(); + assertNullEmbedding(row); + + assertFalse(r.hasNext()); + }); + } + + @Test + public void getEmbeddingWithCustomMultipleEncodingFormats() { + Set formats = Set.of("float", "binary", "ubinary", "int8", "uint8", "base64"); + Map conf = map("encoding_format", formats); + testResult( + db, + "CALL apoc.ml.mixedbread.embedding(['Some Text', 'Other Text'], $apiKey, $conf)", + map("apiKey", apiKey, "conf", conf), + r -> { + Map row = r.next(); + assertEquals(0L, row.get("index")); + assertEquals("Some Text", row.get("text")); + var embedding = (Map) row.get("embedding"); + assertEquals(formats, embedding.keySet()); + assertTrue(embedding.get("float") instanceof List); + assertTrue(embedding.get("binary") instanceof List); + assertTrue(embedding.get("ubinary") instanceof List); + assertTrue(embedding.get("int8") instanceof List); + assertTrue(embedding.get("uint8") instanceof List); + assertTrue(embedding.get("base64") instanceof String); + + row = r.next(); + assertEquals(1L, row.get("index")); + assertEquals("Other Text", row.get("text")); + embedding = (Map) row.get("embedding"); + assertEquals(formats, embedding.keySet()); + assertTrue(embedding.get("float") instanceof List); + assertTrue(embedding.get("binary") instanceof List); + assertTrue(embedding.get("ubinary") instanceof List); + assertTrue(embedding.get("int8") instanceof List); + assertTrue(embedding.get("uint8") instanceof List); + assertTrue(embedding.get("base64") instanceof String); + + assertFalse(r.hasNext()); + }); + } + + @Test + public void getEmbeddingWithCustomEmbeddingSize() { + testResult( + db, + "CALL apoc.ml.mixedbread.embedding(['Some Text', 'Other Text'], $apiKey, $conf)", + map("apiKey", apiKey, "conf", map("dimensions", 256)), + r -> { + Map row = r.next(); + assertEmbedding(row, 0L, "Some Text", 256); + + row = r.next(); + assertEmbedding(row, 1L, "Other Text", 256); + + assertFalse(r.hasNext()); + }); + } + + @Test + public void getEmbeddingWithOtherModel() { + testResult( + db, + "CALL apoc.ml.mixedbread.embedding(['Some Text', 'Other Text'], $apiKey, $conf)", + map("apiKey", apiKey, "conf", map(MODEL_CONF_KEY, "mxbai-embed-2d-large-v1")), + r -> { + Map row = r.next(); + assertEmbedding(row, 0L, "Some Text", 1024); + + row = r.next(); + assertEmbedding(row, 1L, "Other Text", 1024); + + assertFalse(r.hasNext()); + }); + } + + @Test + public void getEmbeddingWithWrongModel() { + try { + testCall( + db, + "CALL apoc.ml.mixedbread.embedding(['Some Text', 'Other Text'], $apiKey, $conf)", + map("apiKey", apiKey, "conf", map(MODEL_CONF_KEY, "wrong-id")), + r -> fail("Should fail due to wrong model id")); + } catch (Exception e) { + String errMsg = e.getMessage(); + assertTrue( + "Actual error message is: " + errMsg, + errMsg.contains( + "Server returned HTTP response code: 422 for URL: https://api.mixedbread.ai/v1/embeddings")); + } + } + + /** + * Example taken from here: https://www.mixedbread.ai/api-reference/endpoints/reranking + */ + @Test + public void customWithReranking() { + List input = List.of( + "To Kill a Mockingbird is a novel by Harper Lee published in 1960. It was immediately successful, winning the Pulitzer Prize, and has become a classic of modern American literature.", + "The novel Moby-Dick was written by Herman Melville and first published in 1851. It is considered a masterpiece of American literature and deals with complex themes of obsession, revenge, and the conflict between good and evil.", + "Harper Lee, an American novelist widely known for her novel To Kill a Mockingbird, was born in 1926 in Monroeville, Alabama. She received the Pulitzer Prize for Fiction in 1961.", + "Jane Austen was an English novelist known primarily for her six major novels, which interpret, critique and comment upon the British landed gentry at the end of the 18th century.", + "The Harry Potter series, which consists of seven fantasy novels written by British author J.K. Rowling, is among the most popular and critically acclaimed books of the modern era.", + "The Great Gatsby, a novel written by American author F. Scott Fitzgerald, was published in 1925. The story is set in the Jazz Age and follows the life of millionaire Jay Gatsby and his pursuit of Daisy Buchanan."); + Map conf = map( + ENDPOINT_CONF_KEY, + MIXEDBREAD_BASE_URL + "/reranking", + MODEL_CONF_KEY, + "mixedbread-ai/mxbai-rerank-large-v1", + "query", + "Who is the author of To Kill a Mockingbird?", + "top_k", + 3, + "input", + input); + testCall(db, "CALL apoc.ml.mixedbread.custom($apiKey, $conf)", Map.of("apiKey", apiKey, "conf", conf), row -> { + Map value = (Map) row.get("value"); + + List data = (List) value.get("data"); + assertEquals(3, data.size()); + + Map firstData = map("index", 0L, "score", 0.9980469, "object", "text_document"); + assertEquals(firstData, data.get(0)); + + Map secondData = map("index", 2L, "score", 0.9980469, "object", "text_document"); + assertEquals(secondData, data.get(1)); + + Map thirdData = map("index", 3L, "score", 0.06915283, "object", "text_document"); + assertEquals(thirdData, data.get(2)); + + assertEquals("list", value.get("object")); + }); + } + + @Test + public void customWithMissingEndpoint() { + try { + testCall( + db, + "CALL apoc.ml.mixedbread.custom($apiKey, $conf)", + map("apiKey", apiKey, "conf", map(MODEL_CONF_KEY, "aModelId")), + r -> fail("Should fail due to missing endpoint")); + } catch (Exception e) { + String errMsg = e.getMessage(); + assertTrue("Actual error message is: " + errMsg, errMsg.contains(ERROR_MSG_MISSING_ENDPOINT)); + } + } + + @Test + public void customWithMissingModel() { + try { + testCall( + db, + "CALL apoc.ml.mixedbread.custom($apiKey, $conf)", + map( + "apiKey", + apiKey, + "conf", + map(ENDPOINT_CONF_KEY, MIXEDBREAD_BASE_URL + "/reranking", "foo", "bar")), + r -> fail("Should fail due to missing model")); + } catch (Exception e) { + String errMsg = e.getMessage(); + assertTrue("Actual error message is: " + errMsg, errMsg.contains(ERROR_MSG_MISSING_MODELID)); + } + } + + private static void assertEmbedding( + Map row, long expectedIdx, String expectedText, Integer expectedSize) { + assertEquals(expectedIdx, row.get("index")); + assertEquals(expectedText, row.get("text")); + var embedding = (List) row.get("embedding"); + assertEquals(expectedSize, embedding.size()); + } + + private static void assertNullEmbedding(Map row) { + assertEmbedding(row, -1, null, 0); + } +} diff --git a/full/src/test/java/apoc/ml/OpenAIAzureIT.java b/full/src/test/java/apoc/ml/OpenAIAzureIT.java index 95f61e0a3d..c63c5adcce 100644 --- a/full/src/test/java/apoc/ml/OpenAIAzureIT.java +++ b/full/src/test/java/apoc/ml/OpenAIAzureIT.java @@ -1,8 +1,8 @@ package apoc.ml; -import static apoc.ml.OpenAI.API_TYPE_CONF_KEY; -import static apoc.ml.OpenAI.API_VERSION_CONF_KEY; -import static apoc.ml.OpenAI.ENDPOINT_CONF_KEY; +import static apoc.ml.MLUtil.API_TYPE_CONF_KEY; +import static apoc.ml.MLUtil.API_VERSION_CONF_KEY; +import static apoc.ml.MLUtil.ENDPOINT_CONF_KEY; import static apoc.ml.OpenAITestResultUtils.assertChatCompletion; import static apoc.ml.OpenAITestResultUtils.assertCompletion; import static apoc.util.TestUtil.testCall; diff --git a/full/src/test/java/apoc/ml/OpenAITest.java b/full/src/test/java/apoc/ml/OpenAITest.java index dd1f4aa4cd..ff19951931 100644 --- a/full/src/test/java/apoc/ml/OpenAITest.java +++ b/full/src/test/java/apoc/ml/OpenAITest.java @@ -1,6 +1,7 @@ package apoc.ml; import static apoc.ApocConfig.APOC_IMPORT_FILE_ENABLED; +import static apoc.ApocConfig.APOC_ML_OPENAI_URL; import static apoc.ApocConfig.apocConfig; import static apoc.util.TestUtil.getUrlFileName; import static apoc.util.TestUtil.testCall; @@ -31,7 +32,7 @@ public void setUp() throws Exception { // openaiKey = System.getenv("OPENAI_KEY"); // Assume.assumeNotNull("No OPENAI_KEY environment configured", openaiKey); var path = Paths.get(getUrlFileName("embeddings").toURI()).getParent().toUri(); - System.setProperty(OpenAI.APOC_ML_OPENAI_URL, path.toString()); + System.setProperty(APOC_ML_OPENAI_URL, path.toString()); apocConfig().setProperty(APOC_IMPORT_FILE_ENABLED, true); TestUtil.registerProcedure(db, OpenAI.class); }