Skip to content

Commit

Permalink
Vectorize: optionally return metadata for a vector query (#1297)
Browse files Browse the repository at this point in the history
  • Loading branch information
ndisidore authored Oct 31, 2023
1 parent 8cb234c commit ce4237a
Show file tree
Hide file tree
Showing 10 changed files with 132 additions and 98 deletions.
32 changes: 32 additions & 0 deletions src/cloudflare/internal/compatibility-flags.d.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
// Copyright (c) 2017-2022 Cloudflare, Inc.
// Licensed under the Apache 2.0 license found in the LICENSE file or at:
// https://opensource.org/licenses/Apache-2.0

// Keep this in sync with compatibility-date.capnp
// TODO(soon): See if we can automatically generate this from the capnp file
export const formDataParserSupportsFiles: boolean;
export const fetchRefusesUnknownProtocols: boolean;
export const esiIncludeIsVoidTag: boolean;
export const durableObjectFetchRequiresSchemeAuthority: boolean;
export const streamsByobReaderDetachesBuffer: boolean;
export const streamsJavaScriptControllers: boolean;
export const jsgPropertyOnPrototypeTemplate: boolean;
export const minimalSubrequests: boolean;
export const noCotsOnExternalFetch: boolean;
export const specCompliantUrl: boolean;
export const globalNavigator: boolean;
export const captureThrowsAsRejections: boolean;
export const r2PublicBetaApi: boolean;
export const obsolete14: boolean;
export const noSubstituteNull: boolean;
export const transformStreamJavaScriptControllers: boolean;
export const r2ListHonorIncludeFields: boolean;
export const exportCommonJsDefaultNamespace: boolean;
export const obsolete19: boolean;
export const webSocketCompression: boolean;
export const nodeJsCompat: boolean;
export const tcpSocketsSupport: boolean;
export const specCompliantResponseRedirect: boolean;
export const workerdExperimental: boolean;
export const durableObjectGetExisting: boolean;
export const vectorizeQueryMetadataOptional: boolean;
57 changes: 25 additions & 32 deletions src/cloudflare/internal/test/vectorize/vectorize-api-test.js
Original file line number Diff line number Diff line change
Expand Up @@ -19,53 +19,46 @@ export const test_vector_search_vector_query = {
async test(ctr, env) {
const IDX = env["vector-search"];
{
// with returnVectors = true
// with returnValues = true, returnMetadata = true
const results = await IDX.query(new Float32Array(new Array(5).fill(0)), {
topK: 3,
returnVectors: true,
returnValues: true,
returnMetadata: true,
});
assert.equal(true, results.count > 0);
/** @type {VectorizeMatches} */
const expected = {
matches: [
{
vectorId: "b0daca4a-ffd8-4865-926b-e24800af2a2d",
id: "b0daca4a-ffd8-4865-926b-e24800af2a2d",
values: [0.2331, 1.0125, 0.6131, 0.9421, 0.9661, 0.8121],
metadata: { text: "She sells seashells by the seashore" },
score: 0.71151,
vector: {
id: "b0daca4a-ffd8-4865-926b-e24800af2a2d",
values: [0.2331, 1.0125, 0.6131, 0.9421, 0.9661, 0.8121],
metadata: { text: "She sells seashells by the seashore" },
},
},
{
vectorId: "a44706aa-a366-48bc-8cc1-3feffd87d548",
score: 0.68913,
vector: {
id: "a44706aa-a366-48bc-8cc1-3feffd87d548",
values: [0.2321, 0.8121, 0.6315, 0.6151, 0.4121, 0.1512],
metadata: {
text: "Peter Piper picked a peck of pickled peppers",
},
id: "a44706aa-a366-48bc-8cc1-3feffd87d548",
values: [0.2321, 0.8121, 0.6315, 0.6151, 0.4121, 0.1512],
metadata: {
text: "Peter Piper picked a peck of pickled peppers",
},
score: 0.68913,
},
{
vectorId: "43cfcb31-07e2-411f-8bf9-f82a95ba8b96",
score: 0.94812,
vector: {
id: "43cfcb31-07e2-411f-8bf9-f82a95ba8b96",
values: [0.0515, 0.7512, 0.8612, 0.2153, 0.15121, 0.6812],
metadata: {
text: "You know New York, you need New York, you know you need unique New York",
},
id: "43cfcb31-07e2-411f-8bf9-f82a95ba8b96",
values: [0.0515, 0.7512, 0.8612, 0.2153, 0.15121, 0.6812],
metadata: {
text: "You know New York, you need New York, you know you need unique New York",
},
score: 0.94812,
},
],
count: 3,
};
assert.deepStrictEqual(results, expected);
}

{
// with returnVectors unset (false)
// with returnValues = unset (false), returnMetadata = unset (false)
const results = await IDX.query(new Float32Array(new Array(5).fill(0)), {
topK: 3,
});
Expand All @@ -74,15 +67,15 @@ export const test_vector_search_vector_query = {
const expected = {
matches: [
{
vectorId: "b0daca4a-ffd8-4865-926b-e24800af2a2d",
id: "b0daca4a-ffd8-4865-926b-e24800af2a2d",
score: 0.71151,
},
{
vectorId: "a44706aa-a366-48bc-8cc1-3feffd87d548",
id: "a44706aa-a366-48bc-8cc1-3feffd87d548",
score: 0.68913,
},
{
vectorId: "43cfcb31-07e2-411f-8bf9-f82a95ba8b96",
id: "43cfcb31-07e2-411f-8bf9-f82a95ba8b96",
score: 0.94812,
},
],
Expand Down Expand Up @@ -141,7 +134,7 @@ export const test_vector_search_vector_insert_error = {
try {
await IDX.insert(newVectors);
} catch (e) {
error = e;
error = /** @type {Error} */ (e);
}

assert.equal(
Expand All @@ -150,7 +143,7 @@ export const test_vector_search_vector_insert_error = {
);
}
},
}
};

export const test_vector_search_vector_upsert = {
/**
Expand Down Expand Up @@ -233,8 +226,8 @@ export const test_vector_search_vector_get_ids = {
export const test_vector_search_can_use_enum_exports = {
async test() {
assert.equal(
KnownModel["openapi-text-embedding-ada-002"],
"openapi-text-embedding-ada-002"
KnownModel["openai/text-embedding-ada-002"],
"openai/text-embedding-ada-002"
);
assert.equal(DistanceMetric.COSINE, "cosine");
},
Expand Down
77 changes: 37 additions & 40 deletions src/cloudflare/internal/test/vectorize/vectorize-mock.js
Original file line number Diff line number Diff line change
Expand Up @@ -5,40 +5,34 @@
/** @type {Array<VectorizeMatch>} */
const exampleVectorMatches = [
{
vectorId: "b0daca4a-ffd8-4865-926b-e24800af2a2d",
id: "b0daca4a-ffd8-4865-926b-e24800af2a2d",
values: [0.2331, 1.0125, 0.6131, 0.9421, 0.9661, 0.8121],
metadata: { text: "She sells seashells by the seashore" },
score: 0.71151,
vector: {
id: "b0daca4a-ffd8-4865-926b-e24800af2a2d",
values: [0.2331, 1.0125, 0.6131, 0.9421, 0.9661, 0.8121],
metadata: { text: "She sells seashells by the seashore" },
},
},
{
vectorId: "a44706aa-a366-48bc-8cc1-3feffd87d548",
id: "a44706aa-a366-48bc-8cc1-3feffd87d548",
values: [0.2321, 0.8121, 0.6315, 0.6151, 0.4121, 0.1512],
metadata: { text: "Peter Piper picked a peck of pickled peppers" },
score: 0.68913,
vector: {
id: "a44706aa-a366-48bc-8cc1-3feffd87d548",
values: [0.2321, 0.8121, 0.6315, 0.6151, 0.4121, 0.1512],
metadata: { text: "Peter Piper picked a peck of pickled peppers" },
},
},
{
vectorId: "43cfcb31-07e2-411f-8bf9-f82a95ba8b96",
score: 0.94812,
vector: {
id: "43cfcb31-07e2-411f-8bf9-f82a95ba8b96",
values: [0.0515, 0.7512, 0.8612, 0.2153, 0.15121, 0.6812],
metadata: {
text: "You know New York, you need New York, you know you need unique New York",
},
id: "43cfcb31-07e2-411f-8bf9-f82a95ba8b96",
values: [0.0515, 0.7512, 0.8612, 0.2153, 0.15121, 0.6812],
metadata: {
text: "You know New York, you need New York, you know you need unique New York",
},
score: 0.94812,
},
];
/** @type {Array<VectorizeVector>} */
// @ts-ignore
const exampleVectors = exampleVectorMatches
.filter((m) => typeof m !== "undefined")
.map((m) => m.vector);
.map(({ id, values, metadata }) => ({
id,
values: values ?? [],
metadata: metadata ?? {},
}));

export default {
/**
Expand Down Expand Up @@ -105,30 +99,33 @@ export default {
} else if (request.method === "POST" && pathname.endsWith("/query")) {
/** @type {VectorizeQueryOptions & {vector: number[]}} */
const body = await request.json();
if (body?.returnVectors) {
return Response.json({
matches: exampleVectorMatches,
count: exampleVectorMatches.length,
const returnSet = exampleVectorMatches;
if (!body?.returnValues)
returnSet.forEach((v) => {
delete v.values;
});
if (!body?.returnMetadata)
returnSet.forEach((v) => {
delete v.metadata;
});
}
return Response.json({
matches: exampleVectorMatches.map(({ vectorId, score }) => ({
vectorId,
score,
})),
count: exampleVectorMatches.length,
matches: returnSet,
count: returnSet.length,
});
} else if (request.method === "POST" && pathname.endsWith("/insert")) {
/** @type {{vectors: Array<VectorizeVector>}} */
const data = await request.json();
if (data.vectors.find((v) => v.id == 'fail-with-test-error')) {
return Response.json({
code: 9999,
error: 'You asked me for this error',
}, {
status: 400
});
};
if (data.vectors.find((v) => v.id == "fail-with-test-error")) {
return Response.json(
{
code: 9999,
error: "You asked me for this error",
},
{
status: 400,
}
);
}

return Response.json({
ids: [
Expand Down
16 changes: 11 additions & 5 deletions src/cloudflare/internal/vectorize-api.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright (c) 2023 Cloudflare, Inc.
// Licensed under the Apache 2.0 license found in the LICENSE file or at:
// https://opensource.org/licenses/Apache-2.0
import * as flags from 'workerd:compatibility-flags'

interface Fetcher {
fetch: typeof fetch;
Expand All @@ -19,7 +20,7 @@ class VectorizeIndexImpl implements VectorizeIndex {
public constructor(
private readonly fetcher: Fetcher,
private readonly indexId: string
) {}
) { }

public async describe(): Promise<VectorizeIndexDetails> {
const res = await this._send(
Expand All @@ -45,10 +46,16 @@ class VectorizeIndexImpl implements VectorizeIndex {
body: JSON.stringify({
...options,
vector: Array.isArray(vector) ? vector : Array.from(vector),
compat: {
queryMetadataOptional: !!flags.vectorizeQueryMetadataOptional,
},
}),
headers: {
"content-type": "application/json",
accept: "application/json",
"cf-vector-search-query-compat": JSON.stringify({
queryMetadataOptional: !!flags.vectorizeQueryMetadataOptional,
})
},
}
);
Expand Down Expand Up @@ -196,10 +203,9 @@ async function toJson<T = unknown>(response: Response): Promise<T> {
return JSON.parse(body) as T;
} catch (e) {
throw new Error(
`Failed to parse body as JSON, got: ${
body.length > maxBodyLogChars
? `${body.slice(0, maxBodyLogChars)}…`
: body
`Failed to parse body as JSON, got: ${body.length > maxBodyLogChars
? `${body.slice(0, maxBodyLogChars)}…`
: body
}`
);
}
Expand Down
11 changes: 4 additions & 7 deletions src/cloudflare/internal/vectorize.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ type VectorizeDistanceMetric = "euclidean" | "cosine" | "dot-product";
interface VectorizeQueryOptions {
topK?: number;
namespace?: string;
returnVectors?: boolean;
returnValues?: boolean;
returnMetadata?: boolean;
}

/**
Expand Down Expand Up @@ -77,20 +78,16 @@ interface VectorizeVector {
values: VectorFloatArray | number[];
/** The namespace this vector belongs to. */
namespace?: string;
/** Metadata associated with the binding. Includes the values of the other fields and potentially additional details. */
/** Metadata associated with the vector. Includes the values of the other fields and potentially additional details. */
metadata?: Record<string, VectorizeVectorMetadata>;
}

/**
* Represents a matched vector for a query along with its score and (if specified) the matching vector information.
*/
interface VectorizeMatch {
/** The ID for the vector. This can be user-defined, and must be unique. It should uniquely identify the object, and is best set based on the ID of what the vector represents. */
vectorId: string;
type VectorizeMatch = Pick<Partial<VectorizeVector>, 'values'> & Omit<VectorizeVector, 'values'> & {
/** The score or rank for similarity, when returned as a result */
score: number;
/** Vector data for the match. Included only if the user specified they want it returned (via {@link VectorizeQueryOptions}). */
vector?: VectorizeVector;
}

/**
Expand Down
3 changes: 2 additions & 1 deletion src/cloudflare/tsconfig.json
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@
],
"paths": {
"cloudflare:*": ["./*"],
"cloudflare-internal:*": ["./internal/*"]
"cloudflare-internal:*": ["./internal/*"],
"workerd:compatibility-flags": ["./internal/compatibility-flags.d.ts"]
}
},
"include": [
Expand Down
8 changes: 5 additions & 3 deletions src/cloudflare/vectorize.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@
* These can be supplied in place of configuring explicit dimensions.
*/
export enum KnownModel {
"openapi-text-embedding-ada-002" = "openapi-text-embedding-ada-002",
"workers-ai/bge-small-en" = "workers-ai/bge-small-en",
"cohere/embed-multilingual-v2.0" = "cohere/embed-multilingual-v2.0",
'openai/text-embedding-ada-002' = 'openai/text-embedding-ada-002',
'cohere/embed-multilingual-v2.0' = 'cohere/embed-multilingual-v2.0',
'@cf/baai/bge-small-en-v1.5' = '@cf/baai/bge-small-en-v1.5',
'@cf/baai/bge-base-en-v1.5' = '@cf/baai/bge-base-en-v1.5',
'@cf/baai/bge-large-en-v1.5' = '@cf/baai/bge-large-en-v1.5',
}

/**
Expand Down
1 change: 1 addition & 0 deletions src/node/internal/compatibility-flags.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,4 @@ export const tcpSocketsSupport: boolean;
export const specCompliantResponseRedirect: boolean;
export const workerdExperimental: boolean;
export const durableObjectGetExisting: boolean;
export const vectorizeQueryMetadataOptional: boolean;
7 changes: 7 additions & 0 deletions src/workerd/io/compatibility-date.capnp
Original file line number Diff line number Diff line change
Expand Up @@ -352,4 +352,11 @@ struct CompatibilityFlags @0x8f8c1b68151b6cef {
# In the WebCrypto API, the `publicExponent` field of the algorithm of RSA keys would previously
# be an ArrayBuffer. Using this flag, publicExponent is a Uint8Array as mandated by the
# specification.

vectorizeQueryMetadataOptional @37 :Bool
$compatEnableFlag("vectorize_query_metadata_optional")
$compatEnableDate("2023-11-08")
$compatDisableFlag("vectorize_query_original");
# Vectorize query option change to allow returning of metadata to be optional. Accompanying this:
# a return format change to move away from a nested object with the VectorizeVector.
}
Loading

0 comments on commit ce4237a

Please sign in to comment.