Skip to content

Commit

Permalink
refactor(js/plugins/vertexai): extract vectorsearch
Browse files Browse the repository at this point in the history
  • Loading branch information
cabljac committed Nov 11, 2024
1 parent dd8f173 commit 9ccbd8d
Show file tree
Hide file tree
Showing 18 changed files with 78 additions and 68 deletions.
9 changes: 9 additions & 0 deletions js/plugins/vertexai/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,12 @@
"import": "./lib/modelgarden/index.mjs",
"types": "./lib/modelgarden/index.d.ts",
"default": "./lib/modelgarden/index.js"
},
"./vectorsearch": {
"require": "./lib/vectorsearch/index.js",
"import": "./lib/vectorsearch/index.mjs",
"types": "./lib/vectorsearch/index.d.ts",
"default": "./lib/vectorsearch/index.js"
}
},
"typesVersions": {
Expand All @@ -96,6 +102,9 @@
],
"modelgarden": [
"./lib/modelgarden/index"
],
"vectorsearch": [
"./lib/vectorsearch/index"
]
}
}
Expand Down
36 changes: 1 addition & 35 deletions js/plugins/vertexai/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -40,22 +40,7 @@ import {
imagen3Fast,
imagenModel,
} from './imagen.js';
import { vertexAiIndexers, vertexAiRetrievers } from './vector-search/index.js';
export { PluginOptions } from './common/types.js';
export {
DocumentIndexer,
DocumentRetriever,
Neighbor,
VectorSearchOptions,
getBigQueryDocumentIndexer,
getBigQueryDocumentRetriever,
getFirestoreDocumentIndexer,
getFirestoreDocumentRetriever,
vertexAiIndexerRef,
vertexAiIndexers,
vertexAiRetrieverRef,
vertexAiRetrievers,
} from './vector-search/index.js';
export {
gemini10Pro,
gemini15Flash,
Expand Down Expand Up @@ -83,28 +68,9 @@ export function vertexAI(options?: PluginOptions): GenkitPlugin {
defineGeminiModel(ai, name, vertexClientFactory, { projectId, location })
);

const embedders = Object.keys(SUPPORTED_EMBEDDER_MODELS).map((name) =>
Object.keys(SUPPORTED_EMBEDDER_MODELS).map((name) =>
defineVertexAIEmbedder(ai, name, authClient, { projectId, location })
);

if (
options?.vectorSearchOptions &&
options.vectorSearchOptions.length > 0
) {
const defaultEmbedder = embedders[0];

vertexAiIndexers(ai, {
pluginOptions: options,
authClient,
defaultEmbedder,
});

vertexAiRetrievers(ai, {
pluginOptions: options,
authClient,
defaultEmbedder,
});
}
});
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,6 @@ import { Genkit } from 'genkit';
import { GenkitPlugin, genkitPlugin } from 'genkit/plugin';
import { getDerivedParams } from '../common/index.js';
import { PluginOptions } from './types.js';
// import {
// SUPPORTED_EMBEDDER_MODELS,
// defineVertexAIEmbedder,
// textEmbedding004,
// textEmbeddingGecko003,
// textEmbeddingGeckoMultilingual001,
// textMultilingualEmbedding002,
// } from '../embedder.js';
import { vertexAiIndexers, vertexAiRetrievers } from './vector_search/index.js';
export { PluginOptions } from '../common/types.js';
export {
Expand All @@ -45,7 +37,7 @@ export {
/**
* Add Google Cloud Vertex AI to Genkit. Includes Gemini and Imagen models and text embedder.
*/
export function vertexAI(options?: PluginOptions): GenkitPlugin {
export function vertexAIVectorSearch(options?: PluginOptions): GenkitPlugin {
return genkitPlugin('vertexai', async (ai: Genkit) => {
const { projectId, location, vertexClientFactory, authClient } =
await getDerivedParams(options);
Expand All @@ -68,5 +60,3 @@ export function vertexAI(options?: PluginOptions): GenkitPlugin {
}
});
}

export default vertexAI;
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,15 @@
* limitations under the License.
*/

import { z } from 'genkit';
import { EmbedderReference, z } from 'genkit';
import { CommonPluginOptions } from '../common/types';
import { VectorSearchOptions } from './vector_search';

/** Options specific to vector search configuration */
export interface VectorSearchOptionsConfig {
/** Configure Vertex AI vector search index options */
vectorSearchOptions?: VectorSearchOptions<z.ZodTypeAny, any, any>[];
embedder?: EmbedderReference;
}

export interface PluginOptions
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ export function vertexAiIndexers<EmbedderCustomOptions extends z.ZodTypeAny>(
params: VertexVectorSearchOptions<EmbedderCustomOptions>
): IndexerAction<z.ZodTypeAny>[] {
const vectorSearchOptions = params.pluginOptions.vectorSearchOptions;
const defaultEmbedder = params.defaultEmbedder;
const indexers: IndexerAction<z.ZodTypeAny>[] = [];

if (!vectorSearchOptions || vectorSearchOptions.length === 0) {
Expand All @@ -67,7 +66,14 @@ export function vertexAiIndexers<EmbedderCustomOptions extends z.ZodTypeAny>(

for (const vectorSearchOption of vectorSearchOptions) {
const { documentIndexer, indexId } = vectorSearchOption;
const embedder = vectorSearchOption.embedder ?? defaultEmbedder;
const embedderReference =
vectorSearchOption.embedder ?? params.defaultEmbedder;

if (!embedderReference) {
throw new Error(
'Embedder reference is required to define Vertex AI retriever'
);
}
const embedderOptions = vectorSearchOption.embedderOptions;

const indexer = ai.defineIndexer(
Expand All @@ -87,7 +93,7 @@ export function vertexAiIndexers<EmbedderCustomOptions extends z.ZodTypeAny>(
}

const embeddings = await ai.embedMany({
embedder,
embedder: embedderReference,
content: docs,
options: embedderOptions,
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,17 @@ export function vertexAiRetrievers<EmbedderCustomOptions extends z.ZodTypeAny>(
configSchema: VertexAIVectorRetrieverOptionsSchema.optional(),
},
async (content, options) => {
const embedderReference =
vectorSearchOption.embedder ?? defaultEmbedder;

if (!embedderReference) {
throw new Error(
'Embedder reference is required to define Vertex AI retriever'
);
}

const queryEmbeddings = await ai.embed({
embedder,
embedder: embedderReference,
options: embedderOptions,
content,
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import * as aiplatform from '@google-cloud/aiplatform';
import { z } from 'genkit';
import { EmbedderArgument } from 'genkit/embedder';
import { EmbedderReference } from 'genkit/embedder';
import { CommonRetrieverOptionsSchema, Document } from 'genkit/retriever';
import { GoogleAuth } from 'google-auth-library';
import { PluginOptions } from '../types';
Expand All @@ -27,7 +27,7 @@ export interface VertexVectorSearchOptions<
> {
pluginOptions: PluginOptions;
authClient: GoogleAuth;
defaultEmbedder: EmbedderArgument<EmbedderCustomOptions>;
defaultEmbedder?: EmbedderReference<EmbedderCustomOptions>;
}

export type IIndexDatapoint =
Expand Down Expand Up @@ -184,6 +184,6 @@ export interface VectorSearchOptions<
documentRetriever: DocumentRetriever<RetrieverOptions>;
documentIndexer: DocumentIndexer<IndexerOptions>;
// Embedder and default options to use for indexing and retrieval
embedder?: EmbedderArgument<EmbedderCustomOptions>;
embedder?: EmbedderReference<EmbedderCustomOptions>;
embedderOptions?: z.infer<EmbedderCustomOptions>;
}
10 changes: 9 additions & 1 deletion js/testapps/model-tester/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,12 @@
*/

import { googleAI } from '@genkit-ai/googleai';
import { claude3Sonnet, llama31, vertexAI } from '@genkit-ai/vertexai';
import { vertexAI } from '@genkit-ai/vertexai';
import {
claude3Sonnet,
llama31,
vertexAIModelGarden,
} from '@genkit-ai/vertexai/modelgarden';
import * as clc from 'colorette';
import { genkit } from 'genkit';
import { testModels } from 'genkit/testing';
Expand All @@ -27,6 +32,9 @@ export const ai = genkit({
googleAI(),
vertexAI({
location: 'us-central1',
}),
vertexAIModelGarden({
location: 'us-central1',
modelGarden: {
models: [claude3Sonnet, llama31],
},
Expand Down
9 changes: 6 additions & 3 deletions js/testapps/rag/src/genkit.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@
import { devLocalVectorstore } from '@genkit-ai/dev-local-vectorstore';
import { genkitEval, GenkitMetric } from '@genkit-ai/evaluator';
import { gemini15Flash, googleAI } from '@genkit-ai/googleai';
import { textEmbedding004, vertexAI } from '@genkit-ai/vertexai';
import {
claude3Sonnet,
llama31,
textEmbedding004,
vertexAI,
} from '@genkit-ai/vertexai';
vertexAIModelGarden,
} from '@genkit-ai/vertexai/modelgarden';
import { genkit } from 'genkit';
import { chroma } from 'genkitx-chromadb';
import { langchain } from 'genkitx-langchain';
Expand Down Expand Up @@ -76,6 +76,9 @@ export const ai = genkit({
}),
vertexAI({
location: 'us-central1',
}),
vertexAIModelGarden({
location: 'us-central1',
modelGarden: {
models: [claude3Sonnet, llama31],
},
Expand Down
16 changes: 10 additions & 6 deletions js/testapps/vertexai-vector-search-bigquery/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,16 @@

import { Document, genkit, z } from 'genkit';
// important imports for this sample:
import { textEmbedding004, vertexAI } from '@genkit-ai/vertexai';
import {
DocumentIndexer,
DocumentRetriever,
getBigQueryDocumentIndexer,
getBigQueryDocumentRetriever,
vertexAI,
vertexAIVectorSearch,
vertexAiIndexerRef,
vertexAiRetrieverRef,
type DocumentIndexer,
type DocumentRetriever,
} from '@genkit-ai/vertexai';

} from '@genkit-ai/vertexai/vectorsearch';
// // Environment variables set with dotenv for simplicity of sample
import {
BIGQUERY_DATASET,
Expand Down Expand Up @@ -81,6 +81,11 @@ const ai = genkit({
googleAuth: {
scopes: ['https://www.googleapis.com/auth/cloud-platform'],
},
}),
vertexAIVectorSearch({
location: LOCATION,
projectId: PROJECT_ID,
embedder: textEmbedding004,
vectorSearchOptions: [
{
publicDomainName: VECTOR_SEARCH_PUBLIC_DOMAIN_NAME,
Expand All @@ -95,7 +100,6 @@ const ai = genkit({
],
});

// // Define indexing flow
export const indexFlow = ai.defineFlow(
{
name: 'indexFlow',
Expand Down
10 changes: 8 additions & 2 deletions js/testapps/vertexai-vector-search-custom/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,15 @@

import { Document, genkit, z } from 'genkit';
// important imports for this sample:
import { textEmbedding004, vertexAI } from '@genkit-ai/vertexai';
import {
vertexAI,
vertexAIVectorSearch,
vertexAiIndexerRef,
vertexAiRetrieverRef,
type DocumentIndexer,
type DocumentRetriever,
type Neighbor,
} from '@genkit-ai/vertexai';
} from '@genkit-ai/vertexai/vectorsearch';

// // Environment variables set with dotenv for simplicity of sample
import {
Expand Down Expand Up @@ -148,6 +149,10 @@ const ai = genkit({
googleAuth: {
scopes: ['https://www.googleapis.com/auth/cloud-platform'],
},
}),
vertexAIVectorSearch({
location: LOCATION,
projectId: PROJECT_ID,
vectorSearchOptions: [
{
publicDomainName: VECTOR_SEARCH_PUBLIC_DOMAIN_NAME,
Expand All @@ -156,6 +161,7 @@ const ai = genkit({
deployedIndexId: VECTOR_SEARCH_DEPLOYED_INDEX_ID,
documentRetriever: localDocumentRetriever,
documentIndexer: localDocumentIndexer,
embedder: textEmbedding004,
},
],
}),
Expand Down
12 changes: 10 additions & 2 deletions js/testapps/vertexai-vector-search-firestore/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,18 @@
import { initializeApp } from 'firebase-admin/app';
import { Document, genkit, z } from 'genkit';
// important imports for this sample:

import { textEmbedding004, vertexAI } from '@genkit-ai/vertexai';

import {
DocumentIndexer,
DocumentRetriever,
getFirestoreDocumentIndexer,
getFirestoreDocumentRetriever,
vertexAI,
vertexAiIndexerRef,
vertexAiRetrieverRef,
} from '@genkit-ai/vertexai';
vertexAIVectorSearch,
} from '@genkit-ai/vertexai/vectorsearch';

// // Environment variables set with dotenv for simplicity of sample
import { getFirestore } from 'firebase-admin/firestore';
Expand Down Expand Up @@ -80,6 +83,10 @@ const ai = genkit({
googleAuth: {
scopes: ['https://www.googleapis.com/auth/cloud-platform'],
},
}),
vertexAIVectorSearch({
projectId: PROJECT_ID,
location: LOCATION,
vectorSearchOptions: [
{
publicDomainName: VECTOR_SEARCH_PUBLIC_DOMAIN_NAME,
Expand All @@ -88,6 +95,7 @@ const ai = genkit({
deployedIndexId: VECTOR_SEARCH_DEPLOYED_INDEX_ID,
documentRetriever: firestoreDocumentRetriever,
documentIndexer: firestoreDocumentIndexer,
embedder: textEmbedding004,
},
],
}),
Expand Down

0 comments on commit 9ccbd8d

Please sign in to comment.