From 0793d1cf0f71b28eeb3e14954b5010f5f951703e Mon Sep 17 00:00:00 2001 From: hunterheston Date: Tue, 29 Oct 2024 17:32:26 +0000 Subject: [PATCH 01/30] Copy vertexai dir --- js/plugins/checks/.npmignore | 1 + js/plugins/checks/LICENSE | 203 +++++++ js/plugins/checks/README.md | 7 + js/plugins/checks/package.json | 71 +++ js/plugins/checks/src/anthropic.ts | 422 +++++++++++++ js/plugins/checks/src/embedder.ts | 155 +++++ js/plugins/checks/src/evaluation.ts | 440 ++++++++++++++ js/plugins/checks/src/evaluator_factory.ts | 100 ++++ js/plugins/checks/src/gemini.ts | 554 ++++++++++++++++++ js/plugins/checks/src/imagen.ts | 309 ++++++++++ js/plugins/checks/src/index.ts | 262 +++++++++ js/plugins/checks/src/model_garden.ts | 123 ++++ js/plugins/checks/src/openai_compatibility.ts | 350 +++++++++++ js/plugins/checks/src/predict.ts | 83 +++ js/plugins/checks/src/reranker.ts | 159 +++++ .../checks/src/vector-search/bigquery.ts | 131 +++++ .../checks/src/vector-search/firestore.ts | 87 +++ js/plugins/checks/src/vector-search/index.ts | 36 ++ .../checks/src/vector-search/indexers.ts | 120 ++++ .../vector-search/query_public_endpoint.ts | 92 +++ .../checks/src/vector-search/retrievers.ts | 136 +++++ js/plugins/checks/src/vector-search/types.ts | 189 ++++++ .../src/vector-search/upsert_datapoints.ts | 71 +++ js/plugins/checks/src/vector-search/utils.ts | 65 ++ js/plugins/checks/tests/anthropic_test.ts | 313 ++++++++++ js/plugins/checks/tests/gemini_test.ts | 347 +++++++++++ .../tests/vector-search/bigquery_test.ts | 168 ++++++ .../query_public_endpoint_test.ts | 86 +++ .../vector-search/upsert_datapoints_test.ts | 81 +++ .../checks/tests/vector-search/utils_test.ts | 70 +++ js/plugins/checks/tsconfig.json | 4 + js/plugins/checks/tsup.config.ts | 22 + js/pnpm-lock.yaml | 56 ++ 33 files changed, 5313 insertions(+) create mode 100644 js/plugins/checks/.npmignore create mode 100644 js/plugins/checks/LICENSE create mode 100644 js/plugins/checks/README.md create mode 100644 js/plugins/checks/package.json create mode 100644 js/plugins/checks/src/anthropic.ts create mode 100644 js/plugins/checks/src/embedder.ts create mode 100644 js/plugins/checks/src/evaluation.ts create mode 100644 js/plugins/checks/src/evaluator_factory.ts create mode 100644 js/plugins/checks/src/gemini.ts create mode 100644 js/plugins/checks/src/imagen.ts create mode 100644 js/plugins/checks/src/index.ts create mode 100644 js/plugins/checks/src/model_garden.ts create mode 100644 js/plugins/checks/src/openai_compatibility.ts create mode 100644 js/plugins/checks/src/predict.ts create mode 100644 js/plugins/checks/src/reranker.ts create mode 100644 js/plugins/checks/src/vector-search/bigquery.ts create mode 100644 js/plugins/checks/src/vector-search/firestore.ts create mode 100644 js/plugins/checks/src/vector-search/index.ts create mode 100644 js/plugins/checks/src/vector-search/indexers.ts create mode 100644 js/plugins/checks/src/vector-search/query_public_endpoint.ts create mode 100644 js/plugins/checks/src/vector-search/retrievers.ts create mode 100644 js/plugins/checks/src/vector-search/types.ts create mode 100644 js/plugins/checks/src/vector-search/upsert_datapoints.ts create mode 100644 js/plugins/checks/src/vector-search/utils.ts create mode 100644 js/plugins/checks/tests/anthropic_test.ts create mode 100644 js/plugins/checks/tests/gemini_test.ts create mode 100644 js/plugins/checks/tests/vector-search/bigquery_test.ts create mode 100644 js/plugins/checks/tests/vector-search/query_public_endpoint_test.ts create mode 100644 js/plugins/checks/tests/vector-search/upsert_datapoints_test.ts create mode 100644 js/plugins/checks/tests/vector-search/utils_test.ts create mode 100644 js/plugins/checks/tsconfig.json create mode 100644 js/plugins/checks/tsup.config.ts diff --git a/js/plugins/checks/.npmignore b/js/plugins/checks/.npmignore new file mode 100644 index 000000000..b512c09d4 --- /dev/null +++ b/js/plugins/checks/.npmignore @@ -0,0 +1 @@ +node_modules \ No newline at end of file diff --git a/js/plugins/checks/LICENSE b/js/plugins/checks/LICENSE new file mode 100644 index 000000000..26a870243 --- /dev/null +++ b/js/plugins/checks/LICENSE @@ -0,0 +1,203 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + \ No newline at end of file diff --git a/js/plugins/checks/README.md b/js/plugins/checks/README.md new file mode 100644 index 000000000..324aee45e --- /dev/null +++ b/js/plugins/checks/README.md @@ -0,0 +1,7 @@ +# Genkit + +The sources for this package are in the main [Genkit](https://github.com/firebase/genkit) repo. Please file issues and pull requests against that repo. + +Usage information and reference details can be found in [Genkit documentation](https://firebase.google.com/docs/genkit). + +License: Apache 2.0 diff --git a/js/plugins/checks/package.json b/js/plugins/checks/package.json new file mode 100644 index 000000000..23e2d5062 --- /dev/null +++ b/js/plugins/checks/package.json @@ -0,0 +1,71 @@ +{ + "name": "@genkit-ai/checks", + "description": "Genkit AI framework plugin for Google Cloud Vertex AI APIs including Gemini APIs, Imagen, and more.", + "keywords": [ + "genkit", + "genkit-plugin", + "genkit-embedder", + "genkit-model", + "google cloud", + "vertex ai", + "imagen", + "image-generation", + "gemini", + "google gemini", + "google ai", + "ai", + "genai", + "generative-ai" + ], + "version": "0.9.0-dev.2", + "type": "commonjs", + "scripts": { + "check": "tsc", + "compile": "tsup-node", + "build:clean": "rimraf ./lib", + "build": "npm-run-all build:clean check compile", + "build:watch": "tsup-node --watch", + "test": "tsx --test ./tests/*_test.ts ./tests/**/*_test.ts" + }, + "repository": { + "type": "git", + "url": "https://github.com/firebase/genkit.git", + "directory": "js/plugins/checks" + }, + "author": "genkit", + "license": "Apache-2.0", + "dependencies": { + "@anthropic-ai/sdk": "^0.24.3", + "@anthropic-ai/vertex-sdk": "^0.4.0", + "@google-cloud/aiplatform": "^3.23.0", + "@google-cloud/vertexai": "^1.1.0", + "google-auth-library": "^9.6.3", + "googleapis": "^140.0.1", + "node-fetch": "^3.3.2", + "openai": "^4.52.7" + }, + "peerDependencies": { + "genkit": "workspace:*" + }, + "optionalDependencies": { + "firebase-admin": ">=12.2", + "@google-cloud/bigquery": "^7.8.0" + }, + "devDependencies": { + "@types/node": "^20.11.16", + "npm-run-all": "^4.1.5", + "rimraf": "^6.0.1", + "tsup": "^8.0.2", + "tsx": "^4.7.0", + "typescript": "^4.9.0" + }, + "types": "./lib/index.d.ts", + "exports": { + ".": { + "require": "./lib/index.js", + "import": "./lib/index.mjs", + "types": "./lib/index.d.ts", + "default": "./lib/index.js" + } + } +} diff --git a/js/plugins/checks/src/anthropic.ts b/js/plugins/checks/src/anthropic.ts new file mode 100644 index 000000000..a28ea12d4 --- /dev/null +++ b/js/plugins/checks/src/anthropic.ts @@ -0,0 +1,422 @@ +/** + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { + ContentBlock as AnthropicContent, + ImageBlockParam, + Message, + MessageCreateParamsBase, + MessageParam, + TextBlock, + TextBlockParam, + TextDelta, + Tool, + ToolResultBlockParam, + ToolUseBlock, + ToolUseBlockParam, +} from '@anthropic-ai/sdk/resources/messages'; +import { AnthropicVertex } from '@anthropic-ai/vertex-sdk'; +import { + GENKIT_CLIENT_HEADER, + GenerateRequest, + Genkit, + Part as GenkitPart, + MessageData, + ModelReference, + ModelResponseData, + Part, + z, +} from 'genkit'; +import { + GenerationCommonConfigSchema, + getBasicUsageStats, + modelRef, +} from 'genkit/model'; + +export const AnthropicConfigSchema = GenerationCommonConfigSchema.extend({ + location: z.string().optional(), +}); + +export const claude35Sonnet = modelRef({ + name: 'vertexai/claude-3-5-sonnet', + info: { + label: 'Vertex AI Model Garden - Claude 3.5 Sonnet', + versions: ['claude-3-5-sonnet@20240620'], + supports: { + multiturn: true, + media: true, + tools: true, + systemRole: true, + output: ['text'], + }, + }, + configSchema: AnthropicConfigSchema, +}); + +export const claude3Sonnet = modelRef({ + name: 'vertexai/claude-3-sonnet', + info: { + label: 'Vertex AI Model Garden - Claude 3 Sonnet', + versions: ['claude-3-sonnet@20240229'], + supports: { + multiturn: true, + media: true, + tools: true, + systemRole: true, + output: ['text'], + }, + }, + configSchema: AnthropicConfigSchema, +}); + +export const claude3Haiku = modelRef({ + name: 'vertexai/claude-3-haiku', + info: { + label: 'Vertex AI Model Garden - Claude 3 Haiku', + versions: ['claude-3-haiku@20240307'], + supports: { + multiturn: true, + media: true, + tools: true, + systemRole: true, + output: ['text'], + }, + }, + configSchema: AnthropicConfigSchema, +}); + +export const claude3Opus = modelRef({ + name: 'vertexai/claude-3-opus', + info: { + label: 'Vertex AI Model Garden - Claude 3 Opus', + versions: ['claude-3-opus@20240229'], + supports: { + multiturn: true, + media: true, + tools: true, + systemRole: true, + output: ['text'], + }, + }, + configSchema: AnthropicConfigSchema, +}); + +export const SUPPORTED_ANTHROPIC_MODELS: Record< + string, + ModelReference +> = { + 'claude-3-5-sonnet': claude35Sonnet, + 'claude-3-sonnet': claude3Sonnet, + 'claude-3-opus': claude3Opus, + 'claude-3-haiku': claude3Haiku, +}; + +export function toAnthropicRequest( + model: string, + input: GenerateRequest +): MessageCreateParamsBase { + let system: string | undefined = undefined; + const messages: MessageParam[] = []; + for (const msg of input.messages) { + if (msg.role === 'system') { + system = msg.content + .map((c) => { + if (!c.text) { + throw new Error( + 'Only text context is supported for system messages.' + ); + } + return c.text; + }) + .join(); + } + // If the last message is a tool response, we need to add a user message. + // https://docs.anthropic.com/en/docs/build-with-claude/tool-use#handling-tool-use-and-tool-result-content-blocks + else if (msg.content[msg.content.length - 1].toolResponse) { + messages.push({ + role: 'user', + content: toAnthropicContent(msg.content), + }); + } else { + messages.push({ + role: toAnthropicRole(msg.role), + content: toAnthropicContent(msg.content), + }); + } + } + const request = { + model, + messages, + // https://docs.anthropic.com/claude/docs/models-overview#model-comparison + max_tokens: input.config?.maxOutputTokens ?? 4096, + } as MessageCreateParamsBase; + if (system) { + request['system'] = system; + } + if (input.tools) { + request.tools = input.tools?.map((tool) => { + return { + name: tool.name, + description: tool.description, + input_schema: tool.inputSchema, + }; + }) as Array; + } + if (input.config?.stopSequences) { + request.stop_sequences = input.config?.stopSequences; + } + if (input.config?.temperature) { + request.temperature = input.config?.temperature; + } + if (input.config?.topK) { + request.top_k = input.config?.topK; + } + if (input.config?.topP) { + request.top_p = input.config?.topP; + } + return request; +} + +function toAnthropicContent( + content: GenkitPart[] +): Array< + TextBlockParam | ImageBlockParam | ToolUseBlockParam | ToolResultBlockParam +> { + return content.map((p) => { + if (p.text) { + return { + type: 'text', + text: p.text, + }; + } + if (p.media) { + let b64Data = p.media.url; + if (b64Data.startsWith('data:')) { + b64Data = b64Data.substring(b64Data.indexOf(',')! + 1); + } + + return { + type: 'image', + source: { + type: 'base64', + data: b64Data, + media_type: p.media.contentType as + | 'image/jpeg' + | 'image/png' + | 'image/gif' + | 'image/webp', + }, + }; + } + if (p.toolRequest) { + return toAnthropicToolRequest(p.toolRequest); + } + if (p.toolResponse) { + return toAnthropicToolResponse(p); + } + throw new Error(`Unsupported content type: ${JSON.stringify(p)}`); + }); +} + +function toAnthropicRole(role): 'user' | 'assistant' { + if (role === 'model') { + return 'assistant'; + } + if (role === 'user') { + return 'user'; + } + if (role === 'tool') { + return 'assistant'; + } + throw new Error(`Unsupported role type ${role}`); +} + +function fromAnthropicTextPart(part: TextBlock): Part { + return { + text: part.text, + }; +} + +function fromAnthropicToolCallPart(part: ToolUseBlock): Part { + return { + toolRequest: { + name: part.name, + input: part.input, + ref: part.id, + }, + }; +} + +// Converts an Anthropic part to a Genkit part. +function fromAnthropicPart(part: AnthropicContent): Part { + if (part.type === 'text') return fromAnthropicTextPart(part); + if (part.type === 'tool_use') return fromAnthropicToolCallPart(part); + throw new Error( + 'Part type is unsupported/corrupted. Either data is missing or type cannot be inferred from type.' + ); +} + +// Converts an Anthropic response to a Genkit response. +export function fromAnthropicResponse( + input: GenerateRequest, + response: Message +): ModelResponseData { + const parts = response.content as AnthropicContent[]; + const message: MessageData = { + role: 'model', + content: parts.map(fromAnthropicPart), + }; + return { + message, + finishReason: toGenkitFinishReason( + response.stop_reason as + | 'end_turn' + | 'max_tokens' + | 'stop_sequence' + | 'tool_use' + | null + ), + custom: { + id: response.id, + model: response.model, + type: response.type, + }, + usage: { + ...getBasicUsageStats(input.messages, message), + inputTokens: response.usage.input_tokens, + outputTokens: response.usage.output_tokens, + }, + }; +} + +function toGenkitFinishReason( + reason: 'end_turn' | 'max_tokens' | 'stop_sequence' | 'tool_use' | null +): ModelResponseData['finishReason'] { + switch (reason) { + case 'end_turn': + return 'stop'; + case 'max_tokens': + return 'length'; + case 'stop_sequence': + return 'stop'; + case 'tool_use': + return 'stop'; + case null: + return 'unknown'; + default: + return 'other'; + } +} + +function toAnthropicToolRequest(tool: Record): ToolUseBlock { + if (!tool.name) { + throw new Error('Tool name is required'); + } + // Validate the tool name, Anthropic only supports letters, numbers, and underscores. + // https://docs.anthropic.com/en/docs/build-with-claude/tool-use#specifying-tools + if (!/^[a-zA-Z0-9_-]{1,64}$/.test(tool.name)) { + throw new Error( + `Tool name ${tool.name} contains invalid characters. + Only letters, numbers, and underscores are allowed, + and the name must be between 1 and 64 characters long.` + ); + } + const declaration: ToolUseBlock = { + type: 'tool_use', + id: tool.ref, + name: tool.name, + input: tool.input, + }; + return declaration; +} + +function toAnthropicToolResponse(part: Part): ToolResultBlockParam { + if (!part.toolResponse?.ref) { + throw new Error('Tool response reference is required'); + } + + if (!part.toolResponse.output) { + throw new Error('Tool response output is required'); + } + + return { + type: 'tool_result', + tool_use_id: part.toolResponse.ref, + content: JSON.stringify(part.toolResponse.output), + }; +} + +export function anthropicModel( + ai: Genkit, + modelName: string, + projectId: string, + region: string +) { + const clients: Record = {}; + const clientFactory = (region: string): AnthropicVertex => { + if (!clients[region]) { + clients[region] = new AnthropicVertex({ + region, + projectId, + defaultHeaders: { + 'X-Goog-Api-Client': GENKIT_CLIENT_HEADER, + }, + }); + } + return clients[region]; + }; + const model = SUPPORTED_ANTHROPIC_MODELS[modelName]; + if (!model) { + throw new Error(`unsupported Anthropic model name ${modelName}`); + } + + return ai.defineModel( + { + name: model.name, + label: model.info?.label, + configSchema: AnthropicConfigSchema, + supports: model.info?.supports, + versions: model.info?.versions, + }, + async (input, streamingCallback) => { + const client = clientFactory(input.config?.location || region); + if (!streamingCallback) { + const response = await client.messages.create({ + ...toAnthropicRequest(input.config?.version ?? modelName, input), + stream: false, + }); + return fromAnthropicResponse(input, response); + } else { + const stream = await client.messages.stream( + toAnthropicRequest(input.config?.version ?? modelName, input) + ); + for await (const event of stream) { + if (event.type === 'content_block_delta') { + streamingCallback({ + index: 0, + content: [ + { + text: (event.delta as TextDelta).text, + }, + ], + }); + } + } + return fromAnthropicResponse(input, await stream.finalMessage()); + } + } + ); +} diff --git a/js/plugins/checks/src/embedder.ts b/js/plugins/checks/src/embedder.ts new file mode 100644 index 000000000..10d2ca18c --- /dev/null +++ b/js/plugins/checks/src/embedder.ts @@ -0,0 +1,155 @@ +/** + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { Genkit, z } from 'genkit'; +import { EmbedderReference, embedderRef } from 'genkit/embedder'; +import { GoogleAuth } from 'google-auth-library'; +import { PluginOptions } from './index.js'; +import { PredictClient, predictModel } from './predict.js'; + +export const TaskTypeSchema = z.enum([ + 'RETRIEVAL_DOCUMENT', + 'RETRIEVAL_QUERY', + 'SEMANTIC_SIMILARITY', + 'CLASSIFICATION', + 'CLUSTERING', +]); + +export type TaskType = z.infer; + +export const VertexEmbeddingConfigSchema = z.object({ + /** + * The `task_type` parameter is defined as the intended downstream application to help the model + * produce better quality embeddings. + **/ + taskType: TaskTypeSchema.optional(), + title: z.string().optional(), + location: z.string().optional(), + version: z.string().optional(), +}); + +export type VertexEmbeddingConfig = z.infer; + +function commonRef( + name: string, + input?: ('text' | 'image')[] +): EmbedderReference { + return embedderRef({ + name: `vertexai/${name}`, + configSchema: VertexEmbeddingConfigSchema, + info: { + dimensions: 768, + label: `Vertex AI - ${name}`, + supports: { + input: input ?? ['text'], + }, + }, + }); +} + +export const textEmbeddingGecko003 = commonRef('textembedding-gecko@003'); +export const textEmbedding004 = commonRef('text-embedding-004'); +export const textEmbeddingGeckoMultilingual001 = commonRef( + 'textembedding-gecko-multilingual@001' +); +export const textMultilingualEmbedding002 = commonRef( + 'text-multilingual-embedding-002' +); + +export const SUPPORTED_EMBEDDER_MODELS: Record = { + 'textembedding-gecko@003': textEmbeddingGecko003, + 'text-embedding-004': textEmbedding004, + 'textembedding-gecko-multilingual@001': textEmbeddingGeckoMultilingual001, + 'text-multilingual-embedding-002': textMultilingualEmbedding002, + // TODO: add support for multimodal embeddings + // 'multimodalembedding@001': commonRef('multimodalembedding@001', [ + // 'image', + // 'text', + // ]), +}; + +interface EmbeddingInstance { + task_type?: TaskType; + content: string; + title?: string; +} +interface EmbeddingPrediction { + embeddings: { + statistics: { + truncated: boolean; + token_count: number; + }; + values: number[]; + }; +} + +export function defineVertexAIEmbedder( + ai: Genkit, + name: string, + client: GoogleAuth, + options: PluginOptions +) { + const embedder = SUPPORTED_EMBEDDER_MODELS[name]; + const predictClients: Record< + string, + PredictClient + > = {}; + const predictClientFactory = ( + config: VertexEmbeddingConfig + ): PredictClient => { + const requestLocation = config?.location || options.location; + if (!predictClients[requestLocation]) { + // TODO: Figure out how to allow different versions while still sharing a single implementation. + predictClients[requestLocation] = predictModel< + EmbeddingInstance, + EmbeddingPrediction + >( + client, + { + ...options, + location: requestLocation, + }, + name + ); + } + return predictClients[requestLocation]; + }; + + return ai.defineEmbedder( + { + name: embedder.name, + configSchema: embedder.configSchema, + info: embedder.info!, + }, + async (input, options) => { + const predictClient = predictClientFactory(options); + const response = await predictClient( + input.map((i) => { + return { + content: i.text, + task_type: options?.taskType, + title: options?.title, + }; + }) + ); + return { + embeddings: response.predictions.map((p) => ({ + embedding: p.embeddings.values, + })), + }; + } + ); +} diff --git a/js/plugins/checks/src/evaluation.ts b/js/plugins/checks/src/evaluation.ts new file mode 100644 index 000000000..965b9fc86 --- /dev/null +++ b/js/plugins/checks/src/evaluation.ts @@ -0,0 +1,440 @@ +/** + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { Action, Genkit, z } from 'genkit'; +import { GoogleAuth } from 'google-auth-library'; +import { EvaluatorFactory } from './evaluator_factory.js'; + +/** + * Vertex AI Evaluation metrics. See API documentation for more information. + * https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/evaluation#parameter-list + */ +export enum VertexAIEvaluationMetricType { + // Update genkit/docs/plugins/vertex-ai.md when modifying the list of enums + BLEU = 'BLEU', + ROUGE = 'ROUGE', + FLUENCY = 'FLEUNCY', + SAFETY = 'SAFETY', + GROUNDEDNESS = 'GROUNDEDNESS', + SUMMARIZATION_QUALITY = 'SUMMARIZATION_QUALITY', + SUMMARIZATION_HELPFULNESS = 'SUMMARIZATION_HELPFULNESS', + SUMMARIZATION_VERBOSITY = 'SUMMARIZATION_VERBOSITY', +} + +/** + * Evaluation metric config. Use `metricSpec` to define the behavior of the metric. + * The value of `metricSpec` will be included in the request to the API. See the API documentation + * for details on the possible values of `metricSpec` for each metric. + * https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/evaluation#parameter-list + */ +export type VertexAIEvaluationMetricConfig = { + type: VertexAIEvaluationMetricType; + metricSpec: any; +}; + +export type VertexAIEvaluationMetric = + | VertexAIEvaluationMetricType + | VertexAIEvaluationMetricConfig; + +export function vertexEvaluators( + ai: Genkit, + auth: GoogleAuth, + metrics: VertexAIEvaluationMetric[], + projectId: string, + location: string +): Action[] { + const factory = new EvaluatorFactory(auth, location, projectId); + return metrics.map((metric) => { + const metricType = isConfig(metric) ? metric.type : metric; + const metricSpec = isConfig(metric) ? metric.metricSpec : {}; + + switch (metricType) { + case VertexAIEvaluationMetricType.BLEU: { + return createBleuEvaluator(ai, factory, metricSpec); + } + case VertexAIEvaluationMetricType.ROUGE: { + return createRougeEvaluator(ai, factory, metricSpec); + } + case VertexAIEvaluationMetricType.FLUENCY: { + return createFluencyEvaluator(ai, factory, metricSpec); + } + case VertexAIEvaluationMetricType.SAFETY: { + return createSafetyEvaluator(ai, factory, metricSpec); + } + case VertexAIEvaluationMetricType.GROUNDEDNESS: { + return createGroundednessEvaluator(ai, factory, metricSpec); + } + case VertexAIEvaluationMetricType.SUMMARIZATION_QUALITY: { + return createSummarizationQualityEvaluator(ai, factory, metricSpec); + } + case VertexAIEvaluationMetricType.SUMMARIZATION_HELPFULNESS: { + return createSummarizationHelpfulnessEvaluator(ai, factory, metricSpec); + } + case VertexAIEvaluationMetricType.SUMMARIZATION_VERBOSITY: { + return createSummarizationVerbosityEvaluator(ai, factory, metricSpec); + } + } + }); +} + +function isConfig( + config: VertexAIEvaluationMetric +): config is VertexAIEvaluationMetricConfig { + return (config as VertexAIEvaluationMetricConfig).type !== undefined; +} + +const BleuResponseSchema = z.object({ + bleuResults: z.object({ + bleuMetricValues: z.array(z.object({ score: z.number() })), + }), +}); + +// TODO: Add support for batch inputs +function createBleuEvaluator( + ai: Genkit, + factory: EvaluatorFactory, + metricSpec: any +): Action { + return factory.create( + ai, + { + metric: VertexAIEvaluationMetricType.BLEU, + displayName: 'BLEU', + definition: + 'Computes the BLEU score by comparing the output against the ground truth', + responseSchema: BleuResponseSchema, + }, + (datapoint) => { + return { + bleuInput: { + metricSpec, + instances: [ + { + prediction: datapoint.output as string, + reference: datapoint.reference, + }, + ], + }, + }; + }, + (response) => { + return { + score: response.bleuResults.bleuMetricValues[0].score, + }; + } + ); +} + +const RougeResponseSchema = z.object({ + rougeResults: z.object({ + rougeMetricValues: z.array(z.object({ score: z.number() })), + }), +}); + +// TODO: Add support for batch inputs +function createRougeEvaluator( + ai: Genkit, + factory: EvaluatorFactory, + metricSpec: any +): Action { + return factory.create( + ai, + { + metric: VertexAIEvaluationMetricType.ROUGE, + displayName: 'ROUGE', + definition: + 'Computes the ROUGE score by comparing the output against the ground truth', + responseSchema: RougeResponseSchema, + }, + (datapoint) => { + return { + rougeInput: { + metricSpec, + instances: { + prediction: datapoint.output as string, + reference: datapoint.reference, + }, + }, + }; + }, + (response) => { + return { + score: response.rougeResults.rougeMetricValues[0].score, + }; + } + ); +} + +const FluencyResponseSchema = z.object({ + fluencyResult: z.object({ + score: z.number(), + explanation: z.string(), + confidence: z.number(), + }), +}); + +function createFluencyEvaluator( + ai: Genkit, + factory: EvaluatorFactory, + metricSpec: any +): Action { + return factory.create( + ai, + { + metric: VertexAIEvaluationMetricType.FLUENCY, + displayName: 'Fluency', + definition: 'Assesses the language mastery of an output', + responseSchema: FluencyResponseSchema, + }, + (datapoint) => { + return { + fluencyInput: { + metricSpec, + instance: { + prediction: datapoint.output as string, + }, + }, + }; + }, + (response) => { + return { + score: response.fluencyResult.score, + details: { + reasoning: response.fluencyResult.explanation, + }, + }; + } + ); +} + +const SafetyResponseSchema = z.object({ + safetyResult: z.object({ + score: z.number(), + explanation: z.string(), + confidence: z.number(), + }), +}); + +function createSafetyEvaluator( + ai: Genkit, + factory: EvaluatorFactory, + metricSpec: any +): Action { + return factory.create( + ai, + { + metric: VertexAIEvaluationMetricType.SAFETY, + displayName: 'Safety', + definition: 'Assesses the level of safety of an output', + responseSchema: SafetyResponseSchema, + }, + (datapoint) => { + return { + safetyInput: { + metricSpec, + instance: { + prediction: datapoint.output as string, + }, + }, + }; + }, + (response) => { + return { + score: response.safetyResult.score, + details: { + reasoning: response.safetyResult.explanation, + }, + }; + } + ); +} + +const GroundednessResponseSchema = z.object({ + groundednessResult: z.object({ + score: z.number(), + explanation: z.string(), + confidence: z.number(), + }), +}); + +function createGroundednessEvaluator( + ai: Genkit, + factory: EvaluatorFactory, + metricSpec: any +): Action { + return factory.create( + ai, + { + metric: VertexAIEvaluationMetricType.GROUNDEDNESS, + displayName: 'Groundedness', + definition: + 'Assesses the ability to provide or reference information included only in the context', + responseSchema: GroundednessResponseSchema, + }, + (datapoint) => { + return { + groundednessInput: { + metricSpec, + instance: { + prediction: datapoint.output as string, + context: datapoint.context?.join('. '), + }, + }, + }; + }, + (response) => { + return { + score: response.groundednessResult.score, + details: { + reasoning: response.groundednessResult.explanation, + }, + }; + } + ); +} + +const SummarizationQualityResponseSchema = z.object({ + summarizationQualityResult: z.object({ + score: z.number(), + explanation: z.string(), + confidence: z.number(), + }), +}); + +function createSummarizationQualityEvaluator( + ai: Genkit, + factory: EvaluatorFactory, + metricSpec: any +): Action { + return factory.create( + ai, + { + metric: VertexAIEvaluationMetricType.SUMMARIZATION_QUALITY, + displayName: 'Summarization quality', + definition: 'Assesses the overall ability to summarize text', + responseSchema: SummarizationQualityResponseSchema, + }, + (datapoint) => { + return { + summarizationQualityInput: { + metricSpec, + instance: { + prediction: datapoint.output as string, + instruction: datapoint.input as string, + context: datapoint.context?.join('. '), + }, + }, + }; + }, + (response) => { + return { + score: response.summarizationQualityResult.score, + details: { + reasoning: response.summarizationQualityResult.explanation, + }, + }; + } + ); +} + +const SummarizationHelpfulnessResponseSchema = z.object({ + summarizationHelpfulnessResult: z.object({ + score: z.number(), + explanation: z.string(), + confidence: z.number(), + }), +}); + +function createSummarizationHelpfulnessEvaluator( + ai: Genkit, + factory: EvaluatorFactory, + metricSpec: any +): Action { + return factory.create( + ai, + { + metric: VertexAIEvaluationMetricType.SUMMARIZATION_HELPFULNESS, + displayName: 'Summarization helpfulness', + definition: + 'Assesses the ability to provide a summarization, which contains the details necessary to substitute the original text', + responseSchema: SummarizationHelpfulnessResponseSchema, + }, + (datapoint) => { + return { + summarizationHelpfulnessInput: { + metricSpec, + instance: { + prediction: datapoint.output as string, + instruction: datapoint.input as string, + context: datapoint.context?.join('. '), + }, + }, + }; + }, + (response) => { + return { + score: response.summarizationHelpfulnessResult.score, + details: { + reasoning: response.summarizationHelpfulnessResult.explanation, + }, + }; + } + ); +} + +const SummarizationVerbositySchema = z.object({ + summarizationVerbosityResult: z.object({ + score: z.number(), + explanation: z.string(), + confidence: z.number(), + }), +}); + +function createSummarizationVerbosityEvaluator( + ai: Genkit, + factory: EvaluatorFactory, + metricSpec: any +): Action { + return factory.create( + ai, + { + metric: VertexAIEvaluationMetricType.SUMMARIZATION_VERBOSITY, + displayName: 'Summarization verbosity', + definition: 'Aassess the ability to provide a succinct summarization', + responseSchema: SummarizationVerbositySchema, + }, + (datapoint) => { + return { + summarizationVerbosityInput: { + metricSpec, + instance: { + prediction: datapoint.output as string, + instruction: datapoint.input as string, + context: datapoint.context?.join('. '), + }, + }, + }; + }, + (response) => { + return { + score: response.summarizationVerbosityResult.score, + details: { + reasoning: response.summarizationVerbosityResult.explanation, + }, + }; + } + ); +} diff --git a/js/plugins/checks/src/evaluator_factory.ts b/js/plugins/checks/src/evaluator_factory.ts new file mode 100644 index 000000000..821f4631b --- /dev/null +++ b/js/plugins/checks/src/evaluator_factory.ts @@ -0,0 +1,100 @@ +/** + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { Action, Genkit, GENKIT_CLIENT_HEADER, z } from 'genkit'; +import { BaseEvalDataPoint, Score } from 'genkit/evaluator'; +import { runInNewSpan } from 'genkit/tracing'; +import { GoogleAuth } from 'google-auth-library'; +import { VertexAIEvaluationMetricType } from './evaluation.js'; + +export class EvaluatorFactory { + constructor( + private readonly auth: GoogleAuth, + private readonly location: string, + private readonly projectId: string + ) {} + + create( + ai: Genkit, + config: { + metric: VertexAIEvaluationMetricType; + displayName: string; + definition: string; + responseSchema: ResponseType; + }, + toRequest: (datapoint: BaseEvalDataPoint) => any, + responseHandler: (response: z.infer) => Score + ): Action { + return ai.defineEvaluator( + { + name: `vertexai/${config.metric.toLocaleLowerCase()}`, + displayName: config.displayName, + definition: config.definition, + }, + async (datapoint: BaseEvalDataPoint) => { + const responseSchema = config.responseSchema; + const response = await this.evaluateInstances( + toRequest(datapoint), + responseSchema + ); + + return { + evaluation: responseHandler(response), + testCaseId: datapoint.testCaseId, + }; + } + ); + } + + async evaluateInstances( + partialRequest: any, + responseSchema: ResponseType + ): Promise> { + const locationName = `projects/${this.projectId}/locations/${this.location}`; + return await runInNewSpan( + { + metadata: { + name: 'EvaluationService#evaluateInstances', + }, + }, + async (metadata, _otSpan) => { + const request = { + location: locationName, + ...partialRequest, + }; + + metadata.input = request; + const client = await this.auth.getClient(); + const url = `https://${this.location}-aiplatform.googleapis.com/v1beta1/${locationName}:evaluateInstances`; + const response = await client.request({ + url, + method: 'POST', + body: JSON.stringify(request), + headers: { + 'X-Goog-Api-Client': GENKIT_CLIENT_HEADER, + }, + }); + metadata.output = response.data; + + try { + return responseSchema.parse(response.data); + } catch (e) { + throw new Error(`Error parsing ${url} API response: ${e}`); + } + } + ); + } +} diff --git a/js/plugins/checks/src/gemini.ts b/js/plugins/checks/src/gemini.ts new file mode 100644 index 000000000..eccf1a869 --- /dev/null +++ b/js/plugins/checks/src/gemini.ts @@ -0,0 +1,554 @@ +/** + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { + Content, + FunctionDeclaration, + FunctionDeclarationSchemaType, + Part as GeminiPart, + GenerateContentCandidate, + GenerateContentResponse, + GenerateContentResult, + HarmBlockThreshold, + HarmCategory, + StartChatParams, + VertexAI, +} from '@google-cloud/vertexai'; +import { GENKIT_CLIENT_HEADER, Genkit, z } from 'genkit'; +import { + CandidateData, + GenerateRequest, + GenerationCommonConfigSchema, + MediaPart, + MessageData, + ModelAction, + ModelMiddleware, + ModelReference, + Part, + ToolDefinitionSchema, + getBasicUsageStats, + modelRef, +} from 'genkit/model'; +import { + downloadRequestMedia, + simulateSystemPrompt, +} from 'genkit/model/middleware'; +import { PluginOptions } from './index.js'; + +const SafetySettingsSchema = z.object({ + category: z.nativeEnum(HarmCategory), + threshold: z.nativeEnum(HarmBlockThreshold), +}); + +const VertexRetrievalSchema = z.object({ + datastore: z.object({ + projectId: z.string().optional(), + location: z.string().optional(), + dataStoreId: z.string(), + }), + disableAttribution: z.boolean().optional(), +}); + +const GoogleSearchRetrievalSchema = z.object({ + disableAttribution: z.boolean().optional(), +}); + +export const GeminiConfigSchema = GenerationCommonConfigSchema.extend({ + safetySettings: z.array(SafetySettingsSchema).optional(), + location: z.string().optional(), + vertexRetrieval: VertexRetrievalSchema.optional(), + googleSearchRetrieval: GoogleSearchRetrievalSchema.optional(), +}); + +export const gemini10Pro = modelRef({ + name: 'vertexai/gemini-1.0-pro', + info: { + label: 'Vertex AI - Gemini Pro', + versions: ['gemini-1.0-pro-001', 'gemini-1.0-pro-002'], + supports: { + multiturn: true, + media: false, + tools: true, + systemRole: true, + }, + }, + configSchema: GeminiConfigSchema, +}); + +export const gemini15Pro = modelRef({ + name: 'vertexai/gemini-1.5-pro', + info: { + label: 'Vertex AI - Gemini 1.5 Pro', + versions: ['gemini-1.5-pro-001', 'gemini-1.5-pro-002'], + supports: { + multiturn: true, + media: true, + tools: true, + systemRole: true, + }, + }, + configSchema: GeminiConfigSchema, +}); + +export const gemini15Flash = modelRef({ + name: 'vertexai/gemini-1.5-flash', + info: { + label: 'Vertex AI - Gemini 1.5 Flash', + versions: ['gemini-1.5-flash-001', 'gemini-1.5-flash-002'], + supports: { + multiturn: true, + media: true, + tools: true, + systemRole: true, + }, + }, + configSchema: GeminiConfigSchema, +}); + +export const SUPPORTED_V1_MODELS = { + 'gemini-1.0-pro': gemini10Pro, +}; + +export const SUPPORTED_V15_MODELS = { + 'gemini-1.5-pro': gemini15Pro, + 'gemini-1.5-flash': gemini15Flash, +}; + +export const SUPPORTED_GEMINI_MODELS = { + ...SUPPORTED_V1_MODELS, + ...SUPPORTED_V15_MODELS, +}; + +function toGeminiRole( + role: MessageData['role'], + model?: ModelReference +): string { + switch (role) { + case 'user': + return 'user'; + case 'model': + return 'model'; + case 'system': + if (model && SUPPORTED_V15_MODELS[model.name]) { + // We should have already pulled out the supported system messages, + // anything remaining is unsupported; throw an error. + throw new Error( + 'system role is only supported for a single message in the first position' + ); + } else { + throw new Error('system role is not supported'); + } + case 'tool': + return 'function'; + default: + return 'user'; + } +} + +const toGeminiTool = ( + tool: z.infer +): FunctionDeclaration => { + const declaration: FunctionDeclaration = { + name: tool.name.replace(/\//g, '__'), // Gemini throws on '/' in tool name + description: tool.description, + parameters: convertSchemaProperty(tool.inputSchema), + }; + return declaration; +}; + +const toGeminiFileDataPart = (part: MediaPart): GeminiPart => { + const media = part.media; + if (media.url.startsWith('gs://')) { + if (!media.contentType) + throw new Error( + 'Must supply contentType when using media from gs:// URLs.' + ); + return { + fileData: { + mimeType: media.contentType, + fileUri: media.url, + }, + }; + } else if (media.url.startsWith('data:')) { + const dataUrl = media.url; + const b64Data = dataUrl.substring(dataUrl.indexOf(',')! + 1); + const contentType = + media.contentType || + dataUrl.substring(dataUrl.indexOf(':')! + 1, dataUrl.indexOf(';')); + return { inlineData: { mimeType: contentType, data: b64Data } }; + } + + throw Error( + 'Could not convert genkit part to gemini tool response part: missing file data' + ); +}; + +const toGeminiToolRequestPart = (part: Part): GeminiPart => { + if (!part?.toolRequest?.input) { + throw Error( + 'Could not convert genkit part to gemini tool response part: missing tool request data' + ); + } + return { + functionCall: { + name: part.toolRequest.name, + args: part.toolRequest.input, + }, + }; +}; + +const toGeminiToolResponsePart = (part: Part): GeminiPart => { + if (!part?.toolResponse?.output) { + throw Error( + 'Could not convert genkit part to gemini tool response part: missing tool response data' + ); + } + return { + functionResponse: { + name: part.toolResponse.name, + response: { + name: part.toolResponse.name, + content: part.toolResponse.output, + }, + }, + }; +}; + +export function toGeminiSystemInstruction(message: MessageData): Content { + return { + role: 'user', + parts: message.content.map(toGeminiPart), + }; +} + +export function toGeminiMessage( + message: MessageData, + model?: ModelReference +): Content { + return { + role: toGeminiRole(message.role, model), + parts: message.content.map(toGeminiPart), + }; +} + +function fromGeminiFinishReason( + reason: GenerateContentCandidate['finishReason'] +): CandidateData['finishReason'] { + if (!reason) return 'unknown'; + switch (reason) { + case 'STOP': + return 'stop'; + case 'MAX_TOKENS': + return 'length'; + case 'SAFETY': // blocked for safety + case 'RECITATION': // blocked for reciting training data + return 'blocked'; + default: + return 'unknown'; + } +} + +function toGeminiPart(part: Part): GeminiPart { + if (part.text) { + return { text: part.text }; + } else if (part.media) { + return toGeminiFileDataPart(part); + } else if (part.toolRequest) { + return toGeminiToolRequestPart(part); + } else if (part.toolResponse) { + return toGeminiToolResponsePart(part); + } else { + throw new Error('unsupported type'); + } +} + +function fromGeminiInlineDataPart(part: GeminiPart): MediaPart { + // Check if the required properties exist + if ( + !part.inlineData || + !part.inlineData.hasOwnProperty('mimeType') || + !part.inlineData.hasOwnProperty('data') + ) { + throw new Error('Invalid GeminiPart: missing required properties'); + } + const { mimeType, data } = part.inlineData; + // Combine data and mimeType into a data URL + const dataUrl = `data:${mimeType};base64,${data}`; + return { + media: { + url: dataUrl, + contentType: mimeType, + }, + }; +} + +function fromGeminiFileDataPart(part: GeminiPart): MediaPart { + if ( + !part.fileData || + !part.fileData.hasOwnProperty('mimeType') || + !part.fileData.hasOwnProperty('url') + ) { + throw new Error( + 'Invalid Gemini File Data Part: missing required properties' + ); + } + + return { + media: { + url: part.fileData?.fileUri, + contentType: part.fileData?.mimeType, + }, + }; +} + +function fromGeminiFunctionCallPart(part: GeminiPart): Part { + if (!part.functionCall) { + throw new Error( + 'Invalid Gemini Function Call Part: missing function call data' + ); + } + return { + toolRequest: { + name: part.functionCall.name, + input: part.functionCall.args, + }, + }; +} + +function fromGeminiFunctionResponsePart(part: GeminiPart): Part { + if (!part.functionResponse) { + throw new Error( + 'Invalid Gemini Function Call Part: missing function call data' + ); + } + return { + toolResponse: { + name: part.functionResponse.name.replace(/__/g, '/'), // restore slashes + output: part.functionResponse.response, + }, + }; +} + +// Converts vertex part to genkit part +function fromGeminiPart(part: GeminiPart): Part { + if (part.text !== undefined) return { text: part.text }; + if (part.functionCall) return fromGeminiFunctionCallPart(part); + if (part.functionResponse) return fromGeminiFunctionResponsePart(part); + if (part.inlineData) return fromGeminiInlineDataPart(part); + if (part.fileData) return fromGeminiFileDataPart(part); + throw new Error( + 'Part type is unsupported/corrupted. Either data is missing or type cannot be inferred from type.' + ); +} + +export function fromGeminiCandidate( + candidate: GenerateContentCandidate +): CandidateData { + const parts = candidate.content.parts || []; + const genkitCandidate: CandidateData = { + index: candidate.index || 0, // reasonable default? + message: { + role: 'model', + content: parts.map(fromGeminiPart), + }, + finishReason: fromGeminiFinishReason(candidate.finishReason), + finishMessage: candidate.finishMessage, + custom: { + safetyRatings: candidate.safetyRatings, + citationMetadata: candidate.citationMetadata, + }, + }; + return genkitCandidate; +} + +// Translate JSON schema to Vertex AI's format. Specifically, the type field needs be mapped. +// Since JSON schemas can include nested arrays/objects, we have to recursively map the type field +// in all nested fields. +const convertSchemaProperty = (property) => { + if (!property || !property.type) { + return null; + } + if (property.type === 'object') { + const nestedProperties = {}; + Object.keys(property.properties).forEach((key) => { + nestedProperties[key] = convertSchemaProperty(property.properties[key]); + }); + return { + type: FunctionDeclarationSchemaType.OBJECT, + properties: nestedProperties, + required: property.required, + }; + } else if (property.type === 'array') { + return { + type: FunctionDeclarationSchemaType.ARRAY, + items: convertSchemaProperty(property.items), + }; + } else { + return { + type: FunctionDeclarationSchemaType[property.type.toUpperCase()], + }; + } +}; + +/** + * Define a Vertex AI Gemini model. + */ +export function defineGeminiModel( + ai: Genkit, + name: string, + vertexClientFactory: ( + request: GenerateRequest + ) => VertexAI, + options: PluginOptions +): ModelAction { + const modelName = `vertexai/${name}`; + + const model: ModelReference = SUPPORTED_GEMINI_MODELS[name]; + if (!model) throw new Error(`Unsupported model: ${name}`); + + const middlewares: ModelMiddleware[] = []; + if (SUPPORTED_V1_MODELS[name]) { + middlewares.push(simulateSystemPrompt()); + } + if (model?.info?.supports?.media) { + // the gemini api doesn't support downloading media from http(s) + middlewares.push(downloadRequestMedia({ maxBytes: 1024 * 1024 * 20 })); + } + + return ai.defineModel( + { + name: modelName, + ...model.info, + configSchema: GeminiConfigSchema, + use: middlewares, + }, + async (request, streamingCallback) => { + const vertex = vertexClientFactory(request); + const client = vertex.preview.getGenerativeModel( + { + model: request.config?.version || model.version || name, + }, + { + apiClient: GENKIT_CLIENT_HEADER, + } + ); + + // make a copy so that modifying the request will not produce side-effects + const messages = [...request.messages]; + if (messages.length === 0) throw new Error('No messages provided.'); + + // Gemini does not support messages with role system and instead expects + // systemInstructions to be provided as a separate input. The first + // message detected with role=system will be used for systemInstructions. + // Any additional system messages may be considered to be "exceptional". + let systemInstruction: Content | undefined = undefined; + if (SUPPORTED_V15_MODELS[name]) { + const systemMessage = messages.find((m) => m.role === 'system'); + if (systemMessage) { + messages.splice(messages.indexOf(systemMessage), 1); + systemInstruction = toGeminiSystemInstruction(systemMessage); + } + } + + const chatRequest: StartChatParams = { + systemInstruction, + tools: request.tools?.length + ? [{ functionDeclarations: request.tools?.map(toGeminiTool) }] + : [], + history: messages + .slice(0, -1) + .map((message) => toGeminiMessage(message, model)), + generationConfig: { + candidateCount: request.candidates || undefined, + temperature: request.config?.temperature, + maxOutputTokens: request.config?.maxOutputTokens, + topK: request.config?.topK, + topP: request.config?.topP, + stopSequences: request.config?.stopSequences, + }, + safetySettings: request.config?.safetySettings, + }; + if (request.config?.googleSearchRetrieval) { + chatRequest.tools?.push({ + googleSearchRetrieval: request.config.googleSearchRetrieval, + }); + } + if (request.config?.vertexRetrieval) { + // https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/ground-gemini#ground-gemini + const vertexRetrieval = request.config.vertexRetrieval; + const _projectId = + vertexRetrieval.datastore.projectId || options.projectId; + const _location = + vertexRetrieval.datastore.location || options.location; + const _dataStoreId = vertexRetrieval.datastore.dataStoreId; + const datastore = `projects/${_projectId}/locations/${_location}/collections/default_collection/dataStores/${_dataStoreId}`; + chatRequest.tools?.push({ + retrieval: { + vertexAiSearch: { + datastore, + }, + disableAttribution: vertexRetrieval.disableAttribution, + }, + }); + } + const msg = toGeminiMessage(messages[messages.length - 1], model); + if (streamingCallback) { + const result = await client + .startChat(chatRequest) + .sendMessageStream(msg.parts); + for await (const item of result.stream) { + (item as GenerateContentResponse).candidates?.forEach((candidate) => { + const c = fromGeminiCandidate(candidate); + streamingCallback({ + index: c.index, + content: c.message.content, + }); + }); + } + const response = await result.response; + if (!response.candidates?.length) { + throw new Error('No valid candidates returned.'); + } + return { + candidates: response.candidates?.map(fromGeminiCandidate) || [], + custom: response, + }; + } else { + let result: GenerateContentResult | undefined; + try { + result = await client.startChat(chatRequest).sendMessage(msg.parts); + } catch (err) { + throw new Error(`Vertex response generation failed: ${err}`); + } + if (!result?.response.candidates?.length) { + throw new Error('No valid candidates returned.'); + } + const responseCandidates = + result.response.candidates?.map(fromGeminiCandidate) || []; + return { + candidates: responseCandidates, + custom: result.response, + usage: { + ...getBasicUsageStats(request.messages, responseCandidates), + inputTokens: result.response.usageMetadata?.promptTokenCount, + outputTokens: result.response.usageMetadata?.candidatesTokenCount, + totalTokens: result.response.usageMetadata?.totalTokenCount, + }, + }; + } + } + ); +} diff --git a/js/plugins/checks/src/imagen.ts b/js/plugins/checks/src/imagen.ts new file mode 100644 index 000000000..12f11fd13 --- /dev/null +++ b/js/plugins/checks/src/imagen.ts @@ -0,0 +1,309 @@ +/** + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { Genkit, z } from 'genkit'; +import { + CandidateData, + GenerateRequest, + GenerationCommonConfigSchema, + ModelReference, + getBasicUsageStats, + modelRef, +} from 'genkit/model'; +import { GoogleAuth } from 'google-auth-library'; +import { PluginOptions } from './index.js'; +import { PredictClient, predictModel } from './predict.js'; + +const ImagenConfigSchema = GenerationCommonConfigSchema.extend({ + /** Language of the prompt text. */ + language: z + .enum(['auto', 'en', 'es', 'hi', 'ja', 'ko', 'pt', 'zh-TW', 'zh', 'zh-CN']) + .optional(), + /** Desired aspect ratio of output image. */ + aspectRatio: z.enum(['1:1', '9:16', '16:9', '3:4', '4:3']).optional(), + /** + * A negative prompt to help generate the images. For example: "animals" + * (removes animals), "blurry" (makes the image clearer), "text" (removes + * text), or "cropped" (removes cropped images). + **/ + negativePrompt: z.string().optional(), + /** + * Any non-negative integer you provide to make output images deterministic. + * Providing the same seed number always results in the same output images. + * Accepted integer values: 1 - 2147483647. + **/ + seed: z.number().optional(), + /** Your GCP project's region. e.g.) us-central1, europe-west2, etc. **/ + location: z.string().optional(), + /** Allow generation of people by the model. */ + personGeneration: z + .enum(['dont_allow', 'allow_adult', 'allow_all']) + .optional(), + /** Adds a filter level to safety filtering. */ + safetySetting: z + .enum(['block_most', 'block_some', 'block_few', 'block_fewest']) + .optional(), + /** Add an invisible watermark to the generated images. */ + addWatermark: z.boolean().optional(), + /** Cloud Storage URI to store the generated images. **/ + storageUri: z.string().optional(), + /** Mode must be set for upscaling requests. */ + mode: z.enum(['upscale']).optional(), + /** + * Describes the editing intention for the request. + * + * Refer to https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/imagen-api#edit_images_2 for details. + */ + editConfig: z + .object({ + /** Describes the editing intention for the request. */ + editMode: z + .enum([ + 'inpainting-insert', + 'inpainting-remove', + 'outpainting', + 'product-image', + ]) + .optional(), + /** Prompts the model to generate a mask instead of you needing to provide one. Consequently, when you provide this parameter you can omit a mask object. */ + maskMode: z + .object({ + maskType: z.enum(['background', 'foreground', 'semantic']), + classes: z.array(z.number()).optional(), + }) + .optional(), + maskDilation: z.number().optional(), + guidanceScale: z.number().optional(), + productPosition: z.enum(['reposition', 'fixed']).optional(), + }) + .passthrough() + .optional(), + /** Upscale config object. */ + upscaleConfig: z.object({ upscaleFactor: z.enum(['x2', 'x4']) }).optional(), +}).passthrough(); + +export const imagen2 = modelRef({ + name: 'vertexai/imagen2', + info: { + label: 'Vertex AI - Imagen2', + versions: ['imagegeneration@006', 'imagegeneration@005'], + supports: { + media: false, + multiturn: false, + tools: false, + systemRole: false, + output: ['media'], + }, + }, + version: 'imagegeneration@006', + configSchema: ImagenConfigSchema, +}); + +export const imagen3 = modelRef({ + name: 'vertexai/imagen3', + info: { + label: 'Vertex AI - Imagen3', + versions: ['imagen-3.0-generate-001'], + supports: { + media: true, + multiturn: false, + tools: false, + systemRole: false, + output: ['media'], + }, + }, + version: 'imagen-3.0-generate-001', + configSchema: ImagenConfigSchema, +}); + +export const imagen3Fast = modelRef({ + name: 'vertexai/imagen3-fast', + info: { + label: 'Vertex AI - Imagen3 Fast', + versions: ['imagen-3.0-fast-generate-001'], + supports: { + media: false, + multiturn: false, + tools: false, + systemRole: false, + output: ['media'], + }, + }, + version: 'imagen-3.0-fast-generate-001', + configSchema: ImagenConfigSchema, +}); + +export const SUPPORTED_IMAGEN_MODELS = { + imagen2: imagen2, + imagen3: imagen3, + 'imagen3-fast': imagen3Fast, +}; + +function extractText(request: GenerateRequest) { + return request.messages + .at(-1)! + .content.map((c) => c.text || '') + .join(''); +} + +interface ImagenParameters { + sampleCount?: number; + aspectRatio?: string; + negativePrompt?: string; + seed?: number; + language?: string; + personGeneration?: string; + safetySetting?: string; + addWatermark?: boolean; + storageUri?: string; +} + +function toParameters( + request: GenerateRequest +): ImagenParameters { + const out = { + sampleCount: request.candidates ?? 1, + ...request?.config, + }; + + for (const k in out) { + if (!out[k]) delete out[k]; + } + + return out; +} + +function extractMaskImage(request: GenerateRequest): string | undefined { + return request.messages + .at(-1) + ?.content.find((p) => !!p.media && p.metadata?.type === 'mask') + ?.media?.url.split(',')[1]; +} + +function extractBaseImage(request: GenerateRequest): string | undefined { + return request.messages + .at(-1) + ?.content.find( + (p) => !!p.media && (!p.metadata?.type || p.metadata?.type === 'base') + ) + ?.media?.url.split(',')[1]; +} + +interface ImagenPrediction { + bytesBase64Encoded: string; + mimeType: string; +} + +interface ImagenInstance { + prompt: string; + image?: { bytesBase64Encoded: string }; + mask?: { image?: { bytesBase64Encoded: string } }; +} + +export function imagenModel( + ai: Genkit, + name: string, + client: GoogleAuth, + options: PluginOptions +) { + const modelName = `vertexai/${name}`; + const model: ModelReference = SUPPORTED_IMAGEN_MODELS[name]; + if (!model) throw new Error(`Unsupported model: ${name}`); + + const predictClients: Record< + string, + PredictClient + > = {}; + const predictClientFactory = ( + request: GenerateRequest + ): PredictClient => { + const requestLocation = request.config?.location || options.location; + if (!predictClients[requestLocation]) { + predictClients[requestLocation] = predictModel< + ImagenInstance, + ImagenPrediction, + ImagenParameters + >( + client, + { + ...options, + location: requestLocation, + }, + request.config?.version || model.version || name + ); + } + return predictClients[requestLocation]; + }; + + return ai.defineModel( + { + name: modelName, + ...model.info, + configSchema: ImagenConfigSchema, + }, + async (request) => { + const instance: ImagenInstance = { + prompt: extractText(request), + }; + const baseImage = extractBaseImage(request); + if (baseImage) { + instance.image = { bytesBase64Encoded: baseImage }; + } + const maskImage = extractMaskImage(request); + if (maskImage) { + instance.mask = { + image: { bytesBase64Encoded: maskImage }, + }; + } + + const req: any = { + instances: [instance], + parameters: toParameters(request), + }; + + const predictClient = predictClientFactory(request); + const response = await predictClient([instance], toParameters(request)); + + const candidates: CandidateData[] = response.predictions.map((p, i) => { + const b64data = p.bytesBase64Encoded; + const mimeType = p.mimeType; + return { + index: i, + finishReason: 'stop', + message: { + role: 'model', + content: [ + { + media: { + url: `data:${mimeType};base64,${b64data}`, + contentType: mimeType, + }, + }, + ], + }, + }; + }); + return { + candidates, + usage: { + ...getBasicUsageStats(request.messages, candidates), + custom: { generations: candidates.length }, + }, + custom: response, + }; + } + ); +} diff --git a/js/plugins/checks/src/index.ts b/js/plugins/checks/src/index.ts new file mode 100644 index 000000000..088d9212d --- /dev/null +++ b/js/plugins/checks/src/index.ts @@ -0,0 +1,262 @@ +/** + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { VertexAI } from '@google-cloud/vertexai'; +import { Genkit, z } from 'genkit'; +import { GenerateRequest, ModelReference } from 'genkit/model'; +import { GenkitPlugin, genkitPlugin } from 'genkit/plugin'; +import { GoogleAuth, GoogleAuthOptions } from 'google-auth-library'; +import { + SUPPORTED_ANTHROPIC_MODELS, + anthropicModel, + claude35Sonnet, + claude3Haiku, + claude3Opus, + claude3Sonnet, +} from './anthropic.js'; +import { + SUPPORTED_EMBEDDER_MODELS, + defineVertexAIEmbedder, + textEmbedding004, + textEmbeddingGecko003, + textEmbeddingGeckoMultilingual001, + textMultilingualEmbedding002, +} from './embedder.js'; +import { + VertexAIEvaluationMetric, + VertexAIEvaluationMetricType, + vertexEvaluators, +} from './evaluation.js'; +import { + GeminiConfigSchema, + SUPPORTED_GEMINI_MODELS, + defineGeminiModel, + gemini10Pro, + gemini15Flash, + gemini15Pro, +} from './gemini.js'; +import { + SUPPORTED_IMAGEN_MODELS, + imagen2, + imagen3, + imagen3Fast, + imagenModel, +} from './imagen.js'; +import { + SUPPORTED_OPENAI_FORMAT_MODELS, + llama3, + llama31, + llama32, + modelGardenOpenaiCompatibleModel, +} from './model_garden.js'; +import { VertexRerankerConfig, vertexAiRerankers } from './reranker.js'; +import { + VectorSearchOptions, + vertexAiIndexers, + vertexAiRetrievers, +} from './vector-search/index.js'; +export { + DocumentIndexer, + DocumentRetriever, + Neighbor, + VectorSearchOptions, + getBigQueryDocumentIndexer, + getBigQueryDocumentRetriever, + getFirestoreDocumentIndexer, + getFirestoreDocumentRetriever, + vertexAiIndexerRef, + vertexAiIndexers, + vertexAiRetrieverRef, + vertexAiRetrievers, +} from './vector-search/index.js'; +export { + VertexAIEvaluationMetricType as VertexAIEvaluationMetricType, + claude35Sonnet, + claude3Haiku, + claude3Opus, + claude3Sonnet, + gemini10Pro, + gemini15Flash, + gemini15Pro, + imagen2, + imagen3, + imagen3Fast, + llama3, + llama31, + llama32, + textEmbedding004, + textEmbeddingGecko003, + textEmbeddingGeckoMultilingual001, + textMultilingualEmbedding002, +}; + +export interface PluginOptions { + /** The Google Cloud project id to call. */ + projectId?: string; + /** The Google Cloud region to call. */ + location: string; + /** Provide custom authentication configuration for connecting to Vertex AI. */ + googleAuth?: GoogleAuthOptions; + /** Configure Vertex AI evaluators */ + evaluation?: { + metrics: VertexAIEvaluationMetric[]; + }; + /** + * @deprecated use `modelGarden.models` + */ + modelGardenModels?: ModelReference[]; + modelGarden?: { + models: ModelReference[]; + openAiBaseUrlTemplate?: string; + }; + /** Configure Vertex AI vector search index options */ + vectorSearchOptions?: VectorSearchOptions[]; + /** Configure reranker options */ + rerankOptions?: VertexRerankerConfig[]; +} + +const CLOUD_PLATFROM_OAUTH_SCOPE = + 'https://www.googleapis.com/auth/cloud-platform'; + +/** + * Add Google Cloud Vertex AI to Genkit. Includes Gemini and Imagen models and text embedder. + */ +export function vertexAI(options?: PluginOptions): GenkitPlugin { + return genkitPlugin('vertexai', async (ai: Genkit) => { + let authClient; + let authOptions = options?.googleAuth; + + // Allow customers to pass in cloud credentials from environment variables + // following: https://github.com/googleapis/google-auth-library-nodejs?tab=readme-ov-file#loading-credentials-from-environment-variables + if (process.env.GCLOUD_SERVICE_ACCOUNT_CREDS) { + const serviceAccountCreds = JSON.parse( + process.env.GCLOUD_SERVICE_ACCOUNT_CREDS + ); + authOptions = { + credentials: serviceAccountCreds, + scopes: [CLOUD_PLATFROM_OAUTH_SCOPE], + }; + authClient = new GoogleAuth(authOptions); + } else { + authClient = new GoogleAuth( + authOptions ?? { scopes: [CLOUD_PLATFROM_OAUTH_SCOPE] } + ); + } + + const projectId = options?.projectId || (await authClient.getProjectId()); + + const location = options?.location || 'us-central1'; + const confError = (parameter: string, envVariableName: string) => { + return new Error( + `VertexAI Plugin is missing the '${parameter}' configuration. Please set the '${envVariableName}' environment variable or explicitly pass '${parameter}' into genkit config.` + ); + }; + if (!location) { + throw confError('location', 'GCLOUD_LOCATION'); + } + if (!projectId) { + throw confError('project', 'GCLOUD_PROJECT'); + } + + const vertexClientFactoryCache: Record = {}; + const vertexClientFactory = ( + request: GenerateRequest + ): VertexAI => { + const requestLocation = request.config?.location || location; + if (!vertexClientFactoryCache[requestLocation]) { + vertexClientFactoryCache[requestLocation] = new VertexAI({ + project: projectId, + location: requestLocation, + googleAuthOptions: authOptions, + }); + } + return vertexClientFactoryCache[requestLocation]; + }; + const metrics = + options?.evaluation && options.evaluation.metrics.length > 0 + ? options.evaluation.metrics + : []; + + Object.keys(SUPPORTED_IMAGEN_MODELS).map((name) => + imagenModel(ai, name, authClient, { projectId, location }) + ); + Object.keys(SUPPORTED_GEMINI_MODELS).map((name) => + defineGeminiModel(ai, name, vertexClientFactory, { projectId, location }) + ); + + if (options?.modelGardenModels || options?.modelGarden?.models) { + const mgModels = + options?.modelGardenModels || options?.modelGarden?.models; + mgModels!.forEach((m) => { + const anthropicEntry = Object.entries(SUPPORTED_ANTHROPIC_MODELS).find( + ([_, value]) => value.name === m.name + ); + if (anthropicEntry) { + anthropicModel(ai, anthropicEntry[0], projectId, location); + return; + } + const openaiModel = Object.entries(SUPPORTED_OPENAI_FORMAT_MODELS).find( + ([_, value]) => value.name === m.name + ); + if (openaiModel) { + modelGardenOpenaiCompatibleModel( + ai, + openaiModel[0], + projectId, + location, + authClient, + options.modelGarden?.openAiBaseUrlTemplate + ); + return; + } + throw new Error(`Unsupported model garden model: ${m.name}`); + }); + } + + const embedders = Object.keys(SUPPORTED_EMBEDDER_MODELS).map((name) => + defineVertexAIEmbedder(ai, name, authClient, { projectId, location }) + ); + + if ( + options?.vectorSearchOptions && + options.vectorSearchOptions.length > 0 + ) { + const defaultEmbedder = embedders[0]; + + vertexAiIndexers(ai, { + pluginOptions: options, + authClient, + defaultEmbedder, + }); + + vertexAiRetrievers(ai, { + pluginOptions: options, + authClient, + defaultEmbedder, + }); + } + + const rerankOptions = { + pluginOptions: options, + authClient, + projectId, + }; + await vertexAiRerankers(ai, rerankOptions); + vertexEvaluators(ai, authClient, metrics, projectId, location); + }); +} + +export default vertexAI; diff --git a/js/plugins/checks/src/model_garden.ts b/js/plugins/checks/src/model_garden.ts new file mode 100644 index 000000000..eec87274c --- /dev/null +++ b/js/plugins/checks/src/model_garden.ts @@ -0,0 +1,123 @@ +/** + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { Genkit, GENKIT_CLIENT_HEADER, z } from 'genkit'; +import { GenerateRequest, ModelAction, modelRef } from 'genkit/model'; +import { GoogleAuth } from 'google-auth-library'; +import OpenAI from 'openai'; +import { + openaiCompatibleModel, + OpenAIConfigSchema, +} from './openai_compatibility.js'; + +export const ModelGardenModelConfigSchema = OpenAIConfigSchema.extend({ + location: z.string().optional(), +}); + +export const llama31 = modelRef({ + name: 'vertexai/llama-3.1', + info: { + label: 'Llama 3.1', + supports: { + multiturn: true, + tools: true, + media: false, + systemRole: true, + output: ['text', 'json'], + }, + versions: [ + 'meta/llama3-405b-instruct-maas', + // 8b and 70b versions are coming soon + ], + }, + configSchema: ModelGardenModelConfigSchema, + version: 'meta/llama3-405b-instruct-maas', +}); + +export const llama32 = modelRef({ + name: 'vertexai/llama-3.2', + info: { + label: 'Llama 3.2', + supports: { + multiturn: true, + tools: true, + media: true, + systemRole: true, + output: ['text', 'json'], + }, + versions: ['meta/llama-3.2-90b-vision-instruct-maas'], + }, + configSchema: ModelGardenModelConfigSchema, + version: 'meta/llama-3.2-90b-vision-instruct-maas', +}); + +/** + * @deprecated use `llama31` instead + */ +export const llama3 = modelRef({ + name: 'vertexai/llama3-405b', + info: { + label: 'Llama 3.1 405b', + supports: { + multiturn: true, + tools: true, + media: false, + systemRole: true, + output: ['text'], + }, + versions: ['meta/llama3-405b-instruct-maas'], + }, + configSchema: ModelGardenModelConfigSchema, + version: 'meta/llama3-405b-instruct-maas', +}); + +export const SUPPORTED_OPENAI_FORMAT_MODELS = { + 'llama3-405b': llama3, + 'llama-3.1': llama31, + 'llama-3.2': llama32, +}; + +export function modelGardenOpenaiCompatibleModel( + ai: Genkit, + name: string, + projectId: string, + location: string, + googleAuth: GoogleAuth, + baseUrlTemplate: string | undefined +): ModelAction { + const model = SUPPORTED_OPENAI_FORMAT_MODELS[name]; + if (!model) throw new Error(`Unsupported model: ${name}`); + if (!baseUrlTemplate) { + baseUrlTemplate = + 'https://{location}-aiplatform.googleapis.com/v1beta1/projects/{projectId}/locations/{location}/endpoints/openapi'; + } + + const clientFactory = async ( + request: GenerateRequest + ): Promise => { + const requestLocation = request.config?.location || location; + return new OpenAI({ + baseURL: baseUrlTemplate! + .replace(/{location}/g, requestLocation) + .replace(/{projectId}/g, projectId), + apiKey: (await googleAuth.getAccessToken())!, + defaultHeaders: { + 'X-Goog-Api-Client': GENKIT_CLIENT_HEADER, + }, + }); + }; + return openaiCompatibleModel(ai, model, clientFactory); +} diff --git a/js/plugins/checks/src/openai_compatibility.ts b/js/plugins/checks/src/openai_compatibility.ts new file mode 100644 index 000000000..2de914f57 --- /dev/null +++ b/js/plugins/checks/src/openai_compatibility.ts @@ -0,0 +1,350 @@ +/** + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { Genkit, Message, StreamingCallback, z } from 'genkit'; +import { + GenerateResponseChunkData, + GenerateResponseData, + GenerationCommonConfigSchema, + ModelAction, + ModelReference, + type CandidateData, + type GenerateRequest, + type MessageData, + type Part, + type Role, + type ToolDefinition, + type ToolRequestPart, +} from 'genkit/model'; +import OpenAI from 'openai'; +import { + type ChatCompletion, + type ChatCompletionChunk, + type ChatCompletionContentPart, + type ChatCompletionCreateParamsNonStreaming, + type ChatCompletionMessageParam, + type ChatCompletionMessageToolCall, + type ChatCompletionRole, + type ChatCompletionTool, + type CompletionChoice, +} from 'openai/resources/index.mjs'; + +export const OpenAIConfigSchema = GenerationCommonConfigSchema.extend({ + frequencyPenalty: z.number().min(-2).max(2).optional(), + logitBias: z.record(z.string(), z.number().min(-100).max(100)).optional(), + logProbs: z.boolean().optional(), + presencePenalty: z.number().min(-2).max(2).optional(), + seed: z.number().int().optional(), + topLogProbs: z.number().int().min(0).max(20).optional(), + user: z.string().optional(), +}); + +export function toOpenAIRole(role: Role): ChatCompletionRole { + switch (role) { + case 'user': + return 'user'; + case 'model': + return 'assistant'; + case 'system': + return 'system'; + case 'tool': + return 'tool'; + default: + throw new Error(`role ${role} doesn't map to an OpenAI role.`); + } +} + +function toOpenAiTool(tool: ToolDefinition): ChatCompletionTool { + return { + type: 'function', + function: { + name: tool.name, + parameters: tool.inputSchema || undefined, + }, + }; +} + +export function toOpenAiTextAndMedia(part: Part): ChatCompletionContentPart { + if (part.text) { + return { + type: 'text', + text: part.text, + }; + } else if (part.media) { + return { + type: 'image_url', + image_url: { + url: part.media.url, + }, + }; + } + throw Error( + `Unsupported genkit part fields encountered for current message role: ${JSON.stringify(part)}.` + ); +} + +export function toOpenAiMessages( + messages: MessageData[] +): ChatCompletionMessageParam[] { + const openAiMsgs: ChatCompletionMessageParam[] = []; + for (const message of messages) { + const msg = new Message(message); + const role = toOpenAIRole(message.role); + switch (role) { + case 'user': + openAiMsgs.push({ + role: role, + content: msg.content.map((part) => toOpenAiTextAndMedia(part)), + }); + break; + case 'system': + openAiMsgs.push({ + role: role, + content: msg.text, + }); + break; + case 'assistant': { + const toolCalls: ChatCompletionMessageToolCall[] = msg.content + .filter( + ( + part + ): part is Part & { + toolRequest: NonNullable; + } => Boolean(part.toolRequest) + ) + .map((part) => ({ + id: part.toolRequest.ref ?? '', + type: 'function', + function: { + name: part.toolRequest.name, + arguments: JSON.stringify(part.toolRequest.input), + }, + })); + if (toolCalls.length > 0) { + openAiMsgs.push({ + role: role, + tool_calls: toolCalls, + }); + } else { + openAiMsgs.push({ + role: role, + content: msg.text, + }); + } + break; + } + case 'tool': { + const toolResponseParts = msg.toolResponseParts(); + toolResponseParts.map((part) => { + openAiMsgs.push({ + role: role, + tool_call_id: part.toolResponse.ref ?? '', + content: + typeof part.toolResponse.output === 'string' + ? part.toolResponse.output + : JSON.stringify(part.toolResponse.output), + }); + }); + break; + } + } + } + return openAiMsgs; +} + +const finishReasonMap: Record< + CompletionChoice['finish_reason'] | 'tool_calls', + CandidateData['finishReason'] +> = { + length: 'length', + stop: 'stop', + tool_calls: 'stop', + content_filter: 'blocked', +}; + +export function fromOpenAiToolCall( + toolCall: + | ChatCompletionMessageToolCall + | ChatCompletionChunk.Choice.Delta.ToolCall +): ToolRequestPart { + if (!toolCall.function) { + throw Error( + `Unexpected openAI chunk choice. tool_calls was provided but one or more tool_calls is missing.` + ); + } + const f = toolCall.function; + return { + toolRequest: { + name: f.name!, + ref: toolCall.id, + input: f.arguments ? JSON.parse(f.arguments) : f.arguments, + }, + }; +} + +export function fromOpenAiChoice( + choice: ChatCompletion.Choice, + jsonMode = false +): CandidateData { + const toolRequestParts = choice.message.tool_calls?.map(fromOpenAiToolCall); + return { + index: choice.index, + finishReason: finishReasonMap[choice.finish_reason] || 'other', + message: { + role: 'model', + content: toolRequestParts + ? // Note: Not sure why I have to cast here exactly. + // Otherwise it thinks toolRequest must be 'undefined' if provided + (toolRequestParts as ToolRequestPart[]) + : [ + jsonMode + ? { data: JSON.parse(choice.message.content!) } + : { text: choice.message.content! }, + ], + }, + custom: {}, + }; +} + +export function fromOpenAiChunkChoice( + choice: ChatCompletionChunk.Choice, + jsonMode = false +): CandidateData { + const toolRequestParts = choice.delta.tool_calls?.map(fromOpenAiToolCall); + return { + index: choice.index, + finishReason: choice.finish_reason + ? finishReasonMap[choice.finish_reason] || 'other' + : 'unknown', + message: { + role: 'model', + content: toolRequestParts + ? (toolRequestParts as ToolRequestPart[]) + : [ + jsonMode + ? { data: JSON.parse(choice.delta.content!) } + : { text: choice.delta.content! }, + ], + }, + custom: {}, + }; +} + +export function toRequestBody( + model: ModelReference, + request: GenerateRequest +) { + const openAiMessages = toOpenAiMessages(request.messages); + const mappedModelName = + request.config?.version || model.version || model.name; + const body = { + model: mappedModelName, + messages: openAiMessages, + temperature: request.config?.temperature, + max_tokens: request.config?.maxOutputTokens, + top_p: request.config?.topP, + stop: request.config?.stopSequences, + frequency_penalty: request.config?.frequencyPenalty, + logit_bias: request.config?.logitBias, + logprobs: request.config?.logProbs, + presence_penalty: request.config?.presencePenalty, + seed: request.config?.seed, + top_logprobs: request.config?.topLogProbs, + user: request.config?.user, + tools: request.tools?.map(toOpenAiTool), + n: request.candidates, + } as ChatCompletionCreateParamsNonStreaming; + const response_format = request.output?.format; + if (response_format) { + if ( + response_format === 'json' && + model.info?.supports?.output?.includes('json') + ) { + body.response_format = { + type: 'json_object', + }; + } else if ( + response_format === 'text' && + model.info?.supports?.output?.includes('text') + ) { + // this is default format, don't need to set it + // body.response_format = { + // type: 'text', + // }; + } else { + throw new Error(`${response_format} format is not supported currently`); + } + } + for (const key in body) { + if (!body[key] || (Array.isArray(body[key]) && !body[key].length)) + delete body[key]; + } + return body; +} + +export function openaiCompatibleModel( + ai: Genkit, + model: ModelReference, + clientFactory: (request: GenerateRequest) => Promise +): ModelAction { + const modelId = model.name; + if (!model) throw new Error(`Unsupported model: ${name}`); + + return ai.defineModel( + { + name: modelId, + ...model.info, + configSchema: model.configSchema, + }, + async ( + request: GenerateRequest, + streamingCallback?: StreamingCallback + ): Promise => { + let response: ChatCompletion; + const client = await clientFactory(request); + const body = toRequestBody(model, request); + if (streamingCallback) { + const stream = client.beta.chat.completions.stream({ + ...body, + stream: true, + }); + for await (const chunk of stream) { + chunk.choices?.forEach((chunk) => { + const c = fromOpenAiChunkChoice(chunk); + streamingCallback({ + index: c.index, + content: c.message.content, + }); + }); + } + response = await stream.finalChatCompletion(); + } else { + response = await client.chat.completions.create(body); + } + return { + candidates: response.choices.map((c) => + fromOpenAiChoice(c, request.output?.format === 'json') + ), + usage: { + inputTokens: response.usage?.prompt_tokens, + outputTokens: response.usage?.completion_tokens, + totalTokens: response.usage?.total_tokens, + }, + custom: response, + }; + } + ); +} diff --git a/js/plugins/checks/src/predict.ts b/js/plugins/checks/src/predict.ts new file mode 100644 index 000000000..dfc538a5b --- /dev/null +++ b/js/plugins/checks/src/predict.ts @@ -0,0 +1,83 @@ +/** + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { GENKIT_CLIENT_HEADER } from 'genkit'; +import { GoogleAuth } from 'google-auth-library'; +import { PluginOptions } from '.'; + +function endpoint(options: { + projectId: string; + location: string; + model: string; +}) { + // eslint-disable-next-line max-len + return `https://${options.location}-aiplatform.googleapis.com/v1/projects/${options.projectId}/locations/${options.location}/publishers/google/models/${options.model}:predict`; +} + +interface PredictionResponse { + predictions: R[]; +} + +export type PredictClient = ( + instances: I[], + parameters?: P +) => Promise>; + +export function predictModel( + auth: GoogleAuth, + { location, projectId }: PluginOptions, + model: string +): PredictClient { + return async ( + instances: I[], + parameters?: P + ): Promise> => { + const fetch = (await import('node-fetch')).default; + + const accessToken = await auth.getAccessToken(); + const req = { + instances, + parameters: parameters || {}, + }; + + const response = await fetch( + endpoint({ + projectId: projectId!, + location, + model, + }), + { + method: 'POST', + body: JSON.stringify(req), + headers: { + Authorization: `Bearer ${accessToken}`, + 'Content-Type': 'application/json', + 'X-Goog-Api-Client': GENKIT_CLIENT_HEADER, + }, + } + ); + + if (!response.ok) { + throw new Error( + `Error from Vertex AI predict: HTTP ${ + response.status + }: ${await response.text()}` + ); + } + + return (await response.json()) as PredictionResponse; + }; +} diff --git a/js/plugins/checks/src/reranker.ts b/js/plugins/checks/src/reranker.ts new file mode 100644 index 000000000..95df9b2c9 --- /dev/null +++ b/js/plugins/checks/src/reranker.ts @@ -0,0 +1,159 @@ +/** + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { Genkit, z } from 'genkit'; +import { RankedDocument, RerankerAction, rerankerRef } from 'genkit/reranker'; +import { GoogleAuth } from 'google-auth-library'; +import { PluginOptions } from '.'; + +const DEFAULT_MODEL = 'semantic-ranker-512@latest'; + +const getRerankEndpoint = (projectId: string, location: string) => { + return `https://discoveryengine.googleapis.com/v1/projects/${projectId}/locations/${location}/rankingConfigs/default_ranking_config:rank`; +}; + +// Define the schema for the options used in the Vertex AI reranker +export const VertexAIRerankerOptionsSchema = z.object({ + k: z.number().optional().describe('Number of top documents to rerank'), // Optional: Number of documents to rerank + model: z.string().optional().describe('Model name for reranking'), // Optional: Model name, defaults to a pre-defined model + location: z + .string() + .optional() + .describe('Google Cloud location, e.g., "us-central1"'), // Optional: Location of the reranking model +}); + +// Type alias for the options schema +export type VertexAIRerankerOptions = z.infer< + typeof VertexAIRerankerOptionsSchema +>; + +// Define the structure for each individual reranker configuration +export const VertexRerankerConfigSchema = z.object({ + model: z.string().optional().describe('Model name for reranking'), // Optional: Model name, defaults to a pre-defined model +}); + +export interface VertexRerankerConfig { + name?: string; + model?: string; +} + +export interface VertexRerankPluginOptions { + rerankOptions: VertexRerankerConfig[]; + projectId: string; + location?: string; // Optional: Location of the reranker service +} + +export interface VertexRerankOptions { + authClient: GoogleAuth; + pluginOptions?: PluginOptions; +} + +/** + * Creates Vertex AI rerankers. + * + * This function returns a list of reranker actions for Vertex AI based on the provided + * rerank options and configuration. + * + * @param {VertexRerankOptions} params - The parameters for creating the rerankers. + * @returns {RerankerAction[]} - An array of reranker actions. + */ +export async function vertexAiRerankers( + ai: Genkit, + params: VertexRerankOptions +): Promise[]> { + if (!params.pluginOptions) { + return []; + } + const pluginOptions = params.pluginOptions; + if (!params.pluginOptions.rerankOptions) { + return []; + } + + const rerankOptions = params.pluginOptions.rerankOptions; + const rerankers: RerankerAction[] = []; + + if (!rerankOptions || rerankOptions.length === 0) { + return rerankers; + } + const auth = new GoogleAuth(); + const client = await auth.getClient(); + const projectId = await auth.getProjectId(); + + for (const rerankOption of rerankOptions) { + const reranker = ai.defineReranker( + { + name: `vertexai/${rerankOption.name || rerankOption.model}`, + configSchema: VertexAIRerankerOptionsSchema.optional(), + }, + async (query, documents, _options) => { + const response = await client.request({ + method: 'POST', + url: getRerankEndpoint( + projectId, + pluginOptions.location ?? 'us-central1' + ), + data: { + model: rerankOption.model || DEFAULT_MODEL, // Use model from config or default + query: query.text, + records: documents.map((doc, idx) => ({ + id: `${idx}`, + content: doc.text, + })), + }, + }); + + const rankedDocuments: RankedDocument[] = ( + response.data as any + ).records.map((record: any) => { + const doc = documents[record.id]; + return new RankedDocument({ + content: doc.content, + metadata: { + ...doc.metadata, + score: record.score, + }, + }); + }); + + return { documents: rankedDocuments }; + } + ); + + rerankers.push(reranker); + } + + return rerankers; +} + +/** + * Creates a reference to a Vertex AI reranker. + * + * @param {Object} params - The parameters for the reranker reference. + * @param {string} [params.displayName] - An optional display name for the reranker. + * @returns {Object} - The reranker reference object. + */ +export const vertexAiRerankerRef = (params: { + name: string; + displayName?: string; +}) => { + return rerankerRef({ + name: `vertexai/${name}`, + info: { + label: params.displayName ?? `Vertex AI Reranker`, + }, + configSchema: VertexAIRerankerOptionsSchema.optional(), + }); +}; diff --git a/js/plugins/checks/src/vector-search/bigquery.ts b/js/plugins/checks/src/vector-search/bigquery.ts new file mode 100644 index 000000000..e3a40ba61 --- /dev/null +++ b/js/plugins/checks/src/vector-search/bigquery.ts @@ -0,0 +1,131 @@ +/** + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { BigQuery, QueryRowsResponse } from '@google-cloud/bigquery'; +import { z } from 'genkit'; +import { logger } from 'genkit/logging'; +import { Document, DocumentDataSchema } from 'genkit/retriever'; +import { DocumentIndexer, DocumentRetriever, Neighbor } from './types'; + +/** + * Creates a BigQuery Document Retriever. + * + * This function returns a DocumentRetriever function that retrieves documents + * from a BigQuery table based on the provided neighbors. + * + * @param {BigQuery} bq - The BigQuery instance. + * @param {string} tableId - The ID of the BigQuery table. + * @param {string} datasetId - The ID of the BigQuery dataset. + * @returns {DocumentRetriever} - The DocumentRetriever function. + */ +export const getBigQueryDocumentRetriever = ( + bq: BigQuery, + tableId: string, + datasetId: string +): DocumentRetriever => { + const bigQueryRetriever: DocumentRetriever = async ( + neighbors: Neighbor[] + ): Promise => { + const ids: string[] = neighbors + .map((neighbor) => neighbor.datapoint?.datapointId) + .filter(Boolean) as string[]; + + const query = ` + SELECT * FROM \`${datasetId}.${tableId}\` + WHERE id IN UNNEST(@ids) + `; + + const options = { + query, + params: { ids }, + }; + + let rows: QueryRowsResponse[0]; + + try { + [rows] = await bq.query(options); + } catch (queryError) { + logger.error('Failed to execute BigQuery query:', queryError); + return []; + } + + const documents: Document[] = []; + + for (const row of rows) { + try { + const docData: { content: any; metadata?: any } = { + content: JSON.parse(row.content), + }; + + if (row.metadata) { + docData.metadata = JSON.parse(row.metadata); + } + + const parsedDocData = DocumentDataSchema.parse(docData); + documents.push(new Document(parsedDocData)); + } catch (error) { + const id = row.id; + const errorPrefix = `Failed to parse document data for document with ID ${id}:`; + + if (error instanceof z.ZodError || error instanceof Error) { + logger.warn(`${errorPrefix} ${error.message}`); + } else { + logger.warn(errorPrefix); + } + } + } + + return documents; + }; + + return bigQueryRetriever; +}; + +/** + * Creates a BigQuery Document Indexer. + * + * This function returns a DocumentIndexer function that indexes documents + * into a BigQuery table. Note this indexer does not handle duplicate + * documents. + * + * @param {BigQuery} bq - The BigQuery instance. + * @param {string} tableId - The ID of the BigQuery table. + * @param {string} datasetId - The ID of the BigQuery dataset. + * @returns {DocumentIndexer} - The DocumentIndexer function. + */ +export const getBigQueryDocumentIndexer = ( + bq: BigQuery, + tableId: string, + datasetId: string +): DocumentIndexer => { + const bigQueryIndexer: DocumentIndexer = async ( + docs: Document[] + ): Promise => { + const ids: string[] = []; + const rows = docs.map((doc) => { + const id = Math.random().toString(36).substring(7); + ids.push(id); + return { + id, + content: JSON.stringify(doc.content), + metadata: JSON.stringify(doc.metadata), + }; + }); + await bq.dataset(datasetId).table(tableId).insert(rows); + return ids; + }; + return bigQueryIndexer; +}; diff --git a/js/plugins/checks/src/vector-search/firestore.ts b/js/plugins/checks/src/vector-search/firestore.ts new file mode 100644 index 000000000..1eefc894f --- /dev/null +++ b/js/plugins/checks/src/vector-search/firestore.ts @@ -0,0 +1,87 @@ +/** + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { Firestore } from 'firebase-admin/firestore'; +import { Document, DocumentDataSchema } from 'genkit'; +import { DocumentIndexer, DocumentRetriever, Neighbor } from './types'; +/** + * Creates a Firestore Document Retriever. + * + * This function returns a DocumentRetriever function that retrieves documents + * from a Firestore collection based on the provided Vertex AI Vector Search neighbors. + * + * @param {Firestore} db - The Firestore instance. + * @param {string} collectionName - The name of the Firestore collection. + * @returns {DocumentRetriever} - The DocumentRetriever function. + */ +export const getFirestoreDocumentRetriever = ( + db: Firestore, + collectionName: string +): DocumentRetriever => { + const firestoreRetriever: DocumentRetriever = async ( + neighbors: Neighbor[] + ): Promise => { + const docs: Document[] = []; + for (const neighbor of neighbors) { + const docRef = db + .collection(collectionName) + .doc(neighbor.datapoint?.datapointId!); + const docSnapshot = await docRef.get(); + if (docSnapshot.exists) { + const docData = { ...docSnapshot.data(), metadata: { ...neighbor } }; + const parsedDocData = DocumentDataSchema.safeParse(docData); + if (parsedDocData.success) { + docs.push(new Document(parsedDocData.data)); + } + } + } + return docs; + }; + return firestoreRetriever; +}; + +/** + * Creates a Firestore Document Indexer. + * + * This function returns a DocumentIndexer function that indexes documents + * into a Firestore collection. + * + * @param {Firestore} db - The Firestore instance. + * @param {string} collectionName - The name of the Firestore collection. + * @returns {DocumentIndexer} - The DocumentIndexer function. + */ +export const getFirestoreDocumentIndexer = ( + db: Firestore, + collectionName: string +) => { + const firestoreIndexer: DocumentIndexer = async ( + docs: Document[] + ): Promise => { + const batch = db.batch(); + const ids: string[] = []; + docs.forEach((doc) => { + const docRef = db.collection(collectionName).doc(); + batch.set(docRef, { + content: doc.content, + metadata: doc.metadata || null, + }); + ids.push(docRef.id); + }); + await batch.commit(); + return ids; + }; + return firestoreIndexer; +}; diff --git a/js/plugins/checks/src/vector-search/index.ts b/js/plugins/checks/src/vector-search/index.ts new file mode 100644 index 000000000..638ba1abc --- /dev/null +++ b/js/plugins/checks/src/vector-search/index.ts @@ -0,0 +1,36 @@ +/** + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +export { + getBigQueryDocumentIndexer, + getBigQueryDocumentRetriever, +} from './bigquery'; +export { + getFirestoreDocumentIndexer, + getFirestoreDocumentRetriever, +} from './firestore'; +export { vertexAiIndexerRef, vertexAiIndexers } from './indexers'; +export { vertexAiRetrieverRef, vertexAiRetrievers } from './retrievers'; +export { + DocumentIndexer, + DocumentRetriever, + Neighbor, + VectorSearchOptions, + VertexAIVectorIndexerOptions, + VertexAIVectorIndexerOptionsSchema, + VertexAIVectorRetrieverOptions, + VertexAIVectorRetrieverOptionsSchema, +} from './types'; diff --git a/js/plugins/checks/src/vector-search/indexers.ts b/js/plugins/checks/src/vector-search/indexers.ts new file mode 100644 index 000000000..66a00e913 --- /dev/null +++ b/js/plugins/checks/src/vector-search/indexers.ts @@ -0,0 +1,120 @@ +/** + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { Genkit, z } from 'genkit'; +import { IndexerAction, indexerRef } from 'genkit/retriever'; +import { + Datapoint, + VertexAIVectorIndexerOptionsSchema, + VertexVectorSearchOptions, +} from './types'; +import { upsertDatapoints } from './upsert_datapoints'; + +/** + * Creates a reference to a Vertex AI indexer. + * + * @param {Object} params - The parameters for the indexer reference. + * @param {string} params.indexId - The ID of the Vertex AI index. + * @param {string} [params.displayName] - An optional display name for the indexer. + * @returns {Object} - The indexer reference object. + */ +export const vertexAiIndexerRef = (params: { + indexId: string; + displayName?: string; +}) => { + return indexerRef({ + name: `vertexai/${params.indexId}`, + info: { + label: params.displayName ?? `Vertex AI - ${params.indexId}`, + }, + configSchema: VertexAIVectorIndexerOptionsSchema.optional(), + }); +}; + +/** + * Creates Vertex AI indexers. + * + * This function returns a list of indexer actions for Vertex AI based on the provided + * vector search options and embedder configurations. + * + * @param {VertexVectorSearchOptions} params - The parameters for creating the indexers. + * @returns {IndexerAction[]} - An array of indexer actions. + */ +export function vertexAiIndexers( + ai: Genkit, + params: VertexVectorSearchOptions +): IndexerAction[] { + const vectorSearchOptions = params.pluginOptions.vectorSearchOptions; + const defaultEmbedder = params.defaultEmbedder; + const indexers: IndexerAction[] = []; + + if (!vectorSearchOptions || vectorSearchOptions.length === 0) { + return indexers; + } + + for (const vectorSearchOption of vectorSearchOptions) { + const { documentIndexer, indexId } = vectorSearchOption; + const embedder = vectorSearchOption.embedder ?? defaultEmbedder; + const embedderOptions = vectorSearchOption.embedderOptions; + + const indexer = ai.defineIndexer( + { + name: `vertexai/${indexId}`, + configSchema: VertexAIVectorIndexerOptionsSchema.optional(), + }, + async (docs, options) => { + let docIds: string[] = []; + + try { + docIds = await documentIndexer(docs, options); + } catch (error) { + throw new Error( + `Error storing your document content/metadata: ${error}` + ); + } + + const embeddings = await ai.embedMany({ + embedder, + content: docs, + options: embedderOptions, + }); + + const datapoints = embeddings.map( + ({ embedding }, i) => + new Datapoint({ + datapointId: docIds[i], + featureVector: embedding, + }) + ); + + try { + await upsertDatapoints({ + datapoints, + authClient: params.authClient, + projectId: params.pluginOptions.projectId!, + location: params.pluginOptions.location!, + indexId: indexId, + }); + } catch (error) { + throw error; + } + } + ); + + indexers.push(indexer); + } + return indexers; +} diff --git a/js/plugins/checks/src/vector-search/query_public_endpoint.ts b/js/plugins/checks/src/vector-search/query_public_endpoint.ts new file mode 100644 index 000000000..f055e3b9d --- /dev/null +++ b/js/plugins/checks/src/vector-search/query_public_endpoint.ts @@ -0,0 +1,92 @@ +/** + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { logger } from 'genkit/logging'; +import { FindNeighborsResponse } from './types'; + +interface QueryPublicEndpointParams { + featureVector: number[]; + neighborCount: number; + accessToken: string; + projectId: string; + location: string; + indexEndpointId: string; + publicDomainName: string; + projectNumber: string; + deployedIndexId: string; +} +/** + * Queries a public index endpoint to find neighbors for a given feature vector. + * + * This function sends a request to a specified public endpoint to find neighbors + * for a given feature vector using the provided parameters. + * + * @param {QueryPublicEndpointParams} params - The parameters required to query the public endpoint. + * @param {number[]} params.featureVector - The feature vector for which to find neighbors. + * @param {number} params.neighborCount - The number of neighbors to retrieve. + * @param {string} params.accessToken - The access token for authorization. + * @param {string} params.projectId - The ID of the Google Cloud project. + * @param {string} params.location - The location of the index endpoint. + * @param {string} params.indexEndpointId - The ID of the index endpoint. + * @param {string} params.publicDomainName - The domain name of the public endpoint. + * @param {string} params.projectNumber - The project number. + * @param {string} params.deployedIndexId - The ID of the deployed index. + * @returns {Promise} - The response from the public endpoint. + */ +export async function queryPublicEndpoint( + params: QueryPublicEndpointParams +): Promise { + const { + featureVector, + neighborCount, + accessToken, + indexEndpointId, + publicDomainName, + projectNumber, + deployedIndexId, + location, + } = params; + const url = new URL( + `https://${publicDomainName}/v1/projects/${projectNumber}/locations/${location}/indexEndpoints/${indexEndpointId}:findNeighbors` + ); + + const requestBody = { + deployed_index_id: deployedIndexId, + queries: [ + { + datapoint: { + datapoint_id: '0', + feature_vector: featureVector, + }, + neighbor_count: neighborCount, + }, + ], + }; + + const response = await fetch(url, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${accessToken}`, + }, + body: JSON.stringify(requestBody), + }); + + if (!response.ok) { + logger.error('Error querying index: ', response.statusText); + } + return response.json(); +} diff --git a/js/plugins/checks/src/vector-search/retrievers.ts b/js/plugins/checks/src/vector-search/retrievers.ts new file mode 100644 index 000000000..67f47f33d --- /dev/null +++ b/js/plugins/checks/src/vector-search/retrievers.ts @@ -0,0 +1,136 @@ +/** + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { Genkit, RetrieverAction, retrieverRef, z } from 'genkit'; +import { queryPublicEndpoint } from './query_public_endpoint'; +import { + VertexAIVectorRetrieverOptionsSchema, + VertexVectorSearchOptions, +} from './types'; +import { getProjectNumber } from './utils'; + +const DEFAULT_K = 10; + +/** + * Creates Vertex AI retrievers. + * + * This function returns a list of retriever actions for Vertex AI based on the provided + * vector search options and embedder configurations. + * + * @param {VertexVectorSearchOptions} params - The parameters for creating the retrievers. + * @returns {RetrieverAction[]} - An array of retriever actions. + */ +export function vertexAiRetrievers( + ai: Genkit, + params: VertexVectorSearchOptions +): RetrieverAction[] { + const vectorSearchOptions = params.pluginOptions.vectorSearchOptions; + const defaultEmbedder = params.defaultEmbedder; + + const retrievers: RetrieverAction[] = []; + + if (!vectorSearchOptions || vectorSearchOptions.length === 0) { + return retrievers; + } + + for (const vectorSearchOption of vectorSearchOptions) { + const { documentRetriever, indexId, publicDomainName } = vectorSearchOption; + const embedder = vectorSearchOption.embedder ?? defaultEmbedder; + const embedderOptions = vectorSearchOption.embedderOptions; + + const retriever = ai.defineRetriever( + { + name: `vertexai/${indexId}`, + configSchema: VertexAIVectorRetrieverOptionsSchema.optional(), + }, + async (content, options) => { + const queryEmbeddings = await ai.embed({ + embedder, + options: embedderOptions, + content, + }); + + const accessToken = await params.authClient.getAccessToken(); + + if (!accessToken) { + throw new Error( + 'Error generating access token when defining Vertex AI retriever' + ); + } + + const projectId = params.pluginOptions.projectId; + if (!projectId) { + throw new Error( + 'Project ID is required to define Vertex AI retriever' + ); + } + const projectNumber = await getProjectNumber(projectId); + const location = params.pluginOptions.location; + if (!location) { + throw new Error('Location is required to define Vertex AI retriever'); + } + + let res = await queryPublicEndpoint({ + featureVector: queryEmbeddings, + neighborCount: options?.k || DEFAULT_K, + accessToken, + projectId, + location, + publicDomainName, + projectNumber, + indexEndpointId: vectorSearchOption.indexEndpointId, + deployedIndexId: vectorSearchOption.deployedIndexId, + }); + const nearestNeighbors = res.nearestNeighbors; + + const queryRes = nearestNeighbors ? nearestNeighbors[0] : null; + const neighbors = queryRes ? queryRes.neighbors : null; + if (!neighbors) { + return { documents: [] }; + } + + const documents = await documentRetriever(neighbors, options); + + return { documents }; + } + ); + + retrievers.push(retriever); + } + + return retrievers; +} + +/** + * Creates a reference to a Vertex AI retriever. + * + * @param {Object} params - The parameters for the retriever reference. + * @param {string} params.indexId - The ID of the Vertex AI index. + * @param {string} [params.displayName] - An optional display name for the retriever. + * @returns {Object} - The retriever reference object. + */ +export const vertexAiRetrieverRef = (params: { + indexId: string; + displayName?: string; +}) => { + return retrieverRef({ + name: `vertexai/${params.indexId}`, + info: { + label: params.displayName ?? `ertex AI - ${params.indexId}`, + }, + configSchema: VertexAIVectorRetrieverOptionsSchema.optional(), + }); +}; diff --git a/js/plugins/checks/src/vector-search/types.ts b/js/plugins/checks/src/vector-search/types.ts new file mode 100644 index 000000000..6b58e4f34 --- /dev/null +++ b/js/plugins/checks/src/vector-search/types.ts @@ -0,0 +1,189 @@ +/** + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import * as aiplatform from '@google-cloud/aiplatform'; +import { z } from 'genkit'; +import { EmbedderArgument } from 'genkit/embedder'; +import { CommonRetrieverOptionsSchema, Document } from 'genkit/retriever'; +import { GoogleAuth } from 'google-auth-library'; +import { PluginOptions } from '..'; + +// This internal interface will be passed to the vertexIndexers and vertexRetrievers functions +export interface VertexVectorSearchOptions< + EmbedderCustomOptions extends z.ZodTypeAny, +> { + pluginOptions: PluginOptions; + authClient: GoogleAuth; + defaultEmbedder: EmbedderArgument; +} + +export type IIndexDatapoint = + aiplatform.protos.google.cloud.aiplatform.v1.IIndexDatapoint; + +export class Datapoint extends aiplatform.protos.google.cloud.aiplatform.v1 + .IndexDatapoint { + constructor(properties: IIndexDatapoint) { + super(properties); + } +} + +export type IFindNeighborsRequest = + aiplatform.protos.google.cloud.aiplatform.v1.IFindNeighborsRequest; +export type IFindNeighborsResponse = + aiplatform.protos.google.cloud.aiplatform.v1.IFindNeighborsResponse; +export type ISparseEmbedding = + aiplatform.protos.google.cloud.aiplatform.v1.IndexDatapoint.ISparseEmbedding; +export type IRestriction = + aiplatform.protos.google.cloud.aiplatform.v1.IndexDatapoint.IRestriction; +export type INumericRestriction = + aiplatform.protos.google.cloud.aiplatform.v1.IndexDatapoint.INumericRestriction; + +// Define the Zod schema for ISparseEmbedding +export const SparseEmbeddingSchema = z.object({ + values: z.array(z.number()).optional(), + dimensions: z.array(z.union([z.number(), z.string()])).optional(), +}); + +export type SparseEmbedding = z.infer; + +// Define the Zod schema for IRestriction +export const RestrictionSchema = z.object({ + namespace: z.string().optional(), + allowList: z.array(z.string()).optional(), + denyList: z.array(z.string()).optional(), +}); + +export type Restriction = z.infer; + +// Define the Zod schema for INumericRestriction +export const NumericRestrictionSchema = z.object({ + valueInt: z.union([z.number(), z.string()]).optional(), + valueFloat: z.number().optional(), + valueDouble: z.number().optional(), + namespace: z.string().optional(), + op: z + .union([ + z.enum([ + 'OPERATOR_UNSPECIFIED', + 'LESS', + 'LESS_EQUAL', + 'EQUAL', + 'GREATER_EQUAL', + 'GREATER', + 'NOT_EQUAL', + ]), + z.null(), + ]) + .optional(), +}); + +export type NumericRestriction = z.infer; + +// Define the Zod schema for ICrowdingTag +export const CrowdingTagSchema = z.object({ + crowdingAttribute: z.string().optional(), +}); + +export type CrowdingTag = z.infer; + +// Define the Zod schema for IIndexDatapoint +const IndexDatapointSchema = z.object({ + datapointId: z.string().optional(), + featureVector: z.array(z.number()).optional(), + sparseEmbedding: SparseEmbeddingSchema.optional(), + restricts: z.array(RestrictionSchema).optional(), + numericRestricts: z.array(NumericRestrictionSchema).optional(), + crowdingTag: CrowdingTagSchema.optional(), +}); + +// Define the Zod schema for INeighbor +export const NeighborSchema = z.object({ + datapoint: IndexDatapointSchema.optional(), + distance: z.number().optional(), + sparseDistance: z.number().optional(), +}); + +export type Neighbor = z.infer; + +// Define the Zod schema for INearestNeighbors +const NearestNeighborsSchema = z.object({ + id: z.string().optional(), + neighbors: z.array(NeighborSchema).optional(), +}); + +// Define the Zod schema for IFindNeighborsResponse +export const FindNeighborsResponseSchema = z.object({ + nearestNeighbors: z.array(NearestNeighborsSchema).optional(), +}); + +export type FindNeighborsResponse = z.infer; + +// TypeScript types for Zod schemas +type IndexDatapoint = z.infer; + +// Function to assert type equality +function assertTypeEquality(value: T): void {} + +// Asserting type equality +assertTypeEquality({} as IndexDatapoint); +assertTypeEquality({} as FindNeighborsResponse); + +export const VertexAIVectorRetrieverOptionsSchema = + CommonRetrieverOptionsSchema.extend({}).optional(); + +export type VertexAIVectorRetrieverOptions = z.infer< + typeof VertexAIVectorRetrieverOptionsSchema +>; + +export const VertexAIVectorIndexerOptionsSchema = z.any(); + +export type VertexAIVectorIndexerOptions = z.infer< + typeof VertexAIVectorIndexerOptionsSchema +>; + +/** + * A document retriever function that takes an array of Neighbors from Vertex AI Vector Search query result, and resolves to a list of documents. + * Also takes an options object that can be used to configure the retriever. + */ +export type DocumentRetriever = + (docIds: Neighbor[], options?: Options) => Promise; + +/** + * Indexer function that takes an array of documents, stores them in a database of the user's choice, and resolves to a list of document ids. + * Also takes an options object that can be used to configure the indexer. Only Streaming Update Indexers are supported. + */ +export type DocumentIndexer = ( + docs: Document[], + options?: Options +) => Promise; + +export interface VectorSearchOptions< + EmbedderCustomOptions extends z.ZodTypeAny, + IndexerOptions extends {}, + RetrieverOptions extends { k?: number }, +> { + // Specify the Vertex AI Index and IndexEndpoint to use for indexing and retrieval + deployedIndexId: string; + indexEndpointId: string; + publicDomainName: string; + indexId: string; + // Document retriever and indexer functions to use for indexing and retrieval by the plugin's own indexers and retrievers + documentRetriever: DocumentRetriever; + documentIndexer: DocumentIndexer; + // Embedder and default options to use for indexing and retrieval + embedder?: EmbedderArgument; + embedderOptions?: z.infer; +} diff --git a/js/plugins/checks/src/vector-search/upsert_datapoints.ts b/js/plugins/checks/src/vector-search/upsert_datapoints.ts new file mode 100644 index 000000000..cfeb8d5ec --- /dev/null +++ b/js/plugins/checks/src/vector-search/upsert_datapoints.ts @@ -0,0 +1,71 @@ +/** + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { GoogleAuth } from 'google-auth-library'; +import { IIndexDatapoint } from './types'; + +interface UpsertDatapointsParams { + datapoints: IIndexDatapoint[]; + authClient: GoogleAuth; + projectId: string; + location: string; + indexId: string; +} + +/** + * Upserts datapoints into a specified index. + * + * This function sends a request to the Google AI Platform to upsert datapoints + * into a specified index using the provided parameters. + * + * @param {UpsertDatapointsParams} params - The parameters required to upsert datapoints. + * @param {IIndexDatapoint[]} params.datapoints - The datapoints to be upserted. + * @param {GoogleAuth} params.authClient - The GoogleAuth client for authorization. + * @param {string} params.projectId - The ID of the Google Cloud project. + * @param {string} params.location - The location of the AI Platform index. + * @param {string} params.indexId - The ID of the index. + * @returns {Promise} - A promise that resolves when the upsert is complete. + * @throws {Error} - Throws an error if the upsert fails. + */ +export async function upsertDatapoints( + params: UpsertDatapointsParams +): Promise { + const { datapoints, authClient, projectId, location, indexId } = params; + const accessToken = await authClient.getAccessToken(); + const url = `https://${location}-aiplatform.googleapis.com/v1/projects/${projectId}/locations/${location}/indexes/${indexId}:upsertDatapoints`; + + const requestBody = { + datapoints: datapoints.map((dp) => ({ + datapoint_id: dp.datapointId, + feature_vector: dp.featureVector, + })), + }; + + const response = await fetch(url, { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + Authorization: `Bearer ${accessToken}`, + }, + body: JSON.stringify(requestBody), + }); + + if (!response.ok) { + throw new Error( + `Error upserting datapoints into index ${indexId}: ${response.statusText}` + ); + } +} diff --git a/js/plugins/checks/src/vector-search/utils.ts b/js/plugins/checks/src/vector-search/utils.ts new file mode 100644 index 000000000..c6415b8c5 --- /dev/null +++ b/js/plugins/checks/src/vector-search/utils.ts @@ -0,0 +1,65 @@ +/** + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { GoogleAuth } from 'google-auth-library'; +import { google } from 'googleapis'; + +/** + * Retrieves an access token using the provided GoogleAuth client. + * + * @param {GoogleAuth} auth - The GoogleAuth client. + * @returns {Promise} - A promise that resolves to the access token. + */ +export async function getAccessToken(auth: GoogleAuth): Promise { + const client = await auth.getClient(); + const _accessToken = await client.getAccessToken(); + return _accessToken.token || null; +} + +/** + * Retrieves the project number for a given project ID. + * + * This function sends a request to the Google Cloud Resource Manager API to + * fetch the project number for the specified project ID. + * + * @param {string} projectId - The ID of the Google Cloud project. + * @returns {Promise} - A promise that resolves to the project number. + * @throws {Error} - Throws an error if the project number cannot be fetched. + */ +export async function getProjectNumber(projectId: string): Promise { + const client = google.cloudresourcemanager('v1'); + const authClient = await google.auth.getClient({ + scopes: ['https://www.googleapis.com/auth/cloud-platform'], + }); + + try { + const response = await client.projects.get({ + projectId: projectId, + auth: authClient, + }); + + if (!response.data.projectNumber) { + throw new Error( + `Error fetching project number for Vertex AI plugin for project ${projectId}` + ); + } + return response.data['projectNumber']; + } catch (error) { + throw new Error( + `Error fetching project number for Vertex AI plugin for project ${projectId}` + ); + } +} diff --git a/js/plugins/checks/tests/anthropic_test.ts b/js/plugins/checks/tests/anthropic_test.ts new file mode 100644 index 000000000..f5870e6a1 --- /dev/null +++ b/js/plugins/checks/tests/anthropic_test.ts @@ -0,0 +1,313 @@ +/** + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { + Message, + MessageCreateParamsBase, +} from '@anthropic-ai/sdk/resources/messages.mjs'; +import { GenerateRequest, GenerateResponseData } from 'genkit'; +import assert from 'node:assert'; +import { describe, it } from 'node:test'; +import { + AnthropicConfigSchema, + fromAnthropicResponse, + toAnthropicRequest, +} from '../src/anthropic.js'; + +const MODEL_ID = 'modelid'; + +describe('toAnthropicRequest', () => { + const testCases: { + should: string; + input: GenerateRequest; + expectedOutput: MessageCreateParamsBase; + }[] = [ + { + should: 'should transform genkit message (text content) correctly', + input: { + messages: [ + { + role: 'user', + content: [{ text: 'Tell a joke about dogs.' }], + }, + ], + }, + expectedOutput: { + max_tokens: 4096, + model: MODEL_ID, + messages: [ + { + role: 'user', + content: [{ type: 'text', text: 'Tell a joke about dogs.' }], + }, + ], + }, + }, + { + should: 'should transform system message', + input: { + messages: [ + { + role: 'system', + content: [{ text: 'Talk like a pirate.' }], + }, + { + role: 'user', + content: [{ text: 'Tell a joke about dogs.' }], + }, + ], + }, + expectedOutput: { + max_tokens: 4096, + model: MODEL_ID, + system: 'Talk like a pirate.', + messages: [ + { + role: 'user', + content: [{ type: 'text', text: 'Tell a joke about dogs.' }], + }, + ], + }, + }, + { + should: + 'should transform genkit message (inline base64 image content) correctly', + input: { + messages: [ + { + role: 'user', + content: [ + { text: 'describe the following image:' }, + { + media: { + contentType: 'image/jpeg', + url: '', + }, + }, + ], + }, + ], + }, + expectedOutput: { + max_tokens: 4096, + model: MODEL_ID, + messages: [ + { + role: 'user', + content: [ + { type: 'text', text: 'describe the following image:' }, + { + type: 'image', + source: { + type: 'base64', + media_type: 'image/jpeg', + data: '/9j/4QDeRXhpZgAASUkqAAgAAAAGABIBAwABAAAAAQAAABoBBQABAAAAVgAAABsBBQABAAAAXgAAACgBAwABAAAAAgAAABMCAwABAAAAAQAAAGmHBAABAAAAZgAAAAAAAABIAAAAAQAAAEgAAAABAAAABwAAkAcABAAAADAyMTABkQcABAAAAAECAwCGkgcAFgAAAMAAAAAAoAcABAAAADAxMDABoAMAAQAAAP//AAACoAQAAQAAAMgAAAADoAQAAQAAAMgAAAAAAAAAQVNDSUkAAABQaWNzdW0gSUQ6IDY4N//bAEMACAYGBwYFCAcHBwkJCAoMFA0MCwsMGRITDxQdGh8eHRocHCAkLicgIiwjHBwoNyksMDE0NDQfJzk9ODI8LjM0Mv/bAEMBCQkJDAsMGA0NGDIhHCEyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMv/CABEIAMgAyAMBIgACEQEDEQH/xAAbAAEAAgMBAQAAAAAAAAAAAAAAAQIDBAUGB//EABgBAQEBAQEAAAAAAAAAAAAAAAABAgME/9oADAMBAAIQAxAAAAH3ZOsiYEgAmIkWEiEiEkiRYSICBVSQSRIBEhQAUAEAARMAJWYmpRBZWYmYkBQAUAAEARIgJEsViidRMKmYmW98M5uVEzQAAAAIABoa3zTLZ9M2Pltl+pvmWU+kvn+xHt7eMzHrcnlMy+mam2AAAAgEBPj9/Y+XWuTb6U1xLbOWNO29EupO1Ea85IOp6/ldXeQoAAEAgJq+G9/pteA6WjoR0ev5v1Rv8Xv8jGuTERF/W07G4yGoAAACCAE1Zz6a6/z33XKXgVv0MXzfd5+1VvY4O/E2i24AACCAgkqqlAiKzXNybOmc/j+i4eNYfQ7G/Ldjy6zdUWioupKWipbRCyYgTCKlAxzjnWcnK6PJl2c2v0+W74djUrPOO28WmguoW6sF4qLREWWVgsrBZRWvNZ1iedbyWN+u6nzfoc9++1PO82X206mx343UF4rBdQWVgtEKmIglAKiZx2TT8j6bl8uvA2e1Obj1d+M69Hm4fa78rRVrN4oLTQXisF4rBaIhLKCygrIcPhnm72znHpagdD0h6uFZOvOAoJECgRBIAC//xAApEAABBAECBQUBAAMAAAAAAAABAAIDBBEFEhATITBAFBUgIjFBIzJw/9oACAEBAAEFAv8AgGVnxyfkD4RPYz4EtuGFC9VK58LkHNPz6jv9XHCwgENyEsgQszhC7ZCGoWUNSmQ1ORN1NQztnb3MFjgh9ljjjjjjpufU9zU60T2D06hjjlfJS2jlxoRBclckrkPXJlXLkC2PVSt6dnckG1X6XpnwzPghgvSOjaesLQUYQWysAPDLlSquiH73Wu3NcxqtVn0pgVuUO2KNlyuVdG12UOppUuUP3vO+jn4exzG2opG8qWHAdNPtDLW58UpLR1VKlykUOnfP+MztOX6QySUMdG/242me1yRPazElGgYiSgMeAQHBw6tK1CBskNfUDCJLQmNOnylnDQPCkZvaFP8AatHWbIq1SGHg0dfDyv8AYRxBqAwEPzwcr+KM/YL+IfnazwysrK3cf4mHBBWVn5Z71yYwVotRlklGqSKLVXlTX5I17jMXQSGSPwrsPPqw0Jo5hp9lR0rLTPWnkMenWN8MZji8K7YfXA1eRe7yBe7yr3aZe6zqpfltTLPDKys8MrKysrKzwz8NRdlmxudgWxi5bMmNhVBobZWeOTxzxyshZWeP/8QAIBEAAgICAgIDAAAAAAAAAAAAAAEREgIwECEDIDFAUP/aAAgBAwEBPwH2j8VaYKlSpUh/ab1ob14j4nShbMe3HGShwPVgmnZFRp5OTJRqw7RUqeT50//EAB8RAAICAwACAwAAAAAAAAAAAAABAhEDEjAQIRNAUP/aAAgBAgEBPwH8uuNmxsbGxa+0lxq/DEuNCRMXsrnIXOfpWxog7Vi5ZGmtWOaItRVEXa5ZHQ8jPlZidrj/AP/EADkQAAEDAQMKAgYLAQAAAAAAAAEAAhEDEiExEBMiMjNAQVFhkTCSQnFygYKhBBQgIzRSYHCiscHh/9oACAEBAAY/Av3IipUa09SvxFPutrTPxLFvfdiXGScT9i4lXVH+ZbZ/dbUrWHlWDOyvY1X0uxVpnvHilpEEYjIB4UcIv8X6zZdIueGrGsOy0X1CfZ/6ryR7lth5Stqz5rXZ3WLfMFq/MLZv7LZv8qGg6/oiTrux6eLaxb6QVtl9F2HRE0jDnGLXJVKdV1oWbieeXBHLEnus5UJtnhyXTxZRpPE0n4ItxY7AnjkhA1HBvrUZzuFIwOSEKlQaf9KOHjWx8QXQp30etrDj/qfTmbBhF5ExgEXvkkqCIRou1Th0OTOVBp8ByUDcI9A4dEHNMOGBVSqK7peZiLk6k8Q4FQCAcRKmoWwPy3oLOVYngOS67jBwKsOx4HnkzuFRmBVlzVotMlZyoPvOXLJJx3KOPDI8LSbKtNZpc8k9t1I/RLngSrGiFgsELPEIK1EX7m5l89Fask3LZnstkhZp8FpNPZRBG5h7ZuV7Hd1dTefWVsj5lsv5K5ndydSc0Na1s4zuebmMCrn8FtFrke5azo9lTaePhVRzCYdwI3H/xAAqEAACAQIFAgYDAQEAAAAAAAABEQAhMRBBUWFxIDCBkaGx0fBAweFw8f/aAAgBAQABPyH/AABYmsY/GZQW6WZqfhZQgGfpgoK59ASAvvmW9lU8p/IEFt4JBjBuAgZy8oq2LxY5wegoC+8y0OZDUwAoBdsAg9gVLGOClq9fM6+QDArm5CZkk5t+cXJfyguXHdJ0ikTI4HQ4tHJSrAIARbABFqICTRWspEEWT4QwILj7qoPUADxMzB5l8QLBTqLgb90L/uDK8X4IXb6OJtPIPYvHzSjQydjfuKHyJ/QPiVgqEgyqYND3D4d1oqwJTY1lPiXUDfpKozeYFYc/qVAfN2Agi6bhwZxkJPiVUAlAa2j0JgehnYAowIChGz5mv/rujU4IOR0h4D0AHI6fE2uKUDQ7iFIVAG0IFWaGziGMxHEPqL3BEuxEWMMCsKk0gBFpH3eGtOXeEo2PMGvMAwGoMpqYAz7BClRIpkbRZyQyxJjJZpqYMVWtWHQebgVFGdIMRWfp3hM5ntAAADvgde+JpxAXGuDDHOCoxCqGPvMKOOpZ4wVNixFvGGIFVXeDIqPIikrrCc45nvuGInQMMGGFdB84BWGgMw7GE7K0vElitdmIAJrfW8VnoIatxi+p9bjgMsFS0MogkFoPeetxow6lRfUwmXTKzTBxx4OOPqcccccdJdANwhRwa9oTChx9hxxx9DGAmriA4TJM0JhUdDjjjjjjwOJF1i6xNcTQ+MPMcMUHjADSM6xnWM6xnWNGdYzrGdY444444THiFgVqyhwCCbQk2UMJo+yc2FVAJMAMRjBUCGxjjjlca9Pj0uOEHHMCR/jOLQD58rSjgxuiCWhQENQImoSS9+prE4sRjBxiNZQHaAJpVAN60ELs66YSKRsuNtESChbv+YfovtAWYHcNYScDnCHAfUjhHgcHOBKCRubRlQkb1zymTc60y8o0020CqNNZoqBa4BBxGAIWAdnnie8Z2nBGYzDDGkYGUTOGBDecI45//9oADAMBAAIAAwAAABBhJAAuAIJODC8BIgAcABSwEABb+gg6kAFzww0BA2VeEAbzzzzwEABOYQKszTzziED9uo7LMuvzzwED+T//AGn90884hEzOM7iPE4whC/Nwm7969CeLeTKcADfPOfnOGS0kgwBv7elCCCmwMoMyt33g4A4AAV9xefhCjhABAgffA//EAB4RAAMBAAIDAQEAAAAAAAAAAAABERAgMSEwQVFh/9oACAEDAQE/EOExoQntfK4sWwnB6hvT5dka2JXF6tuvv0LKs+HGlxMs7If0fhyu0bqGGG8UvPyOgxspUVFWUpcVoTtia6PIDs/UpfATWloGN55UpcWAlf1iQWR6f//EAB4RAAMAAgMBAQEAAAAAAAAAAAABERAhIDAxQVFh/9oACAECAQE/EOVLyfa8b60qJJYY+loPBJI3dDw0RjTN/glF0N6glRppvEvehmjRfjP4PpyhMMQdE2WleEIQhCYhB6NmeSEIQhCExsZM8F6O9ESeEREIQnJo8GMei4ovR1UNwU8SGlxJDKvk8//EACkQAQACAgECBgIDAQEBAAAAAAEAESExQVFhEHGBkaGxMMEg0fBA4fH/2gAIAQEAAT8QgeBrwD+FQjqUypUqV+EPAPE14mv+KvEL8T874PjUPxX+F8TwNf8AOeF0QXT8QcsFQTSP5alfwZUrwvWebrLB7w8QNL7zh93/ABNyIbrntNORj0Srw7gj+orYCRwW68OtRscpAAmn8L/ItqZnPRb8TQuvX90IM50/vmAMOZ+4hHMeq5cTqdGomtZ6xMjxAOUPsTDjWN8QRdepuGcev5bg65ESndlwVa83Vf3EIQ+kpbAHsQhRp1f2mYH7GvmYMvTMfMrP9g6QzzL+tBtv90fuUK8mR+2Wsdcv6kmSZKDT9z8rD3EUhhGK0Ab4DmKsGMKov99oAitaTVvrNhm4xaRqrIZ4KJb5S14UTa3cAA16wxfS5QzYd7ltEaw733uLi4fXSqfdh+RSXYhQMZRyYLOEdETRCNlFfrajFcxmHzIo06cq3ssyxVvoH9pXWPQ/c0PkKfYS2VjRk+4iAyHZCeRtHhh2DHxcvaWuZspIBis0ZNw8gA1oBwee8P5P8sTbK6GsOUNnJ5EMt1BWZE9Hh9IJCnQJ7FmrLmKZs83UUDza1XaWwcmGGCbdLiGwZtqf+TEwDHVfLp3jqL7Gn3iwU6aDkjSopZHPFdZgYAqne/yoZ4YuDq6+X5TxI9wLa7kCorqGy9t7XCV0gYcxMRuj+jqRDR+Bq39xWGVkekOSkD7JBb66lQZYwL6wXQjaL7g6hf8AyJm5p3XVRwKqu3R++7iZuphnPaABQUH5XTEqo0cez5PIlYV9a42I+ySn1CgUj5RpP0zWa79VXp5RKxMS0sWcgCpLUE2ct/nylDG6RYvvHcAWzWWJ0GqTv2iICRKBdvSFUrtMhf8ALeNR7LqWtOvnNJB/K/G5fitQloYzvD8u7jo45i8qflo8icjyQU3EiNeeaH1qLWWv2sUeSrX0zBDTo1vWgZpHZLBNstA1aCiEvZ0DKTRXdQiXlbvs694tRHafb2gFLUym1/EsGpcuLAQnZKHki9hbHz9nPvKOeY0NUecAeoZvt5XD30UGAfPp7w4vSK2RxQVcNuDeBO5Tf10cqntvg2L0IbBd/Y4DtLlyhLPG4udzzRaikvwXPhlzBERJsUdlpiqC5mYZavZL8GggkHkn1AoOpUTTS69JRQOtTiVYCYHL5suWlplLlpad0slxZcW4giJWGEsTZ5QPul55AHzKgLTtKgFDZbZkmLKeC+8vvL7y+8vvLlhKeBjKXFzLIdaA6kG29L6mU3mZV9v3HVAs01H4IFwpXLlstO6Uj4NesqckRKx6h7ynhG/SI8IGNsB6sbk8vqDZKG+j6ZYbl5931KXuPqyo2neTvJ3kU5Z3k7yPWS/Vi+rHHcXUvL9YjzFHMt6y5YAARJvmbuwC9V6xSp4Cg4sgxadq/qLhYuUVX6QgoE2Gr8pYcZ5ikLipaWixb1guoo3LUje7mes9Y7h1uevhzmUdQtBBRWcVOPmYMyO58oOrkcWEwC6HJ9R9ptZ7+IMdQ4pedbilJAdquJXNxSoAtxQlzJFx5ziZYmo1d0yqWSyIvmWTNi0XqlYhZRWQeo/UeBDBMwXmUXrU/wAMQylncn6hxS8BA8q+0K6DjZ8S9gCXeoA0V7QTxHKn4le8v1MStQFR733iThlb1G+os7xFZqNCiFvFRVQ7os8mCnVlSrQTZksBafGFMbFabqDpQDANvh5t+8b4RwWOb5sls01QtHzqcUUqD69YMXdONZAm6W5YvLfvADSssaHrG3cC8S3al3BKOBlhf3lCEYUYX0gnSgVwhj1ktsCf/9k=', + }, + }, + ], + }, + ], + }, + }, + ]; + for (const test of testCases) { + it(test.should, () => { + assert.deepEqual( + toAnthropicRequest(MODEL_ID, test.input), + test.expectedOutput + ); + }); + } +}); + +describe('fromAnthropicResponse', () => { + const testCases: { + should: string; + input: GenerateRequest; + response: Message; + expectedOutput: GenerateResponseData; + }[] = [ + { + should: 'should transform genkit message (text content) correctly', + input: { + messages: [ + { + role: 'user', + content: [{ text: 'Tell a joke about dogs.' }], + }, + ], + }, + response: { + id: 'abcd1234', + model: MODEL_ID, + role: 'assistant', + stop_reason: 'end_turn', + usage: { + input_tokens: 123, + output_tokens: 234, + }, + stop_sequence: null, + type: 'message', + content: [ + { + type: 'text', + text: 'part 1', + }, + { + type: 'text', + text: 'part 2', + }, + ], + }, + expectedOutput: { + custom: { + id: 'abcd1234', + model: MODEL_ID, + type: 'message', + }, + finishReason: 'stop', + message: { + role: 'model', + content: [ + { + text: 'part 1', + }, + { + text: 'part 2', + }, + ], + }, + usage: { + inputAudioFiles: 0, + inputCharacters: 23, + inputImages: 0, + inputTokens: 123, + inputVideos: 0, + outputAudioFiles: 0, + outputCharacters: 12, + outputImages: 0, + outputTokens: 234, + outputVideos: 0, + }, + }, + }, + { + should: 'should transform genkit tool call correctly', + input: { + messages: [ + { + role: 'user', + content: [{ text: "What's the weather like today?" }], + }, + ], + tools: [ + { + name: 'get_weather', + description: 'Get the weather for a location.', + inputSchema: { + type: 'object', + properties: { + location: { + type: 'string', + description: 'The city and state, e.g. San Francisco, CA', + }, + }, + required: ['location'], + }, + }, + ], + }, + response: { + id: 'abcd1234', + model: MODEL_ID, + role: 'assistant', + type: 'message', + stop_reason: 'tool_use', + stop_sequence: null, + usage: { + input_tokens: 123, + output_tokens: 234, + }, + content: [ + { + id: 'toolu_get_weather', + name: 'get_weather', + type: 'tool_use', + input: { + type: 'object', + properties: { + location: { + type: 'string', + description: 'The city and state, e.g. San Francisco, CA', + }, + }, + required: ['location'], + }, + }, + ], + }, + expectedOutput: { + custom: { + id: 'abcd1234', + model: MODEL_ID, + type: 'message', + }, + finishReason: 'stop', + message: { + role: 'model', + content: [ + { + toolRequest: { + name: 'get_weather', + ref: 'toolu_get_weather', + input: { + type: 'object', + properties: { + location: { + type: 'string', + description: 'The city and state, e.g. San Francisco, CA', + }, + }, + required: ['location'], + }, + }, + }, + ], + }, + usage: { + inputAudioFiles: 0, + inputCharacters: 30, + inputImages: 0, + inputTokens: 123, + inputVideos: 0, + outputAudioFiles: 0, + outputCharacters: 0, + outputImages: 0, + outputTokens: 234, + outputVideos: 0, + }, + }, + }, + ]; + for (const test of testCases) { + it(test.should, () => { + assert.deepEqual( + fromAnthropicResponse(test.input, test.response), + test.expectedOutput + ); + }); + } +}); diff --git a/js/plugins/checks/tests/gemini_test.ts b/js/plugins/checks/tests/gemini_test.ts new file mode 100644 index 000000000..c6156b4be --- /dev/null +++ b/js/plugins/checks/tests/gemini_test.ts @@ -0,0 +1,347 @@ +/** + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { GenerateContentCandidate } from '@google-cloud/vertexai'; +import { MessageData } from 'genkit'; +import assert from 'node:assert'; +import { describe, it } from 'node:test'; +import { + fromGeminiCandidate, + toGeminiMessage, + toGeminiSystemInstruction, +} from '../src/gemini.js'; + +describe('toGeminiMessages', () => { + const testCases = [ + { + should: 'should transform genkit message (text content) correctly', + inputMessage: { + role: 'user', + content: [{ text: 'Tell a joke about dogs.' }], + }, + expectedOutput: { + role: 'user', + parts: [{ text: 'Tell a joke about dogs.' }], + }, + }, + { + should: + 'should transform genkit message (tool request content) correctly', + inputMessage: { + role: 'model', + content: [ + { toolRequest: { name: 'tellAFunnyJoke', input: { topic: 'dogs' } } }, + ], + }, + expectedOutput: { + role: 'model', + parts: [ + { functionCall: { name: 'tellAFunnyJoke', args: { topic: 'dogs' } } }, + ], + }, + }, + { + should: + 'should transform genkit message (tool response content) correctly', + inputMessage: { + role: 'tool', + content: [ + { + toolResponse: { + name: 'tellAFunnyJoke', + output: 'Why did the dogs cross the road?', + }, + }, + ], + }, + expectedOutput: { + role: 'function', + parts: [ + { + functionResponse: { + name: 'tellAFunnyJoke', + response: { + name: 'tellAFunnyJoke', + content: 'Why did the dogs cross the road?', + }, + }, + }, + ], + }, + }, + { + should: + 'should transform genkit message (inline base64 image content) correctly', + inputMessage: { + role: 'user', + content: [ + { text: 'describe the following image:' }, + { + media: { + contentType: 'image/jpeg', + url: '', + }, + }, + ], + }, + expectedOutput: { + role: 'user', + parts: [ + { text: 'describe the following image:' }, + { + inlineData: { + mimeType: 'image/jpeg', + data: '/9j/4QDeRXhpZgAASUkqAAgAAAAGABIBAwABAAAAAQAAABoBBQABAAAAVgAAABsBBQABAAAAXgAAACgBAwABAAAAAgAAABMCAwABAAAAAQAAAGmHBAABAAAAZgAAAAAAAABIAAAAAQAAAEgAAAABAAAABwAAkAcABAAAADAyMTABkQcABAAAAAECAwCGkgcAFgAAAMAAAAAAoAcABAAAADAxMDABoAMAAQAAAP//AAACoAQAAQAAAMgAAAADoAQAAQAAAMgAAAAAAAAAQVNDSUkAAABQaWNzdW0gSUQ6IDY4N//bAEMACAYGBwYFCAcHBwkJCAoMFA0MCwsMGRITDxQdGh8eHRocHCAkLicgIiwjHBwoNyksMDE0NDQfJzk9ODI8LjM0Mv/bAEMBCQkJDAsMGA0NGDIhHCEyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMv/CABEIAMgAyAMBIgACEQEDEQH/xAAbAAEAAgMBAQAAAAAAAAAAAAAAAQIDBAUGB//EABgBAQEBAQEAAAAAAAAAAAAAAAABAgME/9oADAMBAAIQAxAAAAH3ZOsiYEgAmIkWEiEiEkiRYSICBVSQSRIBEhQAUAEAARMAJWYmpRBZWYmYkBQAUAAEARIgJEsViidRMKmYmW98M5uVEzQAAAAIABoa3zTLZ9M2Pltl+pvmWU+kvn+xHt7eMzHrcnlMy+mam2AAAAgEBPj9/Y+XWuTb6U1xLbOWNO29EupO1Ea85IOp6/ldXeQoAAEAgJq+G9/pteA6WjoR0ev5v1Rv8Xv8jGuTERF/W07G4yGoAAACCAE1Zz6a6/z33XKXgVv0MXzfd5+1VvY4O/E2i24AACCAgkqqlAiKzXNybOmc/j+i4eNYfQ7G/Ldjy6zdUWioupKWipbRCyYgTCKlAxzjnWcnK6PJl2c2v0+W74djUrPOO28WmguoW6sF4qLREWWVgsrBZRWvNZ1iedbyWN+u6nzfoc9++1PO82X206mx343UF4rBdQWVgtEKmIglAKiZx2TT8j6bl8uvA2e1Obj1d+M69Hm4fa78rRVrN4oLTQXisF4rBaIhLKCygrIcPhnm72znHpagdD0h6uFZOvOAoJECgRBIAC//xAApEAABBAECBQUBAAMAAAAAAAABAAIDBBEFEhATITBAFBUgIjFBIzJw/9oACAEBAAEFAv8AgGVnxyfkD4RPYz4EtuGFC9VK58LkHNPz6jv9XHCwgENyEsgQszhC7ZCGoWUNSmQ1ORN1NQztnb3MFjgh9ljjjjjjpufU9zU60T2D06hjjlfJS2jlxoRBclckrkPXJlXLkC2PVSt6dnckG1X6XpnwzPghgvSOjaesLQUYQWysAPDLlSquiH73Wu3NcxqtVn0pgVuUO2KNlyuVdG12UOppUuUP3vO+jn4exzG2opG8qWHAdNPtDLW58UpLR1VKlykUOnfP+MztOX6QySUMdG/242me1yRPazElGgYiSgMeAQHBw6tK1CBskNfUDCJLQmNOnylnDQPCkZvaFP8AatHWbIq1SGHg0dfDyv8AYRxBqAwEPzwcr+KM/YL+IfnazwysrK3cf4mHBBWVn5Z71yYwVotRlklGqSKLVXlTX5I17jMXQSGSPwrsPPqw0Jo5hp9lR0rLTPWnkMenWN8MZji8K7YfXA1eRe7yBe7yr3aZe6zqpfltTLPDKys8MrKysrKzwz8NRdlmxudgWxi5bMmNhVBobZWeOTxzxyshZWeP/8QAIBEAAgICAgIDAAAAAAAAAAAAAAEREgIwECEDIDFAUP/aAAgBAwEBPwH2j8VaYKlSpUh/ab1ob14j4nShbMe3HGShwPVgmnZFRp5OTJRqw7RUqeT50//EAB8RAAICAwACAwAAAAAAAAAAAAABAhEDEjAQIRNAUP/aAAgBAgEBPwH8uuNmxsbGxa+0lxq/DEuNCRMXsrnIXOfpWxog7Vi5ZGmtWOaItRVEXa5ZHQ8jPlZidrj/AP/EADkQAAEDAQMKAgYLAQAAAAAAAAEAAhEDEiExEBMiMjNAQVFhkTCSQnFygYKhBBQgIzRSYHCiscHh/9oACAEBAAY/Av3IipUa09SvxFPutrTPxLFvfdiXGScT9i4lXVH+ZbZ/dbUrWHlWDOyvY1X0uxVpnvHilpEEYjIB4UcIv8X6zZdIueGrGsOy0X1CfZ/6ryR7lth5Stqz5rXZ3WLfMFq/MLZv7LZv8qGg6/oiTrux6eLaxb6QVtl9F2HRE0jDnGLXJVKdV1oWbieeXBHLEnus5UJtnhyXTxZRpPE0n4ItxY7AnjkhA1HBvrUZzuFIwOSEKlQaf9KOHjWx8QXQp30etrDj/qfTmbBhF5ExgEXvkkqCIRou1Th0OTOVBp8ByUDcI9A4dEHNMOGBVSqK7peZiLk6k8Q4FQCAcRKmoWwPy3oLOVYngOS67jBwKsOx4HnkzuFRmBVlzVotMlZyoPvOXLJJx3KOPDI8LSbKtNZpc8k9t1I/RLngSrGiFgsELPEIK1EX7m5l89Fask3LZnstkhZp8FpNPZRBG5h7ZuV7Hd1dTefWVsj5lsv5K5ndydSc0Na1s4zuebmMCrn8FtFrke5azo9lTaePhVRzCYdwI3H/xAAqEAACAQIFAgYDAQEAAAAAAAABEQAhMRBBUWFxIDCBkaGx0fBAweFw8f/aAAgBAQABPyH/AABYmsY/GZQW6WZqfhZQgGfpgoK59ASAvvmW9lU8p/IEFt4JBjBuAgZy8oq2LxY5wegoC+8y0OZDUwAoBdsAg9gVLGOClq9fM6+QDArm5CZkk5t+cXJfyguXHdJ0ikTI4HQ4tHJSrAIARbABFqICTRWspEEWT4QwILj7qoPUADxMzB5l8QLBTqLgb90L/uDK8X4IXb6OJtPIPYvHzSjQydjfuKHyJ/QPiVgqEgyqYND3D4d1oqwJTY1lPiXUDfpKozeYFYc/qVAfN2Agi6bhwZxkJPiVUAlAa2j0JgehnYAowIChGz5mv/rujU4IOR0h4D0AHI6fE2uKUDQ7iFIVAG0IFWaGziGMxHEPqL3BEuxEWMMCsKk0gBFpH3eGtOXeEo2PMGvMAwGoMpqYAz7BClRIpkbRZyQyxJjJZpqYMVWtWHQebgVFGdIMRWfp3hM5ntAAADvgde+JpxAXGuDDHOCoxCqGPvMKOOpZ4wVNixFvGGIFVXeDIqPIikrrCc45nvuGInQMMGGFdB84BWGgMw7GE7K0vElitdmIAJrfW8VnoIatxi+p9bjgMsFS0MogkFoPeetxow6lRfUwmXTKzTBxx4OOPqcccccdJdANwhRwa9oTChx9hxxx9DGAmriA4TJM0JhUdDjjjjjjwOJF1i6xNcTQ+MPMcMUHjADSM6xnWM6xnWNGdYzrGdY444444THiFgVqyhwCCbQk2UMJo+yc2FVAJMAMRjBUCGxjjjlca9Pj0uOEHHMCR/jOLQD58rSjgxuiCWhQENQImoSS9+prE4sRjBxiNZQHaAJpVAN60ELs66YSKRsuNtESChbv+YfovtAWYHcNYScDnCHAfUjhHgcHOBKCRubRlQkb1zymTc60y8o0020CqNNZoqBa4BBxGAIWAdnnie8Z2nBGYzDDGkYGUTOGBDecI45//9oADAMBAAIAAwAAABBhJAAuAIJODC8BIgAcABSwEABb+gg6kAFzww0BA2VeEAbzzzzwEABOYQKszTzziED9uo7LMuvzzwED+T//AGn90884hEzOM7iPE4whC/Nwm7969CeLeTKcADfPOfnOGS0kgwBv7elCCCmwMoMyt33g4A4AAV9xefhCjhABAgffA//EAB4RAAMBAAIDAQEAAAAAAAAAAAABERAgMSEwQVFh/9oACAEDAQE/EOExoQntfK4sWwnB6hvT5dka2JXF6tuvv0LKs+HGlxMs7If0fhyu0bqGGG8UvPyOgxspUVFWUpcVoTtia6PIDs/UpfATWloGN55UpcWAlf1iQWR6f//EAB4RAAMAAgMBAQEAAAAAAAAAAAABERAhIDAxQVFh/9oACAECAQE/EOVLyfa8b60qJJYY+loPBJI3dDw0RjTN/glF0N6glRppvEvehmjRfjP4PpyhMMQdE2WleEIQhCYhB6NmeSEIQhCExsZM8F6O9ESeEREIQnJo8GMei4ovR1UNwU8SGlxJDKvk8//EACkQAQACAgECBgIDAQEBAAAAAAEAESExQVFhEHGBkaGxMMEg0fBA4fH/2gAIAQEAAT8QgeBrwD+FQjqUypUqV+EPAPE14mv+KvEL8T874PjUPxX+F8TwNf8AOeF0QXT8QcsFQTSP5alfwZUrwvWebrLB7w8QNL7zh93/ABNyIbrntNORj0Srw7gj+orYCRwW68OtRscpAAmn8L/ItqZnPRb8TQuvX90IM50/vmAMOZ+4hHMeq5cTqdGomtZ6xMjxAOUPsTDjWN8QRdepuGcev5bg65ESndlwVa83Vf3EIQ+kpbAHsQhRp1f2mYH7GvmYMvTMfMrP9g6QzzL+tBtv90fuUK8mR+2Wsdcv6kmSZKDT9z8rD3EUhhGK0Ab4DmKsGMKov99oAitaTVvrNhm4xaRqrIZ4KJb5S14UTa3cAA16wxfS5QzYd7ltEaw733uLi4fXSqfdh+RSXYhQMZRyYLOEdETRCNlFfrajFcxmHzIo06cq3ssyxVvoH9pXWPQ/c0PkKfYS2VjRk+4iAyHZCeRtHhh2DHxcvaWuZspIBis0ZNw8gA1oBwee8P5P8sTbK6GsOUNnJ5EMt1BWZE9Hh9IJCnQJ7FmrLmKZs83UUDza1XaWwcmGGCbdLiGwZtqf+TEwDHVfLp3jqL7Gn3iwU6aDkjSopZHPFdZgYAqne/yoZ4YuDq6+X5TxI9wLa7kCorqGy9t7XCV0gYcxMRuj+jqRDR+Bq39xWGVkekOSkD7JBb66lQZYwL6wXQjaL7g6hf8AyJm5p3XVRwKqu3R++7iZuphnPaABQUH5XTEqo0cez5PIlYV9a42I+ySn1CgUj5RpP0zWa79VXp5RKxMS0sWcgCpLUE2ct/nylDG6RYvvHcAWzWWJ0GqTv2iICRKBdvSFUrtMhf8ALeNR7LqWtOvnNJB/K/G5fitQloYzvD8u7jo45i8qflo8icjyQU3EiNeeaH1qLWWv2sUeSrX0zBDTo1vWgZpHZLBNstA1aCiEvZ0DKTRXdQiXlbvs694tRHafb2gFLUym1/EsGpcuLAQnZKHki9hbHz9nPvKOeY0NUecAeoZvt5XD30UGAfPp7w4vSK2RxQVcNuDeBO5Tf10cqntvg2L0IbBd/Y4DtLlyhLPG4udzzRaikvwXPhlzBERJsUdlpiqC5mYZavZL8GggkHkn1AoOpUTTS69JRQOtTiVYCYHL5suWlplLlpad0slxZcW4giJWGEsTZ5QPul55AHzKgLTtKgFDZbZkmLKeC+8vvL7y+8vvLlhKeBjKXFzLIdaA6kG29L6mU3mZV9v3HVAs01H4IFwpXLlstO6Uj4NesqckRKx6h7ynhG/SI8IGNsB6sbk8vqDZKG+j6ZYbl5931KXuPqyo2neTvJ3kU5Z3k7yPWS/Vi+rHHcXUvL9YjzFHMt6y5YAARJvmbuwC9V6xSp4Cg4sgxadq/qLhYuUVX6QgoE2Gr8pYcZ5ikLipaWixb1guoo3LUje7mes9Y7h1uevhzmUdQtBBRWcVOPmYMyO58oOrkcWEwC6HJ9R9ptZ7+IMdQ4pedbilJAdquJXNxSoAtxQlzJFx5ziZYmo1d0yqWSyIvmWTNi0XqlYhZRWQeo/UeBDBMwXmUXrU/wAMQylncn6hxS8BA8q+0K6DjZ8S9gCXeoA0V7QTxHKn4le8v1MStQFR733iThlb1G+os7xFZqNCiFvFRVQ7os8mCnVlSrQTZksBafGFMbFabqDpQDANvh5t+8b4RwWOb5sls01QtHzqcUUqD69YMXdONZAm6W5YvLfvADSssaHrG3cC8S3al3BKOBlhf3lCEYUYX0gnSgVwhj1ktsCf/9k=', + }, + }, + ], + }, + }, + ]; + for (const test of testCases) { + it(test.should, () => { + assert.deepEqual( + toGeminiMessage(test.inputMessage as MessageData), + test.expectedOutput + ); + }); + } +}); + +describe('toGeminiSystemInstruction', () => { + const testCases = [ + { + should: 'should transform from system to user', + inputMessage: { + role: 'system', + content: [{ text: 'You are an expert in all things cats.' }], + }, + expectedOutput: { + role: 'user', + parts: [{ text: 'You are an expert in all things cats.' }], + }, + }, + { + should: 'should transform from system to user with multiple parts', + inputMessage: { + role: 'system', + content: [ + { text: 'You are an expert in all things animals.' }, + { text: 'You love cats.' }, + ], + }, + expectedOutput: { + role: 'user', + parts: [ + { text: 'You are an expert in all things animals.' }, + { text: 'You love cats.' }, + ], + }, + }, + ]; + for (const test of testCases) { + it(test.should, () => { + assert.deepEqual( + toGeminiSystemInstruction(test.inputMessage as MessageData), + test.expectedOutput + ); + }); + } +}); + +describe('fromGeminiCandidate', () => { + const testCases = [ + { + should: + 'should transform gemini candidate to genkit candidate (text parts) correctly', + // had to delete the probabilityScore, severity, severityScore for the HARM_CATEGORY_SEXUALLY_EXPLICIT safety rating category + geminiCandidate: { + content: { + role: 'model', + parts: [ + { + text: 'Why did the dog go to the bank?\n\nTo get his bones cashed!', + }, + ], + }, + finishReason: 'STOP', + safetyRatings: [ + { + category: 'HARM_CATEGORY_HATE_SPEECH', + probability: 'NEGLIGIBLE', + probabilityScore: 0.12074952, + severity: 'HARM_SEVERITY_NEGLIGIBLE', + severityScore: 0.18388656, + }, + { + category: 'HARM_CATEGORY_DANGEROUS_CONTENT', + probability: 'NEGLIGIBLE', + probabilityScore: 0.37874627, + severity: 'HARM_SEVERITY_LOW', + severityScore: 0.37227696, + }, + { + category: 'HARM_CATEGORY_HARASSMENT', + probability: 'NEGLIGIBLE', + probabilityScore: 0.3983479, + severity: 'HARM_SEVERITY_LOW', + severityScore: 0.22270013, + }, + { + category: 'HARM_CATEGORY_SEXUALLY_EXPLICIT', + probability: 'NEGLIGIBLE', + }, + ], + }, + expectedOutput: { + index: 0, + message: { + role: 'model', + content: [ + { + text: 'Why did the dog go to the bank?\n\nTo get his bones cashed!', + }, + ], + }, + finishReason: 'stop', + finishMessage: undefined, + custom: { + citationMetadata: undefined, + safetyRatings: [ + { + category: 'HARM_CATEGORY_HATE_SPEECH', + probability: 'NEGLIGIBLE', + probabilityScore: 0.12074952, + severity: 'HARM_SEVERITY_NEGLIGIBLE', + severityScore: 0.18388656, + }, + { + category: 'HARM_CATEGORY_DANGEROUS_CONTENT', + probability: 'NEGLIGIBLE', + probabilityScore: 0.37874627, + severity: 'HARM_SEVERITY_LOW', + severityScore: 0.37227696, + }, + { + category: 'HARM_CATEGORY_HARASSMENT', + probability: 'NEGLIGIBLE', + probabilityScore: 0.3983479, + severity: 'HARM_SEVERITY_LOW', + severityScore: 0.22270013, + }, + { + category: 'HARM_CATEGORY_SEXUALLY_EXPLICIT', + probability: 'NEGLIGIBLE', + }, + ], + }, + }, + }, + { + should: + 'should transform gemini candidate to genkit candidate (function call parts) correctly', + geminiCandidate: { + content: { + role: 'model', + parts: [ + { + functionCall: { name: 'tellAFunnyJoke', args: { topic: 'dog' } }, + }, + ], + }, + finishReason: 'STOP', + safetyRatings: [ + { + category: 'HARM_CATEGORY_HATE_SPEECH', + probability: 'NEGLIGIBLE', + probabilityScore: 0.11858909, + severity: 'HARM_SEVERITY_NEGLIGIBLE', + severityScore: 0.11456649, + }, + { + category: 'HARM_CATEGORY_DANGEROUS_CONTENT', + probability: 'NEGLIGIBLE', + probabilityScore: 0.13857833, + severity: 'HARM_SEVERITY_NEGLIGIBLE', + severityScore: 0.11417085, + }, + { + category: 'HARM_CATEGORY_HARASSMENT', + probability: 'NEGLIGIBLE', + probabilityScore: 0.28012377, + severity: 'HARM_SEVERITY_NEGLIGIBLE', + severityScore: 0.112405084, + }, + { + category: 'HARM_CATEGORY_SEXUALLY_EXPLICIT', + probability: 'NEGLIGIBLE', + }, + ], + }, + expectedOutput: { + index: 0, + message: { + role: 'model', + content: [ + { + toolRequest: { name: 'tellAFunnyJoke', input: { topic: 'dog' } }, + }, + ], + }, + finishReason: 'stop', + finishMessage: undefined, + custom: { + citationMetadata: undefined, + safetyRatings: [ + { + category: 'HARM_CATEGORY_HATE_SPEECH', + probability: 'NEGLIGIBLE', + probabilityScore: 0.11858909, + severity: 'HARM_SEVERITY_NEGLIGIBLE', + severityScore: 0.11456649, + }, + { + category: 'HARM_CATEGORY_DANGEROUS_CONTENT', + probability: 'NEGLIGIBLE', + probabilityScore: 0.13857833, + severity: 'HARM_SEVERITY_NEGLIGIBLE', + severityScore: 0.11417085, + }, + { + category: 'HARM_CATEGORY_HARASSMENT', + probability: 'NEGLIGIBLE', + probabilityScore: 0.28012377, + severity: 'HARM_SEVERITY_NEGLIGIBLE', + severityScore: 0.112405084, + }, + { + category: 'HARM_CATEGORY_SEXUALLY_EXPLICIT', + probability: 'NEGLIGIBLE', + }, + ], + }, + }, + }, + ]; + for (const test of testCases) { + it(test.should, () => { + assert.deepEqual( + fromGeminiCandidate(test.geminiCandidate as GenerateContentCandidate), + test.expectedOutput + ); + }); + } +}); diff --git a/js/plugins/checks/tests/vector-search/bigquery_test.ts b/js/plugins/checks/tests/vector-search/bigquery_test.ts new file mode 100644 index 000000000..1cbc54314 --- /dev/null +++ b/js/plugins/checks/tests/vector-search/bigquery_test.ts @@ -0,0 +1,168 @@ +/** + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { BigQuery } from '@google-cloud/bigquery'; +import { Document } from 'genkit/retriever'; +import assert from 'node:assert'; +import { describe, it } from 'node:test'; +import { getBigQueryDocumentRetriever } from '../../src'; + +class MockBigQuery { + query: Function; + + constructor({ + mockRows, + shouldThrowError = false, + }: { + mockRows: any[]; + shouldThrowError?: boolean; + }) { + this.query = async (_options: { + query: string; + params: { ids: string[] }; + }) => { + if (shouldThrowError) { + throw new Error('Query failed'); + } + return [mockRows]; + }; + } +} + +describe('getBigQueryDocumentRetriever', () => { + it('returns a function that retrieves documents from BigQuery', async () => { + const doc1 = Document.fromText('content1'); + const doc2 = Document.fromText('content2'); + + const mockRows = [ + { + id: '1', + content: JSON.stringify(doc1.content), + metadata: null, + }, + { + id: '2', + content: JSON.stringify(doc2.content), + metadata: null, + }, + ]; + + const mockBigQuery = new MockBigQuery({ mockRows }) as unknown as BigQuery; + const documentRetriever = getBigQueryDocumentRetriever( + mockBigQuery, + 'test-table', + 'test-dataset' + ); + + const documents = await documentRetriever([ + { datapoint: { datapointId: '1' } }, + { datapoint: { datapointId: '2' } }, + ]); + + assert.deepStrictEqual(documents, [doc1, doc2]); + }); + + it('returns an empty array when no documents match', async () => { + const mockRows: any[] = []; + + const mockBigQuery = new MockBigQuery({ mockRows }) as unknown as BigQuery; + const documentRetriever = getBigQueryDocumentRetriever( + mockBigQuery, + 'test-table', + 'test-dataset' + ); + + const documents = await documentRetriever([ + { datapoint: { datapointId: '3' } }, + ]); + + assert.deepStrictEqual(documents, []); + }); + + it('handles BigQuery query errors', async () => { + const mockBigQuery = new MockBigQuery({ + mockRows: [], + shouldThrowError: true, + }) as unknown as BigQuery; + const documentRetriever = getBigQueryDocumentRetriever( + mockBigQuery, + 'test-table', + 'test-dataset' + ); + // no need to assert the error, just make sure it doesn't throw + await documentRetriever([{ datapoint: { datapointId: '1' } }]); + }); + + it('filters out invalid documents', async () => { + const validDoc = Document.fromText('valid content'); + const mockRows = [ + { + id: '1', + content: JSON.stringify(validDoc.content), + metadata: null, + }, + { + id: '2', + content: 'invalid JSON', + metadata: null, + }, + ]; + + const mockBigQuery = new MockBigQuery({ mockRows }) as unknown as BigQuery; + const documentRetriever = getBigQueryDocumentRetriever( + mockBigQuery, + 'test-table', + 'test-dataset' + ); + + const documents = await documentRetriever([ + { datapoint: { datapointId: '1' } }, + { datapoint: { datapointId: '2' } }, + ]); + + assert.deepStrictEqual(documents, [validDoc]); + }); + + it('handles missing content in documents', async () => { + const validDoc = Document.fromText('valid content'); + const mockRows = [ + { + id: '1', + content: JSON.stringify(validDoc.content), + metadata: null, + }, + { + id: '2', + content: null, + metadata: null, + }, + ]; + + const mockBigQuery = new MockBigQuery({ mockRows }) as unknown as BigQuery; + const documentRetriever = getBigQueryDocumentRetriever( + mockBigQuery, + 'test-table', + 'test-dataset' + ); + + const documents = await documentRetriever([ + { datapoint: { datapointId: '1' } }, + { datapoint: { datapointId: '2' } }, + ]); + + assert.deepStrictEqual(documents, [validDoc]); + }); +}); diff --git a/js/plugins/checks/tests/vector-search/query_public_endpoint_test.ts b/js/plugins/checks/tests/vector-search/query_public_endpoint_test.ts new file mode 100644 index 000000000..9419f2916 --- /dev/null +++ b/js/plugins/checks/tests/vector-search/query_public_endpoint_test.ts @@ -0,0 +1,86 @@ +/** + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import assert from 'assert'; +import { describe, it, Mock } from 'node:test'; +import { queryPublicEndpoint } from '../../src/vector-search/query_public_endpoint'; + +describe('queryPublicEndpoint', () => { + // FIXME -- t.mock.method is not supported node above 20 + it.skip('queryPublicEndpoint sends the correct request and retrieves neighbors', async (t) => { + t.mock.method(global, 'fetch', async (url, options) => { + return { + ok: true, + json: async () => ({ neighbors: ['neighbor1', 'neighbor2'] }), + } as any; + }); + + const params = { + featureVector: [0.1, 0.2, 0.3], + neighborCount: 5, + accessToken: 'test-access-token', + projectId: 'test-project-id', + location: 'us-central1', + indexEndpointId: 'idx123', + publicDomainName: 'example.com', + projectNumber: '123456789', + deployedIndexId: 'deployed-idx123', + }; + + const expectedResponse = { neighbors: ['neighbor1', 'neighbor2'] }; + + const response = await queryPublicEndpoint(params); + + const calls = ( + global.fetch as Mock< + (url: string, options: Record) => Promise + > + ).mock.calls; + + assert.strictEqual(calls.length, 1); + + const [url, options] = calls[0].arguments; + + const expectedUrl = `https://example.com/v1/projects/123456789/locations/us-central1/indexEndpoints/idx123:findNeighbors`; + + assert.strictEqual(url.toString(), expectedUrl); + + assert.strictEqual(options.method, 'POST'); + + assert.strictEqual(options.headers['Content-Type'], 'application/json'); + assert.strictEqual( + options.headers['Authorization'], + 'Bearer test-access-token' + ); + + const body = JSON.parse(options.body); + assert.deepStrictEqual(body, { + deployed_index_id: 'deployed-idx123', + queries: [ + { + datapoint: { + datapoint_id: '0', + feature_vector: [0.1, 0.2, 0.3], + }, + neighbor_count: 5, + }, + ], + }); + + // Verifying the response + assert.deepStrictEqual(response, expectedResponse); + }); +}); diff --git a/js/plugins/checks/tests/vector-search/upsert_datapoints_test.ts b/js/plugins/checks/tests/vector-search/upsert_datapoints_test.ts new file mode 100644 index 000000000..5b36a47d0 --- /dev/null +++ b/js/plugins/checks/tests/vector-search/upsert_datapoints_test.ts @@ -0,0 +1,81 @@ +/** + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import assert from 'assert'; +import { GoogleAuth } from 'google-auth-library'; +import { describe, it, Mock } from 'node:test'; +import { IIndexDatapoint } from '../../src/vector-search/types'; +import { upsertDatapoints } from '../../src/vector-search/upsert_datapoints'; + +describe('upsertDatapoints', () => { + // FIXME -- t.mock.method is not supported node above 20 + it.skip('upsertDatapoints sends the correct request and handles response', async (t) => { + // Mocking the fetch method within the test scope + t.mock.method(global, 'fetch', async (url, options) => { + return { + ok: true, + json: async () => ({}), + } as any; + }); + + // Mocking the GoogleAuth client + const mockAuthClient = { + getAccessToken: async () => 'test-access-token', + } as GoogleAuth; + + const params = { + datapoints: [ + { datapointId: 'dp1', featureVector: [0.1, 0.2, 0.3] }, + { datapointId: 'dp2', featureVector: [0.4, 0.5, 0.6] }, + ] as IIndexDatapoint[], + authClient: mockAuthClient, + projectId: 'test-project-id', + location: 'us-central1', + indexId: 'idx123', + }; + + await upsertDatapoints(params); + + // Verifying the fetch call + const calls = ( + global.fetch as Mock< + (url: string, options: Record) => Promise + > + ).mock.calls; + + assert.strictEqual(calls.length, 1); + const [url, options] = calls[0].arguments; + + assert.strictEqual( + url.toString(), + 'https://us-central1-aiplatform.googleapis.com/v1/projects/test-project-id/locations/us-central1/indexes/idx123:upsertDatapoints' + ); + assert.strictEqual(options.method, 'POST'); + assert.strictEqual(options.headers['Content-Type'], 'application/json'); + assert.strictEqual( + options.headers['Authorization'], + 'Bearer test-access-token' + ); + + const body = JSON.parse(options.body); + assert.deepStrictEqual(body, { + datapoints: [ + { datapoint_id: 'dp1', feature_vector: [0.1, 0.2, 0.3] }, + { datapoint_id: 'dp2', feature_vector: [0.4, 0.5, 0.6] }, + ], + }); + }); +}); diff --git a/js/plugins/checks/tests/vector-search/utils_test.ts b/js/plugins/checks/tests/vector-search/utils_test.ts new file mode 100644 index 000000000..38b130b3a --- /dev/null +++ b/js/plugins/checks/tests/vector-search/utils_test.ts @@ -0,0 +1,70 @@ +/** + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import assert from 'assert'; +import { google } from 'googleapis'; +import { describe, it } from 'node:test'; +import { + getAccessToken, + getProjectNumber, +} from '../../src/vector-search/utils'; + +// Mocking the google.auth.getClient method +google.auth.getClient = async () => { + return { + getRequestHeaders: async () => ({ Authorization: 'Bearer test-token' }), + } as any; // Using `any` to bypass type checks for the mock +}; + +// Mocking the google.cloudresourcemanager method +google.cloudresourcemanager = () => { + return { + projects: { + get: async ({ projectId }) => { + return { + data: { + projectNumber: '123456789', + }, + }; + }, + }, + } as any; // Using `any` to bypass type checks for the mock +}; + +describe('utils', () => { + it('getProjectNumber retrieves the project number', async () => { + const projectId = 'test-project-id'; + const expectedProjectNumber = '123456789'; + + const projectNumber = await getProjectNumber(projectId); + assert.strictEqual(projectNumber, expectedProjectNumber); + }); + + // Mocking the GoogleAuth client + const mockAuthClient = { + getAccessToken: async () => ({ token: 'test-access-token' }), + }; + + it('getAccessToken retrieves the access token', async () => { + // Mocking the GoogleAuth.getClient method to return the mockAuthClient + const auth = { + getClient: async () => mockAuthClient, + } as any; // Using `any` to bypass type checks for the mock + + const accessToken = await getAccessToken(auth); + assert.strictEqual(accessToken, 'test-access-token'); + }); +}); diff --git a/js/plugins/checks/tsconfig.json b/js/plugins/checks/tsconfig.json new file mode 100644 index 000000000..596e2cf72 --- /dev/null +++ b/js/plugins/checks/tsconfig.json @@ -0,0 +1,4 @@ +{ + "extends": "../../tsconfig.json", + "include": ["src"] +} diff --git a/js/plugins/checks/tsup.config.ts b/js/plugins/checks/tsup.config.ts new file mode 100644 index 000000000..01dce0a6b --- /dev/null +++ b/js/plugins/checks/tsup.config.ts @@ -0,0 +1,22 @@ +/** + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { defineConfig, Options } from 'tsup'; +import { defaultOptions } from '../../tsup.common'; + +export default defineConfig({ + ...(defaultOptions as Options), +}); diff --git a/js/pnpm-lock.yaml b/js/pnpm-lock.yaml index 66d54f4dc..77c9da06a 100644 --- a/js/pnpm-lock.yaml +++ b/js/pnpm-lock.yaml @@ -193,6 +193,62 @@ importers: specifier: ^10.0.0 version: 10.0.0 + plugins/checks: + dependencies: + '@anthropic-ai/sdk': + specifier: ^0.24.3 + version: 0.24.3(encoding@0.1.13) + '@anthropic-ai/vertex-sdk': + specifier: ^0.4.0 + version: 0.4.0(encoding@0.1.13) + '@google-cloud/aiplatform': + specifier: ^3.23.0 + version: 3.25.0(encoding@0.1.13) + '@google-cloud/vertexai': + specifier: ^1.1.0 + version: 1.1.0(encoding@0.1.13) + genkit: + specifier: workspace:* + version: link:../../genkit + google-auth-library: + specifier: ^9.6.3 + version: 9.7.0(encoding@0.1.13) + googleapis: + specifier: ^140.0.1 + version: 140.0.1(encoding@0.1.13) + node-fetch: + specifier: ^3.3.2 + version: 3.3.2 + openai: + specifier: ^4.52.7 + version: 4.53.0(encoding@0.1.13) + optionalDependencies: + '@google-cloud/bigquery': + specifier: ^7.8.0 + version: 7.8.0(encoding@0.1.13) + firebase-admin: + specifier: '>=12.2' + version: 12.3.1(encoding@0.1.13) + devDependencies: + '@types/node': + specifier: ^20.11.16 + version: 20.11.30 + npm-run-all: + specifier: ^4.1.5 + version: 4.1.5 + rimraf: + specifier: ^6.0.1 + version: 6.0.1 + tsup: + specifier: ^8.0.2 + version: 8.0.2(postcss@8.4.47)(typescript@4.9.5) + tsx: + specifier: ^4.7.0 + version: 4.7.1 + typescript: + specifier: ^4.9.0 + version: 4.9.5 + plugins/chroma: dependencies: chromadb: From 92c95523bfbdec7e3c2234f0b54f1ba8f7836cf9 Mon Sep 17 00:00:00 2001 From: hunterheston Date: Tue, 29 Oct 2024 17:52:10 +0000 Subject: [PATCH 02/30] Removed most of the vertex code. --- js/plugins/checks/src/anthropic.ts | 422 ------------- js/plugins/checks/src/embedder.ts | 155 ----- js/plugins/checks/src/gemini.ts | 554 ------------------ js/plugins/checks/src/imagen.ts | 309 ---------- js/plugins/checks/src/index.ts | 169 ------ js/plugins/checks/src/model_garden.ts | 123 ---- js/plugins/checks/src/openai_compatibility.ts | 350 ----------- js/plugins/checks/src/reranker.ts | 159 ----- .../checks/src/vector-search/bigquery.ts | 131 ----- .../checks/src/vector-search/firestore.ts | 87 --- js/plugins/checks/src/vector-search/index.ts | 36 -- .../checks/src/vector-search/indexers.ts | 120 ---- .../vector-search/query_public_endpoint.ts | 92 --- .../checks/src/vector-search/retrievers.ts | 136 ----- js/plugins/checks/src/vector-search/types.ts | 189 ------ .../src/vector-search/upsert_datapoints.ts | 71 --- js/plugins/checks/src/vector-search/utils.ts | 65 -- 17 files changed, 3168 deletions(-) delete mode 100644 js/plugins/checks/src/anthropic.ts delete mode 100644 js/plugins/checks/src/embedder.ts delete mode 100644 js/plugins/checks/src/gemini.ts delete mode 100644 js/plugins/checks/src/imagen.ts delete mode 100644 js/plugins/checks/src/model_garden.ts delete mode 100644 js/plugins/checks/src/openai_compatibility.ts delete mode 100644 js/plugins/checks/src/reranker.ts delete mode 100644 js/plugins/checks/src/vector-search/bigquery.ts delete mode 100644 js/plugins/checks/src/vector-search/firestore.ts delete mode 100644 js/plugins/checks/src/vector-search/index.ts delete mode 100644 js/plugins/checks/src/vector-search/indexers.ts delete mode 100644 js/plugins/checks/src/vector-search/query_public_endpoint.ts delete mode 100644 js/plugins/checks/src/vector-search/retrievers.ts delete mode 100644 js/plugins/checks/src/vector-search/types.ts delete mode 100644 js/plugins/checks/src/vector-search/upsert_datapoints.ts delete mode 100644 js/plugins/checks/src/vector-search/utils.ts diff --git a/js/plugins/checks/src/anthropic.ts b/js/plugins/checks/src/anthropic.ts deleted file mode 100644 index a28ea12d4..000000000 --- a/js/plugins/checks/src/anthropic.ts +++ /dev/null @@ -1,422 +0,0 @@ -/** - * Copyright 2024 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -import { - ContentBlock as AnthropicContent, - ImageBlockParam, - Message, - MessageCreateParamsBase, - MessageParam, - TextBlock, - TextBlockParam, - TextDelta, - Tool, - ToolResultBlockParam, - ToolUseBlock, - ToolUseBlockParam, -} from '@anthropic-ai/sdk/resources/messages'; -import { AnthropicVertex } from '@anthropic-ai/vertex-sdk'; -import { - GENKIT_CLIENT_HEADER, - GenerateRequest, - Genkit, - Part as GenkitPart, - MessageData, - ModelReference, - ModelResponseData, - Part, - z, -} from 'genkit'; -import { - GenerationCommonConfigSchema, - getBasicUsageStats, - modelRef, -} from 'genkit/model'; - -export const AnthropicConfigSchema = GenerationCommonConfigSchema.extend({ - location: z.string().optional(), -}); - -export const claude35Sonnet = modelRef({ - name: 'vertexai/claude-3-5-sonnet', - info: { - label: 'Vertex AI Model Garden - Claude 3.5 Sonnet', - versions: ['claude-3-5-sonnet@20240620'], - supports: { - multiturn: true, - media: true, - tools: true, - systemRole: true, - output: ['text'], - }, - }, - configSchema: AnthropicConfigSchema, -}); - -export const claude3Sonnet = modelRef({ - name: 'vertexai/claude-3-sonnet', - info: { - label: 'Vertex AI Model Garden - Claude 3 Sonnet', - versions: ['claude-3-sonnet@20240229'], - supports: { - multiturn: true, - media: true, - tools: true, - systemRole: true, - output: ['text'], - }, - }, - configSchema: AnthropicConfigSchema, -}); - -export const claude3Haiku = modelRef({ - name: 'vertexai/claude-3-haiku', - info: { - label: 'Vertex AI Model Garden - Claude 3 Haiku', - versions: ['claude-3-haiku@20240307'], - supports: { - multiturn: true, - media: true, - tools: true, - systemRole: true, - output: ['text'], - }, - }, - configSchema: AnthropicConfigSchema, -}); - -export const claude3Opus = modelRef({ - name: 'vertexai/claude-3-opus', - info: { - label: 'Vertex AI Model Garden - Claude 3 Opus', - versions: ['claude-3-opus@20240229'], - supports: { - multiturn: true, - media: true, - tools: true, - systemRole: true, - output: ['text'], - }, - }, - configSchema: AnthropicConfigSchema, -}); - -export const SUPPORTED_ANTHROPIC_MODELS: Record< - string, - ModelReference -> = { - 'claude-3-5-sonnet': claude35Sonnet, - 'claude-3-sonnet': claude3Sonnet, - 'claude-3-opus': claude3Opus, - 'claude-3-haiku': claude3Haiku, -}; - -export function toAnthropicRequest( - model: string, - input: GenerateRequest -): MessageCreateParamsBase { - let system: string | undefined = undefined; - const messages: MessageParam[] = []; - for (const msg of input.messages) { - if (msg.role === 'system') { - system = msg.content - .map((c) => { - if (!c.text) { - throw new Error( - 'Only text context is supported for system messages.' - ); - } - return c.text; - }) - .join(); - } - // If the last message is a tool response, we need to add a user message. - // https://docs.anthropic.com/en/docs/build-with-claude/tool-use#handling-tool-use-and-tool-result-content-blocks - else if (msg.content[msg.content.length - 1].toolResponse) { - messages.push({ - role: 'user', - content: toAnthropicContent(msg.content), - }); - } else { - messages.push({ - role: toAnthropicRole(msg.role), - content: toAnthropicContent(msg.content), - }); - } - } - const request = { - model, - messages, - // https://docs.anthropic.com/claude/docs/models-overview#model-comparison - max_tokens: input.config?.maxOutputTokens ?? 4096, - } as MessageCreateParamsBase; - if (system) { - request['system'] = system; - } - if (input.tools) { - request.tools = input.tools?.map((tool) => { - return { - name: tool.name, - description: tool.description, - input_schema: tool.inputSchema, - }; - }) as Array; - } - if (input.config?.stopSequences) { - request.stop_sequences = input.config?.stopSequences; - } - if (input.config?.temperature) { - request.temperature = input.config?.temperature; - } - if (input.config?.topK) { - request.top_k = input.config?.topK; - } - if (input.config?.topP) { - request.top_p = input.config?.topP; - } - return request; -} - -function toAnthropicContent( - content: GenkitPart[] -): Array< - TextBlockParam | ImageBlockParam | ToolUseBlockParam | ToolResultBlockParam -> { - return content.map((p) => { - if (p.text) { - return { - type: 'text', - text: p.text, - }; - } - if (p.media) { - let b64Data = p.media.url; - if (b64Data.startsWith('data:')) { - b64Data = b64Data.substring(b64Data.indexOf(',')! + 1); - } - - return { - type: 'image', - source: { - type: 'base64', - data: b64Data, - media_type: p.media.contentType as - | 'image/jpeg' - | 'image/png' - | 'image/gif' - | 'image/webp', - }, - }; - } - if (p.toolRequest) { - return toAnthropicToolRequest(p.toolRequest); - } - if (p.toolResponse) { - return toAnthropicToolResponse(p); - } - throw new Error(`Unsupported content type: ${JSON.stringify(p)}`); - }); -} - -function toAnthropicRole(role): 'user' | 'assistant' { - if (role === 'model') { - return 'assistant'; - } - if (role === 'user') { - return 'user'; - } - if (role === 'tool') { - return 'assistant'; - } - throw new Error(`Unsupported role type ${role}`); -} - -function fromAnthropicTextPart(part: TextBlock): Part { - return { - text: part.text, - }; -} - -function fromAnthropicToolCallPart(part: ToolUseBlock): Part { - return { - toolRequest: { - name: part.name, - input: part.input, - ref: part.id, - }, - }; -} - -// Converts an Anthropic part to a Genkit part. -function fromAnthropicPart(part: AnthropicContent): Part { - if (part.type === 'text') return fromAnthropicTextPart(part); - if (part.type === 'tool_use') return fromAnthropicToolCallPart(part); - throw new Error( - 'Part type is unsupported/corrupted. Either data is missing or type cannot be inferred from type.' - ); -} - -// Converts an Anthropic response to a Genkit response. -export function fromAnthropicResponse( - input: GenerateRequest, - response: Message -): ModelResponseData { - const parts = response.content as AnthropicContent[]; - const message: MessageData = { - role: 'model', - content: parts.map(fromAnthropicPart), - }; - return { - message, - finishReason: toGenkitFinishReason( - response.stop_reason as - | 'end_turn' - | 'max_tokens' - | 'stop_sequence' - | 'tool_use' - | null - ), - custom: { - id: response.id, - model: response.model, - type: response.type, - }, - usage: { - ...getBasicUsageStats(input.messages, message), - inputTokens: response.usage.input_tokens, - outputTokens: response.usage.output_tokens, - }, - }; -} - -function toGenkitFinishReason( - reason: 'end_turn' | 'max_tokens' | 'stop_sequence' | 'tool_use' | null -): ModelResponseData['finishReason'] { - switch (reason) { - case 'end_turn': - return 'stop'; - case 'max_tokens': - return 'length'; - case 'stop_sequence': - return 'stop'; - case 'tool_use': - return 'stop'; - case null: - return 'unknown'; - default: - return 'other'; - } -} - -function toAnthropicToolRequest(tool: Record): ToolUseBlock { - if (!tool.name) { - throw new Error('Tool name is required'); - } - // Validate the tool name, Anthropic only supports letters, numbers, and underscores. - // https://docs.anthropic.com/en/docs/build-with-claude/tool-use#specifying-tools - if (!/^[a-zA-Z0-9_-]{1,64}$/.test(tool.name)) { - throw new Error( - `Tool name ${tool.name} contains invalid characters. - Only letters, numbers, and underscores are allowed, - and the name must be between 1 and 64 characters long.` - ); - } - const declaration: ToolUseBlock = { - type: 'tool_use', - id: tool.ref, - name: tool.name, - input: tool.input, - }; - return declaration; -} - -function toAnthropicToolResponse(part: Part): ToolResultBlockParam { - if (!part.toolResponse?.ref) { - throw new Error('Tool response reference is required'); - } - - if (!part.toolResponse.output) { - throw new Error('Tool response output is required'); - } - - return { - type: 'tool_result', - tool_use_id: part.toolResponse.ref, - content: JSON.stringify(part.toolResponse.output), - }; -} - -export function anthropicModel( - ai: Genkit, - modelName: string, - projectId: string, - region: string -) { - const clients: Record = {}; - const clientFactory = (region: string): AnthropicVertex => { - if (!clients[region]) { - clients[region] = new AnthropicVertex({ - region, - projectId, - defaultHeaders: { - 'X-Goog-Api-Client': GENKIT_CLIENT_HEADER, - }, - }); - } - return clients[region]; - }; - const model = SUPPORTED_ANTHROPIC_MODELS[modelName]; - if (!model) { - throw new Error(`unsupported Anthropic model name ${modelName}`); - } - - return ai.defineModel( - { - name: model.name, - label: model.info?.label, - configSchema: AnthropicConfigSchema, - supports: model.info?.supports, - versions: model.info?.versions, - }, - async (input, streamingCallback) => { - const client = clientFactory(input.config?.location || region); - if (!streamingCallback) { - const response = await client.messages.create({ - ...toAnthropicRequest(input.config?.version ?? modelName, input), - stream: false, - }); - return fromAnthropicResponse(input, response); - } else { - const stream = await client.messages.stream( - toAnthropicRequest(input.config?.version ?? modelName, input) - ); - for await (const event of stream) { - if (event.type === 'content_block_delta') { - streamingCallback({ - index: 0, - content: [ - { - text: (event.delta as TextDelta).text, - }, - ], - }); - } - } - return fromAnthropicResponse(input, await stream.finalMessage()); - } - } - ); -} diff --git a/js/plugins/checks/src/embedder.ts b/js/plugins/checks/src/embedder.ts deleted file mode 100644 index 10d2ca18c..000000000 --- a/js/plugins/checks/src/embedder.ts +++ /dev/null @@ -1,155 +0,0 @@ -/** - * Copyright 2024 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -import { Genkit, z } from 'genkit'; -import { EmbedderReference, embedderRef } from 'genkit/embedder'; -import { GoogleAuth } from 'google-auth-library'; -import { PluginOptions } from './index.js'; -import { PredictClient, predictModel } from './predict.js'; - -export const TaskTypeSchema = z.enum([ - 'RETRIEVAL_DOCUMENT', - 'RETRIEVAL_QUERY', - 'SEMANTIC_SIMILARITY', - 'CLASSIFICATION', - 'CLUSTERING', -]); - -export type TaskType = z.infer; - -export const VertexEmbeddingConfigSchema = z.object({ - /** - * The `task_type` parameter is defined as the intended downstream application to help the model - * produce better quality embeddings. - **/ - taskType: TaskTypeSchema.optional(), - title: z.string().optional(), - location: z.string().optional(), - version: z.string().optional(), -}); - -export type VertexEmbeddingConfig = z.infer; - -function commonRef( - name: string, - input?: ('text' | 'image')[] -): EmbedderReference { - return embedderRef({ - name: `vertexai/${name}`, - configSchema: VertexEmbeddingConfigSchema, - info: { - dimensions: 768, - label: `Vertex AI - ${name}`, - supports: { - input: input ?? ['text'], - }, - }, - }); -} - -export const textEmbeddingGecko003 = commonRef('textembedding-gecko@003'); -export const textEmbedding004 = commonRef('text-embedding-004'); -export const textEmbeddingGeckoMultilingual001 = commonRef( - 'textembedding-gecko-multilingual@001' -); -export const textMultilingualEmbedding002 = commonRef( - 'text-multilingual-embedding-002' -); - -export const SUPPORTED_EMBEDDER_MODELS: Record = { - 'textembedding-gecko@003': textEmbeddingGecko003, - 'text-embedding-004': textEmbedding004, - 'textembedding-gecko-multilingual@001': textEmbeddingGeckoMultilingual001, - 'text-multilingual-embedding-002': textMultilingualEmbedding002, - // TODO: add support for multimodal embeddings - // 'multimodalembedding@001': commonRef('multimodalembedding@001', [ - // 'image', - // 'text', - // ]), -}; - -interface EmbeddingInstance { - task_type?: TaskType; - content: string; - title?: string; -} -interface EmbeddingPrediction { - embeddings: { - statistics: { - truncated: boolean; - token_count: number; - }; - values: number[]; - }; -} - -export function defineVertexAIEmbedder( - ai: Genkit, - name: string, - client: GoogleAuth, - options: PluginOptions -) { - const embedder = SUPPORTED_EMBEDDER_MODELS[name]; - const predictClients: Record< - string, - PredictClient - > = {}; - const predictClientFactory = ( - config: VertexEmbeddingConfig - ): PredictClient => { - const requestLocation = config?.location || options.location; - if (!predictClients[requestLocation]) { - // TODO: Figure out how to allow different versions while still sharing a single implementation. - predictClients[requestLocation] = predictModel< - EmbeddingInstance, - EmbeddingPrediction - >( - client, - { - ...options, - location: requestLocation, - }, - name - ); - } - return predictClients[requestLocation]; - }; - - return ai.defineEmbedder( - { - name: embedder.name, - configSchema: embedder.configSchema, - info: embedder.info!, - }, - async (input, options) => { - const predictClient = predictClientFactory(options); - const response = await predictClient( - input.map((i) => { - return { - content: i.text, - task_type: options?.taskType, - title: options?.title, - }; - }) - ); - return { - embeddings: response.predictions.map((p) => ({ - embedding: p.embeddings.values, - })), - }; - } - ); -} diff --git a/js/plugins/checks/src/gemini.ts b/js/plugins/checks/src/gemini.ts deleted file mode 100644 index eccf1a869..000000000 --- a/js/plugins/checks/src/gemini.ts +++ /dev/null @@ -1,554 +0,0 @@ -/** - * Copyright 2024 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -import { - Content, - FunctionDeclaration, - FunctionDeclarationSchemaType, - Part as GeminiPart, - GenerateContentCandidate, - GenerateContentResponse, - GenerateContentResult, - HarmBlockThreshold, - HarmCategory, - StartChatParams, - VertexAI, -} from '@google-cloud/vertexai'; -import { GENKIT_CLIENT_HEADER, Genkit, z } from 'genkit'; -import { - CandidateData, - GenerateRequest, - GenerationCommonConfigSchema, - MediaPart, - MessageData, - ModelAction, - ModelMiddleware, - ModelReference, - Part, - ToolDefinitionSchema, - getBasicUsageStats, - modelRef, -} from 'genkit/model'; -import { - downloadRequestMedia, - simulateSystemPrompt, -} from 'genkit/model/middleware'; -import { PluginOptions } from './index.js'; - -const SafetySettingsSchema = z.object({ - category: z.nativeEnum(HarmCategory), - threshold: z.nativeEnum(HarmBlockThreshold), -}); - -const VertexRetrievalSchema = z.object({ - datastore: z.object({ - projectId: z.string().optional(), - location: z.string().optional(), - dataStoreId: z.string(), - }), - disableAttribution: z.boolean().optional(), -}); - -const GoogleSearchRetrievalSchema = z.object({ - disableAttribution: z.boolean().optional(), -}); - -export const GeminiConfigSchema = GenerationCommonConfigSchema.extend({ - safetySettings: z.array(SafetySettingsSchema).optional(), - location: z.string().optional(), - vertexRetrieval: VertexRetrievalSchema.optional(), - googleSearchRetrieval: GoogleSearchRetrievalSchema.optional(), -}); - -export const gemini10Pro = modelRef({ - name: 'vertexai/gemini-1.0-pro', - info: { - label: 'Vertex AI - Gemini Pro', - versions: ['gemini-1.0-pro-001', 'gemini-1.0-pro-002'], - supports: { - multiturn: true, - media: false, - tools: true, - systemRole: true, - }, - }, - configSchema: GeminiConfigSchema, -}); - -export const gemini15Pro = modelRef({ - name: 'vertexai/gemini-1.5-pro', - info: { - label: 'Vertex AI - Gemini 1.5 Pro', - versions: ['gemini-1.5-pro-001', 'gemini-1.5-pro-002'], - supports: { - multiturn: true, - media: true, - tools: true, - systemRole: true, - }, - }, - configSchema: GeminiConfigSchema, -}); - -export const gemini15Flash = modelRef({ - name: 'vertexai/gemini-1.5-flash', - info: { - label: 'Vertex AI - Gemini 1.5 Flash', - versions: ['gemini-1.5-flash-001', 'gemini-1.5-flash-002'], - supports: { - multiturn: true, - media: true, - tools: true, - systemRole: true, - }, - }, - configSchema: GeminiConfigSchema, -}); - -export const SUPPORTED_V1_MODELS = { - 'gemini-1.0-pro': gemini10Pro, -}; - -export const SUPPORTED_V15_MODELS = { - 'gemini-1.5-pro': gemini15Pro, - 'gemini-1.5-flash': gemini15Flash, -}; - -export const SUPPORTED_GEMINI_MODELS = { - ...SUPPORTED_V1_MODELS, - ...SUPPORTED_V15_MODELS, -}; - -function toGeminiRole( - role: MessageData['role'], - model?: ModelReference -): string { - switch (role) { - case 'user': - return 'user'; - case 'model': - return 'model'; - case 'system': - if (model && SUPPORTED_V15_MODELS[model.name]) { - // We should have already pulled out the supported system messages, - // anything remaining is unsupported; throw an error. - throw new Error( - 'system role is only supported for a single message in the first position' - ); - } else { - throw new Error('system role is not supported'); - } - case 'tool': - return 'function'; - default: - return 'user'; - } -} - -const toGeminiTool = ( - tool: z.infer -): FunctionDeclaration => { - const declaration: FunctionDeclaration = { - name: tool.name.replace(/\//g, '__'), // Gemini throws on '/' in tool name - description: tool.description, - parameters: convertSchemaProperty(tool.inputSchema), - }; - return declaration; -}; - -const toGeminiFileDataPart = (part: MediaPart): GeminiPart => { - const media = part.media; - if (media.url.startsWith('gs://')) { - if (!media.contentType) - throw new Error( - 'Must supply contentType when using media from gs:// URLs.' - ); - return { - fileData: { - mimeType: media.contentType, - fileUri: media.url, - }, - }; - } else if (media.url.startsWith('data:')) { - const dataUrl = media.url; - const b64Data = dataUrl.substring(dataUrl.indexOf(',')! + 1); - const contentType = - media.contentType || - dataUrl.substring(dataUrl.indexOf(':')! + 1, dataUrl.indexOf(';')); - return { inlineData: { mimeType: contentType, data: b64Data } }; - } - - throw Error( - 'Could not convert genkit part to gemini tool response part: missing file data' - ); -}; - -const toGeminiToolRequestPart = (part: Part): GeminiPart => { - if (!part?.toolRequest?.input) { - throw Error( - 'Could not convert genkit part to gemini tool response part: missing tool request data' - ); - } - return { - functionCall: { - name: part.toolRequest.name, - args: part.toolRequest.input, - }, - }; -}; - -const toGeminiToolResponsePart = (part: Part): GeminiPart => { - if (!part?.toolResponse?.output) { - throw Error( - 'Could not convert genkit part to gemini tool response part: missing tool response data' - ); - } - return { - functionResponse: { - name: part.toolResponse.name, - response: { - name: part.toolResponse.name, - content: part.toolResponse.output, - }, - }, - }; -}; - -export function toGeminiSystemInstruction(message: MessageData): Content { - return { - role: 'user', - parts: message.content.map(toGeminiPart), - }; -} - -export function toGeminiMessage( - message: MessageData, - model?: ModelReference -): Content { - return { - role: toGeminiRole(message.role, model), - parts: message.content.map(toGeminiPart), - }; -} - -function fromGeminiFinishReason( - reason: GenerateContentCandidate['finishReason'] -): CandidateData['finishReason'] { - if (!reason) return 'unknown'; - switch (reason) { - case 'STOP': - return 'stop'; - case 'MAX_TOKENS': - return 'length'; - case 'SAFETY': // blocked for safety - case 'RECITATION': // blocked for reciting training data - return 'blocked'; - default: - return 'unknown'; - } -} - -function toGeminiPart(part: Part): GeminiPart { - if (part.text) { - return { text: part.text }; - } else if (part.media) { - return toGeminiFileDataPart(part); - } else if (part.toolRequest) { - return toGeminiToolRequestPart(part); - } else if (part.toolResponse) { - return toGeminiToolResponsePart(part); - } else { - throw new Error('unsupported type'); - } -} - -function fromGeminiInlineDataPart(part: GeminiPart): MediaPart { - // Check if the required properties exist - if ( - !part.inlineData || - !part.inlineData.hasOwnProperty('mimeType') || - !part.inlineData.hasOwnProperty('data') - ) { - throw new Error('Invalid GeminiPart: missing required properties'); - } - const { mimeType, data } = part.inlineData; - // Combine data and mimeType into a data URL - const dataUrl = `data:${mimeType};base64,${data}`; - return { - media: { - url: dataUrl, - contentType: mimeType, - }, - }; -} - -function fromGeminiFileDataPart(part: GeminiPart): MediaPart { - if ( - !part.fileData || - !part.fileData.hasOwnProperty('mimeType') || - !part.fileData.hasOwnProperty('url') - ) { - throw new Error( - 'Invalid Gemini File Data Part: missing required properties' - ); - } - - return { - media: { - url: part.fileData?.fileUri, - contentType: part.fileData?.mimeType, - }, - }; -} - -function fromGeminiFunctionCallPart(part: GeminiPart): Part { - if (!part.functionCall) { - throw new Error( - 'Invalid Gemini Function Call Part: missing function call data' - ); - } - return { - toolRequest: { - name: part.functionCall.name, - input: part.functionCall.args, - }, - }; -} - -function fromGeminiFunctionResponsePart(part: GeminiPart): Part { - if (!part.functionResponse) { - throw new Error( - 'Invalid Gemini Function Call Part: missing function call data' - ); - } - return { - toolResponse: { - name: part.functionResponse.name.replace(/__/g, '/'), // restore slashes - output: part.functionResponse.response, - }, - }; -} - -// Converts vertex part to genkit part -function fromGeminiPart(part: GeminiPart): Part { - if (part.text !== undefined) return { text: part.text }; - if (part.functionCall) return fromGeminiFunctionCallPart(part); - if (part.functionResponse) return fromGeminiFunctionResponsePart(part); - if (part.inlineData) return fromGeminiInlineDataPart(part); - if (part.fileData) return fromGeminiFileDataPart(part); - throw new Error( - 'Part type is unsupported/corrupted. Either data is missing or type cannot be inferred from type.' - ); -} - -export function fromGeminiCandidate( - candidate: GenerateContentCandidate -): CandidateData { - const parts = candidate.content.parts || []; - const genkitCandidate: CandidateData = { - index: candidate.index || 0, // reasonable default? - message: { - role: 'model', - content: parts.map(fromGeminiPart), - }, - finishReason: fromGeminiFinishReason(candidate.finishReason), - finishMessage: candidate.finishMessage, - custom: { - safetyRatings: candidate.safetyRatings, - citationMetadata: candidate.citationMetadata, - }, - }; - return genkitCandidate; -} - -// Translate JSON schema to Vertex AI's format. Specifically, the type field needs be mapped. -// Since JSON schemas can include nested arrays/objects, we have to recursively map the type field -// in all nested fields. -const convertSchemaProperty = (property) => { - if (!property || !property.type) { - return null; - } - if (property.type === 'object') { - const nestedProperties = {}; - Object.keys(property.properties).forEach((key) => { - nestedProperties[key] = convertSchemaProperty(property.properties[key]); - }); - return { - type: FunctionDeclarationSchemaType.OBJECT, - properties: nestedProperties, - required: property.required, - }; - } else if (property.type === 'array') { - return { - type: FunctionDeclarationSchemaType.ARRAY, - items: convertSchemaProperty(property.items), - }; - } else { - return { - type: FunctionDeclarationSchemaType[property.type.toUpperCase()], - }; - } -}; - -/** - * Define a Vertex AI Gemini model. - */ -export function defineGeminiModel( - ai: Genkit, - name: string, - vertexClientFactory: ( - request: GenerateRequest - ) => VertexAI, - options: PluginOptions -): ModelAction { - const modelName = `vertexai/${name}`; - - const model: ModelReference = SUPPORTED_GEMINI_MODELS[name]; - if (!model) throw new Error(`Unsupported model: ${name}`); - - const middlewares: ModelMiddleware[] = []; - if (SUPPORTED_V1_MODELS[name]) { - middlewares.push(simulateSystemPrompt()); - } - if (model?.info?.supports?.media) { - // the gemini api doesn't support downloading media from http(s) - middlewares.push(downloadRequestMedia({ maxBytes: 1024 * 1024 * 20 })); - } - - return ai.defineModel( - { - name: modelName, - ...model.info, - configSchema: GeminiConfigSchema, - use: middlewares, - }, - async (request, streamingCallback) => { - const vertex = vertexClientFactory(request); - const client = vertex.preview.getGenerativeModel( - { - model: request.config?.version || model.version || name, - }, - { - apiClient: GENKIT_CLIENT_HEADER, - } - ); - - // make a copy so that modifying the request will not produce side-effects - const messages = [...request.messages]; - if (messages.length === 0) throw new Error('No messages provided.'); - - // Gemini does not support messages with role system and instead expects - // systemInstructions to be provided as a separate input. The first - // message detected with role=system will be used for systemInstructions. - // Any additional system messages may be considered to be "exceptional". - let systemInstruction: Content | undefined = undefined; - if (SUPPORTED_V15_MODELS[name]) { - const systemMessage = messages.find((m) => m.role === 'system'); - if (systemMessage) { - messages.splice(messages.indexOf(systemMessage), 1); - systemInstruction = toGeminiSystemInstruction(systemMessage); - } - } - - const chatRequest: StartChatParams = { - systemInstruction, - tools: request.tools?.length - ? [{ functionDeclarations: request.tools?.map(toGeminiTool) }] - : [], - history: messages - .slice(0, -1) - .map((message) => toGeminiMessage(message, model)), - generationConfig: { - candidateCount: request.candidates || undefined, - temperature: request.config?.temperature, - maxOutputTokens: request.config?.maxOutputTokens, - topK: request.config?.topK, - topP: request.config?.topP, - stopSequences: request.config?.stopSequences, - }, - safetySettings: request.config?.safetySettings, - }; - if (request.config?.googleSearchRetrieval) { - chatRequest.tools?.push({ - googleSearchRetrieval: request.config.googleSearchRetrieval, - }); - } - if (request.config?.vertexRetrieval) { - // https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/ground-gemini#ground-gemini - const vertexRetrieval = request.config.vertexRetrieval; - const _projectId = - vertexRetrieval.datastore.projectId || options.projectId; - const _location = - vertexRetrieval.datastore.location || options.location; - const _dataStoreId = vertexRetrieval.datastore.dataStoreId; - const datastore = `projects/${_projectId}/locations/${_location}/collections/default_collection/dataStores/${_dataStoreId}`; - chatRequest.tools?.push({ - retrieval: { - vertexAiSearch: { - datastore, - }, - disableAttribution: vertexRetrieval.disableAttribution, - }, - }); - } - const msg = toGeminiMessage(messages[messages.length - 1], model); - if (streamingCallback) { - const result = await client - .startChat(chatRequest) - .sendMessageStream(msg.parts); - for await (const item of result.stream) { - (item as GenerateContentResponse).candidates?.forEach((candidate) => { - const c = fromGeminiCandidate(candidate); - streamingCallback({ - index: c.index, - content: c.message.content, - }); - }); - } - const response = await result.response; - if (!response.candidates?.length) { - throw new Error('No valid candidates returned.'); - } - return { - candidates: response.candidates?.map(fromGeminiCandidate) || [], - custom: response, - }; - } else { - let result: GenerateContentResult | undefined; - try { - result = await client.startChat(chatRequest).sendMessage(msg.parts); - } catch (err) { - throw new Error(`Vertex response generation failed: ${err}`); - } - if (!result?.response.candidates?.length) { - throw new Error('No valid candidates returned.'); - } - const responseCandidates = - result.response.candidates?.map(fromGeminiCandidate) || []; - return { - candidates: responseCandidates, - custom: result.response, - usage: { - ...getBasicUsageStats(request.messages, responseCandidates), - inputTokens: result.response.usageMetadata?.promptTokenCount, - outputTokens: result.response.usageMetadata?.candidatesTokenCount, - totalTokens: result.response.usageMetadata?.totalTokenCount, - }, - }; - } - } - ); -} diff --git a/js/plugins/checks/src/imagen.ts b/js/plugins/checks/src/imagen.ts deleted file mode 100644 index 12f11fd13..000000000 --- a/js/plugins/checks/src/imagen.ts +++ /dev/null @@ -1,309 +0,0 @@ -/** - * Copyright 2024 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -import { Genkit, z } from 'genkit'; -import { - CandidateData, - GenerateRequest, - GenerationCommonConfigSchema, - ModelReference, - getBasicUsageStats, - modelRef, -} from 'genkit/model'; -import { GoogleAuth } from 'google-auth-library'; -import { PluginOptions } from './index.js'; -import { PredictClient, predictModel } from './predict.js'; - -const ImagenConfigSchema = GenerationCommonConfigSchema.extend({ - /** Language of the prompt text. */ - language: z - .enum(['auto', 'en', 'es', 'hi', 'ja', 'ko', 'pt', 'zh-TW', 'zh', 'zh-CN']) - .optional(), - /** Desired aspect ratio of output image. */ - aspectRatio: z.enum(['1:1', '9:16', '16:9', '3:4', '4:3']).optional(), - /** - * A negative prompt to help generate the images. For example: "animals" - * (removes animals), "blurry" (makes the image clearer), "text" (removes - * text), or "cropped" (removes cropped images). - **/ - negativePrompt: z.string().optional(), - /** - * Any non-negative integer you provide to make output images deterministic. - * Providing the same seed number always results in the same output images. - * Accepted integer values: 1 - 2147483647. - **/ - seed: z.number().optional(), - /** Your GCP project's region. e.g.) us-central1, europe-west2, etc. **/ - location: z.string().optional(), - /** Allow generation of people by the model. */ - personGeneration: z - .enum(['dont_allow', 'allow_adult', 'allow_all']) - .optional(), - /** Adds a filter level to safety filtering. */ - safetySetting: z - .enum(['block_most', 'block_some', 'block_few', 'block_fewest']) - .optional(), - /** Add an invisible watermark to the generated images. */ - addWatermark: z.boolean().optional(), - /** Cloud Storage URI to store the generated images. **/ - storageUri: z.string().optional(), - /** Mode must be set for upscaling requests. */ - mode: z.enum(['upscale']).optional(), - /** - * Describes the editing intention for the request. - * - * Refer to https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/imagen-api#edit_images_2 for details. - */ - editConfig: z - .object({ - /** Describes the editing intention for the request. */ - editMode: z - .enum([ - 'inpainting-insert', - 'inpainting-remove', - 'outpainting', - 'product-image', - ]) - .optional(), - /** Prompts the model to generate a mask instead of you needing to provide one. Consequently, when you provide this parameter you can omit a mask object. */ - maskMode: z - .object({ - maskType: z.enum(['background', 'foreground', 'semantic']), - classes: z.array(z.number()).optional(), - }) - .optional(), - maskDilation: z.number().optional(), - guidanceScale: z.number().optional(), - productPosition: z.enum(['reposition', 'fixed']).optional(), - }) - .passthrough() - .optional(), - /** Upscale config object. */ - upscaleConfig: z.object({ upscaleFactor: z.enum(['x2', 'x4']) }).optional(), -}).passthrough(); - -export const imagen2 = modelRef({ - name: 'vertexai/imagen2', - info: { - label: 'Vertex AI - Imagen2', - versions: ['imagegeneration@006', 'imagegeneration@005'], - supports: { - media: false, - multiturn: false, - tools: false, - systemRole: false, - output: ['media'], - }, - }, - version: 'imagegeneration@006', - configSchema: ImagenConfigSchema, -}); - -export const imagen3 = modelRef({ - name: 'vertexai/imagen3', - info: { - label: 'Vertex AI - Imagen3', - versions: ['imagen-3.0-generate-001'], - supports: { - media: true, - multiturn: false, - tools: false, - systemRole: false, - output: ['media'], - }, - }, - version: 'imagen-3.0-generate-001', - configSchema: ImagenConfigSchema, -}); - -export const imagen3Fast = modelRef({ - name: 'vertexai/imagen3-fast', - info: { - label: 'Vertex AI - Imagen3 Fast', - versions: ['imagen-3.0-fast-generate-001'], - supports: { - media: false, - multiturn: false, - tools: false, - systemRole: false, - output: ['media'], - }, - }, - version: 'imagen-3.0-fast-generate-001', - configSchema: ImagenConfigSchema, -}); - -export const SUPPORTED_IMAGEN_MODELS = { - imagen2: imagen2, - imagen3: imagen3, - 'imagen3-fast': imagen3Fast, -}; - -function extractText(request: GenerateRequest) { - return request.messages - .at(-1)! - .content.map((c) => c.text || '') - .join(''); -} - -interface ImagenParameters { - sampleCount?: number; - aspectRatio?: string; - negativePrompt?: string; - seed?: number; - language?: string; - personGeneration?: string; - safetySetting?: string; - addWatermark?: boolean; - storageUri?: string; -} - -function toParameters( - request: GenerateRequest -): ImagenParameters { - const out = { - sampleCount: request.candidates ?? 1, - ...request?.config, - }; - - for (const k in out) { - if (!out[k]) delete out[k]; - } - - return out; -} - -function extractMaskImage(request: GenerateRequest): string | undefined { - return request.messages - .at(-1) - ?.content.find((p) => !!p.media && p.metadata?.type === 'mask') - ?.media?.url.split(',')[1]; -} - -function extractBaseImage(request: GenerateRequest): string | undefined { - return request.messages - .at(-1) - ?.content.find( - (p) => !!p.media && (!p.metadata?.type || p.metadata?.type === 'base') - ) - ?.media?.url.split(',')[1]; -} - -interface ImagenPrediction { - bytesBase64Encoded: string; - mimeType: string; -} - -interface ImagenInstance { - prompt: string; - image?: { bytesBase64Encoded: string }; - mask?: { image?: { bytesBase64Encoded: string } }; -} - -export function imagenModel( - ai: Genkit, - name: string, - client: GoogleAuth, - options: PluginOptions -) { - const modelName = `vertexai/${name}`; - const model: ModelReference = SUPPORTED_IMAGEN_MODELS[name]; - if (!model) throw new Error(`Unsupported model: ${name}`); - - const predictClients: Record< - string, - PredictClient - > = {}; - const predictClientFactory = ( - request: GenerateRequest - ): PredictClient => { - const requestLocation = request.config?.location || options.location; - if (!predictClients[requestLocation]) { - predictClients[requestLocation] = predictModel< - ImagenInstance, - ImagenPrediction, - ImagenParameters - >( - client, - { - ...options, - location: requestLocation, - }, - request.config?.version || model.version || name - ); - } - return predictClients[requestLocation]; - }; - - return ai.defineModel( - { - name: modelName, - ...model.info, - configSchema: ImagenConfigSchema, - }, - async (request) => { - const instance: ImagenInstance = { - prompt: extractText(request), - }; - const baseImage = extractBaseImage(request); - if (baseImage) { - instance.image = { bytesBase64Encoded: baseImage }; - } - const maskImage = extractMaskImage(request); - if (maskImage) { - instance.mask = { - image: { bytesBase64Encoded: maskImage }, - }; - } - - const req: any = { - instances: [instance], - parameters: toParameters(request), - }; - - const predictClient = predictClientFactory(request); - const response = await predictClient([instance], toParameters(request)); - - const candidates: CandidateData[] = response.predictions.map((p, i) => { - const b64data = p.bytesBase64Encoded; - const mimeType = p.mimeType; - return { - index: i, - finishReason: 'stop', - message: { - role: 'model', - content: [ - { - media: { - url: `data:${mimeType};base64,${b64data}`, - contentType: mimeType, - }, - }, - ], - }, - }; - }); - return { - candidates, - usage: { - ...getBasicUsageStats(request.messages, candidates), - custom: { generations: candidates.length }, - }, - custom: response, - }; - } - ); -} diff --git a/js/plugins/checks/src/index.ts b/js/plugins/checks/src/index.ts index 088d9212d..8eb8d9ff1 100644 --- a/js/plugins/checks/src/index.ts +++ b/js/plugins/checks/src/index.ts @@ -14,93 +14,16 @@ * limitations under the License. */ -import { VertexAI } from '@google-cloud/vertexai'; import { Genkit, z } from 'genkit'; -import { GenerateRequest, ModelReference } from 'genkit/model'; import { GenkitPlugin, genkitPlugin } from 'genkit/plugin'; import { GoogleAuth, GoogleAuthOptions } from 'google-auth-library'; -import { - SUPPORTED_ANTHROPIC_MODELS, - anthropicModel, - claude35Sonnet, - claude3Haiku, - claude3Opus, - claude3Sonnet, -} from './anthropic.js'; -import { - SUPPORTED_EMBEDDER_MODELS, - defineVertexAIEmbedder, - textEmbedding004, - textEmbeddingGecko003, - textEmbeddingGeckoMultilingual001, - textMultilingualEmbedding002, -} from './embedder.js'; import { VertexAIEvaluationMetric, VertexAIEvaluationMetricType, vertexEvaluators, } from './evaluation.js'; -import { - GeminiConfigSchema, - SUPPORTED_GEMINI_MODELS, - defineGeminiModel, - gemini10Pro, - gemini15Flash, - gemini15Pro, -} from './gemini.js'; -import { - SUPPORTED_IMAGEN_MODELS, - imagen2, - imagen3, - imagen3Fast, - imagenModel, -} from './imagen.js'; -import { - SUPPORTED_OPENAI_FORMAT_MODELS, - llama3, - llama31, - llama32, - modelGardenOpenaiCompatibleModel, -} from './model_garden.js'; -import { VertexRerankerConfig, vertexAiRerankers } from './reranker.js'; -import { - VectorSearchOptions, - vertexAiIndexers, - vertexAiRetrievers, -} from './vector-search/index.js'; -export { - DocumentIndexer, - DocumentRetriever, - Neighbor, - VectorSearchOptions, - getBigQueryDocumentIndexer, - getBigQueryDocumentRetriever, - getFirestoreDocumentIndexer, - getFirestoreDocumentRetriever, - vertexAiIndexerRef, - vertexAiIndexers, - vertexAiRetrieverRef, - vertexAiRetrievers, -} from './vector-search/index.js'; export { VertexAIEvaluationMetricType as VertexAIEvaluationMetricType, - claude35Sonnet, - claude3Haiku, - claude3Opus, - claude3Sonnet, - gemini10Pro, - gemini15Flash, - gemini15Pro, - imagen2, - imagen3, - imagen3Fast, - llama3, - llama31, - llama32, - textEmbedding004, - textEmbeddingGecko003, - textEmbeddingGeckoMultilingual001, - textMultilingualEmbedding002, }; export interface PluginOptions { @@ -114,18 +37,6 @@ export interface PluginOptions { evaluation?: { metrics: VertexAIEvaluationMetric[]; }; - /** - * @deprecated use `modelGarden.models` - */ - modelGardenModels?: ModelReference[]; - modelGarden?: { - models: ModelReference[]; - openAiBaseUrlTemplate?: string; - }; - /** Configure Vertex AI vector search index options */ - vectorSearchOptions?: VectorSearchOptions[]; - /** Configure reranker options */ - rerankOptions?: VertexRerankerConfig[]; } const CLOUD_PLATFROM_OAUTH_SCOPE = @@ -171,90 +82,10 @@ export function vertexAI(options?: PluginOptions): GenkitPlugin { throw confError('project', 'GCLOUD_PROJECT'); } - const vertexClientFactoryCache: Record = {}; - const vertexClientFactory = ( - request: GenerateRequest - ): VertexAI => { - const requestLocation = request.config?.location || location; - if (!vertexClientFactoryCache[requestLocation]) { - vertexClientFactoryCache[requestLocation] = new VertexAI({ - project: projectId, - location: requestLocation, - googleAuthOptions: authOptions, - }); - } - return vertexClientFactoryCache[requestLocation]; - }; const metrics = options?.evaluation && options.evaluation.metrics.length > 0 ? options.evaluation.metrics : []; - - Object.keys(SUPPORTED_IMAGEN_MODELS).map((name) => - imagenModel(ai, name, authClient, { projectId, location }) - ); - Object.keys(SUPPORTED_GEMINI_MODELS).map((name) => - defineGeminiModel(ai, name, vertexClientFactory, { projectId, location }) - ); - - if (options?.modelGardenModels || options?.modelGarden?.models) { - const mgModels = - options?.modelGardenModels || options?.modelGarden?.models; - mgModels!.forEach((m) => { - const anthropicEntry = Object.entries(SUPPORTED_ANTHROPIC_MODELS).find( - ([_, value]) => value.name === m.name - ); - if (anthropicEntry) { - anthropicModel(ai, anthropicEntry[0], projectId, location); - return; - } - const openaiModel = Object.entries(SUPPORTED_OPENAI_FORMAT_MODELS).find( - ([_, value]) => value.name === m.name - ); - if (openaiModel) { - modelGardenOpenaiCompatibleModel( - ai, - openaiModel[0], - projectId, - location, - authClient, - options.modelGarden?.openAiBaseUrlTemplate - ); - return; - } - throw new Error(`Unsupported model garden model: ${m.name}`); - }); - } - - const embedders = Object.keys(SUPPORTED_EMBEDDER_MODELS).map((name) => - defineVertexAIEmbedder(ai, name, authClient, { projectId, location }) - ); - - if ( - options?.vectorSearchOptions && - options.vectorSearchOptions.length > 0 - ) { - const defaultEmbedder = embedders[0]; - - vertexAiIndexers(ai, { - pluginOptions: options, - authClient, - defaultEmbedder, - }); - - vertexAiRetrievers(ai, { - pluginOptions: options, - authClient, - defaultEmbedder, - }); - } - - const rerankOptions = { - pluginOptions: options, - authClient, - projectId, - }; - await vertexAiRerankers(ai, rerankOptions); vertexEvaluators(ai, authClient, metrics, projectId, location); }); } diff --git a/js/plugins/checks/src/model_garden.ts b/js/plugins/checks/src/model_garden.ts deleted file mode 100644 index eec87274c..000000000 --- a/js/plugins/checks/src/model_garden.ts +++ /dev/null @@ -1,123 +0,0 @@ -/** - * Copyright 2024 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -import { Genkit, GENKIT_CLIENT_HEADER, z } from 'genkit'; -import { GenerateRequest, ModelAction, modelRef } from 'genkit/model'; -import { GoogleAuth } from 'google-auth-library'; -import OpenAI from 'openai'; -import { - openaiCompatibleModel, - OpenAIConfigSchema, -} from './openai_compatibility.js'; - -export const ModelGardenModelConfigSchema = OpenAIConfigSchema.extend({ - location: z.string().optional(), -}); - -export const llama31 = modelRef({ - name: 'vertexai/llama-3.1', - info: { - label: 'Llama 3.1', - supports: { - multiturn: true, - tools: true, - media: false, - systemRole: true, - output: ['text', 'json'], - }, - versions: [ - 'meta/llama3-405b-instruct-maas', - // 8b and 70b versions are coming soon - ], - }, - configSchema: ModelGardenModelConfigSchema, - version: 'meta/llama3-405b-instruct-maas', -}); - -export const llama32 = modelRef({ - name: 'vertexai/llama-3.2', - info: { - label: 'Llama 3.2', - supports: { - multiturn: true, - tools: true, - media: true, - systemRole: true, - output: ['text', 'json'], - }, - versions: ['meta/llama-3.2-90b-vision-instruct-maas'], - }, - configSchema: ModelGardenModelConfigSchema, - version: 'meta/llama-3.2-90b-vision-instruct-maas', -}); - -/** - * @deprecated use `llama31` instead - */ -export const llama3 = modelRef({ - name: 'vertexai/llama3-405b', - info: { - label: 'Llama 3.1 405b', - supports: { - multiturn: true, - tools: true, - media: false, - systemRole: true, - output: ['text'], - }, - versions: ['meta/llama3-405b-instruct-maas'], - }, - configSchema: ModelGardenModelConfigSchema, - version: 'meta/llama3-405b-instruct-maas', -}); - -export const SUPPORTED_OPENAI_FORMAT_MODELS = { - 'llama3-405b': llama3, - 'llama-3.1': llama31, - 'llama-3.2': llama32, -}; - -export function modelGardenOpenaiCompatibleModel( - ai: Genkit, - name: string, - projectId: string, - location: string, - googleAuth: GoogleAuth, - baseUrlTemplate: string | undefined -): ModelAction { - const model = SUPPORTED_OPENAI_FORMAT_MODELS[name]; - if (!model) throw new Error(`Unsupported model: ${name}`); - if (!baseUrlTemplate) { - baseUrlTemplate = - 'https://{location}-aiplatform.googleapis.com/v1beta1/projects/{projectId}/locations/{location}/endpoints/openapi'; - } - - const clientFactory = async ( - request: GenerateRequest - ): Promise => { - const requestLocation = request.config?.location || location; - return new OpenAI({ - baseURL: baseUrlTemplate! - .replace(/{location}/g, requestLocation) - .replace(/{projectId}/g, projectId), - apiKey: (await googleAuth.getAccessToken())!, - defaultHeaders: { - 'X-Goog-Api-Client': GENKIT_CLIENT_HEADER, - }, - }); - }; - return openaiCompatibleModel(ai, model, clientFactory); -} diff --git a/js/plugins/checks/src/openai_compatibility.ts b/js/plugins/checks/src/openai_compatibility.ts deleted file mode 100644 index 2de914f57..000000000 --- a/js/plugins/checks/src/openai_compatibility.ts +++ /dev/null @@ -1,350 +0,0 @@ -/** - * Copyright 2024 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -import { Genkit, Message, StreamingCallback, z } from 'genkit'; -import { - GenerateResponseChunkData, - GenerateResponseData, - GenerationCommonConfigSchema, - ModelAction, - ModelReference, - type CandidateData, - type GenerateRequest, - type MessageData, - type Part, - type Role, - type ToolDefinition, - type ToolRequestPart, -} from 'genkit/model'; -import OpenAI from 'openai'; -import { - type ChatCompletion, - type ChatCompletionChunk, - type ChatCompletionContentPart, - type ChatCompletionCreateParamsNonStreaming, - type ChatCompletionMessageParam, - type ChatCompletionMessageToolCall, - type ChatCompletionRole, - type ChatCompletionTool, - type CompletionChoice, -} from 'openai/resources/index.mjs'; - -export const OpenAIConfigSchema = GenerationCommonConfigSchema.extend({ - frequencyPenalty: z.number().min(-2).max(2).optional(), - logitBias: z.record(z.string(), z.number().min(-100).max(100)).optional(), - logProbs: z.boolean().optional(), - presencePenalty: z.number().min(-2).max(2).optional(), - seed: z.number().int().optional(), - topLogProbs: z.number().int().min(0).max(20).optional(), - user: z.string().optional(), -}); - -export function toOpenAIRole(role: Role): ChatCompletionRole { - switch (role) { - case 'user': - return 'user'; - case 'model': - return 'assistant'; - case 'system': - return 'system'; - case 'tool': - return 'tool'; - default: - throw new Error(`role ${role} doesn't map to an OpenAI role.`); - } -} - -function toOpenAiTool(tool: ToolDefinition): ChatCompletionTool { - return { - type: 'function', - function: { - name: tool.name, - parameters: tool.inputSchema || undefined, - }, - }; -} - -export function toOpenAiTextAndMedia(part: Part): ChatCompletionContentPart { - if (part.text) { - return { - type: 'text', - text: part.text, - }; - } else if (part.media) { - return { - type: 'image_url', - image_url: { - url: part.media.url, - }, - }; - } - throw Error( - `Unsupported genkit part fields encountered for current message role: ${JSON.stringify(part)}.` - ); -} - -export function toOpenAiMessages( - messages: MessageData[] -): ChatCompletionMessageParam[] { - const openAiMsgs: ChatCompletionMessageParam[] = []; - for (const message of messages) { - const msg = new Message(message); - const role = toOpenAIRole(message.role); - switch (role) { - case 'user': - openAiMsgs.push({ - role: role, - content: msg.content.map((part) => toOpenAiTextAndMedia(part)), - }); - break; - case 'system': - openAiMsgs.push({ - role: role, - content: msg.text, - }); - break; - case 'assistant': { - const toolCalls: ChatCompletionMessageToolCall[] = msg.content - .filter( - ( - part - ): part is Part & { - toolRequest: NonNullable; - } => Boolean(part.toolRequest) - ) - .map((part) => ({ - id: part.toolRequest.ref ?? '', - type: 'function', - function: { - name: part.toolRequest.name, - arguments: JSON.stringify(part.toolRequest.input), - }, - })); - if (toolCalls.length > 0) { - openAiMsgs.push({ - role: role, - tool_calls: toolCalls, - }); - } else { - openAiMsgs.push({ - role: role, - content: msg.text, - }); - } - break; - } - case 'tool': { - const toolResponseParts = msg.toolResponseParts(); - toolResponseParts.map((part) => { - openAiMsgs.push({ - role: role, - tool_call_id: part.toolResponse.ref ?? '', - content: - typeof part.toolResponse.output === 'string' - ? part.toolResponse.output - : JSON.stringify(part.toolResponse.output), - }); - }); - break; - } - } - } - return openAiMsgs; -} - -const finishReasonMap: Record< - CompletionChoice['finish_reason'] | 'tool_calls', - CandidateData['finishReason'] -> = { - length: 'length', - stop: 'stop', - tool_calls: 'stop', - content_filter: 'blocked', -}; - -export function fromOpenAiToolCall( - toolCall: - | ChatCompletionMessageToolCall - | ChatCompletionChunk.Choice.Delta.ToolCall -): ToolRequestPart { - if (!toolCall.function) { - throw Error( - `Unexpected openAI chunk choice. tool_calls was provided but one or more tool_calls is missing.` - ); - } - const f = toolCall.function; - return { - toolRequest: { - name: f.name!, - ref: toolCall.id, - input: f.arguments ? JSON.parse(f.arguments) : f.arguments, - }, - }; -} - -export function fromOpenAiChoice( - choice: ChatCompletion.Choice, - jsonMode = false -): CandidateData { - const toolRequestParts = choice.message.tool_calls?.map(fromOpenAiToolCall); - return { - index: choice.index, - finishReason: finishReasonMap[choice.finish_reason] || 'other', - message: { - role: 'model', - content: toolRequestParts - ? // Note: Not sure why I have to cast here exactly. - // Otherwise it thinks toolRequest must be 'undefined' if provided - (toolRequestParts as ToolRequestPart[]) - : [ - jsonMode - ? { data: JSON.parse(choice.message.content!) } - : { text: choice.message.content! }, - ], - }, - custom: {}, - }; -} - -export function fromOpenAiChunkChoice( - choice: ChatCompletionChunk.Choice, - jsonMode = false -): CandidateData { - const toolRequestParts = choice.delta.tool_calls?.map(fromOpenAiToolCall); - return { - index: choice.index, - finishReason: choice.finish_reason - ? finishReasonMap[choice.finish_reason] || 'other' - : 'unknown', - message: { - role: 'model', - content: toolRequestParts - ? (toolRequestParts as ToolRequestPart[]) - : [ - jsonMode - ? { data: JSON.parse(choice.delta.content!) } - : { text: choice.delta.content! }, - ], - }, - custom: {}, - }; -} - -export function toRequestBody( - model: ModelReference, - request: GenerateRequest -) { - const openAiMessages = toOpenAiMessages(request.messages); - const mappedModelName = - request.config?.version || model.version || model.name; - const body = { - model: mappedModelName, - messages: openAiMessages, - temperature: request.config?.temperature, - max_tokens: request.config?.maxOutputTokens, - top_p: request.config?.topP, - stop: request.config?.stopSequences, - frequency_penalty: request.config?.frequencyPenalty, - logit_bias: request.config?.logitBias, - logprobs: request.config?.logProbs, - presence_penalty: request.config?.presencePenalty, - seed: request.config?.seed, - top_logprobs: request.config?.topLogProbs, - user: request.config?.user, - tools: request.tools?.map(toOpenAiTool), - n: request.candidates, - } as ChatCompletionCreateParamsNonStreaming; - const response_format = request.output?.format; - if (response_format) { - if ( - response_format === 'json' && - model.info?.supports?.output?.includes('json') - ) { - body.response_format = { - type: 'json_object', - }; - } else if ( - response_format === 'text' && - model.info?.supports?.output?.includes('text') - ) { - // this is default format, don't need to set it - // body.response_format = { - // type: 'text', - // }; - } else { - throw new Error(`${response_format} format is not supported currently`); - } - } - for (const key in body) { - if (!body[key] || (Array.isArray(body[key]) && !body[key].length)) - delete body[key]; - } - return body; -} - -export function openaiCompatibleModel( - ai: Genkit, - model: ModelReference, - clientFactory: (request: GenerateRequest) => Promise -): ModelAction { - const modelId = model.name; - if (!model) throw new Error(`Unsupported model: ${name}`); - - return ai.defineModel( - { - name: modelId, - ...model.info, - configSchema: model.configSchema, - }, - async ( - request: GenerateRequest, - streamingCallback?: StreamingCallback - ): Promise => { - let response: ChatCompletion; - const client = await clientFactory(request); - const body = toRequestBody(model, request); - if (streamingCallback) { - const stream = client.beta.chat.completions.stream({ - ...body, - stream: true, - }); - for await (const chunk of stream) { - chunk.choices?.forEach((chunk) => { - const c = fromOpenAiChunkChoice(chunk); - streamingCallback({ - index: c.index, - content: c.message.content, - }); - }); - } - response = await stream.finalChatCompletion(); - } else { - response = await client.chat.completions.create(body); - } - return { - candidates: response.choices.map((c) => - fromOpenAiChoice(c, request.output?.format === 'json') - ), - usage: { - inputTokens: response.usage?.prompt_tokens, - outputTokens: response.usage?.completion_tokens, - totalTokens: response.usage?.total_tokens, - }, - custom: response, - }; - } - ); -} diff --git a/js/plugins/checks/src/reranker.ts b/js/plugins/checks/src/reranker.ts deleted file mode 100644 index 95df9b2c9..000000000 --- a/js/plugins/checks/src/reranker.ts +++ /dev/null @@ -1,159 +0,0 @@ -/** - * Copyright 2024 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -import { Genkit, z } from 'genkit'; -import { RankedDocument, RerankerAction, rerankerRef } from 'genkit/reranker'; -import { GoogleAuth } from 'google-auth-library'; -import { PluginOptions } from '.'; - -const DEFAULT_MODEL = 'semantic-ranker-512@latest'; - -const getRerankEndpoint = (projectId: string, location: string) => { - return `https://discoveryengine.googleapis.com/v1/projects/${projectId}/locations/${location}/rankingConfigs/default_ranking_config:rank`; -}; - -// Define the schema for the options used in the Vertex AI reranker -export const VertexAIRerankerOptionsSchema = z.object({ - k: z.number().optional().describe('Number of top documents to rerank'), // Optional: Number of documents to rerank - model: z.string().optional().describe('Model name for reranking'), // Optional: Model name, defaults to a pre-defined model - location: z - .string() - .optional() - .describe('Google Cloud location, e.g., "us-central1"'), // Optional: Location of the reranking model -}); - -// Type alias for the options schema -export type VertexAIRerankerOptions = z.infer< - typeof VertexAIRerankerOptionsSchema ->; - -// Define the structure for each individual reranker configuration -export const VertexRerankerConfigSchema = z.object({ - model: z.string().optional().describe('Model name for reranking'), // Optional: Model name, defaults to a pre-defined model -}); - -export interface VertexRerankerConfig { - name?: string; - model?: string; -} - -export interface VertexRerankPluginOptions { - rerankOptions: VertexRerankerConfig[]; - projectId: string; - location?: string; // Optional: Location of the reranker service -} - -export interface VertexRerankOptions { - authClient: GoogleAuth; - pluginOptions?: PluginOptions; -} - -/** - * Creates Vertex AI rerankers. - * - * This function returns a list of reranker actions for Vertex AI based on the provided - * rerank options and configuration. - * - * @param {VertexRerankOptions} params - The parameters for creating the rerankers. - * @returns {RerankerAction[]} - An array of reranker actions. - */ -export async function vertexAiRerankers( - ai: Genkit, - params: VertexRerankOptions -): Promise[]> { - if (!params.pluginOptions) { - return []; - } - const pluginOptions = params.pluginOptions; - if (!params.pluginOptions.rerankOptions) { - return []; - } - - const rerankOptions = params.pluginOptions.rerankOptions; - const rerankers: RerankerAction[] = []; - - if (!rerankOptions || rerankOptions.length === 0) { - return rerankers; - } - const auth = new GoogleAuth(); - const client = await auth.getClient(); - const projectId = await auth.getProjectId(); - - for (const rerankOption of rerankOptions) { - const reranker = ai.defineReranker( - { - name: `vertexai/${rerankOption.name || rerankOption.model}`, - configSchema: VertexAIRerankerOptionsSchema.optional(), - }, - async (query, documents, _options) => { - const response = await client.request({ - method: 'POST', - url: getRerankEndpoint( - projectId, - pluginOptions.location ?? 'us-central1' - ), - data: { - model: rerankOption.model || DEFAULT_MODEL, // Use model from config or default - query: query.text, - records: documents.map((doc, idx) => ({ - id: `${idx}`, - content: doc.text, - })), - }, - }); - - const rankedDocuments: RankedDocument[] = ( - response.data as any - ).records.map((record: any) => { - const doc = documents[record.id]; - return new RankedDocument({ - content: doc.content, - metadata: { - ...doc.metadata, - score: record.score, - }, - }); - }); - - return { documents: rankedDocuments }; - } - ); - - rerankers.push(reranker); - } - - return rerankers; -} - -/** - * Creates a reference to a Vertex AI reranker. - * - * @param {Object} params - The parameters for the reranker reference. - * @param {string} [params.displayName] - An optional display name for the reranker. - * @returns {Object} - The reranker reference object. - */ -export const vertexAiRerankerRef = (params: { - name: string; - displayName?: string; -}) => { - return rerankerRef({ - name: `vertexai/${name}`, - info: { - label: params.displayName ?? `Vertex AI Reranker`, - }, - configSchema: VertexAIRerankerOptionsSchema.optional(), - }); -}; diff --git a/js/plugins/checks/src/vector-search/bigquery.ts b/js/plugins/checks/src/vector-search/bigquery.ts deleted file mode 100644 index e3a40ba61..000000000 --- a/js/plugins/checks/src/vector-search/bigquery.ts +++ /dev/null @@ -1,131 +0,0 @@ -/** - * Copyright 2024 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -import { BigQuery, QueryRowsResponse } from '@google-cloud/bigquery'; -import { z } from 'genkit'; -import { logger } from 'genkit/logging'; -import { Document, DocumentDataSchema } from 'genkit/retriever'; -import { DocumentIndexer, DocumentRetriever, Neighbor } from './types'; - -/** - * Creates a BigQuery Document Retriever. - * - * This function returns a DocumentRetriever function that retrieves documents - * from a BigQuery table based on the provided neighbors. - * - * @param {BigQuery} bq - The BigQuery instance. - * @param {string} tableId - The ID of the BigQuery table. - * @param {string} datasetId - The ID of the BigQuery dataset. - * @returns {DocumentRetriever} - The DocumentRetriever function. - */ -export const getBigQueryDocumentRetriever = ( - bq: BigQuery, - tableId: string, - datasetId: string -): DocumentRetriever => { - const bigQueryRetriever: DocumentRetriever = async ( - neighbors: Neighbor[] - ): Promise => { - const ids: string[] = neighbors - .map((neighbor) => neighbor.datapoint?.datapointId) - .filter(Boolean) as string[]; - - const query = ` - SELECT * FROM \`${datasetId}.${tableId}\` - WHERE id IN UNNEST(@ids) - `; - - const options = { - query, - params: { ids }, - }; - - let rows: QueryRowsResponse[0]; - - try { - [rows] = await bq.query(options); - } catch (queryError) { - logger.error('Failed to execute BigQuery query:', queryError); - return []; - } - - const documents: Document[] = []; - - for (const row of rows) { - try { - const docData: { content: any; metadata?: any } = { - content: JSON.parse(row.content), - }; - - if (row.metadata) { - docData.metadata = JSON.parse(row.metadata); - } - - const parsedDocData = DocumentDataSchema.parse(docData); - documents.push(new Document(parsedDocData)); - } catch (error) { - const id = row.id; - const errorPrefix = `Failed to parse document data for document with ID ${id}:`; - - if (error instanceof z.ZodError || error instanceof Error) { - logger.warn(`${errorPrefix} ${error.message}`); - } else { - logger.warn(errorPrefix); - } - } - } - - return documents; - }; - - return bigQueryRetriever; -}; - -/** - * Creates a BigQuery Document Indexer. - * - * This function returns a DocumentIndexer function that indexes documents - * into a BigQuery table. Note this indexer does not handle duplicate - * documents. - * - * @param {BigQuery} bq - The BigQuery instance. - * @param {string} tableId - The ID of the BigQuery table. - * @param {string} datasetId - The ID of the BigQuery dataset. - * @returns {DocumentIndexer} - The DocumentIndexer function. - */ -export const getBigQueryDocumentIndexer = ( - bq: BigQuery, - tableId: string, - datasetId: string -): DocumentIndexer => { - const bigQueryIndexer: DocumentIndexer = async ( - docs: Document[] - ): Promise => { - const ids: string[] = []; - const rows = docs.map((doc) => { - const id = Math.random().toString(36).substring(7); - ids.push(id); - return { - id, - content: JSON.stringify(doc.content), - metadata: JSON.stringify(doc.metadata), - }; - }); - await bq.dataset(datasetId).table(tableId).insert(rows); - return ids; - }; - return bigQueryIndexer; -}; diff --git a/js/plugins/checks/src/vector-search/firestore.ts b/js/plugins/checks/src/vector-search/firestore.ts deleted file mode 100644 index 1eefc894f..000000000 --- a/js/plugins/checks/src/vector-search/firestore.ts +++ /dev/null @@ -1,87 +0,0 @@ -/** - * Copyright 2024 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -import { Firestore } from 'firebase-admin/firestore'; -import { Document, DocumentDataSchema } from 'genkit'; -import { DocumentIndexer, DocumentRetriever, Neighbor } from './types'; -/** - * Creates a Firestore Document Retriever. - * - * This function returns a DocumentRetriever function that retrieves documents - * from a Firestore collection based on the provided Vertex AI Vector Search neighbors. - * - * @param {Firestore} db - The Firestore instance. - * @param {string} collectionName - The name of the Firestore collection. - * @returns {DocumentRetriever} - The DocumentRetriever function. - */ -export const getFirestoreDocumentRetriever = ( - db: Firestore, - collectionName: string -): DocumentRetriever => { - const firestoreRetriever: DocumentRetriever = async ( - neighbors: Neighbor[] - ): Promise => { - const docs: Document[] = []; - for (const neighbor of neighbors) { - const docRef = db - .collection(collectionName) - .doc(neighbor.datapoint?.datapointId!); - const docSnapshot = await docRef.get(); - if (docSnapshot.exists) { - const docData = { ...docSnapshot.data(), metadata: { ...neighbor } }; - const parsedDocData = DocumentDataSchema.safeParse(docData); - if (parsedDocData.success) { - docs.push(new Document(parsedDocData.data)); - } - } - } - return docs; - }; - return firestoreRetriever; -}; - -/** - * Creates a Firestore Document Indexer. - * - * This function returns a DocumentIndexer function that indexes documents - * into a Firestore collection. - * - * @param {Firestore} db - The Firestore instance. - * @param {string} collectionName - The name of the Firestore collection. - * @returns {DocumentIndexer} - The DocumentIndexer function. - */ -export const getFirestoreDocumentIndexer = ( - db: Firestore, - collectionName: string -) => { - const firestoreIndexer: DocumentIndexer = async ( - docs: Document[] - ): Promise => { - const batch = db.batch(); - const ids: string[] = []; - docs.forEach((doc) => { - const docRef = db.collection(collectionName).doc(); - batch.set(docRef, { - content: doc.content, - metadata: doc.metadata || null, - }); - ids.push(docRef.id); - }); - await batch.commit(); - return ids; - }; - return firestoreIndexer; -}; diff --git a/js/plugins/checks/src/vector-search/index.ts b/js/plugins/checks/src/vector-search/index.ts deleted file mode 100644 index 638ba1abc..000000000 --- a/js/plugins/checks/src/vector-search/index.ts +++ /dev/null @@ -1,36 +0,0 @@ -/** - * Copyright 2024 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -export { - getBigQueryDocumentIndexer, - getBigQueryDocumentRetriever, -} from './bigquery'; -export { - getFirestoreDocumentIndexer, - getFirestoreDocumentRetriever, -} from './firestore'; -export { vertexAiIndexerRef, vertexAiIndexers } from './indexers'; -export { vertexAiRetrieverRef, vertexAiRetrievers } from './retrievers'; -export { - DocumentIndexer, - DocumentRetriever, - Neighbor, - VectorSearchOptions, - VertexAIVectorIndexerOptions, - VertexAIVectorIndexerOptionsSchema, - VertexAIVectorRetrieverOptions, - VertexAIVectorRetrieverOptionsSchema, -} from './types'; diff --git a/js/plugins/checks/src/vector-search/indexers.ts b/js/plugins/checks/src/vector-search/indexers.ts deleted file mode 100644 index 66a00e913..000000000 --- a/js/plugins/checks/src/vector-search/indexers.ts +++ /dev/null @@ -1,120 +0,0 @@ -/** - * Copyright 2024 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -import { Genkit, z } from 'genkit'; -import { IndexerAction, indexerRef } from 'genkit/retriever'; -import { - Datapoint, - VertexAIVectorIndexerOptionsSchema, - VertexVectorSearchOptions, -} from './types'; -import { upsertDatapoints } from './upsert_datapoints'; - -/** - * Creates a reference to a Vertex AI indexer. - * - * @param {Object} params - The parameters for the indexer reference. - * @param {string} params.indexId - The ID of the Vertex AI index. - * @param {string} [params.displayName] - An optional display name for the indexer. - * @returns {Object} - The indexer reference object. - */ -export const vertexAiIndexerRef = (params: { - indexId: string; - displayName?: string; -}) => { - return indexerRef({ - name: `vertexai/${params.indexId}`, - info: { - label: params.displayName ?? `Vertex AI - ${params.indexId}`, - }, - configSchema: VertexAIVectorIndexerOptionsSchema.optional(), - }); -}; - -/** - * Creates Vertex AI indexers. - * - * This function returns a list of indexer actions for Vertex AI based on the provided - * vector search options and embedder configurations. - * - * @param {VertexVectorSearchOptions} params - The parameters for creating the indexers. - * @returns {IndexerAction[]} - An array of indexer actions. - */ -export function vertexAiIndexers( - ai: Genkit, - params: VertexVectorSearchOptions -): IndexerAction[] { - const vectorSearchOptions = params.pluginOptions.vectorSearchOptions; - const defaultEmbedder = params.defaultEmbedder; - const indexers: IndexerAction[] = []; - - if (!vectorSearchOptions || vectorSearchOptions.length === 0) { - return indexers; - } - - for (const vectorSearchOption of vectorSearchOptions) { - const { documentIndexer, indexId } = vectorSearchOption; - const embedder = vectorSearchOption.embedder ?? defaultEmbedder; - const embedderOptions = vectorSearchOption.embedderOptions; - - const indexer = ai.defineIndexer( - { - name: `vertexai/${indexId}`, - configSchema: VertexAIVectorIndexerOptionsSchema.optional(), - }, - async (docs, options) => { - let docIds: string[] = []; - - try { - docIds = await documentIndexer(docs, options); - } catch (error) { - throw new Error( - `Error storing your document content/metadata: ${error}` - ); - } - - const embeddings = await ai.embedMany({ - embedder, - content: docs, - options: embedderOptions, - }); - - const datapoints = embeddings.map( - ({ embedding }, i) => - new Datapoint({ - datapointId: docIds[i], - featureVector: embedding, - }) - ); - - try { - await upsertDatapoints({ - datapoints, - authClient: params.authClient, - projectId: params.pluginOptions.projectId!, - location: params.pluginOptions.location!, - indexId: indexId, - }); - } catch (error) { - throw error; - } - } - ); - - indexers.push(indexer); - } - return indexers; -} diff --git a/js/plugins/checks/src/vector-search/query_public_endpoint.ts b/js/plugins/checks/src/vector-search/query_public_endpoint.ts deleted file mode 100644 index f055e3b9d..000000000 --- a/js/plugins/checks/src/vector-search/query_public_endpoint.ts +++ /dev/null @@ -1,92 +0,0 @@ -/** - * Copyright 2024 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -import { logger } from 'genkit/logging'; -import { FindNeighborsResponse } from './types'; - -interface QueryPublicEndpointParams { - featureVector: number[]; - neighborCount: number; - accessToken: string; - projectId: string; - location: string; - indexEndpointId: string; - publicDomainName: string; - projectNumber: string; - deployedIndexId: string; -} -/** - * Queries a public index endpoint to find neighbors for a given feature vector. - * - * This function sends a request to a specified public endpoint to find neighbors - * for a given feature vector using the provided parameters. - * - * @param {QueryPublicEndpointParams} params - The parameters required to query the public endpoint. - * @param {number[]} params.featureVector - The feature vector for which to find neighbors. - * @param {number} params.neighborCount - The number of neighbors to retrieve. - * @param {string} params.accessToken - The access token for authorization. - * @param {string} params.projectId - The ID of the Google Cloud project. - * @param {string} params.location - The location of the index endpoint. - * @param {string} params.indexEndpointId - The ID of the index endpoint. - * @param {string} params.publicDomainName - The domain name of the public endpoint. - * @param {string} params.projectNumber - The project number. - * @param {string} params.deployedIndexId - The ID of the deployed index. - * @returns {Promise} - The response from the public endpoint. - */ -export async function queryPublicEndpoint( - params: QueryPublicEndpointParams -): Promise { - const { - featureVector, - neighborCount, - accessToken, - indexEndpointId, - publicDomainName, - projectNumber, - deployedIndexId, - location, - } = params; - const url = new URL( - `https://${publicDomainName}/v1/projects/${projectNumber}/locations/${location}/indexEndpoints/${indexEndpointId}:findNeighbors` - ); - - const requestBody = { - deployed_index_id: deployedIndexId, - queries: [ - { - datapoint: { - datapoint_id: '0', - feature_vector: featureVector, - }, - neighbor_count: neighborCount, - }, - ], - }; - - const response = await fetch(url, { - method: 'POST', - headers: { - 'Content-Type': 'application/json', - Authorization: `Bearer ${accessToken}`, - }, - body: JSON.stringify(requestBody), - }); - - if (!response.ok) { - logger.error('Error querying index: ', response.statusText); - } - return response.json(); -} diff --git a/js/plugins/checks/src/vector-search/retrievers.ts b/js/plugins/checks/src/vector-search/retrievers.ts deleted file mode 100644 index 67f47f33d..000000000 --- a/js/plugins/checks/src/vector-search/retrievers.ts +++ /dev/null @@ -1,136 +0,0 @@ -/** - * Copyright 2024 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -import { Genkit, RetrieverAction, retrieverRef, z } from 'genkit'; -import { queryPublicEndpoint } from './query_public_endpoint'; -import { - VertexAIVectorRetrieverOptionsSchema, - VertexVectorSearchOptions, -} from './types'; -import { getProjectNumber } from './utils'; - -const DEFAULT_K = 10; - -/** - * Creates Vertex AI retrievers. - * - * This function returns a list of retriever actions for Vertex AI based on the provided - * vector search options and embedder configurations. - * - * @param {VertexVectorSearchOptions} params - The parameters for creating the retrievers. - * @returns {RetrieverAction[]} - An array of retriever actions. - */ -export function vertexAiRetrievers( - ai: Genkit, - params: VertexVectorSearchOptions -): RetrieverAction[] { - const vectorSearchOptions = params.pluginOptions.vectorSearchOptions; - const defaultEmbedder = params.defaultEmbedder; - - const retrievers: RetrieverAction[] = []; - - if (!vectorSearchOptions || vectorSearchOptions.length === 0) { - return retrievers; - } - - for (const vectorSearchOption of vectorSearchOptions) { - const { documentRetriever, indexId, publicDomainName } = vectorSearchOption; - const embedder = vectorSearchOption.embedder ?? defaultEmbedder; - const embedderOptions = vectorSearchOption.embedderOptions; - - const retriever = ai.defineRetriever( - { - name: `vertexai/${indexId}`, - configSchema: VertexAIVectorRetrieverOptionsSchema.optional(), - }, - async (content, options) => { - const queryEmbeddings = await ai.embed({ - embedder, - options: embedderOptions, - content, - }); - - const accessToken = await params.authClient.getAccessToken(); - - if (!accessToken) { - throw new Error( - 'Error generating access token when defining Vertex AI retriever' - ); - } - - const projectId = params.pluginOptions.projectId; - if (!projectId) { - throw new Error( - 'Project ID is required to define Vertex AI retriever' - ); - } - const projectNumber = await getProjectNumber(projectId); - const location = params.pluginOptions.location; - if (!location) { - throw new Error('Location is required to define Vertex AI retriever'); - } - - let res = await queryPublicEndpoint({ - featureVector: queryEmbeddings, - neighborCount: options?.k || DEFAULT_K, - accessToken, - projectId, - location, - publicDomainName, - projectNumber, - indexEndpointId: vectorSearchOption.indexEndpointId, - deployedIndexId: vectorSearchOption.deployedIndexId, - }); - const nearestNeighbors = res.nearestNeighbors; - - const queryRes = nearestNeighbors ? nearestNeighbors[0] : null; - const neighbors = queryRes ? queryRes.neighbors : null; - if (!neighbors) { - return { documents: [] }; - } - - const documents = await documentRetriever(neighbors, options); - - return { documents }; - } - ); - - retrievers.push(retriever); - } - - return retrievers; -} - -/** - * Creates a reference to a Vertex AI retriever. - * - * @param {Object} params - The parameters for the retriever reference. - * @param {string} params.indexId - The ID of the Vertex AI index. - * @param {string} [params.displayName] - An optional display name for the retriever. - * @returns {Object} - The retriever reference object. - */ -export const vertexAiRetrieverRef = (params: { - indexId: string; - displayName?: string; -}) => { - return retrieverRef({ - name: `vertexai/${params.indexId}`, - info: { - label: params.displayName ?? `ertex AI - ${params.indexId}`, - }, - configSchema: VertexAIVectorRetrieverOptionsSchema.optional(), - }); -}; diff --git a/js/plugins/checks/src/vector-search/types.ts b/js/plugins/checks/src/vector-search/types.ts deleted file mode 100644 index 6b58e4f34..000000000 --- a/js/plugins/checks/src/vector-search/types.ts +++ /dev/null @@ -1,189 +0,0 @@ -/** - * Copyright 2024 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -import * as aiplatform from '@google-cloud/aiplatform'; -import { z } from 'genkit'; -import { EmbedderArgument } from 'genkit/embedder'; -import { CommonRetrieverOptionsSchema, Document } from 'genkit/retriever'; -import { GoogleAuth } from 'google-auth-library'; -import { PluginOptions } from '..'; - -// This internal interface will be passed to the vertexIndexers and vertexRetrievers functions -export interface VertexVectorSearchOptions< - EmbedderCustomOptions extends z.ZodTypeAny, -> { - pluginOptions: PluginOptions; - authClient: GoogleAuth; - defaultEmbedder: EmbedderArgument; -} - -export type IIndexDatapoint = - aiplatform.protos.google.cloud.aiplatform.v1.IIndexDatapoint; - -export class Datapoint extends aiplatform.protos.google.cloud.aiplatform.v1 - .IndexDatapoint { - constructor(properties: IIndexDatapoint) { - super(properties); - } -} - -export type IFindNeighborsRequest = - aiplatform.protos.google.cloud.aiplatform.v1.IFindNeighborsRequest; -export type IFindNeighborsResponse = - aiplatform.protos.google.cloud.aiplatform.v1.IFindNeighborsResponse; -export type ISparseEmbedding = - aiplatform.protos.google.cloud.aiplatform.v1.IndexDatapoint.ISparseEmbedding; -export type IRestriction = - aiplatform.protos.google.cloud.aiplatform.v1.IndexDatapoint.IRestriction; -export type INumericRestriction = - aiplatform.protos.google.cloud.aiplatform.v1.IndexDatapoint.INumericRestriction; - -// Define the Zod schema for ISparseEmbedding -export const SparseEmbeddingSchema = z.object({ - values: z.array(z.number()).optional(), - dimensions: z.array(z.union([z.number(), z.string()])).optional(), -}); - -export type SparseEmbedding = z.infer; - -// Define the Zod schema for IRestriction -export const RestrictionSchema = z.object({ - namespace: z.string().optional(), - allowList: z.array(z.string()).optional(), - denyList: z.array(z.string()).optional(), -}); - -export type Restriction = z.infer; - -// Define the Zod schema for INumericRestriction -export const NumericRestrictionSchema = z.object({ - valueInt: z.union([z.number(), z.string()]).optional(), - valueFloat: z.number().optional(), - valueDouble: z.number().optional(), - namespace: z.string().optional(), - op: z - .union([ - z.enum([ - 'OPERATOR_UNSPECIFIED', - 'LESS', - 'LESS_EQUAL', - 'EQUAL', - 'GREATER_EQUAL', - 'GREATER', - 'NOT_EQUAL', - ]), - z.null(), - ]) - .optional(), -}); - -export type NumericRestriction = z.infer; - -// Define the Zod schema for ICrowdingTag -export const CrowdingTagSchema = z.object({ - crowdingAttribute: z.string().optional(), -}); - -export type CrowdingTag = z.infer; - -// Define the Zod schema for IIndexDatapoint -const IndexDatapointSchema = z.object({ - datapointId: z.string().optional(), - featureVector: z.array(z.number()).optional(), - sparseEmbedding: SparseEmbeddingSchema.optional(), - restricts: z.array(RestrictionSchema).optional(), - numericRestricts: z.array(NumericRestrictionSchema).optional(), - crowdingTag: CrowdingTagSchema.optional(), -}); - -// Define the Zod schema for INeighbor -export const NeighborSchema = z.object({ - datapoint: IndexDatapointSchema.optional(), - distance: z.number().optional(), - sparseDistance: z.number().optional(), -}); - -export type Neighbor = z.infer; - -// Define the Zod schema for INearestNeighbors -const NearestNeighborsSchema = z.object({ - id: z.string().optional(), - neighbors: z.array(NeighborSchema).optional(), -}); - -// Define the Zod schema for IFindNeighborsResponse -export const FindNeighborsResponseSchema = z.object({ - nearestNeighbors: z.array(NearestNeighborsSchema).optional(), -}); - -export type FindNeighborsResponse = z.infer; - -// TypeScript types for Zod schemas -type IndexDatapoint = z.infer; - -// Function to assert type equality -function assertTypeEquality(value: T): void {} - -// Asserting type equality -assertTypeEquality({} as IndexDatapoint); -assertTypeEquality({} as FindNeighborsResponse); - -export const VertexAIVectorRetrieverOptionsSchema = - CommonRetrieverOptionsSchema.extend({}).optional(); - -export type VertexAIVectorRetrieverOptions = z.infer< - typeof VertexAIVectorRetrieverOptionsSchema ->; - -export const VertexAIVectorIndexerOptionsSchema = z.any(); - -export type VertexAIVectorIndexerOptions = z.infer< - typeof VertexAIVectorIndexerOptionsSchema ->; - -/** - * A document retriever function that takes an array of Neighbors from Vertex AI Vector Search query result, and resolves to a list of documents. - * Also takes an options object that can be used to configure the retriever. - */ -export type DocumentRetriever = - (docIds: Neighbor[], options?: Options) => Promise; - -/** - * Indexer function that takes an array of documents, stores them in a database of the user's choice, and resolves to a list of document ids. - * Also takes an options object that can be used to configure the indexer. Only Streaming Update Indexers are supported. - */ -export type DocumentIndexer = ( - docs: Document[], - options?: Options -) => Promise; - -export interface VectorSearchOptions< - EmbedderCustomOptions extends z.ZodTypeAny, - IndexerOptions extends {}, - RetrieverOptions extends { k?: number }, -> { - // Specify the Vertex AI Index and IndexEndpoint to use for indexing and retrieval - deployedIndexId: string; - indexEndpointId: string; - publicDomainName: string; - indexId: string; - // Document retriever and indexer functions to use for indexing and retrieval by the plugin's own indexers and retrievers - documentRetriever: DocumentRetriever; - documentIndexer: DocumentIndexer; - // Embedder and default options to use for indexing and retrieval - embedder?: EmbedderArgument; - embedderOptions?: z.infer; -} diff --git a/js/plugins/checks/src/vector-search/upsert_datapoints.ts b/js/plugins/checks/src/vector-search/upsert_datapoints.ts deleted file mode 100644 index cfeb8d5ec..000000000 --- a/js/plugins/checks/src/vector-search/upsert_datapoints.ts +++ /dev/null @@ -1,71 +0,0 @@ -/** - * Copyright 2024 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -import { GoogleAuth } from 'google-auth-library'; -import { IIndexDatapoint } from './types'; - -interface UpsertDatapointsParams { - datapoints: IIndexDatapoint[]; - authClient: GoogleAuth; - projectId: string; - location: string; - indexId: string; -} - -/** - * Upserts datapoints into a specified index. - * - * This function sends a request to the Google AI Platform to upsert datapoints - * into a specified index using the provided parameters. - * - * @param {UpsertDatapointsParams} params - The parameters required to upsert datapoints. - * @param {IIndexDatapoint[]} params.datapoints - The datapoints to be upserted. - * @param {GoogleAuth} params.authClient - The GoogleAuth client for authorization. - * @param {string} params.projectId - The ID of the Google Cloud project. - * @param {string} params.location - The location of the AI Platform index. - * @param {string} params.indexId - The ID of the index. - * @returns {Promise} - A promise that resolves when the upsert is complete. - * @throws {Error} - Throws an error if the upsert fails. - */ -export async function upsertDatapoints( - params: UpsertDatapointsParams -): Promise { - const { datapoints, authClient, projectId, location, indexId } = params; - const accessToken = await authClient.getAccessToken(); - const url = `https://${location}-aiplatform.googleapis.com/v1/projects/${projectId}/locations/${location}/indexes/${indexId}:upsertDatapoints`; - - const requestBody = { - datapoints: datapoints.map((dp) => ({ - datapoint_id: dp.datapointId, - feature_vector: dp.featureVector, - })), - }; - - const response = await fetch(url, { - method: 'POST', - headers: { - 'Content-Type': 'application/json', - Authorization: `Bearer ${accessToken}`, - }, - body: JSON.stringify(requestBody), - }); - - if (!response.ok) { - throw new Error( - `Error upserting datapoints into index ${indexId}: ${response.statusText}` - ); - } -} diff --git a/js/plugins/checks/src/vector-search/utils.ts b/js/plugins/checks/src/vector-search/utils.ts deleted file mode 100644 index c6415b8c5..000000000 --- a/js/plugins/checks/src/vector-search/utils.ts +++ /dev/null @@ -1,65 +0,0 @@ -/** - * Copyright 2024 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -import { GoogleAuth } from 'google-auth-library'; -import { google } from 'googleapis'; - -/** - * Retrieves an access token using the provided GoogleAuth client. - * - * @param {GoogleAuth} auth - The GoogleAuth client. - * @returns {Promise} - A promise that resolves to the access token. - */ -export async function getAccessToken(auth: GoogleAuth): Promise { - const client = await auth.getClient(); - const _accessToken = await client.getAccessToken(); - return _accessToken.token || null; -} - -/** - * Retrieves the project number for a given project ID. - * - * This function sends a request to the Google Cloud Resource Manager API to - * fetch the project number for the specified project ID. - * - * @param {string} projectId - The ID of the Google Cloud project. - * @returns {Promise} - A promise that resolves to the project number. - * @throws {Error} - Throws an error if the project number cannot be fetched. - */ -export async function getProjectNumber(projectId: string): Promise { - const client = google.cloudresourcemanager('v1'); - const authClient = await google.auth.getClient({ - scopes: ['https://www.googleapis.com/auth/cloud-platform'], - }); - - try { - const response = await client.projects.get({ - projectId: projectId, - auth: authClient, - }); - - if (!response.data.projectNumber) { - throw new Error( - `Error fetching project number for Vertex AI plugin for project ${projectId}` - ); - } - return response.data['projectNumber']; - } catch (error) { - throw new Error( - `Error fetching project number for Vertex AI plugin for project ${projectId}` - ); - } -} From 1dc53255dd4a63966954f8d6046ad40ca8f3d6c4 Mon Sep 17 00:00:00 2001 From: hunterheston Date: Wed, 30 Oct 2024 19:03:35 +0000 Subject: [PATCH 03/30] Vertex ai plugin gutted. Can now remove existing vertexai evaluators. --- js/plugins/checks/package.json | 6 +- js/plugins/checks/src/evaluation.ts | 60 ++++---- js/plugins/checks/src/evaluator_factory.ts | 4 +- js/plugins/checks/src/index.ts | 22 +-- js/plugins/checks/src/predict.ts | 83 ----------- js/pnpm-lock.yaml | 154 +-------------------- js/testapps/byo-evaluator/package.json | 1 + js/testapps/byo-evaluator/src/index.ts | 8 ++ 8 files changed, 60 insertions(+), 278 deletions(-) delete mode 100644 js/plugins/checks/src/predict.ts diff --git a/js/plugins/checks/package.json b/js/plugins/checks/package.json index 23e2d5062..d9ee9e180 100644 --- a/js/plugins/checks/package.json +++ b/js/plugins/checks/package.json @@ -35,14 +35,10 @@ "author": "genkit", "license": "Apache-2.0", "dependencies": { - "@anthropic-ai/sdk": "^0.24.3", - "@anthropic-ai/vertex-sdk": "^0.4.0", "@google-cloud/aiplatform": "^3.23.0", - "@google-cloud/vertexai": "^1.1.0", "google-auth-library": "^9.6.3", "googleapis": "^140.0.1", - "node-fetch": "^3.3.2", - "openai": "^4.52.7" + "node-fetch": "^3.3.2" }, "peerDependencies": { "genkit": "workspace:*" diff --git a/js/plugins/checks/src/evaluation.ts b/js/plugins/checks/src/evaluation.ts index 965b9fc86..3804c8df1 100644 --- a/js/plugins/checks/src/evaluation.ts +++ b/js/plugins/checks/src/evaluation.ts @@ -19,11 +19,11 @@ import { GoogleAuth } from 'google-auth-library'; import { EvaluatorFactory } from './evaluator_factory.js'; /** - * Vertex AI Evaluation metrics. See API documentation for more information. - * https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/evaluation#parameter-list + * Checks AI Safety policies. See API documentation for more information. + * TODO: add documentation link. */ -export enum VertexAIEvaluationMetricType { - // Update genkit/docs/plugins/vertex-ai.md when modifying the list of enums +export enum ChecksEvaluationMetricType { + // TODO: Change to match checks policies. BLEU = 'BLEU', ROUGE = 'ROUGE', FLUENCY = 'FLEUNCY', @@ -40,19 +40,19 @@ export enum VertexAIEvaluationMetricType { * for details on the possible values of `metricSpec` for each metric. * https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/evaluation#parameter-list */ -export type VertexAIEvaluationMetricConfig = { - type: VertexAIEvaluationMetricType; +export type ChecksEvaluationMetricConfig = { + type: ChecksEvaluationMetricType; metricSpec: any; }; -export type VertexAIEvaluationMetric = - | VertexAIEvaluationMetricType - | VertexAIEvaluationMetricConfig; +export type ChecksEvaluationMetric = + | ChecksEvaluationMetricType + | ChecksEvaluationMetricConfig; -export function vertexEvaluators( +export function checksEvaluators( ai: Genkit, auth: GoogleAuth, - metrics: VertexAIEvaluationMetric[], + metrics: ChecksEvaluationMetric[], projectId: string, location: string ): Action[] { @@ -62,28 +62,28 @@ export function vertexEvaluators( const metricSpec = isConfig(metric) ? metric.metricSpec : {}; switch (metricType) { - case VertexAIEvaluationMetricType.BLEU: { + case ChecksEvaluationMetricType.BLEU: { return createBleuEvaluator(ai, factory, metricSpec); } - case VertexAIEvaluationMetricType.ROUGE: { + case ChecksEvaluationMetricType.ROUGE: { return createRougeEvaluator(ai, factory, metricSpec); } - case VertexAIEvaluationMetricType.FLUENCY: { + case ChecksEvaluationMetricType.FLUENCY: { return createFluencyEvaluator(ai, factory, metricSpec); } - case VertexAIEvaluationMetricType.SAFETY: { + case ChecksEvaluationMetricType.SAFETY: { return createSafetyEvaluator(ai, factory, metricSpec); } - case VertexAIEvaluationMetricType.GROUNDEDNESS: { + case ChecksEvaluationMetricType.GROUNDEDNESS: { return createGroundednessEvaluator(ai, factory, metricSpec); } - case VertexAIEvaluationMetricType.SUMMARIZATION_QUALITY: { + case ChecksEvaluationMetricType.SUMMARIZATION_QUALITY: { return createSummarizationQualityEvaluator(ai, factory, metricSpec); } - case VertexAIEvaluationMetricType.SUMMARIZATION_HELPFULNESS: { + case ChecksEvaluationMetricType.SUMMARIZATION_HELPFULNESS: { return createSummarizationHelpfulnessEvaluator(ai, factory, metricSpec); } - case VertexAIEvaluationMetricType.SUMMARIZATION_VERBOSITY: { + case ChecksEvaluationMetricType.SUMMARIZATION_VERBOSITY: { return createSummarizationVerbosityEvaluator(ai, factory, metricSpec); } } @@ -91,9 +91,9 @@ export function vertexEvaluators( } function isConfig( - config: VertexAIEvaluationMetric -): config is VertexAIEvaluationMetricConfig { - return (config as VertexAIEvaluationMetricConfig).type !== undefined; + config: ChecksEvaluationMetric +): config is ChecksEvaluationMetricConfig { + return (config as ChecksEvaluationMetricConfig).type !== undefined; } const BleuResponseSchema = z.object({ @@ -111,7 +111,7 @@ function createBleuEvaluator( return factory.create( ai, { - metric: VertexAIEvaluationMetricType.BLEU, + metric: ChecksEvaluationMetricType.BLEU, displayName: 'BLEU', definition: 'Computes the BLEU score by comparing the output against the ground truth', @@ -153,7 +153,7 @@ function createRougeEvaluator( return factory.create( ai, { - metric: VertexAIEvaluationMetricType.ROUGE, + metric: ChecksEvaluationMetricType.ROUGE, displayName: 'ROUGE', definition: 'Computes the ROUGE score by comparing the output against the ground truth', @@ -194,7 +194,7 @@ function createFluencyEvaluator( return factory.create( ai, { - metric: VertexAIEvaluationMetricType.FLUENCY, + metric: ChecksEvaluationMetricType.FLUENCY, displayName: 'Fluency', definition: 'Assesses the language mastery of an output', responseSchema: FluencyResponseSchema, @@ -236,7 +236,7 @@ function createSafetyEvaluator( return factory.create( ai, { - metric: VertexAIEvaluationMetricType.SAFETY, + metric: ChecksEvaluationMetricType.SAFETY, displayName: 'Safety', definition: 'Assesses the level of safety of an output', responseSchema: SafetyResponseSchema, @@ -278,7 +278,7 @@ function createGroundednessEvaluator( return factory.create( ai, { - metric: VertexAIEvaluationMetricType.GROUNDEDNESS, + metric: ChecksEvaluationMetricType.GROUNDEDNESS, displayName: 'Groundedness', definition: 'Assesses the ability to provide or reference information included only in the context', @@ -322,7 +322,7 @@ function createSummarizationQualityEvaluator( return factory.create( ai, { - metric: VertexAIEvaluationMetricType.SUMMARIZATION_QUALITY, + metric: ChecksEvaluationMetricType.SUMMARIZATION_QUALITY, displayName: 'Summarization quality', definition: 'Assesses the overall ability to summarize text', responseSchema: SummarizationQualityResponseSchema, @@ -366,7 +366,7 @@ function createSummarizationHelpfulnessEvaluator( return factory.create( ai, { - metric: VertexAIEvaluationMetricType.SUMMARIZATION_HELPFULNESS, + metric: ChecksEvaluationMetricType.SUMMARIZATION_HELPFULNESS, displayName: 'Summarization helpfulness', definition: 'Assesses the ability to provide a summarization, which contains the details necessary to substitute the original text', @@ -411,7 +411,7 @@ function createSummarizationVerbosityEvaluator( return factory.create( ai, { - metric: VertexAIEvaluationMetricType.SUMMARIZATION_VERBOSITY, + metric: ChecksEvaluationMetricType.SUMMARIZATION_VERBOSITY, displayName: 'Summarization verbosity', definition: 'Aassess the ability to provide a succinct summarization', responseSchema: SummarizationVerbositySchema, diff --git a/js/plugins/checks/src/evaluator_factory.ts b/js/plugins/checks/src/evaluator_factory.ts index 821f4631b..2f74ff4c8 100644 --- a/js/plugins/checks/src/evaluator_factory.ts +++ b/js/plugins/checks/src/evaluator_factory.ts @@ -18,7 +18,7 @@ import { Action, Genkit, GENKIT_CLIENT_HEADER, z } from 'genkit'; import { BaseEvalDataPoint, Score } from 'genkit/evaluator'; import { runInNewSpan } from 'genkit/tracing'; import { GoogleAuth } from 'google-auth-library'; -import { VertexAIEvaluationMetricType } from './evaluation.js'; +import { ChecksEvaluationMetricType } from './evaluation.js'; export class EvaluatorFactory { constructor( @@ -30,7 +30,7 @@ export class EvaluatorFactory { create( ai: Genkit, config: { - metric: VertexAIEvaluationMetricType; + metric: ChecksEvaluationMetricType; displayName: string; definition: string; responseSchema: ResponseType; diff --git a/js/plugins/checks/src/index.ts b/js/plugins/checks/src/index.ts index 8eb8d9ff1..3a44aa7ff 100644 --- a/js/plugins/checks/src/index.ts +++ b/js/plugins/checks/src/index.ts @@ -18,12 +18,12 @@ import { Genkit, z } from 'genkit'; import { GenkitPlugin, genkitPlugin } from 'genkit/plugin'; import { GoogleAuth, GoogleAuthOptions } from 'google-auth-library'; import { - VertexAIEvaluationMetric, - VertexAIEvaluationMetricType, - vertexEvaluators, + ChecksEvaluationMetric, + ChecksEvaluationMetricType, + checksEvaluators, } from './evaluation.js'; export { - VertexAIEvaluationMetricType as VertexAIEvaluationMetricType, + ChecksEvaluationMetricType as ChecksEvaluationMetricType, }; export interface PluginOptions { @@ -35,7 +35,7 @@ export interface PluginOptions { googleAuth?: GoogleAuthOptions; /** Configure Vertex AI evaluators */ evaluation?: { - metrics: VertexAIEvaluationMetric[]; + metrics: ChecksEvaluationMetric[]; }; } @@ -43,10 +43,10 @@ const CLOUD_PLATFROM_OAUTH_SCOPE = 'https://www.googleapis.com/auth/cloud-platform'; /** - * Add Google Cloud Vertex AI to Genkit. Includes Gemini and Imagen models and text embedder. + * Add Google Checks evaluators. */ -export function vertexAI(options?: PluginOptions): GenkitPlugin { - return genkitPlugin('vertexai', async (ai: Genkit) => { +export function checks(options?: PluginOptions): GenkitPlugin { + return genkitPlugin('checks', async (ai: Genkit) => { let authClient; let authOptions = options?.googleAuth; @@ -72,7 +72,7 @@ export function vertexAI(options?: PluginOptions): GenkitPlugin { const location = options?.location || 'us-central1'; const confError = (parameter: string, envVariableName: string) => { return new Error( - `VertexAI Plugin is missing the '${parameter}' configuration. Please set the '${envVariableName}' environment variable or explicitly pass '${parameter}' into genkit config.` + `Checks Plugin is missing the '${parameter}' configuration. Please set the '${envVariableName}' environment variable or explicitly pass '${parameter}' into genkit config.` ); }; if (!location) { @@ -86,8 +86,8 @@ export function vertexAI(options?: PluginOptions): GenkitPlugin { options?.evaluation && options.evaluation.metrics.length > 0 ? options.evaluation.metrics : []; - vertexEvaluators(ai, authClient, metrics, projectId, location); + checksEvaluators(ai, authClient, metrics, projectId, location); }); } -export default vertexAI; +export default checks; diff --git a/js/plugins/checks/src/predict.ts b/js/plugins/checks/src/predict.ts deleted file mode 100644 index dfc538a5b..000000000 --- a/js/plugins/checks/src/predict.ts +++ /dev/null @@ -1,83 +0,0 @@ -/** - * Copyright 2024 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -import { GENKIT_CLIENT_HEADER } from 'genkit'; -import { GoogleAuth } from 'google-auth-library'; -import { PluginOptions } from '.'; - -function endpoint(options: { - projectId: string; - location: string; - model: string; -}) { - // eslint-disable-next-line max-len - return `https://${options.location}-aiplatform.googleapis.com/v1/projects/${options.projectId}/locations/${options.location}/publishers/google/models/${options.model}:predict`; -} - -interface PredictionResponse { - predictions: R[]; -} - -export type PredictClient = ( - instances: I[], - parameters?: P -) => Promise>; - -export function predictModel( - auth: GoogleAuth, - { location, projectId }: PluginOptions, - model: string -): PredictClient { - return async ( - instances: I[], - parameters?: P - ): Promise> => { - const fetch = (await import('node-fetch')).default; - - const accessToken = await auth.getAccessToken(); - const req = { - instances, - parameters: parameters || {}, - }; - - const response = await fetch( - endpoint({ - projectId: projectId!, - location, - model, - }), - { - method: 'POST', - body: JSON.stringify(req), - headers: { - Authorization: `Bearer ${accessToken}`, - 'Content-Type': 'application/json', - 'X-Goog-Api-Client': GENKIT_CLIENT_HEADER, - }, - } - ); - - if (!response.ok) { - throw new Error( - `Error from Vertex AI predict: HTTP ${ - response.status - }: ${await response.text()}` - ); - } - - return (await response.json()) as PredictionResponse; - }; -} diff --git a/js/pnpm-lock.yaml b/js/pnpm-lock.yaml index 77c9da06a..1bd683e9c 100644 --- a/js/pnpm-lock.yaml +++ b/js/pnpm-lock.yaml @@ -195,18 +195,9 @@ importers: plugins/checks: dependencies: - '@anthropic-ai/sdk': - specifier: ^0.24.3 - version: 0.24.3(encoding@0.1.13) - '@anthropic-ai/vertex-sdk': - specifier: ^0.4.0 - version: 0.4.0(encoding@0.1.13) '@google-cloud/aiplatform': specifier: ^3.23.0 version: 3.25.0(encoding@0.1.13) - '@google-cloud/vertexai': - specifier: ^1.1.0 - version: 1.1.0(encoding@0.1.13) genkit: specifier: workspace:* version: link:../../genkit @@ -219,9 +210,6 @@ importers: node-fetch: specifier: ^3.3.2 version: 3.3.2 - openai: - specifier: ^4.52.7 - version: 4.53.0(encoding@0.1.13) optionalDependencies: '@google-cloud/bigquery': specifier: ^7.8.0 @@ -763,6 +751,9 @@ importers: testapps/byo-evaluator: dependencies: + '@genkit-ai/checks': + specifier: workspace:* + version: link:../../plugins/checks '@genkit-ai/dev-local-vectorstore': specifier: workspace:* version: link:../../plugins/dev-local-vectorstore @@ -1210,7 +1201,7 @@ importers: version: link:../../plugins/ollama genkitx-openai: specifier: ^0.10.1 - version: 0.10.1(@genkit-ai/ai@0.9.0-dev.1)(@genkit-ai/core@0.9.0-dev.1) + version: 0.10.1(@genkit-ai/ai@ai)(@genkit-ai/core@core) devDependencies: rimraf: specifier: ^6.0.1 @@ -1977,12 +1968,6 @@ packages: '@firebase/util@1.9.5': resolution: {integrity: sha512-PP4pAFISDxsf70l3pEy34Mf3GkkUcVQ3MdKp6aSVb7tcpfUQxnsdV7twDd8EkfB6zZylH6wpUAoangQDmCUMqw==} - '@genkit-ai/ai@0.9.0-dev.1': - resolution: {integrity: sha512-ETAlyS/tX5bvv9NrPZ+6cuDStNwy5Yl2CBZjoXQle0jBuBCQr3HLjUH8ntbBX55E8mCQ+5A6Bpi2TXOx1yu1dw==} - - '@genkit-ai/core@0.9.0-dev.1': - resolution: {integrity: sha512-zWlzCaAKpNRwtMrZaA2h0o0yx4uj9OBmPhN5vMUTipWsaKIF1A3STvzRjxz4vFF2U87Uzvl2287JqyUNEXwQbA==} - '@google-cloud/aiplatform@3.25.0': resolution: {integrity: sha512-qKnJgbyCENjed8e1G5zZGFTxxNKhhaKQN414W2KIVHrLxMFmlMuG+3QkXPOWwXBnT5zZ7aMxypt5og0jCirpHg==} engines: {node: '>=14.0.0'} @@ -2520,24 +2505,12 @@ packages: peerDependencies: '@opentelemetry/api': '>=1.0.0 <1.10.0' - '@opentelemetry/context-async-hooks@1.26.0': - resolution: {integrity: sha512-HedpXXYzzbaoutw6DFLWLDket2FwLkLpil4hGCZ1xYEIMTcivdfwEOISgdbLEWyG3HW52gTq2V9mOVJrONgiwg==} - engines: {node: '>=14'} - peerDependencies: - '@opentelemetry/api': '>=1.0.0 <1.10.0' - '@opentelemetry/core@1.25.1': resolution: {integrity: sha512-GeT/l6rBYWVQ4XArluLVB6WWQ8flHbdb6r2FCHC3smtdOAbrJBIv35tpV/yp9bmYUJf+xmZpu9DRTIeJVhFbEQ==} engines: {node: '>=14'} peerDependencies: '@opentelemetry/api': '>=1.0.0 <1.10.0' - '@opentelemetry/core@1.26.0': - resolution: {integrity: sha512-1iKxXXE8415Cdv0yjG3G6hQnB5eVEsJce3QaawX8SjDn0mAS0ZM8fAbZZJD4ajvhC15cePvosSCut404KrIIvQ==} - engines: {node: '>=14'} - peerDependencies: - '@opentelemetry/api': '>=1.0.0 <1.10.0' - '@opentelemetry/exporter-trace-otlp-grpc@0.52.1': resolution: {integrity: sha512-pVkSH20crBwMTqB3nIN4jpQKUEoB0Z94drIHpYyEqs7UBr+I0cpYyOR3bqjA/UasQUMROb3GX8ZX4/9cVRqGBQ==} engines: {node: '>=14'} @@ -2884,12 +2857,6 @@ packages: peerDependencies: '@opentelemetry/api': '>=1.0.0 <1.10.0' - '@opentelemetry/resources@1.26.0': - resolution: {integrity: sha512-CPNYchBE7MBecCSVy0HKpUISEeJOniWqcHaAHpmasZ3j9o6V3AyBzhRc90jdmemq0HOxDr6ylhUbDhBqqPpeNw==} - engines: {node: '>=14'} - peerDependencies: - '@opentelemetry/api': '>=1.0.0 <1.10.0' - '@opentelemetry/sdk-logs@0.52.1': resolution: {integrity: sha512-MBYh+WcPPsN8YpRHRmK1Hsca9pVlyyKd4BxOC4SsgHACnl/bPp4Cri9hWhVm5+2tiQ9Zf4qSc1Jshw9tOLGWQA==} engines: {node: '>=14'} @@ -2902,12 +2869,6 @@ packages: peerDependencies: '@opentelemetry/api': '>=1.3.0 <1.10.0' - '@opentelemetry/sdk-metrics@1.26.0': - resolution: {integrity: sha512-0SvDXmou/JjzSDOjUmetAAvcKQW6ZrvosU0rkbDGpXvvZN+pQF6JbK/Kd4hNdK4q/22yeruqvukXEJyySTzyTQ==} - engines: {node: '>=14'} - peerDependencies: - '@opentelemetry/api': '>=1.3.0 <1.10.0' - '@opentelemetry/sdk-node@0.52.1': resolution: {integrity: sha512-uEG+gtEr6eKd8CVWeKMhH2olcCHM9dEK68pe0qE0be32BcCRsvYURhHaD1Srngh1SQcnQzZ4TP324euxqtBOJA==} engines: {node: '>=14'} @@ -2920,12 +2881,6 @@ packages: peerDependencies: '@opentelemetry/api': '>=1.0.0 <1.10.0' - '@opentelemetry/sdk-trace-base@1.26.0': - resolution: {integrity: sha512-olWQldtvbK4v22ymrKLbIcBi9L2SpMO84sCPY54IVsJhP9fRsxJT194C/AVaAuJzLE30EdhhM1VmvVYR7az+cw==} - engines: {node: '>=14'} - peerDependencies: - '@opentelemetry/api': '>=1.0.0 <1.10.0' - '@opentelemetry/sdk-trace-node@1.25.1': resolution: {integrity: sha512-nMcjFIKxnFqoez4gUmihdBrbpsEnAX/Xj16sGvZm+guceYE0NE00vLhpDVK6f3q8Q4VFI5xG8JjlXKMB/SkTTQ==} engines: {node: '>=14'} @@ -2944,10 +2899,6 @@ packages: resolution: {integrity: sha512-U9PJlOswJPSgQVPI+XEuNLElyFWkb0hAiMg+DExD9V0St03X2lPHGMdxMY/LrVmoukuIpXJ12oyrOtEZ4uXFkw==} engines: {node: '>=14'} - '@opentelemetry/semantic-conventions@1.27.0': - resolution: {integrity: sha512-sAay1RrB+ONOem0OZanAR1ZI/k7yDpnOQSQmTMuGImUQb2y8EbSaCJ94FQluM74xoU03vlb2d2U90hZluL6nQg==} - engines: {node: '>=14'} - '@opentelemetry/sql-common@0.40.1': resolution: {integrity: sha512-nSDlnHSqzC3pXn/wZEZVLuAuJ1MYMXPBwtv2qAbCa3847SaHItdE7SzUq/Jtb0KZmh1zfAbNi3AAMjztTT4Ugg==} engines: {node: '>=14'} @@ -3354,9 +3305,6 @@ packages: ajv@8.12.0: resolution: {integrity: sha512-sRu1kpcO9yLtYxBKvqfTeh9KzZEwO3STyX1HT+4CaDzC6HpTGYhIhPIzj9XuKU7KYDwnaeh5hcOwjy1QuJzBPA==} - ajv@8.17.1: - resolution: {integrity: sha512-B/gBuNg5SiMTrPkC+A2+cW0RszwxYmn6VYxB/inlBStS5nx6xHIt/ehKRhIMhqusl7a8LjQoZnjCs5vhwxOQ1g==} - ansi-escapes@4.3.2: resolution: {integrity: sha512-gKXj5ALrKWQLsYG9jlTRmR/xKluxHV+Z9QEwNIgCfM1/uwPMCuzVVnh5mwTd+OuBZcwSIMbqssNWRm1lE51QaQ==} engines: {node: '>=8'} @@ -4026,9 +3974,6 @@ packages: fast-text-encoding@1.0.6: resolution: {integrity: sha512-VhXlQgj9ioXCqGstD37E/HBeqEGV/qOD/kmbVG8h5xKBYvM1L3lR1Zn4555cQ8GkYbJa8aJSipLPndE1k6zK2w==} - fast-uri@3.0.1: - resolution: {integrity: sha512-MWipKbbYiYI0UC7cl8m/i/IWTqfC8YXsqjzybjddLsFjStroQzsHXkc73JutMvBiXmOvapk+axIl79ig5t55Bw==} - fast-xml-parser@4.3.6: resolution: {integrity: sha512-M2SovcRxD4+vC493Uc2GZVcZaj66CCJhWurC4viynVSTvrpErCShNcDz1lAho6n9REQKvL/ll4A4/fw6Y9z8nw==} hasBin: true @@ -6304,11 +6249,6 @@ packages: peerDependencies: zod: ^3.22.4 - zod-to-json-schema@3.23.3: - resolution: {integrity: sha512-TYWChTxKQbRJp5ST22o/Irt9KC5nj7CdBKYB/AosCRdj/wxEMvv4NNaj9XVUHDOIp53ZxArGhnw5HMZziPFjog==} - peerDependencies: - zod: ^3.23.3 - zod@3.22.4: resolution: {integrity: sha512-iC+8Io04lddc+mVqQ9AZ7OQ2MrUKGN+oIQyq1vemgt46jwCwLfhq7/pwnBnNXXXZb8VTVLKwp9EDkx+ryxIWmg==} @@ -6752,39 +6692,6 @@ snapshots: dependencies: tslib: 2.6.2 - '@genkit-ai/ai@0.9.0-dev.1': - dependencies: - '@genkit-ai/core': 0.9.0-dev.1 - '@opentelemetry/api': 1.9.0 - '@types/node': 20.16.9 - colorette: 2.0.20 - json5: 2.2.3 - node-fetch: 3.3.2 - partial-json: 0.1.7 - transitivePeerDependencies: - - supports-color - - '@genkit-ai/core@0.9.0-dev.1': - dependencies: - '@opentelemetry/api': 1.9.0 - '@opentelemetry/context-async-hooks': 1.26.0(@opentelemetry/api@1.9.0) - '@opentelemetry/core': 1.26.0(@opentelemetry/api@1.9.0) - '@opentelemetry/sdk-metrics': 1.26.0(@opentelemetry/api@1.9.0) - '@opentelemetry/sdk-node': 0.52.1(@opentelemetry/api@1.9.0) - '@opentelemetry/sdk-trace-base': 1.26.0(@opentelemetry/api@1.9.0) - ajv: 8.17.1 - ajv-formats: 3.0.1(ajv@8.17.1) - async-mutex: 0.5.0 - body-parser: 1.20.3 - cors: 2.8.5 - express: 4.21.0 - get-port: 5.1.0 - json-schema: 0.4.0 - zod: 3.23.8 - zod-to-json-schema: 3.23.3(zod@3.23.8) - transitivePeerDependencies: - - supports-color - '@google-cloud/aiplatform@3.25.0(encoding@0.1.13)': dependencies: google-gax: 4.3.7(encoding@0.1.13) @@ -7328,20 +7235,11 @@ snapshots: dependencies: '@opentelemetry/api': 1.9.0 - '@opentelemetry/context-async-hooks@1.26.0(@opentelemetry/api@1.9.0)': - dependencies: - '@opentelemetry/api': 1.9.0 - '@opentelemetry/core@1.25.1(@opentelemetry/api@1.9.0)': dependencies: '@opentelemetry/api': 1.9.0 '@opentelemetry/semantic-conventions': 1.25.1 - '@opentelemetry/core@1.26.0(@opentelemetry/api@1.9.0)': - dependencies: - '@opentelemetry/api': 1.9.0 - '@opentelemetry/semantic-conventions': 1.27.0 - '@opentelemetry/exporter-trace-otlp-grpc@0.52.1(@opentelemetry/api@1.9.0)': dependencies: '@grpc/grpc-js': 1.10.10 @@ -7815,12 +7713,6 @@ snapshots: '@opentelemetry/core': 1.25.1(@opentelemetry/api@1.9.0) '@opentelemetry/semantic-conventions': 1.25.1 - '@opentelemetry/resources@1.26.0(@opentelemetry/api@1.9.0)': - dependencies: - '@opentelemetry/api': 1.9.0 - '@opentelemetry/core': 1.26.0(@opentelemetry/api@1.9.0) - '@opentelemetry/semantic-conventions': 1.27.0 - '@opentelemetry/sdk-logs@0.52.1(@opentelemetry/api@1.9.0)': dependencies: '@opentelemetry/api': 1.9.0 @@ -7835,12 +7727,6 @@ snapshots: '@opentelemetry/resources': 1.25.1(@opentelemetry/api@1.9.0) lodash.merge: 4.6.2 - '@opentelemetry/sdk-metrics@1.26.0(@opentelemetry/api@1.9.0)': - dependencies: - '@opentelemetry/api': 1.9.0 - '@opentelemetry/core': 1.26.0(@opentelemetry/api@1.9.0) - '@opentelemetry/resources': 1.26.0(@opentelemetry/api@1.9.0) - '@opentelemetry/sdk-node@0.52.1(@opentelemetry/api@1.9.0)': dependencies: '@opentelemetry/api': 1.9.0 @@ -7867,13 +7753,6 @@ snapshots: '@opentelemetry/resources': 1.25.1(@opentelemetry/api@1.9.0) '@opentelemetry/semantic-conventions': 1.25.1 - '@opentelemetry/sdk-trace-base@1.26.0(@opentelemetry/api@1.9.0)': - dependencies: - '@opentelemetry/api': 1.9.0 - '@opentelemetry/core': 1.26.0(@opentelemetry/api@1.9.0) - '@opentelemetry/resources': 1.26.0(@opentelemetry/api@1.9.0) - '@opentelemetry/semantic-conventions': 1.27.0 - '@opentelemetry/sdk-trace-node@1.25.1(@opentelemetry/api@1.9.0)': dependencies: '@opentelemetry/api': 1.9.0 @@ -7890,8 +7769,6 @@ snapshots: '@opentelemetry/semantic-conventions@1.26.0': {} - '@opentelemetry/semantic-conventions@1.27.0': {} - '@opentelemetry/sql-common@0.40.1(@opentelemetry/api@1.9.0)': dependencies: '@opentelemetry/api': 1.9.0 @@ -8259,10 +8136,6 @@ snapshots: optionalDependencies: ajv: 8.12.0 - ajv-formats@3.0.1(ajv@8.17.1): - optionalDependencies: - ajv: 8.17.1 - ajv@8.12.0: dependencies: fast-deep-equal: 3.1.3 @@ -8270,13 +8143,6 @@ snapshots: require-from-string: 2.0.2 uri-js: 4.4.1 - ajv@8.17.1: - dependencies: - fast-deep-equal: 3.1.3 - fast-uri: 3.0.1 - json-schema-traverse: 1.0.0 - require-from-string: 2.0.2 - ansi-escapes@4.3.2: dependencies: type-fest: 0.21.3 @@ -9110,8 +8976,6 @@ snapshots: fast-text-encoding@1.0.6: optional: true - fast-uri@3.0.1: {} - fast-xml-parser@4.3.6: dependencies: strnum: 1.0.5 @@ -9330,10 +9194,10 @@ snapshots: - encoding - supports-color - genkitx-openai@0.10.1(@genkit-ai/ai@0.9.0-dev.1)(@genkit-ai/core@0.9.0-dev.1): + genkitx-openai@0.10.1(@genkit-ai/ai@ai)(@genkit-ai/core@core): dependencies: - '@genkit-ai/ai': 0.9.0-dev.1 - '@genkit-ai/core': 0.9.0-dev.1 + '@genkit-ai/ai': link:ai + '@genkit-ai/core': link:core openai: 4.53.0(encoding@0.1.13) zod: 3.23.8 transitivePeerDependencies: @@ -11781,10 +11645,6 @@ snapshots: dependencies: zod: 3.23.8 - zod-to-json-schema@3.23.3(zod@3.23.8): - dependencies: - zod: 3.23.8 - zod@3.22.4: {} zod@3.23.8: {} diff --git a/js/testapps/byo-evaluator/package.json b/js/testapps/byo-evaluator/package.json index 53c996e56..c4ba52e44 100644 --- a/js/testapps/byo-evaluator/package.json +++ b/js/testapps/byo-evaluator/package.json @@ -19,6 +19,7 @@ "@genkit-ai/firebase": "workspace:*", "@genkit-ai/googleai": "workspace:*", "@genkit-ai/vertexai": "workspace:*", + "@genkit-ai/checks": "workspace:*", "genkit": "workspace:*", "path": "^0.12.7" }, diff --git a/js/testapps/byo-evaluator/src/index.ts b/js/testapps/byo-evaluator/src/index.ts index 9f9e4fdb8..f60292e24 100644 --- a/js/testapps/byo-evaluator/src/index.ts +++ b/js/testapps/byo-evaluator/src/index.ts @@ -36,6 +36,7 @@ import { isRegexMetric, regexMatcher, } from './regex/regex_evaluator.js'; +import {checks, ChecksEvaluationMetricType} from "@genkit-ai/checks" export const ai = genkit({ plugins: [ @@ -54,6 +55,13 @@ export const ai = genkit({ FUNNINESS, ], }), + checks({ + location: "us-central1", + projectId: "checks-prod", + evaluation: { + metrics:[ChecksEvaluationMetricType.SAFETY], + }, + }) ], }); From 61754b5efd3f7012079eda53aaf9d0fb3e4a72d7 Mon Sep 17 00:00:00 2001 From: hunterheston Date: Thu, 31 Oct 2024 19:53:28 +0000 Subject: [PATCH 04/30] hardcoded example for calling the checks api is working. Needs to be cleaned up and turned into working code. --- js/plugins/checks/src/evaluation.ts | 328 --------------------- js/plugins/checks/src/evaluator_factory.ts | 26 +- js/plugins/checks/src/index.ts | 10 +- 3 files changed, 33 insertions(+), 331 deletions(-) diff --git a/js/plugins/checks/src/evaluation.ts b/js/plugins/checks/src/evaluation.ts index 3804c8df1..09eb041e2 100644 --- a/js/plugins/checks/src/evaluation.ts +++ b/js/plugins/checks/src/evaluation.ts @@ -24,14 +24,7 @@ import { EvaluatorFactory } from './evaluator_factory.js'; */ export enum ChecksEvaluationMetricType { // TODO: Change to match checks policies. - BLEU = 'BLEU', - ROUGE = 'ROUGE', - FLUENCY = 'FLEUNCY', SAFETY = 'SAFETY', - GROUNDEDNESS = 'GROUNDEDNESS', - SUMMARIZATION_QUALITY = 'SUMMARIZATION_QUALITY', - SUMMARIZATION_HELPFULNESS = 'SUMMARIZATION_HELPFULNESS', - SUMMARIZATION_VERBOSITY = 'SUMMARIZATION_VERBOSITY', } /** @@ -62,30 +55,9 @@ export function checksEvaluators( const metricSpec = isConfig(metric) ? metric.metricSpec : {}; switch (metricType) { - case ChecksEvaluationMetricType.BLEU: { - return createBleuEvaluator(ai, factory, metricSpec); - } - case ChecksEvaluationMetricType.ROUGE: { - return createRougeEvaluator(ai, factory, metricSpec); - } - case ChecksEvaluationMetricType.FLUENCY: { - return createFluencyEvaluator(ai, factory, metricSpec); - } case ChecksEvaluationMetricType.SAFETY: { return createSafetyEvaluator(ai, factory, metricSpec); } - case ChecksEvaluationMetricType.GROUNDEDNESS: { - return createGroundednessEvaluator(ai, factory, metricSpec); - } - case ChecksEvaluationMetricType.SUMMARIZATION_QUALITY: { - return createSummarizationQualityEvaluator(ai, factory, metricSpec); - } - case ChecksEvaluationMetricType.SUMMARIZATION_HELPFULNESS: { - return createSummarizationHelpfulnessEvaluator(ai, factory, metricSpec); - } - case ChecksEvaluationMetricType.SUMMARIZATION_VERBOSITY: { - return createSummarizationVerbosityEvaluator(ai, factory, metricSpec); - } } }); } @@ -96,129 +68,6 @@ function isConfig( return (config as ChecksEvaluationMetricConfig).type !== undefined; } -const BleuResponseSchema = z.object({ - bleuResults: z.object({ - bleuMetricValues: z.array(z.object({ score: z.number() })), - }), -}); - -// TODO: Add support for batch inputs -function createBleuEvaluator( - ai: Genkit, - factory: EvaluatorFactory, - metricSpec: any -): Action { - return factory.create( - ai, - { - metric: ChecksEvaluationMetricType.BLEU, - displayName: 'BLEU', - definition: - 'Computes the BLEU score by comparing the output against the ground truth', - responseSchema: BleuResponseSchema, - }, - (datapoint) => { - return { - bleuInput: { - metricSpec, - instances: [ - { - prediction: datapoint.output as string, - reference: datapoint.reference, - }, - ], - }, - }; - }, - (response) => { - return { - score: response.bleuResults.bleuMetricValues[0].score, - }; - } - ); -} - -const RougeResponseSchema = z.object({ - rougeResults: z.object({ - rougeMetricValues: z.array(z.object({ score: z.number() })), - }), -}); - -// TODO: Add support for batch inputs -function createRougeEvaluator( - ai: Genkit, - factory: EvaluatorFactory, - metricSpec: any -): Action { - return factory.create( - ai, - { - metric: ChecksEvaluationMetricType.ROUGE, - displayName: 'ROUGE', - definition: - 'Computes the ROUGE score by comparing the output against the ground truth', - responseSchema: RougeResponseSchema, - }, - (datapoint) => { - return { - rougeInput: { - metricSpec, - instances: { - prediction: datapoint.output as string, - reference: datapoint.reference, - }, - }, - }; - }, - (response) => { - return { - score: response.rougeResults.rougeMetricValues[0].score, - }; - } - ); -} - -const FluencyResponseSchema = z.object({ - fluencyResult: z.object({ - score: z.number(), - explanation: z.string(), - confidence: z.number(), - }), -}); - -function createFluencyEvaluator( - ai: Genkit, - factory: EvaluatorFactory, - metricSpec: any -): Action { - return factory.create( - ai, - { - metric: ChecksEvaluationMetricType.FLUENCY, - displayName: 'Fluency', - definition: 'Assesses the language mastery of an output', - responseSchema: FluencyResponseSchema, - }, - (datapoint) => { - return { - fluencyInput: { - metricSpec, - instance: { - prediction: datapoint.output as string, - }, - }, - }; - }, - (response) => { - return { - score: response.fluencyResult.score, - details: { - reasoning: response.fluencyResult.explanation, - }, - }; - } - ); -} const SafetyResponseSchema = z.object({ safetyResult: z.object({ @@ -261,180 +110,3 @@ function createSafetyEvaluator( } ); } - -const GroundednessResponseSchema = z.object({ - groundednessResult: z.object({ - score: z.number(), - explanation: z.string(), - confidence: z.number(), - }), -}); - -function createGroundednessEvaluator( - ai: Genkit, - factory: EvaluatorFactory, - metricSpec: any -): Action { - return factory.create( - ai, - { - metric: ChecksEvaluationMetricType.GROUNDEDNESS, - displayName: 'Groundedness', - definition: - 'Assesses the ability to provide or reference information included only in the context', - responseSchema: GroundednessResponseSchema, - }, - (datapoint) => { - return { - groundednessInput: { - metricSpec, - instance: { - prediction: datapoint.output as string, - context: datapoint.context?.join('. '), - }, - }, - }; - }, - (response) => { - return { - score: response.groundednessResult.score, - details: { - reasoning: response.groundednessResult.explanation, - }, - }; - } - ); -} - -const SummarizationQualityResponseSchema = z.object({ - summarizationQualityResult: z.object({ - score: z.number(), - explanation: z.string(), - confidence: z.number(), - }), -}); - -function createSummarizationQualityEvaluator( - ai: Genkit, - factory: EvaluatorFactory, - metricSpec: any -): Action { - return factory.create( - ai, - { - metric: ChecksEvaluationMetricType.SUMMARIZATION_QUALITY, - displayName: 'Summarization quality', - definition: 'Assesses the overall ability to summarize text', - responseSchema: SummarizationQualityResponseSchema, - }, - (datapoint) => { - return { - summarizationQualityInput: { - metricSpec, - instance: { - prediction: datapoint.output as string, - instruction: datapoint.input as string, - context: datapoint.context?.join('. '), - }, - }, - }; - }, - (response) => { - return { - score: response.summarizationQualityResult.score, - details: { - reasoning: response.summarizationQualityResult.explanation, - }, - }; - } - ); -} - -const SummarizationHelpfulnessResponseSchema = z.object({ - summarizationHelpfulnessResult: z.object({ - score: z.number(), - explanation: z.string(), - confidence: z.number(), - }), -}); - -function createSummarizationHelpfulnessEvaluator( - ai: Genkit, - factory: EvaluatorFactory, - metricSpec: any -): Action { - return factory.create( - ai, - { - metric: ChecksEvaluationMetricType.SUMMARIZATION_HELPFULNESS, - displayName: 'Summarization helpfulness', - definition: - 'Assesses the ability to provide a summarization, which contains the details necessary to substitute the original text', - responseSchema: SummarizationHelpfulnessResponseSchema, - }, - (datapoint) => { - return { - summarizationHelpfulnessInput: { - metricSpec, - instance: { - prediction: datapoint.output as string, - instruction: datapoint.input as string, - context: datapoint.context?.join('. '), - }, - }, - }; - }, - (response) => { - return { - score: response.summarizationHelpfulnessResult.score, - details: { - reasoning: response.summarizationHelpfulnessResult.explanation, - }, - }; - } - ); -} - -const SummarizationVerbositySchema = z.object({ - summarizationVerbosityResult: z.object({ - score: z.number(), - explanation: z.string(), - confidence: z.number(), - }), -}); - -function createSummarizationVerbosityEvaluator( - ai: Genkit, - factory: EvaluatorFactory, - metricSpec: any -): Action { - return factory.create( - ai, - { - metric: ChecksEvaluationMetricType.SUMMARIZATION_VERBOSITY, - displayName: 'Summarization verbosity', - definition: 'Aassess the ability to provide a succinct summarization', - responseSchema: SummarizationVerbositySchema, - }, - (datapoint) => { - return { - summarizationVerbosityInput: { - metricSpec, - instance: { - prediction: datapoint.output as string, - instruction: datapoint.input as string, - context: datapoint.context?.join('. '), - }, - }, - }; - }, - (response) => { - return { - score: response.summarizationVerbosityResult.score, - details: { - reasoning: response.summarizationVerbosityResult.explanation, - }, - }; - } - ); -} diff --git a/js/plugins/checks/src/evaluator_factory.ts b/js/plugins/checks/src/evaluator_factory.ts index 2f74ff4c8..843dd2834 100644 --- a/js/plugins/checks/src/evaluator_factory.ts +++ b/js/plugins/checks/src/evaluator_factory.ts @@ -40,7 +40,7 @@ export class EvaluatorFactory { ): Action { return ai.defineEvaluator( { - name: `vertexai/${config.metric.toLocaleLowerCase()}`, + name: `checks/${config.metric.toLocaleLowerCase()}`, displayName: config.displayName, definition: config.definition, }, @@ -64,6 +64,7 @@ export class EvaluatorFactory { responseSchema: ResponseType ): Promise> { const locationName = `projects/${this.projectId}/locations/${this.location}`; + return await runInNewSpan( { metadata: { @@ -76,6 +77,7 @@ export class EvaluatorFactory { ...partialRequest, }; + metadata.input = request; const client = await this.auth.getClient(); const url = `https://${this.location}-aiplatform.googleapis.com/v1beta1/${locationName}:evaluateInstances`; @@ -89,8 +91,30 @@ export class EvaluatorFactory { }); metadata.output = response.data; + const checksResponse = await client.request({ + url: "https://checks.googleapis.com/v1alpha/aisafety:classifyContent", + method: "POST", + body: `{ + "input": { + "text_input": { + "content": "I hate you and all people on earth" + } + }, + "policies": { "policy_type": "HARASSMENT" } + }`, + headers: { + "X-Goog-User-Project": "checks-api-370419", + "Content-Type": "application/json", + } + }) + try { return responseSchema.parse(response.data); + // return responseSchema.parse({ + // score: 1, + // explanation: "the explanation", + // confidence: 100 + // }); } catch (e) { throw new Error(`Error parsing ${url} API response: ${e}`); } diff --git a/js/plugins/checks/src/index.ts b/js/plugins/checks/src/index.ts index 3a44aa7ff..73488a077 100644 --- a/js/plugins/checks/src/index.ts +++ b/js/plugins/checks/src/index.ts @@ -42,6 +42,8 @@ export interface PluginOptions { const CLOUD_PLATFROM_OAUTH_SCOPE = 'https://www.googleapis.com/auth/cloud-platform'; +const CHECKS_OAUTH_SCOPE = 'https://www.googleapis.com/auth/checks'; + /** * Add Google Checks evaluators. */ @@ -53,20 +55,24 @@ export function checks(options?: PluginOptions): GenkitPlugin { // Allow customers to pass in cloud credentials from environment variables // following: https://github.com/googleapis/google-auth-library-nodejs?tab=readme-ov-file#loading-credentials-from-environment-variables if (process.env.GCLOUD_SERVICE_ACCOUNT_CREDS) { + console.log("HSH initilizing google auth via path 1") const serviceAccountCreds = JSON.parse( process.env.GCLOUD_SERVICE_ACCOUNT_CREDS ); authOptions = { credentials: serviceAccountCreds, - scopes: [CLOUD_PLATFROM_OAUTH_SCOPE], + scopes: [CLOUD_PLATFROM_OAUTH_SCOPE, CHECKS_OAUTH_SCOPE], }; authClient = new GoogleAuth(authOptions); } else { + console.log("HSH initilizing google auth via path 2") authClient = new GoogleAuth( - authOptions ?? { scopes: [CLOUD_PLATFROM_OAUTH_SCOPE] } + authOptions ?? { scopes: [CLOUD_PLATFROM_OAUTH_SCOPE, CHECKS_OAUTH_SCOPE] } ); } + console.log("HSH Google auth client initialized: ", authClient) + const projectId = options?.projectId || (await authClient.getProjectId()); const location = options?.location || 'us-central1'; From abcf80f8183ab104e6b2426ad05f8ad82681f01b Mon Sep 17 00:00:00 2001 From: hunterheston Date: Fri, 1 Nov 2024 19:15:18 +0000 Subject: [PATCH 05/30] Fully working evaluations. Lots of cleanup. Need to decide on how to handle multiple policies for the same text input. --- js/plugins/checks/src/evaluation.ts | 62 +++++++++++ js/plugins/checks/src/evaluator_factory.ts | 114 +++++++++++++++++++-- js/testapps/byo-evaluator/src/index.ts | 4 +- 3 files changed, 172 insertions(+), 8 deletions(-) diff --git a/js/plugins/checks/src/evaluation.ts b/js/plugins/checks/src/evaluation.ts index 09eb041e2..049a65738 100644 --- a/js/plugins/checks/src/evaluation.ts +++ b/js/plugins/checks/src/evaluation.ts @@ -25,6 +25,7 @@ import { EvaluatorFactory } from './evaluator_factory.js'; export enum ChecksEvaluationMetricType { // TODO: Change to match checks policies. SAFETY = 'SAFETY', + HARASSMENT = 'HARASSMENT', } /** @@ -58,6 +59,9 @@ export function checksEvaluators( case ChecksEvaluationMetricType.SAFETY: { return createSafetyEvaluator(ai, factory, metricSpec); } + case ChecksEvaluationMetricType.HARASSMENT: { + return createHarassmentEvaluator(ai, factory, metricSpec) + } } }); } @@ -69,6 +73,64 @@ function isConfig( } +//TODO: this is the schema: +// { +// policyResults: [ +// { +// policyType: 'HARASSMENT', +// score: 0.31868133, +// violationResult: 'NON_VIOLATIVE' +// } +// ] +// } + +const HarassmentResponseSchema = z.object({ + policyResults: z.array( + z.object({ + policyType: z.string(), + score: z.number(), + violationResult: z.string() + }) + ) +}); + +function createHarassmentEvaluator( + ai: Genkit, + factory: EvaluatorFactory, + metricSpec: any +): Action { + return factory.create( + ai, + { + metric: ChecksEvaluationMetricType.HARASSMENT, + displayName: 'Harassment', + definition: 'Assesses the text constittues harassment.', + responseSchema: HarassmentResponseSchema, + checksEval: true + }, + (datapoint) => { + return { + input: { + text_input: { + content: datapoint.output as string + }, + }, + policies: { + policy_type: "HARASSMENT", + } + }; + }, + (response) => { + return { + score: response.policyResults[0].score, + details: { + reasoning: response.policyResults[0].violationResult + } + }; + } + ); +} + const SafetyResponseSchema = z.object({ safetyResult: z.object({ score: z.number(), diff --git a/js/plugins/checks/src/evaluator_factory.ts b/js/plugins/checks/src/evaluator_factory.ts index 843dd2834..af16bdd28 100644 --- a/js/plugins/checks/src/evaluator_factory.ts +++ b/js/plugins/checks/src/evaluator_factory.ts @@ -25,7 +25,7 @@ export class EvaluatorFactory { private readonly auth: GoogleAuth, private readonly location: string, private readonly projectId: string - ) {} + ) { } create( ai: Genkit, @@ -34,6 +34,7 @@ export class EvaluatorFactory { displayName: string; definition: string; responseSchema: ResponseType; + checksEval?: boolean; }, toRequest: (datapoint: BaseEvalDataPoint) => any, responseHandler: (response: z.infer) => Score @@ -46,10 +47,19 @@ export class EvaluatorFactory { }, async (datapoint: BaseEvalDataPoint) => { const responseSchema = config.responseSchema; - const response = await this.evaluateInstances( - toRequest(datapoint), - responseSchema - ); + let response; + + if (config.checksEval) { + response = await this.checksEvalInstance( + toRequest(datapoint), + responseSchema + ); + } else { + response = await this.evaluateInstances( + toRequest(datapoint), + responseSchema + ); + } return { evaluation: responseHandler(response), @@ -59,12 +69,80 @@ export class EvaluatorFactory { ); } + + async checksEvalInstance( + partialRequest: any, + responseSchema: ResponseType + ): Promise> { + + console.log('HSH::partialRequest: ', partialRequest) + return await runInNewSpan( + { + metadata: { + name: 'EvaluationService#evaluateInstances', + }, + }, + async (metadata, _otSpan) => { + const request = { + ...partialRequest, + }; + + console.log("HSH::request: ", request) + + /** + gcloud auth application-default login --scopes=https://www.googleapis.com/auth/cloud-platform,https://www.googleapis.com/auth/checks + + curl -X POST https://checks.googleapis.com/v1alpha/aisafety:classifyContent \ + -H "Authorization: Bearer $(gcloud auth application-default print-access-token)" \ + -H "X-Goog-User-Project: checks-api-370419" \ + -H "Content-Type: application/json" \ + -d '{ \ + "input": { \ + "text_input": { \ + "content": "I hate you and all people on earth" \ + } \ + }, \ + "policies": { "policy_type": "HARASSMENT" } \ + }'\ + */ + + metadata.input = request; + const client = await this.auth.getClient(); + const url = "https://checks.googleapis.com/v1alpha/aisafety:classifyContent" + + const response = await client.request({ + url, + method: "POST", + body: JSON.stringify(request), + headers: { + "X-Goog-User-Project": "checks-api-370419", + "Content-Type": "application/json", + } + }) + metadata.output = response.data; + + console.log("HSH::response: ", response) + console.log("HSH::response.data: ", response.data) + + // console.log("HSH::response: ", response) + // console.log("HSH::metadata: ", metadata) + + try { + return responseSchema.parse(response.data); + } catch (e) { + throw new Error(`Error parsing ${url} API response: ${e}`); + } + } + ); + } + async evaluateInstances( partialRequest: any, responseSchema: ResponseType ): Promise> { const locationName = `projects/${this.projectId}/locations/${this.location}`; + console.log('HSH::partialRequest: ', partialRequest) return await runInNewSpan( { metadata: { @@ -77,6 +155,24 @@ export class EvaluatorFactory { ...partialRequest, }; + console.log("HSH::request: ", request) + + /** + gcloud auth application-default login --scopes=https://www.googleapis.com/auth/cloud-platform,https://www.googleapis.com/auth/checks + + curl -X POST https://checks.googleapis.com/v1alpha/aisafety:classifyContent \ + -H "Authorization: Bearer $(gcloud auth application-default print-access-token)" \ + -H "X-Goog-User-Project: checks-api-370419" \ + -H "Content-Type: application/json" \ + -d '{ \ + "input": { \ + "text_input": { \ + "content": "I hate you and all people on earth" \ + } \ + }, \ + "policies": { "policy_type": "HARASSMENT" } \ + }'\ + */ metadata.input = request; const client = await this.auth.getClient(); @@ -107,7 +203,13 @@ export class EvaluatorFactory { "Content-Type": "application/json", } }) - + + console.log("HSH::checksResponse: ", checksResponse) + console.log("HSH::checksResponse.data: ", checksResponse.data) + + // console.log("HSH::response: ", response) + // console.log("HSH::metadata: ", metadata) + try { return responseSchema.parse(response.data); // return responseSchema.parse({ diff --git a/js/testapps/byo-evaluator/src/index.ts b/js/testapps/byo-evaluator/src/index.ts index f60292e24..e0cb71ad2 100644 --- a/js/testapps/byo-evaluator/src/index.ts +++ b/js/testapps/byo-evaluator/src/index.ts @@ -36,7 +36,7 @@ import { isRegexMetric, regexMatcher, } from './regex/regex_evaluator.js'; -import {checks, ChecksEvaluationMetricType} from "@genkit-ai/checks" +import { checks, ChecksEvaluationMetricType } from "@genkit-ai/checks" export const ai = genkit({ plugins: [ @@ -59,7 +59,7 @@ export const ai = genkit({ location: "us-central1", projectId: "checks-prod", evaluation: { - metrics:[ChecksEvaluationMetricType.SAFETY], + metrics: [ChecksEvaluationMetricType.SAFETY, ChecksEvaluationMetricType.HARASSMENT], }, }) ], From 8c8c1f091d402c31dd85640d2eca275e0a09e7ec Mon Sep 17 00:00:00 2001 From: hunterheston Date: Mon, 4 Nov 2024 17:08:55 +0000 Subject: [PATCH 06/30] Add all policies --- js/plugins/checks/src/evaluation.ts | 309 ++++++++++++++++++--- js/plugins/checks/src/evaluator_factory.ts | 105 +------ js/testapps/byo-evaluator/src/index.ts | 11 +- 3 files changed, 287 insertions(+), 138 deletions(-) diff --git a/js/plugins/checks/src/evaluation.ts b/js/plugins/checks/src/evaluation.ts index 049a65738..2761fac96 100644 --- a/js/plugins/checks/src/evaluation.ts +++ b/js/plugins/checks/src/evaluation.ts @@ -23,9 +23,27 @@ import { EvaluatorFactory } from './evaluator_factory.js'; * TODO: add documentation link. */ export enum ChecksEvaluationMetricType { - // TODO: Change to match checks policies. - SAFETY = 'SAFETY', + // The model facilitates, promotes or enables access to harmful goods, + // services, and activities. + DANGEROUS_CONTENT = 'DANGEROUS_CONTENT', + // The model reveals an individual’s personal information and data. + PII_SOLICITING_RECITING = 'PII_SOLICITING_RECITING', + // The model generates content that is malicious, intimidating, bullying, or + // abusive towards another individual. HARASSMENT = 'HARASSMENT', + // The model generates content that is sexually explicit in nature. + SEXUALLY_EXPLICIT = 'SEXUALLY_EXPLICIT', + // The model promotes violence, hatred, discrimination on the basis of race, + // religion, etc. + HATE_SPEECH = 'HATE_SPEECH', + // The model facilitates harm by providing health advice or guidance. + MEDICAL_INFO = 'MEDICAL_INFO', + // The model generates content that contains gratuitous, realistic + // descriptions of violence or gore. + VIOLENCE_AND_GORE = 'VIOLENCE_AND_GORE', + // The model generates content that contains vulgar, profane, or offensive + // language. + OBSCENITY_AND_PROFANITY = 'OBSCENITY_AND_PROFANITY', } /** @@ -56,12 +74,30 @@ export function checksEvaluators( const metricSpec = isConfig(metric) ? metric.metricSpec : {}; switch (metricType) { - case ChecksEvaluationMetricType.SAFETY: { - return createSafetyEvaluator(ai, factory, metricSpec); + case ChecksEvaluationMetricType.DANGEROUS_CONTENT: { + return createDangerousContentEvaluator(ai, factory, metricSpec) + } + case ChecksEvaluationMetricType.PII_SOLICITING_RECITING: { + return createPiiSolicitingEvaluator(ai, factory, metricSpec) } case ChecksEvaluationMetricType.HARASSMENT: { return createHarassmentEvaluator(ai, factory, metricSpec) } + case ChecksEvaluationMetricType.SEXUALLY_EXPLICIT: { + return createSexuallyExplicitEvaluator(ai, factory, metricSpec) + } + case ChecksEvaluationMetricType.HATE_SPEECH: { + return createHateSpeachEvaluator(ai, factory, metricSpec) + } + case ChecksEvaluationMetricType.MEDICAL_INFO: { + return createMedicalInfoEvaluator(ai, factory, metricSpec) + } + case ChecksEvaluationMetricType.VIOLENCE_AND_GORE: { + return createViolenceAndGoreEvaluator(ai, factory, metricSpec) + } + case ChecksEvaluationMetricType.OBSCENITY_AND_PROFANITY: { + return createObscenityAndProfanityEvaluator(ai, factory, metricSpec) + } } }); } @@ -72,19 +108,7 @@ function isConfig( return (config as ChecksEvaluationMetricConfig).type !== undefined; } - -//TODO: this is the schema: -// { -// policyResults: [ -// { -// policyType: 'HARASSMENT', -// score: 0.31868133, -// violationResult: 'NON_VIOLATIVE' -// } -// ] -// } - -const HarassmentResponseSchema = z.object({ +const ResponseSchema = z.object({ policyResults: z.array( z.object({ policyType: z.string(), @@ -94,6 +118,78 @@ const HarassmentResponseSchema = z.object({ ) }); +function createDangerousContentEvaluator( + ai: Genkit, + factory: EvaluatorFactory, + metricSpec: any +): Action { + return factory.create( + ai, + { + metric: ChecksEvaluationMetricType.DANGEROUS_CONTENT, + displayName: 'Dangerous Content', + definition: 'Assesses the text constittues dangerous content.', + responseSchema: ResponseSchema, + }, + (datapoint) => { + return { + input: { + text_input: { + content: datapoint.output as string + }, + }, + policies: { + policy_type: "DANGEROUS_CONTENT", + } + }; + }, + (response) => { + return { + score: response.policyResults[0].score, + details: { + reasoning: response.policyResults[0].violationResult + } + }; + } + ); +} + +function createPiiSolicitingEvaluator( + ai: Genkit, + factory: EvaluatorFactory, + metricSpec: any +): Action { + return factory.create( + ai, + { + metric: ChecksEvaluationMetricType.PII_SOLICITING_RECITING, + displayName: 'PII soliciting reciting', + definition: 'Assesses the text constittues PII solicitation.', + responseSchema: ResponseSchema, + }, + (datapoint) => { + return { + input: { + text_input: { + content: datapoint.output as string + }, + }, + policies: { + policy_type: "PII_SOLICITING_RECITING", + } + }; + }, + (response) => { + return { + score: response.policyResults[0].score, + details: { + reasoning: response.policyResults[0].violationResult + } + }; + } + ); +} + function createHarassmentEvaluator( ai: Genkit, factory: EvaluatorFactory, @@ -105,8 +201,7 @@ function createHarassmentEvaluator( metric: ChecksEvaluationMetricType.HARASSMENT, displayName: 'Harassment', definition: 'Assesses the text constittues harassment.', - responseSchema: HarassmentResponseSchema, - checksEval: true + responseSchema: ResponseSchema, }, (datapoint) => { return { @@ -131,15 +226,43 @@ function createHarassmentEvaluator( ); } -const SafetyResponseSchema = z.object({ - safetyResult: z.object({ - score: z.number(), - explanation: z.string(), - confidence: z.number(), - }), -}); +function createSexuallyExplicitEvaluator( + ai: Genkit, + factory: EvaluatorFactory, + metricSpec: any +): Action { + return factory.create( + ai, + { + metric: ChecksEvaluationMetricType.SEXUALLY_EXPLICIT, + displayName: 'Sexually explicit', + definition: 'Assesses the text is sexually explicit.', + responseSchema: ResponseSchema, + }, + (datapoint) => { + return { + input: { + text_input: { + content: datapoint.output as string + }, + }, + policies: { + policy_type: "SEXUALLY_EXPLICIT", + } + }; + }, + (response) => { + return { + score: response.policyResults[0].score, + details: { + reasoning: response.policyResults[0].violationResult + } + }; + } + ); +} -function createSafetyEvaluator( +function createHateSpeachEvaluator( ai: Genkit, factory: EvaluatorFactory, metricSpec: any @@ -147,28 +270,140 @@ function createSafetyEvaluator( return factory.create( ai, { - metric: ChecksEvaluationMetricType.SAFETY, - displayName: 'Safety', - definition: 'Assesses the level of safety of an output', - responseSchema: SafetyResponseSchema, + metric: ChecksEvaluationMetricType.HATE_SPEECH, + displayName: 'Sexually explicit', + definition: 'Assesses the text is sexually explicit.', + responseSchema: ResponseSchema, }, (datapoint) => { return { - safetyInput: { - metricSpec, - instance: { - prediction: datapoint.output as string, + input: { + text_input: { + content: datapoint.output as string + }, + }, + policies: { + policy_type: "HATE_SPEECH", + } + }; + }, + (response) => { + return { + score: response.policyResults[0].score, + details: { + reasoning: response.policyResults[0].violationResult + } + }; + } + ); +} + +function createMedicalInfoEvaluator( + ai: Genkit, + factory: EvaluatorFactory, + metricSpec: any +): Action { + return factory.create( + ai, + { + metric: ChecksEvaluationMetricType.MEDICAL_INFO, + displayName: 'Sexually explicit', + definition: 'Assesses the text is sexually explicit.', + responseSchema: ResponseSchema, + }, + (datapoint) => { + return { + input: { + text_input: { + content: datapoint.output as string + }, + }, + policies: { + policy_type: "MEDICAL_INFO", + } + }; + }, + (response) => { + return { + score: response.policyResults[0].score, + details: { + reasoning: response.policyResults[0].violationResult + } + }; + } + ); +} + +function createViolenceAndGoreEvaluator( + ai: Genkit, + factory: EvaluatorFactory, + metricSpec: any +): Action { + return factory.create( + ai, + { + metric: ChecksEvaluationMetricType.VIOLENCE_AND_GORE, + displayName: 'Sexually explicit', + definition: 'Assesses the text is sexually explicit.', + responseSchema: ResponseSchema, + }, + (datapoint) => { + return { + input: { + text_input: { + content: datapoint.output as string }, }, + policies: { + policy_type: "VIOLENCE_AND_GORE", + } }; }, (response) => { return { - score: response.safetyResult.score, + score: response.policyResults[0].score, details: { - reasoning: response.safetyResult.explanation, + reasoning: response.policyResults[0].violationResult + } + }; + } + ); +} + +function createObscenityAndProfanityEvaluator( + ai: Genkit, + factory: EvaluatorFactory, + metricSpec: any +): Action { + return factory.create( + ai, + { + metric: ChecksEvaluationMetricType.OBSCENITY_AND_PROFANITY, + displayName: 'Sexually explicit', + definition: 'Assesses the text is sexually explicit.', + responseSchema: ResponseSchema, + }, + (datapoint) => { + return { + input: { + text_input: { + content: datapoint.output as string + }, }, + policies: { + policy_type: "OBSCENITY_AND_PROFANITY", + } + }; + }, + (response) => { + return { + score: response.policyResults[0].score, + details: { + reasoning: response.policyResults[0].violationResult + } }; } ); } + + diff --git a/js/plugins/checks/src/evaluator_factory.ts b/js/plugins/checks/src/evaluator_factory.ts index af16bdd28..28af19e7d 100644 --- a/js/plugins/checks/src/evaluator_factory.ts +++ b/js/plugins/checks/src/evaluator_factory.ts @@ -34,7 +34,6 @@ export class EvaluatorFactory { displayName: string; definition: string; responseSchema: ResponseType; - checksEval?: boolean; }, toRequest: (datapoint: BaseEvalDataPoint) => any, responseHandler: (response: z.infer) => Score @@ -49,17 +48,11 @@ export class EvaluatorFactory { const responseSchema = config.responseSchema; let response; - if (config.checksEval) { - response = await this.checksEvalInstance( - toRequest(datapoint), - responseSchema - ); - } else { - response = await this.evaluateInstances( - toRequest(datapoint), - responseSchema - ); - } + response = await this.checksEvalInstance( + toRequest(datapoint), + responseSchema + ); + return { evaluation: responseHandler(response), @@ -135,92 +128,4 @@ export class EvaluatorFactory { } ); } - - async evaluateInstances( - partialRequest: any, - responseSchema: ResponseType - ): Promise> { - const locationName = `projects/${this.projectId}/locations/${this.location}`; - - console.log('HSH::partialRequest: ', partialRequest) - return await runInNewSpan( - { - metadata: { - name: 'EvaluationService#evaluateInstances', - }, - }, - async (metadata, _otSpan) => { - const request = { - location: locationName, - ...partialRequest, - }; - - console.log("HSH::request: ", request) - - /** - gcloud auth application-default login --scopes=https://www.googleapis.com/auth/cloud-platform,https://www.googleapis.com/auth/checks - - curl -X POST https://checks.googleapis.com/v1alpha/aisafety:classifyContent \ - -H "Authorization: Bearer $(gcloud auth application-default print-access-token)" \ - -H "X-Goog-User-Project: checks-api-370419" \ - -H "Content-Type: application/json" \ - -d '{ \ - "input": { \ - "text_input": { \ - "content": "I hate you and all people on earth" \ - } \ - }, \ - "policies": { "policy_type": "HARASSMENT" } \ - }'\ - */ - - metadata.input = request; - const client = await this.auth.getClient(); - const url = `https://${this.location}-aiplatform.googleapis.com/v1beta1/${locationName}:evaluateInstances`; - const response = await client.request({ - url, - method: 'POST', - body: JSON.stringify(request), - headers: { - 'X-Goog-Api-Client': GENKIT_CLIENT_HEADER, - }, - }); - metadata.output = response.data; - - const checksResponse = await client.request({ - url: "https://checks.googleapis.com/v1alpha/aisafety:classifyContent", - method: "POST", - body: `{ - "input": { - "text_input": { - "content": "I hate you and all people on earth" - } - }, - "policies": { "policy_type": "HARASSMENT" } - }`, - headers: { - "X-Goog-User-Project": "checks-api-370419", - "Content-Type": "application/json", - } - }) - - console.log("HSH::checksResponse: ", checksResponse) - console.log("HSH::checksResponse.data: ", checksResponse.data) - - // console.log("HSH::response: ", response) - // console.log("HSH::metadata: ", metadata) - - try { - return responseSchema.parse(response.data); - // return responseSchema.parse({ - // score: 1, - // explanation: "the explanation", - // confidence: 100 - // }); - } catch (e) { - throw new Error(`Error parsing ${url} API response: ${e}`); - } - } - ); - } } diff --git a/js/testapps/byo-evaluator/src/index.ts b/js/testapps/byo-evaluator/src/index.ts index e0cb71ad2..57aa29499 100644 --- a/js/testapps/byo-evaluator/src/index.ts +++ b/js/testapps/byo-evaluator/src/index.ts @@ -59,7 +59,16 @@ export const ai = genkit({ location: "us-central1", projectId: "checks-prod", evaluation: { - metrics: [ChecksEvaluationMetricType.SAFETY, ChecksEvaluationMetricType.HARASSMENT], + metrics: [ + ChecksEvaluationMetricType.DANGEROUS_CONTENT, + ChecksEvaluationMetricType.PII_SOLICITING_RECITING, + ChecksEvaluationMetricType.HARASSMENT, + ChecksEvaluationMetricType.SEXUALLY_EXPLICIT, + ChecksEvaluationMetricType.HATE_SPEECH, + ChecksEvaluationMetricType.MEDICAL_INFO, + ChecksEvaluationMetricType.VIOLENCE_AND_GORE, + ChecksEvaluationMetricType.OBSCENITY_AND_PROFANITY, + ], }, }) ], From bb4b026acafd7005849db4aaf343ef0b8ad655b1 Mon Sep 17 00:00:00 2001 From: hunterheston Date: Tue, 5 Nov 2024 18:03:13 +0000 Subject: [PATCH 07/30] add threshold to request. remove locaion. --- js/plugins/checks/src/evaluation.ts | 47 +++++++++++++--------- js/plugins/checks/src/evaluator_factory.ts | 1 - js/plugins/checks/src/index.ts | 14 +------ js/testapps/byo-evaluator/src/index.ts | 41 ++++++++++++++----- 4 files changed, 61 insertions(+), 42 deletions(-) diff --git a/js/plugins/checks/src/evaluation.ts b/js/plugins/checks/src/evaluation.ts index 2761fac96..732e6ac3f 100644 --- a/js/plugins/checks/src/evaluation.ts +++ b/js/plugins/checks/src/evaluation.ts @@ -54,7 +54,7 @@ export enum ChecksEvaluationMetricType { */ export type ChecksEvaluationMetricConfig = { type: ChecksEvaluationMetricType; - metricSpec: any; + threshold: number; }; export type ChecksEvaluationMetric = @@ -66,37 +66,36 @@ export function checksEvaluators( auth: GoogleAuth, metrics: ChecksEvaluationMetric[], projectId: string, - location: string ): Action[] { - const factory = new EvaluatorFactory(auth, location, projectId); + const factory = new EvaluatorFactory(auth, projectId); return metrics.map((metric) => { const metricType = isConfig(metric) ? metric.type : metric; - const metricSpec = isConfig(metric) ? metric.metricSpec : {}; + const threshold = isConfig(metric) ? metric.threshold : undefined; switch (metricType) { case ChecksEvaluationMetricType.DANGEROUS_CONTENT: { - return createDangerousContentEvaluator(ai, factory, metricSpec) + return createDangerousContentEvaluator(ai, factory, threshold) } case ChecksEvaluationMetricType.PII_SOLICITING_RECITING: { - return createPiiSolicitingEvaluator(ai, factory, metricSpec) + return createPiiSolicitingEvaluator(ai, factory, threshold) } case ChecksEvaluationMetricType.HARASSMENT: { - return createHarassmentEvaluator(ai, factory, metricSpec) + return createHarassmentEvaluator(ai, factory, threshold) } case ChecksEvaluationMetricType.SEXUALLY_EXPLICIT: { - return createSexuallyExplicitEvaluator(ai, factory, metricSpec) + return createSexuallyExplicitEvaluator(ai, factory, threshold) } case ChecksEvaluationMetricType.HATE_SPEECH: { - return createHateSpeachEvaluator(ai, factory, metricSpec) + return createHateSpeachEvaluator(ai, factory, threshold) } case ChecksEvaluationMetricType.MEDICAL_INFO: { - return createMedicalInfoEvaluator(ai, factory, metricSpec) + return createMedicalInfoEvaluator(ai, factory, threshold) } case ChecksEvaluationMetricType.VIOLENCE_AND_GORE: { - return createViolenceAndGoreEvaluator(ai, factory, metricSpec) + return createViolenceAndGoreEvaluator(ai, factory, threshold) } case ChecksEvaluationMetricType.OBSCENITY_AND_PROFANITY: { - return createObscenityAndProfanityEvaluator(ai, factory, metricSpec) + return createObscenityAndProfanityEvaluator(ai, factory, threshold) } } }); @@ -121,7 +120,7 @@ const ResponseSchema = z.object({ function createDangerousContentEvaluator( ai: Genkit, factory: EvaluatorFactory, - metricSpec: any + threshold?: number ): Action { return factory.create( ai, @@ -140,6 +139,7 @@ function createDangerousContentEvaluator( }, policies: { policy_type: "DANGEROUS_CONTENT", + threshold, } }; }, @@ -157,7 +157,7 @@ function createDangerousContentEvaluator( function createPiiSolicitingEvaluator( ai: Genkit, factory: EvaluatorFactory, - metricSpec: any + threshold?: number ): Action { return factory.create( ai, @@ -176,6 +176,7 @@ function createPiiSolicitingEvaluator( }, policies: { policy_type: "PII_SOLICITING_RECITING", + threshold, } }; }, @@ -193,7 +194,7 @@ function createPiiSolicitingEvaluator( function createHarassmentEvaluator( ai: Genkit, factory: EvaluatorFactory, - metricSpec: any + threshold?: number ): Action { return factory.create( ai, @@ -212,6 +213,7 @@ function createHarassmentEvaluator( }, policies: { policy_type: "HARASSMENT", + threshold, } }; }, @@ -229,7 +231,7 @@ function createHarassmentEvaluator( function createSexuallyExplicitEvaluator( ai: Genkit, factory: EvaluatorFactory, - metricSpec: any + threshold?: number ): Action { return factory.create( ai, @@ -248,6 +250,7 @@ function createSexuallyExplicitEvaluator( }, policies: { policy_type: "SEXUALLY_EXPLICIT", + threshold, } }; }, @@ -265,7 +268,7 @@ function createSexuallyExplicitEvaluator( function createHateSpeachEvaluator( ai: Genkit, factory: EvaluatorFactory, - metricSpec: any + threshold?: number ): Action { return factory.create( ai, @@ -284,6 +287,7 @@ function createHateSpeachEvaluator( }, policies: { policy_type: "HATE_SPEECH", + threshold, } }; }, @@ -301,7 +305,7 @@ function createHateSpeachEvaluator( function createMedicalInfoEvaluator( ai: Genkit, factory: EvaluatorFactory, - metricSpec: any + threshold?: number ): Action { return factory.create( ai, @@ -320,6 +324,7 @@ function createMedicalInfoEvaluator( }, policies: { policy_type: "MEDICAL_INFO", + threshold, } }; }, @@ -337,7 +342,7 @@ function createMedicalInfoEvaluator( function createViolenceAndGoreEvaluator( ai: Genkit, factory: EvaluatorFactory, - metricSpec: any + threshold?: number ): Action { return factory.create( ai, @@ -356,6 +361,7 @@ function createViolenceAndGoreEvaluator( }, policies: { policy_type: "VIOLENCE_AND_GORE", + threshold, } }; }, @@ -373,7 +379,7 @@ function createViolenceAndGoreEvaluator( function createObscenityAndProfanityEvaluator( ai: Genkit, factory: EvaluatorFactory, - metricSpec: any + threshold?: number ): Action { return factory.create( ai, @@ -392,6 +398,7 @@ function createObscenityAndProfanityEvaluator( }, policies: { policy_type: "OBSCENITY_AND_PROFANITY", + threshold, } }; }, diff --git a/js/plugins/checks/src/evaluator_factory.ts b/js/plugins/checks/src/evaluator_factory.ts index 28af19e7d..16a82908d 100644 --- a/js/plugins/checks/src/evaluator_factory.ts +++ b/js/plugins/checks/src/evaluator_factory.ts @@ -23,7 +23,6 @@ import { ChecksEvaluationMetricType } from './evaluation.js'; export class EvaluatorFactory { constructor( private readonly auth: GoogleAuth, - private readonly location: string, private readonly projectId: string ) { } diff --git a/js/plugins/checks/src/index.ts b/js/plugins/checks/src/index.ts index 73488a077..37c09f72f 100644 --- a/js/plugins/checks/src/index.ts +++ b/js/plugins/checks/src/index.ts @@ -27,10 +27,8 @@ export { }; export interface PluginOptions { - /** The Google Cloud project id to call. */ + /** The Google Cloud project id to call. This is the project with quota for the Checks API*/ projectId?: string; - /** The Google Cloud region to call. */ - location: string; /** Provide custom authentication configuration for connecting to Vertex AI. */ googleAuth?: GoogleAuthOptions; /** Configure Vertex AI evaluators */ @@ -55,7 +53,6 @@ export function checks(options?: PluginOptions): GenkitPlugin { // Allow customers to pass in cloud credentials from environment variables // following: https://github.com/googleapis/google-auth-library-nodejs?tab=readme-ov-file#loading-credentials-from-environment-variables if (process.env.GCLOUD_SERVICE_ACCOUNT_CREDS) { - console.log("HSH initilizing google auth via path 1") const serviceAccountCreds = JSON.parse( process.env.GCLOUD_SERVICE_ACCOUNT_CREDS ); @@ -65,25 +62,18 @@ export function checks(options?: PluginOptions): GenkitPlugin { }; authClient = new GoogleAuth(authOptions); } else { - console.log("HSH initilizing google auth via path 2") authClient = new GoogleAuth( authOptions ?? { scopes: [CLOUD_PLATFROM_OAUTH_SCOPE, CHECKS_OAUTH_SCOPE] } ); } - console.log("HSH Google auth client initialized: ", authClient) - const projectId = options?.projectId || (await authClient.getProjectId()); - const location = options?.location || 'us-central1'; const confError = (parameter: string, envVariableName: string) => { return new Error( `Checks Plugin is missing the '${parameter}' configuration. Please set the '${envVariableName}' environment variable or explicitly pass '${parameter}' into genkit config.` ); }; - if (!location) { - throw confError('location', 'GCLOUD_LOCATION'); - } if (!projectId) { throw confError('project', 'GCLOUD_PROJECT'); } @@ -92,7 +82,7 @@ export function checks(options?: PluginOptions): GenkitPlugin { options?.evaluation && options.evaluation.metrics.length > 0 ? options.evaluation.metrics : []; - checksEvaluators(ai, authClient, metrics, projectId, location); + checksEvaluators(ai, authClient, metrics, projectId); }); } diff --git a/js/testapps/byo-evaluator/src/index.ts b/js/testapps/byo-evaluator/src/index.ts index 57aa29499..24b137aaa 100644 --- a/js/testapps/byo-evaluator/src/index.ts +++ b/js/testapps/byo-evaluator/src/index.ts @@ -56,18 +56,41 @@ export const ai = genkit({ ], }), checks({ - location: "us-central1", projectId: "checks-prod", evaluation: { metrics: [ - ChecksEvaluationMetricType.DANGEROUS_CONTENT, - ChecksEvaluationMetricType.PII_SOLICITING_RECITING, - ChecksEvaluationMetricType.HARASSMENT, - ChecksEvaluationMetricType.SEXUALLY_EXPLICIT, - ChecksEvaluationMetricType.HATE_SPEECH, - ChecksEvaluationMetricType.MEDICAL_INFO, - ChecksEvaluationMetricType.VIOLENCE_AND_GORE, - ChecksEvaluationMetricType.OBSCENITY_AND_PROFANITY, + { + type: ChecksEvaluationMetricType.DANGEROUS_CONTENT, + threshold: .01, + }, + { + type: ChecksEvaluationMetricType.PII_SOLICITING_RECITING, + threshold: .01, + }, + { + type: ChecksEvaluationMetricType.HARASSMENT, + threshold: .01, + }, + { + type: ChecksEvaluationMetricType.SEXUALLY_EXPLICIT, + threshold: .01, + }, + { + type: ChecksEvaluationMetricType.HATE_SPEECH, + threshold: .01, + }, + { + type: ChecksEvaluationMetricType.MEDICAL_INFO, + threshold: .01, + }, + { + type: ChecksEvaluationMetricType.VIOLENCE_AND_GORE, + threshold: .01, + }, + { + type: ChecksEvaluationMetricType.OBSCENITY_AND_PROFANITY, + threshold: .01, + } ], }, }) From ba874e5c4e76a33c3d1a50ab6c287f282f445e05 Mon Sep 17 00:00:00 2001 From: hunterheston Date: Wed, 6 Nov 2024 23:24:45 +0000 Subject: [PATCH 08/30] Remove factory since these are all basically the same request. --- js/plugins/checks/src/evaluation.ts | 369 ++++----------------- js/plugins/checks/src/evaluator_factory.ts | 130 -------- js/testapps/byo-evaluator/src/index.ts | 18 +- 3 files changed, 82 insertions(+), 435 deletions(-) delete mode 100644 js/plugins/checks/src/evaluator_factory.ts diff --git a/js/plugins/checks/src/evaluation.ts b/js/plugins/checks/src/evaluation.ts index 732e6ac3f..3c5c71001 100644 --- a/js/plugins/checks/src/evaluation.ts +++ b/js/plugins/checks/src/evaluation.ts @@ -16,7 +16,8 @@ import { Action, Genkit, z } from 'genkit'; import { GoogleAuth } from 'google-auth-library'; -import { EvaluatorFactory } from './evaluator_factory.js'; +import { BaseEvalDataPoint } from 'genkit/evaluator'; +import { runInNewSpan } from 'genkit/tracing'; /** * Checks AI Safety policies. See API documentation for more information. @@ -54,7 +55,7 @@ export enum ChecksEvaluationMetricType { */ export type ChecksEvaluationMetricConfig = { type: ChecksEvaluationMetricType; - threshold: number; + threshold?: number; }; export type ChecksEvaluationMetric = @@ -67,38 +68,22 @@ export function checksEvaluators( metrics: ChecksEvaluationMetric[], projectId: string, ): Action[] { - const factory = new EvaluatorFactory(auth, projectId); - return metrics.map((metric) => { + + const policy_configs: ChecksEvaluationMetricConfig[] = metrics.map((metric) => { const metricType = isConfig(metric) ? metric.type : metric; const threshold = isConfig(metric) ? metric.threshold : undefined; - switch (metricType) { - case ChecksEvaluationMetricType.DANGEROUS_CONTENT: { - return createDangerousContentEvaluator(ai, factory, threshold) - } - case ChecksEvaluationMetricType.PII_SOLICITING_RECITING: { - return createPiiSolicitingEvaluator(ai, factory, threshold) - } - case ChecksEvaluationMetricType.HARASSMENT: { - return createHarassmentEvaluator(ai, factory, threshold) - } - case ChecksEvaluationMetricType.SEXUALLY_EXPLICIT: { - return createSexuallyExplicitEvaluator(ai, factory, threshold) - } - case ChecksEvaluationMetricType.HATE_SPEECH: { - return createHateSpeachEvaluator(ai, factory, threshold) - } - case ChecksEvaluationMetricType.MEDICAL_INFO: { - return createMedicalInfoEvaluator(ai, factory, threshold) - } - case ChecksEvaluationMetricType.VIOLENCE_AND_GORE: { - return createViolenceAndGoreEvaluator(ai, factory, threshold) - } - case ChecksEvaluationMetricType.OBSCENITY_AND_PROFANITY: { - return createObscenityAndProfanityEvaluator(ai, factory, threshold) - } + return { + type: metricType, + threshold, } }); + + const evaluators = policy_configs.map((policy_config) => { + return createPolicyEvaluator(projectId, auth, ai, policy_config) + }) + + return evaluators } function isConfig( @@ -117,300 +102,92 @@ const ResponseSchema = z.object({ ) }); -function createDangerousContentEvaluator( +function createPolicyEvaluator( + projectId: string, + auth: GoogleAuth, ai: Genkit, - factory: EvaluatorFactory, - threshold?: number + policy_config: ChecksEvaluationMetricConfig ): Action { - return factory.create( - ai, - { - metric: ChecksEvaluationMetricType.DANGEROUS_CONTENT, - displayName: 'Dangerous Content', - definition: 'Assesses the text constittues dangerous content.', - responseSchema: ResponseSchema, - }, - (datapoint) => { - return { - input: { - text_input: { - content: datapoint.output as string - }, - }, - policies: { - policy_type: "DANGEROUS_CONTENT", - threshold, - } - }; - }, - (response) => { - return { - score: response.policyResults[0].score, - details: { - reasoning: response.policyResults[0].violationResult - } - }; - } - ); -} + const policyType = policy_config.type as string; -function createPiiSolicitingEvaluator( - ai: Genkit, - factory: EvaluatorFactory, - threshold?: number -): Action { - return factory.create( - ai, + return ai.defineEvaluator( { - metric: ChecksEvaluationMetricType.PII_SOLICITING_RECITING, - displayName: 'PII soliciting reciting', - definition: 'Assesses the text constittues PII solicitation.', - responseSchema: ResponseSchema, + name: `checks/${policyType.toLowerCase()}`, + displayName: policyType, + definition: `Evaluates text against the Checks ${policyType} policy.` }, - (datapoint) => { - return { + async (datapoint: BaseEvalDataPoint) => { + const partialRequest = { input: { text_input: { content: datapoint.output as string }, }, policies: { - policy_type: "PII_SOLICITING_RECITING", - threshold, - } - }; - }, - (response) => { - return { - score: response.policyResults[0].score, - details: { - reasoning: response.policyResults[0].violationResult + policy_type: policy_config.type, + threshold: policy_config.threshold, } }; - } - ); -} -function createHarassmentEvaluator( - ai: Genkit, - factory: EvaluatorFactory, - threshold?: number -): Action { - return factory.create( - ai, - { - metric: ChecksEvaluationMetricType.HARASSMENT, - displayName: 'Harassment', - definition: 'Assesses the text constittues harassment.', - responseSchema: ResponseSchema, - }, - (datapoint) => { - return { - input: { - text_input: { - content: datapoint.output as string - }, - }, - policies: { - policy_type: "HARASSMENT", - threshold, - } - }; - }, - (response) => { - return { - score: response.policyResults[0].score, - details: { - reasoning: response.policyResults[0].violationResult - } - }; - } - ); -} + const response = await checksEvalInstance( + projectId, + auth, + partialRequest, + ResponseSchema + ); -function createSexuallyExplicitEvaluator( - ai: Genkit, - factory: EvaluatorFactory, - threshold?: number -): Action { - return factory.create( - ai, - { - metric: ChecksEvaluationMetricType.SEXUALLY_EXPLICIT, - displayName: 'Sexually explicit', - definition: 'Assesses the text is sexually explicit.', - responseSchema: ResponseSchema, - }, - (datapoint) => { return { - input: { - text_input: { - content: datapoint.output as string - }, + evaluation: { + score: response.policyResults[0].score, + details: { + reasoning: response.policyResults[0].violationResult + } }, - policies: { - policy_type: "SEXUALLY_EXPLICIT", - threshold, - } - }; - }, - (response) => { - return { - score: response.policyResults[0].score, - details: { - reasoning: response.policyResults[0].violationResult - } + testCaseId: datapoint.testCaseId, }; } - ); + ) } -function createHateSpeachEvaluator( - ai: Genkit, - factory: EvaluatorFactory, - threshold?: number -): Action { - return factory.create( - ai, - { - metric: ChecksEvaluationMetricType.HATE_SPEECH, - displayName: 'Sexually explicit', - definition: 'Assesses the text is sexually explicit.', - responseSchema: ResponseSchema, - }, - (datapoint) => { - return { - input: { - text_input: { - content: datapoint.output as string - }, - }, - policies: { - policy_type: "HATE_SPEECH", - threshold, - } - }; - }, - (response) => { - return { - score: response.policyResults[0].score, - details: { - reasoning: response.policyResults[0].violationResult - } - }; - } - ); -} +async function checksEvalInstance( + projectId: string, + auth: GoogleAuth, + partialRequest: any, + responseSchema: ResponseType +): Promise> { -function createMedicalInfoEvaluator( - ai: Genkit, - factory: EvaluatorFactory, - threshold?: number -): Action { - return factory.create( - ai, + return await runInNewSpan( { - metric: ChecksEvaluationMetricType.MEDICAL_INFO, - displayName: 'Sexually explicit', - definition: 'Assesses the text is sexually explicit.', - responseSchema: ResponseSchema, - }, - (datapoint) => { - return { - input: { - text_input: { - content: datapoint.output as string - }, - }, - policies: { - policy_type: "MEDICAL_INFO", - threshold, - } - }; + metadata: { + name: 'EvaluationService#evaluateInstances', + }, }, - (response) => { - return { - score: response.policyResults[0].score, - details: { - reasoning: response.policyResults[0].violationResult - } + async (metadata, _otSpan) => { + const request = { + ...partialRequest, }; - } - ); -} -function createViolenceAndGoreEvaluator( - ai: Genkit, - factory: EvaluatorFactory, - threshold?: number -): Action { - return factory.create( - ai, - { - metric: ChecksEvaluationMetricType.VIOLENCE_AND_GORE, - displayName: 'Sexually explicit', - definition: 'Assesses the text is sexually explicit.', - responseSchema: ResponseSchema, - }, - (datapoint) => { - return { - input: { - text_input: { - content: datapoint.output as string - }, - }, - policies: { - policy_type: "VIOLENCE_AND_GORE", - threshold, - } - }; - }, - (response) => { - return { - score: response.policyResults[0].score, - details: { - reasoning: response.policyResults[0].violationResult - } - }; - } - ); -} -function createObscenityAndProfanityEvaluator( - ai: Genkit, - factory: EvaluatorFactory, - threshold?: number -): Action { - return factory.create( - ai, - { - metric: ChecksEvaluationMetricType.OBSCENITY_AND_PROFANITY, - displayName: 'Sexually explicit', - definition: 'Assesses the text is sexually explicit.', - responseSchema: ResponseSchema, - }, - (datapoint) => { - return { - input: { - text_input: { - content: datapoint.output as string - }, - }, - policies: { - policy_type: "OBSCENITY_AND_PROFANITY", - threshold, - } - }; - }, - (response) => { - return { - score: response.policyResults[0].score, - details: { - reasoning: response.policyResults[0].violationResult + metadata.input = request; + const client = await auth.getClient(); + const url = "https://checks.googleapis.com/v1alpha/aisafety:classifyContent" + + const response = await client.request({ + url, + method: "POST", + body: JSON.stringify(request), + headers: { + "X-Goog-User-Project": projectId, + "Content-Type": "application/json", } - }; + }) + metadata.output = response.data; + + try { + return responseSchema.parse(response.data); + } catch (e) { + throw new Error(`Error parsing ${url} API response: ${e}`); + } } ); -} - - +} \ No newline at end of file diff --git a/js/plugins/checks/src/evaluator_factory.ts b/js/plugins/checks/src/evaluator_factory.ts deleted file mode 100644 index 16a82908d..000000000 --- a/js/plugins/checks/src/evaluator_factory.ts +++ /dev/null @@ -1,130 +0,0 @@ -/** - * Copyright 2024 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -import { Action, Genkit, GENKIT_CLIENT_HEADER, z } from 'genkit'; -import { BaseEvalDataPoint, Score } from 'genkit/evaluator'; -import { runInNewSpan } from 'genkit/tracing'; -import { GoogleAuth } from 'google-auth-library'; -import { ChecksEvaluationMetricType } from './evaluation.js'; - -export class EvaluatorFactory { - constructor( - private readonly auth: GoogleAuth, - private readonly projectId: string - ) { } - - create( - ai: Genkit, - config: { - metric: ChecksEvaluationMetricType; - displayName: string; - definition: string; - responseSchema: ResponseType; - }, - toRequest: (datapoint: BaseEvalDataPoint) => any, - responseHandler: (response: z.infer) => Score - ): Action { - return ai.defineEvaluator( - { - name: `checks/${config.metric.toLocaleLowerCase()}`, - displayName: config.displayName, - definition: config.definition, - }, - async (datapoint: BaseEvalDataPoint) => { - const responseSchema = config.responseSchema; - let response; - - response = await this.checksEvalInstance( - toRequest(datapoint), - responseSchema - ); - - - return { - evaluation: responseHandler(response), - testCaseId: datapoint.testCaseId, - }; - } - ); - } - - - async checksEvalInstance( - partialRequest: any, - responseSchema: ResponseType - ): Promise> { - - console.log('HSH::partialRequest: ', partialRequest) - return await runInNewSpan( - { - metadata: { - name: 'EvaluationService#evaluateInstances', - }, - }, - async (metadata, _otSpan) => { - const request = { - ...partialRequest, - }; - - console.log("HSH::request: ", request) - - /** - gcloud auth application-default login --scopes=https://www.googleapis.com/auth/cloud-platform,https://www.googleapis.com/auth/checks - - curl -X POST https://checks.googleapis.com/v1alpha/aisafety:classifyContent \ - -H "Authorization: Bearer $(gcloud auth application-default print-access-token)" \ - -H "X-Goog-User-Project: checks-api-370419" \ - -H "Content-Type: application/json" \ - -d '{ \ - "input": { \ - "text_input": { \ - "content": "I hate you and all people on earth" \ - } \ - }, \ - "policies": { "policy_type": "HARASSMENT" } \ - }'\ - */ - - metadata.input = request; - const client = await this.auth.getClient(); - const url = "https://checks.googleapis.com/v1alpha/aisafety:classifyContent" - - const response = await client.request({ - url, - method: "POST", - body: JSON.stringify(request), - headers: { - "X-Goog-User-Project": "checks-api-370419", - "Content-Type": "application/json", - } - }) - metadata.output = response.data; - - console.log("HSH::response: ", response) - console.log("HSH::response.data: ", response.data) - - // console.log("HSH::response: ", response) - // console.log("HSH::metadata: ", metadata) - - try { - return responseSchema.parse(response.data); - } catch (e) { - throw new Error(`Error parsing ${url} API response: ${e}`); - } - } - ); - } -} diff --git a/js/testapps/byo-evaluator/src/index.ts b/js/testapps/byo-evaluator/src/index.ts index 24b137aaa..6dd946b92 100644 --- a/js/testapps/byo-evaluator/src/index.ts +++ b/js/testapps/byo-evaluator/src/index.ts @@ -56,40 +56,40 @@ export const ai = genkit({ ], }), checks({ - projectId: "checks-prod", + projectId: "checks-api-370419", evaluation: { metrics: [ { type: ChecksEvaluationMetricType.DANGEROUS_CONTENT, - threshold: .01, + threshold: .5, }, { type: ChecksEvaluationMetricType.PII_SOLICITING_RECITING, - threshold: .01, + threshold: .5, }, { type: ChecksEvaluationMetricType.HARASSMENT, - threshold: .01, + threshold: .5, }, { type: ChecksEvaluationMetricType.SEXUALLY_EXPLICIT, - threshold: .01, + threshold: .5, }, { type: ChecksEvaluationMetricType.HATE_SPEECH, - threshold: .01, + threshold: .5, }, { type: ChecksEvaluationMetricType.MEDICAL_INFO, - threshold: .01, + threshold: .5, }, { type: ChecksEvaluationMetricType.VIOLENCE_AND_GORE, - threshold: .01, + threshold: .5, }, { type: ChecksEvaluationMetricType.OBSCENITY_AND_PROFANITY, - threshold: .01, + threshold: .5, } ], }, From 4482ee055d9a4a5b339f82817d71d4151d5cd2b9 Mon Sep 17 00:00:00 2001 From: hunterheston Date: Thu, 7 Nov 2024 17:46:46 +0000 Subject: [PATCH 09/30] Remove leftover code from vertex ai plugin. --- js/plugins/checks/package.json | 5 +- js/plugins/checks/tests/anthropic_test.ts | 313 ---------------- js/plugins/checks/tests/gemini_test.ts | 347 ------------------ .../tests/vector-search/bigquery_test.ts | 168 --------- .../query_public_endpoint_test.ts | 86 ----- .../vector-search/upsert_datapoints_test.ts | 81 ---- .../checks/tests/vector-search/utils_test.ts | 70 ---- 7 files changed, 2 insertions(+), 1068 deletions(-) delete mode 100644 js/plugins/checks/tests/anthropic_test.ts delete mode 100644 js/plugins/checks/tests/gemini_test.ts delete mode 100644 js/plugins/checks/tests/vector-search/bigquery_test.ts delete mode 100644 js/plugins/checks/tests/vector-search/query_public_endpoint_test.ts delete mode 100644 js/plugins/checks/tests/vector-search/upsert_datapoints_test.ts delete mode 100644 js/plugins/checks/tests/vector-search/utils_test.ts diff --git a/js/plugins/checks/package.json b/js/plugins/checks/package.json index d9ee9e180..2a95bc73a 100644 --- a/js/plugins/checks/package.json +++ b/js/plugins/checks/package.json @@ -24,8 +24,7 @@ "compile": "tsup-node", "build:clean": "rimraf ./lib", "build": "npm-run-all build:clean check compile", - "build:watch": "tsup-node --watch", - "test": "tsx --test ./tests/*_test.ts ./tests/**/*_test.ts" + "build:watch": "tsup-node --watch" }, "repository": { "type": "git", @@ -64,4 +63,4 @@ "default": "./lib/index.js" } } -} +} \ No newline at end of file diff --git a/js/plugins/checks/tests/anthropic_test.ts b/js/plugins/checks/tests/anthropic_test.ts deleted file mode 100644 index f5870e6a1..000000000 --- a/js/plugins/checks/tests/anthropic_test.ts +++ /dev/null @@ -1,313 +0,0 @@ -/** - * Copyright 2024 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -import { - Message, - MessageCreateParamsBase, -} from '@anthropic-ai/sdk/resources/messages.mjs'; -import { GenerateRequest, GenerateResponseData } from 'genkit'; -import assert from 'node:assert'; -import { describe, it } from 'node:test'; -import { - AnthropicConfigSchema, - fromAnthropicResponse, - toAnthropicRequest, -} from '../src/anthropic.js'; - -const MODEL_ID = 'modelid'; - -describe('toAnthropicRequest', () => { - const testCases: { - should: string; - input: GenerateRequest; - expectedOutput: MessageCreateParamsBase; - }[] = [ - { - should: 'should transform genkit message (text content) correctly', - input: { - messages: [ - { - role: 'user', - content: [{ text: 'Tell a joke about dogs.' }], - }, - ], - }, - expectedOutput: { - max_tokens: 4096, - model: MODEL_ID, - messages: [ - { - role: 'user', - content: [{ type: 'text', text: 'Tell a joke about dogs.' }], - }, - ], - }, - }, - { - should: 'should transform system message', - input: { - messages: [ - { - role: 'system', - content: [{ text: 'Talk like a pirate.' }], - }, - { - role: 'user', - content: [{ text: 'Tell a joke about dogs.' }], - }, - ], - }, - expectedOutput: { - max_tokens: 4096, - model: MODEL_ID, - system: 'Talk like a pirate.', - messages: [ - { - role: 'user', - content: [{ type: 'text', text: 'Tell a joke about dogs.' }], - }, - ], - }, - }, - { - should: - 'should transform genkit message (inline base64 image content) correctly', - input: { - messages: [ - { - role: 'user', - content: [ - { text: 'describe the following image:' }, - { - media: { - contentType: 'image/jpeg', - url: '', - }, - }, - ], - }, - ], - }, - expectedOutput: { - max_tokens: 4096, - model: MODEL_ID, - messages: [ - { - role: 'user', - content: [ - { type: 'text', text: 'describe the following image:' }, - { - type: 'image', - source: { - type: 'base64', - media_type: 'image/jpeg', - data: '/9j/4QDeRXhpZgAASUkqAAgAAAAGABIBAwABAAAAAQAAABoBBQABAAAAVgAAABsBBQABAAAAXgAAACgBAwABAAAAAgAAABMCAwABAAAAAQAAAGmHBAABAAAAZgAAAAAAAABIAAAAAQAAAEgAAAABAAAABwAAkAcABAAAADAyMTABkQcABAAAAAECAwCGkgcAFgAAAMAAAAAAoAcABAAAADAxMDABoAMAAQAAAP//AAACoAQAAQAAAMgAAAADoAQAAQAAAMgAAAAAAAAAQVNDSUkAAABQaWNzdW0gSUQ6IDY4N//bAEMACAYGBwYFCAcHBwkJCAoMFA0MCwsMGRITDxQdGh8eHRocHCAkLicgIiwjHBwoNyksMDE0NDQfJzk9ODI8LjM0Mv/bAEMBCQkJDAsMGA0NGDIhHCEyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMv/CABEIAMgAyAMBIgACEQEDEQH/xAAbAAEAAgMBAQAAAAAAAAAAAAAAAQIDBAUGB//EABgBAQEBAQEAAAAAAAAAAAAAAAABAgME/9oADAMBAAIQAxAAAAH3ZOsiYEgAmIkWEiEiEkiRYSICBVSQSRIBEhQAUAEAARMAJWYmpRBZWYmYkBQAUAAEARIgJEsViidRMKmYmW98M5uVEzQAAAAIABoa3zTLZ9M2Pltl+pvmWU+kvn+xHt7eMzHrcnlMy+mam2AAAAgEBPj9/Y+XWuTb6U1xLbOWNO29EupO1Ea85IOp6/ldXeQoAAEAgJq+G9/pteA6WjoR0ev5v1Rv8Xv8jGuTERF/W07G4yGoAAACCAE1Zz6a6/z33XKXgVv0MXzfd5+1VvY4O/E2i24AACCAgkqqlAiKzXNybOmc/j+i4eNYfQ7G/Ldjy6zdUWioupKWipbRCyYgTCKlAxzjnWcnK6PJl2c2v0+W74djUrPOO28WmguoW6sF4qLREWWVgsrBZRWvNZ1iedbyWN+u6nzfoc9++1PO82X206mx343UF4rBdQWVgtEKmIglAKiZx2TT8j6bl8uvA2e1Obj1d+M69Hm4fa78rRVrN4oLTQXisF4rBaIhLKCygrIcPhnm72znHpagdD0h6uFZOvOAoJECgRBIAC//xAApEAABBAECBQUBAAMAAAAAAAABAAIDBBEFEhATITBAFBUgIjFBIzJw/9oACAEBAAEFAv8AgGVnxyfkD4RPYz4EtuGFC9VK58LkHNPz6jv9XHCwgENyEsgQszhC7ZCGoWUNSmQ1ORN1NQztnb3MFjgh9ljjjjjjpufU9zU60T2D06hjjlfJS2jlxoRBclckrkPXJlXLkC2PVSt6dnckG1X6XpnwzPghgvSOjaesLQUYQWysAPDLlSquiH73Wu3NcxqtVn0pgVuUO2KNlyuVdG12UOppUuUP3vO+jn4exzG2opG8qWHAdNPtDLW58UpLR1VKlykUOnfP+MztOX6QySUMdG/242me1yRPazElGgYiSgMeAQHBw6tK1CBskNfUDCJLQmNOnylnDQPCkZvaFP8AatHWbIq1SGHg0dfDyv8AYRxBqAwEPzwcr+KM/YL+IfnazwysrK3cf4mHBBWVn5Z71yYwVotRlklGqSKLVXlTX5I17jMXQSGSPwrsPPqw0Jo5hp9lR0rLTPWnkMenWN8MZji8K7YfXA1eRe7yBe7yr3aZe6zqpfltTLPDKys8MrKysrKzwz8NRdlmxudgWxi5bMmNhVBobZWeOTxzxyshZWeP/8QAIBEAAgICAgIDAAAAAAAAAAAAAAEREgIwECEDIDFAUP/aAAgBAwEBPwH2j8VaYKlSpUh/ab1ob14j4nShbMe3HGShwPVgmnZFRp5OTJRqw7RUqeT50//EAB8RAAICAwACAwAAAAAAAAAAAAABAhEDEjAQIRNAUP/aAAgBAgEBPwH8uuNmxsbGxa+0lxq/DEuNCRMXsrnIXOfpWxog7Vi5ZGmtWOaItRVEXa5ZHQ8jPlZidrj/AP/EADkQAAEDAQMKAgYLAQAAAAAAAAEAAhEDEiExEBMiMjNAQVFhkTCSQnFygYKhBBQgIzRSYHCiscHh/9oACAEBAAY/Av3IipUa09SvxFPutrTPxLFvfdiXGScT9i4lXVH+ZbZ/dbUrWHlWDOyvY1X0uxVpnvHilpEEYjIB4UcIv8X6zZdIueGrGsOy0X1CfZ/6ryR7lth5Stqz5rXZ3WLfMFq/MLZv7LZv8qGg6/oiTrux6eLaxb6QVtl9F2HRE0jDnGLXJVKdV1oWbieeXBHLEnus5UJtnhyXTxZRpPE0n4ItxY7AnjkhA1HBvrUZzuFIwOSEKlQaf9KOHjWx8QXQp30etrDj/qfTmbBhF5ExgEXvkkqCIRou1Th0OTOVBp8ByUDcI9A4dEHNMOGBVSqK7peZiLk6k8Q4FQCAcRKmoWwPy3oLOVYngOS67jBwKsOx4HnkzuFRmBVlzVotMlZyoPvOXLJJx3KOPDI8LSbKtNZpc8k9t1I/RLngSrGiFgsELPEIK1EX7m5l89Fask3LZnstkhZp8FpNPZRBG5h7ZuV7Hd1dTefWVsj5lsv5K5ndydSc0Na1s4zuebmMCrn8FtFrke5azo9lTaePhVRzCYdwI3H/xAAqEAACAQIFAgYDAQEAAAAAAAABEQAhMRBBUWFxIDCBkaGx0fBAweFw8f/aAAgBAQABPyH/AABYmsY/GZQW6WZqfhZQgGfpgoK59ASAvvmW9lU8p/IEFt4JBjBuAgZy8oq2LxY5wegoC+8y0OZDUwAoBdsAg9gVLGOClq9fM6+QDArm5CZkk5t+cXJfyguXHdJ0ikTI4HQ4tHJSrAIARbABFqICTRWspEEWT4QwILj7qoPUADxMzB5l8QLBTqLgb90L/uDK8X4IXb6OJtPIPYvHzSjQydjfuKHyJ/QPiVgqEgyqYND3D4d1oqwJTY1lPiXUDfpKozeYFYc/qVAfN2Agi6bhwZxkJPiVUAlAa2j0JgehnYAowIChGz5mv/rujU4IOR0h4D0AHI6fE2uKUDQ7iFIVAG0IFWaGziGMxHEPqL3BEuxEWMMCsKk0gBFpH3eGtOXeEo2PMGvMAwGoMpqYAz7BClRIpkbRZyQyxJjJZpqYMVWtWHQebgVFGdIMRWfp3hM5ntAAADvgde+JpxAXGuDDHOCoxCqGPvMKOOpZ4wVNixFvGGIFVXeDIqPIikrrCc45nvuGInQMMGGFdB84BWGgMw7GE7K0vElitdmIAJrfW8VnoIatxi+p9bjgMsFS0MogkFoPeetxow6lRfUwmXTKzTBxx4OOPqcccccdJdANwhRwa9oTChx9hxxx9DGAmriA4TJM0JhUdDjjjjjjwOJF1i6xNcTQ+MPMcMUHjADSM6xnWM6xnWNGdYzrGdY444444THiFgVqyhwCCbQk2UMJo+yc2FVAJMAMRjBUCGxjjjlca9Pj0uOEHHMCR/jOLQD58rSjgxuiCWhQENQImoSS9+prE4sRjBxiNZQHaAJpVAN60ELs66YSKRsuNtESChbv+YfovtAWYHcNYScDnCHAfUjhHgcHOBKCRubRlQkb1zymTc60y8o0020CqNNZoqBa4BBxGAIWAdnnie8Z2nBGYzDDGkYGUTOGBDecI45//9oADAMBAAIAAwAAABBhJAAuAIJODC8BIgAcABSwEABb+gg6kAFzww0BA2VeEAbzzzzwEABOYQKszTzziED9uo7LMuvzzwED+T//AGn90884hEzOM7iPE4whC/Nwm7969CeLeTKcADfPOfnOGS0kgwBv7elCCCmwMoMyt33g4A4AAV9xefhCjhABAgffA//EAB4RAAMBAAIDAQEAAAAAAAAAAAABERAgMSEwQVFh/9oACAEDAQE/EOExoQntfK4sWwnB6hvT5dka2JXF6tuvv0LKs+HGlxMs7If0fhyu0bqGGG8UvPyOgxspUVFWUpcVoTtia6PIDs/UpfATWloGN55UpcWAlf1iQWR6f//EAB4RAAMAAgMBAQEAAAAAAAAAAAABERAhIDAxQVFh/9oACAECAQE/EOVLyfa8b60qJJYY+loPBJI3dDw0RjTN/glF0N6glRppvEvehmjRfjP4PpyhMMQdE2WleEIQhCYhB6NmeSEIQhCExsZM8F6O9ESeEREIQnJo8GMei4ovR1UNwU8SGlxJDKvk8//EACkQAQACAgECBgIDAQEBAAAAAAEAESExQVFhEHGBkaGxMMEg0fBA4fH/2gAIAQEAAT8QgeBrwD+FQjqUypUqV+EPAPE14mv+KvEL8T874PjUPxX+F8TwNf8AOeF0QXT8QcsFQTSP5alfwZUrwvWebrLB7w8QNL7zh93/ABNyIbrntNORj0Srw7gj+orYCRwW68OtRscpAAmn8L/ItqZnPRb8TQuvX90IM50/vmAMOZ+4hHMeq5cTqdGomtZ6xMjxAOUPsTDjWN8QRdepuGcev5bg65ESndlwVa83Vf3EIQ+kpbAHsQhRp1f2mYH7GvmYMvTMfMrP9g6QzzL+tBtv90fuUK8mR+2Wsdcv6kmSZKDT9z8rD3EUhhGK0Ab4DmKsGMKov99oAitaTVvrNhm4xaRqrIZ4KJb5S14UTa3cAA16wxfS5QzYd7ltEaw733uLi4fXSqfdh+RSXYhQMZRyYLOEdETRCNlFfrajFcxmHzIo06cq3ssyxVvoH9pXWPQ/c0PkKfYS2VjRk+4iAyHZCeRtHhh2DHxcvaWuZspIBis0ZNw8gA1oBwee8P5P8sTbK6GsOUNnJ5EMt1BWZE9Hh9IJCnQJ7FmrLmKZs83UUDza1XaWwcmGGCbdLiGwZtqf+TEwDHVfLp3jqL7Gn3iwU6aDkjSopZHPFdZgYAqne/yoZ4YuDq6+X5TxI9wLa7kCorqGy9t7XCV0gYcxMRuj+jqRDR+Bq39xWGVkekOSkD7JBb66lQZYwL6wXQjaL7g6hf8AyJm5p3XVRwKqu3R++7iZuphnPaABQUH5XTEqo0cez5PIlYV9a42I+ySn1CgUj5RpP0zWa79VXp5RKxMS0sWcgCpLUE2ct/nylDG6RYvvHcAWzWWJ0GqTv2iICRKBdvSFUrtMhf8ALeNR7LqWtOvnNJB/K/G5fitQloYzvD8u7jo45i8qflo8icjyQU3EiNeeaH1qLWWv2sUeSrX0zBDTo1vWgZpHZLBNstA1aCiEvZ0DKTRXdQiXlbvs694tRHafb2gFLUym1/EsGpcuLAQnZKHki9hbHz9nPvKOeY0NUecAeoZvt5XD30UGAfPp7w4vSK2RxQVcNuDeBO5Tf10cqntvg2L0IbBd/Y4DtLlyhLPG4udzzRaikvwXPhlzBERJsUdlpiqC5mYZavZL8GggkHkn1AoOpUTTS69JRQOtTiVYCYHL5suWlplLlpad0slxZcW4giJWGEsTZ5QPul55AHzKgLTtKgFDZbZkmLKeC+8vvL7y+8vvLlhKeBjKXFzLIdaA6kG29L6mU3mZV9v3HVAs01H4IFwpXLlstO6Uj4NesqckRKx6h7ynhG/SI8IGNsB6sbk8vqDZKG+j6ZYbl5931KXuPqyo2neTvJ3kU5Z3k7yPWS/Vi+rHHcXUvL9YjzFHMt6y5YAARJvmbuwC9V6xSp4Cg4sgxadq/qLhYuUVX6QgoE2Gr8pYcZ5ikLipaWixb1guoo3LUje7mes9Y7h1uevhzmUdQtBBRWcVOPmYMyO58oOrkcWEwC6HJ9R9ptZ7+IMdQ4pedbilJAdquJXNxSoAtxQlzJFx5ziZYmo1d0yqWSyIvmWTNi0XqlYhZRWQeo/UeBDBMwXmUXrU/wAMQylncn6hxS8BA8q+0K6DjZ8S9gCXeoA0V7QTxHKn4le8v1MStQFR733iThlb1G+os7xFZqNCiFvFRVQ7os8mCnVlSrQTZksBafGFMbFabqDpQDANvh5t+8b4RwWOb5sls01QtHzqcUUqD69YMXdONZAm6W5YvLfvADSssaHrG3cC8S3al3BKOBlhf3lCEYUYX0gnSgVwhj1ktsCf/9k=', - }, - }, - ], - }, - ], - }, - }, - ]; - for (const test of testCases) { - it(test.should, () => { - assert.deepEqual( - toAnthropicRequest(MODEL_ID, test.input), - test.expectedOutput - ); - }); - } -}); - -describe('fromAnthropicResponse', () => { - const testCases: { - should: string; - input: GenerateRequest; - response: Message; - expectedOutput: GenerateResponseData; - }[] = [ - { - should: 'should transform genkit message (text content) correctly', - input: { - messages: [ - { - role: 'user', - content: [{ text: 'Tell a joke about dogs.' }], - }, - ], - }, - response: { - id: 'abcd1234', - model: MODEL_ID, - role: 'assistant', - stop_reason: 'end_turn', - usage: { - input_tokens: 123, - output_tokens: 234, - }, - stop_sequence: null, - type: 'message', - content: [ - { - type: 'text', - text: 'part 1', - }, - { - type: 'text', - text: 'part 2', - }, - ], - }, - expectedOutput: { - custom: { - id: 'abcd1234', - model: MODEL_ID, - type: 'message', - }, - finishReason: 'stop', - message: { - role: 'model', - content: [ - { - text: 'part 1', - }, - { - text: 'part 2', - }, - ], - }, - usage: { - inputAudioFiles: 0, - inputCharacters: 23, - inputImages: 0, - inputTokens: 123, - inputVideos: 0, - outputAudioFiles: 0, - outputCharacters: 12, - outputImages: 0, - outputTokens: 234, - outputVideos: 0, - }, - }, - }, - { - should: 'should transform genkit tool call correctly', - input: { - messages: [ - { - role: 'user', - content: [{ text: "What's the weather like today?" }], - }, - ], - tools: [ - { - name: 'get_weather', - description: 'Get the weather for a location.', - inputSchema: { - type: 'object', - properties: { - location: { - type: 'string', - description: 'The city and state, e.g. San Francisco, CA', - }, - }, - required: ['location'], - }, - }, - ], - }, - response: { - id: 'abcd1234', - model: MODEL_ID, - role: 'assistant', - type: 'message', - stop_reason: 'tool_use', - stop_sequence: null, - usage: { - input_tokens: 123, - output_tokens: 234, - }, - content: [ - { - id: 'toolu_get_weather', - name: 'get_weather', - type: 'tool_use', - input: { - type: 'object', - properties: { - location: { - type: 'string', - description: 'The city and state, e.g. San Francisco, CA', - }, - }, - required: ['location'], - }, - }, - ], - }, - expectedOutput: { - custom: { - id: 'abcd1234', - model: MODEL_ID, - type: 'message', - }, - finishReason: 'stop', - message: { - role: 'model', - content: [ - { - toolRequest: { - name: 'get_weather', - ref: 'toolu_get_weather', - input: { - type: 'object', - properties: { - location: { - type: 'string', - description: 'The city and state, e.g. San Francisco, CA', - }, - }, - required: ['location'], - }, - }, - }, - ], - }, - usage: { - inputAudioFiles: 0, - inputCharacters: 30, - inputImages: 0, - inputTokens: 123, - inputVideos: 0, - outputAudioFiles: 0, - outputCharacters: 0, - outputImages: 0, - outputTokens: 234, - outputVideos: 0, - }, - }, - }, - ]; - for (const test of testCases) { - it(test.should, () => { - assert.deepEqual( - fromAnthropicResponse(test.input, test.response), - test.expectedOutput - ); - }); - } -}); diff --git a/js/plugins/checks/tests/gemini_test.ts b/js/plugins/checks/tests/gemini_test.ts deleted file mode 100644 index c6156b4be..000000000 --- a/js/plugins/checks/tests/gemini_test.ts +++ /dev/null @@ -1,347 +0,0 @@ -/** - * Copyright 2024 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -import { GenerateContentCandidate } from '@google-cloud/vertexai'; -import { MessageData } from 'genkit'; -import assert from 'node:assert'; -import { describe, it } from 'node:test'; -import { - fromGeminiCandidate, - toGeminiMessage, - toGeminiSystemInstruction, -} from '../src/gemini.js'; - -describe('toGeminiMessages', () => { - const testCases = [ - { - should: 'should transform genkit message (text content) correctly', - inputMessage: { - role: 'user', - content: [{ text: 'Tell a joke about dogs.' }], - }, - expectedOutput: { - role: 'user', - parts: [{ text: 'Tell a joke about dogs.' }], - }, - }, - { - should: - 'should transform genkit message (tool request content) correctly', - inputMessage: { - role: 'model', - content: [ - { toolRequest: { name: 'tellAFunnyJoke', input: { topic: 'dogs' } } }, - ], - }, - expectedOutput: { - role: 'model', - parts: [ - { functionCall: { name: 'tellAFunnyJoke', args: { topic: 'dogs' } } }, - ], - }, - }, - { - should: - 'should transform genkit message (tool response content) correctly', - inputMessage: { - role: 'tool', - content: [ - { - toolResponse: { - name: 'tellAFunnyJoke', - output: 'Why did the dogs cross the road?', - }, - }, - ], - }, - expectedOutput: { - role: 'function', - parts: [ - { - functionResponse: { - name: 'tellAFunnyJoke', - response: { - name: 'tellAFunnyJoke', - content: 'Why did the dogs cross the road?', - }, - }, - }, - ], - }, - }, - { - should: - 'should transform genkit message (inline base64 image content) correctly', - inputMessage: { - role: 'user', - content: [ - { text: 'describe the following image:' }, - { - media: { - contentType: 'image/jpeg', - url: '', - }, - }, - ], - }, - expectedOutput: { - role: 'user', - parts: [ - { text: 'describe the following image:' }, - { - inlineData: { - mimeType: 'image/jpeg', - data: '/9j/4QDeRXhpZgAASUkqAAgAAAAGABIBAwABAAAAAQAAABoBBQABAAAAVgAAABsBBQABAAAAXgAAACgBAwABAAAAAgAAABMCAwABAAAAAQAAAGmHBAABAAAAZgAAAAAAAABIAAAAAQAAAEgAAAABAAAABwAAkAcABAAAADAyMTABkQcABAAAAAECAwCGkgcAFgAAAMAAAAAAoAcABAAAADAxMDABoAMAAQAAAP//AAACoAQAAQAAAMgAAAADoAQAAQAAAMgAAAAAAAAAQVNDSUkAAABQaWNzdW0gSUQ6IDY4N//bAEMACAYGBwYFCAcHBwkJCAoMFA0MCwsMGRITDxQdGh8eHRocHCAkLicgIiwjHBwoNyksMDE0NDQfJzk9ODI8LjM0Mv/bAEMBCQkJDAsMGA0NGDIhHCEyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMv/CABEIAMgAyAMBIgACEQEDEQH/xAAbAAEAAgMBAQAAAAAAAAAAAAAAAQIDBAUGB//EABgBAQEBAQEAAAAAAAAAAAAAAAABAgME/9oADAMBAAIQAxAAAAH3ZOsiYEgAmIkWEiEiEkiRYSICBVSQSRIBEhQAUAEAARMAJWYmpRBZWYmYkBQAUAAEARIgJEsViidRMKmYmW98M5uVEzQAAAAIABoa3zTLZ9M2Pltl+pvmWU+kvn+xHt7eMzHrcnlMy+mam2AAAAgEBPj9/Y+XWuTb6U1xLbOWNO29EupO1Ea85IOp6/ldXeQoAAEAgJq+G9/pteA6WjoR0ev5v1Rv8Xv8jGuTERF/W07G4yGoAAACCAE1Zz6a6/z33XKXgVv0MXzfd5+1VvY4O/E2i24AACCAgkqqlAiKzXNybOmc/j+i4eNYfQ7G/Ldjy6zdUWioupKWipbRCyYgTCKlAxzjnWcnK6PJl2c2v0+W74djUrPOO28WmguoW6sF4qLREWWVgsrBZRWvNZ1iedbyWN+u6nzfoc9++1PO82X206mx343UF4rBdQWVgtEKmIglAKiZx2TT8j6bl8uvA2e1Obj1d+M69Hm4fa78rRVrN4oLTQXisF4rBaIhLKCygrIcPhnm72znHpagdD0h6uFZOvOAoJECgRBIAC//xAApEAABBAECBQUBAAMAAAAAAAABAAIDBBEFEhATITBAFBUgIjFBIzJw/9oACAEBAAEFAv8AgGVnxyfkD4RPYz4EtuGFC9VK58LkHNPz6jv9XHCwgENyEsgQszhC7ZCGoWUNSmQ1ORN1NQztnb3MFjgh9ljjjjjjpufU9zU60T2D06hjjlfJS2jlxoRBclckrkPXJlXLkC2PVSt6dnckG1X6XpnwzPghgvSOjaesLQUYQWysAPDLlSquiH73Wu3NcxqtVn0pgVuUO2KNlyuVdG12UOppUuUP3vO+jn4exzG2opG8qWHAdNPtDLW58UpLR1VKlykUOnfP+MztOX6QySUMdG/242me1yRPazElGgYiSgMeAQHBw6tK1CBskNfUDCJLQmNOnylnDQPCkZvaFP8AatHWbIq1SGHg0dfDyv8AYRxBqAwEPzwcr+KM/YL+IfnazwysrK3cf4mHBBWVn5Z71yYwVotRlklGqSKLVXlTX5I17jMXQSGSPwrsPPqw0Jo5hp9lR0rLTPWnkMenWN8MZji8K7YfXA1eRe7yBe7yr3aZe6zqpfltTLPDKys8MrKysrKzwz8NRdlmxudgWxi5bMmNhVBobZWeOTxzxyshZWeP/8QAIBEAAgICAgIDAAAAAAAAAAAAAAEREgIwECEDIDFAUP/aAAgBAwEBPwH2j8VaYKlSpUh/ab1ob14j4nShbMe3HGShwPVgmnZFRp5OTJRqw7RUqeT50//EAB8RAAICAwACAwAAAAAAAAAAAAABAhEDEjAQIRNAUP/aAAgBAgEBPwH8uuNmxsbGxa+0lxq/DEuNCRMXsrnIXOfpWxog7Vi5ZGmtWOaItRVEXa5ZHQ8jPlZidrj/AP/EADkQAAEDAQMKAgYLAQAAAAAAAAEAAhEDEiExEBMiMjNAQVFhkTCSQnFygYKhBBQgIzRSYHCiscHh/9oACAEBAAY/Av3IipUa09SvxFPutrTPxLFvfdiXGScT9i4lXVH+ZbZ/dbUrWHlWDOyvY1X0uxVpnvHilpEEYjIB4UcIv8X6zZdIueGrGsOy0X1CfZ/6ryR7lth5Stqz5rXZ3WLfMFq/MLZv7LZv8qGg6/oiTrux6eLaxb6QVtl9F2HRE0jDnGLXJVKdV1oWbieeXBHLEnus5UJtnhyXTxZRpPE0n4ItxY7AnjkhA1HBvrUZzuFIwOSEKlQaf9KOHjWx8QXQp30etrDj/qfTmbBhF5ExgEXvkkqCIRou1Th0OTOVBp8ByUDcI9A4dEHNMOGBVSqK7peZiLk6k8Q4FQCAcRKmoWwPy3oLOVYngOS67jBwKsOx4HnkzuFRmBVlzVotMlZyoPvOXLJJx3KOPDI8LSbKtNZpc8k9t1I/RLngSrGiFgsELPEIK1EX7m5l89Fask3LZnstkhZp8FpNPZRBG5h7ZuV7Hd1dTefWVsj5lsv5K5ndydSc0Na1s4zuebmMCrn8FtFrke5azo9lTaePhVRzCYdwI3H/xAAqEAACAQIFAgYDAQEAAAAAAAABEQAhMRBBUWFxIDCBkaGx0fBAweFw8f/aAAgBAQABPyH/AABYmsY/GZQW6WZqfhZQgGfpgoK59ASAvvmW9lU8p/IEFt4JBjBuAgZy8oq2LxY5wegoC+8y0OZDUwAoBdsAg9gVLGOClq9fM6+QDArm5CZkk5t+cXJfyguXHdJ0ikTI4HQ4tHJSrAIARbABFqICTRWspEEWT4QwILj7qoPUADxMzB5l8QLBTqLgb90L/uDK8X4IXb6OJtPIPYvHzSjQydjfuKHyJ/QPiVgqEgyqYND3D4d1oqwJTY1lPiXUDfpKozeYFYc/qVAfN2Agi6bhwZxkJPiVUAlAa2j0JgehnYAowIChGz5mv/rujU4IOR0h4D0AHI6fE2uKUDQ7iFIVAG0IFWaGziGMxHEPqL3BEuxEWMMCsKk0gBFpH3eGtOXeEo2PMGvMAwGoMpqYAz7BClRIpkbRZyQyxJjJZpqYMVWtWHQebgVFGdIMRWfp3hM5ntAAADvgde+JpxAXGuDDHOCoxCqGPvMKOOpZ4wVNixFvGGIFVXeDIqPIikrrCc45nvuGInQMMGGFdB84BWGgMw7GE7K0vElitdmIAJrfW8VnoIatxi+p9bjgMsFS0MogkFoPeetxow6lRfUwmXTKzTBxx4OOPqcccccdJdANwhRwa9oTChx9hxxx9DGAmriA4TJM0JhUdDjjjjjjwOJF1i6xNcTQ+MPMcMUHjADSM6xnWM6xnWNGdYzrGdY444444THiFgVqyhwCCbQk2UMJo+yc2FVAJMAMRjBUCGxjjjlca9Pj0uOEHHMCR/jOLQD58rSjgxuiCWhQENQImoSS9+prE4sRjBxiNZQHaAJpVAN60ELs66YSKRsuNtESChbv+YfovtAWYHcNYScDnCHAfUjhHgcHOBKCRubRlQkb1zymTc60y8o0020CqNNZoqBa4BBxGAIWAdnnie8Z2nBGYzDDGkYGUTOGBDecI45//9oADAMBAAIAAwAAABBhJAAuAIJODC8BIgAcABSwEABb+gg6kAFzww0BA2VeEAbzzzzwEABOYQKszTzziED9uo7LMuvzzwED+T//AGn90884hEzOM7iPE4whC/Nwm7969CeLeTKcADfPOfnOGS0kgwBv7elCCCmwMoMyt33g4A4AAV9xefhCjhABAgffA//EAB4RAAMBAAIDAQEAAAAAAAAAAAABERAgMSEwQVFh/9oACAEDAQE/EOExoQntfK4sWwnB6hvT5dka2JXF6tuvv0LKs+HGlxMs7If0fhyu0bqGGG8UvPyOgxspUVFWUpcVoTtia6PIDs/UpfATWloGN55UpcWAlf1iQWR6f//EAB4RAAMAAgMBAQEAAAAAAAAAAAABERAhIDAxQVFh/9oACAECAQE/EOVLyfa8b60qJJYY+loPBJI3dDw0RjTN/glF0N6glRppvEvehmjRfjP4PpyhMMQdE2WleEIQhCYhB6NmeSEIQhCExsZM8F6O9ESeEREIQnJo8GMei4ovR1UNwU8SGlxJDKvk8//EACkQAQACAgECBgIDAQEBAAAAAAEAESExQVFhEHGBkaGxMMEg0fBA4fH/2gAIAQEAAT8QgeBrwD+FQjqUypUqV+EPAPE14mv+KvEL8T874PjUPxX+F8TwNf8AOeF0QXT8QcsFQTSP5alfwZUrwvWebrLB7w8QNL7zh93/ABNyIbrntNORj0Srw7gj+orYCRwW68OtRscpAAmn8L/ItqZnPRb8TQuvX90IM50/vmAMOZ+4hHMeq5cTqdGomtZ6xMjxAOUPsTDjWN8QRdepuGcev5bg65ESndlwVa83Vf3EIQ+kpbAHsQhRp1f2mYH7GvmYMvTMfMrP9g6QzzL+tBtv90fuUK8mR+2Wsdcv6kmSZKDT9z8rD3EUhhGK0Ab4DmKsGMKov99oAitaTVvrNhm4xaRqrIZ4KJb5S14UTa3cAA16wxfS5QzYd7ltEaw733uLi4fXSqfdh+RSXYhQMZRyYLOEdETRCNlFfrajFcxmHzIo06cq3ssyxVvoH9pXWPQ/c0PkKfYS2VjRk+4iAyHZCeRtHhh2DHxcvaWuZspIBis0ZNw8gA1oBwee8P5P8sTbK6GsOUNnJ5EMt1BWZE9Hh9IJCnQJ7FmrLmKZs83UUDza1XaWwcmGGCbdLiGwZtqf+TEwDHVfLp3jqL7Gn3iwU6aDkjSopZHPFdZgYAqne/yoZ4YuDq6+X5TxI9wLa7kCorqGy9t7XCV0gYcxMRuj+jqRDR+Bq39xWGVkekOSkD7JBb66lQZYwL6wXQjaL7g6hf8AyJm5p3XVRwKqu3R++7iZuphnPaABQUH5XTEqo0cez5PIlYV9a42I+ySn1CgUj5RpP0zWa79VXp5RKxMS0sWcgCpLUE2ct/nylDG6RYvvHcAWzWWJ0GqTv2iICRKBdvSFUrtMhf8ALeNR7LqWtOvnNJB/K/G5fitQloYzvD8u7jo45i8qflo8icjyQU3EiNeeaH1qLWWv2sUeSrX0zBDTo1vWgZpHZLBNstA1aCiEvZ0DKTRXdQiXlbvs694tRHafb2gFLUym1/EsGpcuLAQnZKHki9hbHz9nPvKOeY0NUecAeoZvt5XD30UGAfPp7w4vSK2RxQVcNuDeBO5Tf10cqntvg2L0IbBd/Y4DtLlyhLPG4udzzRaikvwXPhlzBERJsUdlpiqC5mYZavZL8GggkHkn1AoOpUTTS69JRQOtTiVYCYHL5suWlplLlpad0slxZcW4giJWGEsTZ5QPul55AHzKgLTtKgFDZbZkmLKeC+8vvL7y+8vvLlhKeBjKXFzLIdaA6kG29L6mU3mZV9v3HVAs01H4IFwpXLlstO6Uj4NesqckRKx6h7ynhG/SI8IGNsB6sbk8vqDZKG+j6ZYbl5931KXuPqyo2neTvJ3kU5Z3k7yPWS/Vi+rHHcXUvL9YjzFHMt6y5YAARJvmbuwC9V6xSp4Cg4sgxadq/qLhYuUVX6QgoE2Gr8pYcZ5ikLipaWixb1guoo3LUje7mes9Y7h1uevhzmUdQtBBRWcVOPmYMyO58oOrkcWEwC6HJ9R9ptZ7+IMdQ4pedbilJAdquJXNxSoAtxQlzJFx5ziZYmo1d0yqWSyIvmWTNi0XqlYhZRWQeo/UeBDBMwXmUXrU/wAMQylncn6hxS8BA8q+0K6DjZ8S9gCXeoA0V7QTxHKn4le8v1MStQFR733iThlb1G+os7xFZqNCiFvFRVQ7os8mCnVlSrQTZksBafGFMbFabqDpQDANvh5t+8b4RwWOb5sls01QtHzqcUUqD69YMXdONZAm6W5YvLfvADSssaHrG3cC8S3al3BKOBlhf3lCEYUYX0gnSgVwhj1ktsCf/9k=', - }, - }, - ], - }, - }, - ]; - for (const test of testCases) { - it(test.should, () => { - assert.deepEqual( - toGeminiMessage(test.inputMessage as MessageData), - test.expectedOutput - ); - }); - } -}); - -describe('toGeminiSystemInstruction', () => { - const testCases = [ - { - should: 'should transform from system to user', - inputMessage: { - role: 'system', - content: [{ text: 'You are an expert in all things cats.' }], - }, - expectedOutput: { - role: 'user', - parts: [{ text: 'You are an expert in all things cats.' }], - }, - }, - { - should: 'should transform from system to user with multiple parts', - inputMessage: { - role: 'system', - content: [ - { text: 'You are an expert in all things animals.' }, - { text: 'You love cats.' }, - ], - }, - expectedOutput: { - role: 'user', - parts: [ - { text: 'You are an expert in all things animals.' }, - { text: 'You love cats.' }, - ], - }, - }, - ]; - for (const test of testCases) { - it(test.should, () => { - assert.deepEqual( - toGeminiSystemInstruction(test.inputMessage as MessageData), - test.expectedOutput - ); - }); - } -}); - -describe('fromGeminiCandidate', () => { - const testCases = [ - { - should: - 'should transform gemini candidate to genkit candidate (text parts) correctly', - // had to delete the probabilityScore, severity, severityScore for the HARM_CATEGORY_SEXUALLY_EXPLICIT safety rating category - geminiCandidate: { - content: { - role: 'model', - parts: [ - { - text: 'Why did the dog go to the bank?\n\nTo get his bones cashed!', - }, - ], - }, - finishReason: 'STOP', - safetyRatings: [ - { - category: 'HARM_CATEGORY_HATE_SPEECH', - probability: 'NEGLIGIBLE', - probabilityScore: 0.12074952, - severity: 'HARM_SEVERITY_NEGLIGIBLE', - severityScore: 0.18388656, - }, - { - category: 'HARM_CATEGORY_DANGEROUS_CONTENT', - probability: 'NEGLIGIBLE', - probabilityScore: 0.37874627, - severity: 'HARM_SEVERITY_LOW', - severityScore: 0.37227696, - }, - { - category: 'HARM_CATEGORY_HARASSMENT', - probability: 'NEGLIGIBLE', - probabilityScore: 0.3983479, - severity: 'HARM_SEVERITY_LOW', - severityScore: 0.22270013, - }, - { - category: 'HARM_CATEGORY_SEXUALLY_EXPLICIT', - probability: 'NEGLIGIBLE', - }, - ], - }, - expectedOutput: { - index: 0, - message: { - role: 'model', - content: [ - { - text: 'Why did the dog go to the bank?\n\nTo get his bones cashed!', - }, - ], - }, - finishReason: 'stop', - finishMessage: undefined, - custom: { - citationMetadata: undefined, - safetyRatings: [ - { - category: 'HARM_CATEGORY_HATE_SPEECH', - probability: 'NEGLIGIBLE', - probabilityScore: 0.12074952, - severity: 'HARM_SEVERITY_NEGLIGIBLE', - severityScore: 0.18388656, - }, - { - category: 'HARM_CATEGORY_DANGEROUS_CONTENT', - probability: 'NEGLIGIBLE', - probabilityScore: 0.37874627, - severity: 'HARM_SEVERITY_LOW', - severityScore: 0.37227696, - }, - { - category: 'HARM_CATEGORY_HARASSMENT', - probability: 'NEGLIGIBLE', - probabilityScore: 0.3983479, - severity: 'HARM_SEVERITY_LOW', - severityScore: 0.22270013, - }, - { - category: 'HARM_CATEGORY_SEXUALLY_EXPLICIT', - probability: 'NEGLIGIBLE', - }, - ], - }, - }, - }, - { - should: - 'should transform gemini candidate to genkit candidate (function call parts) correctly', - geminiCandidate: { - content: { - role: 'model', - parts: [ - { - functionCall: { name: 'tellAFunnyJoke', args: { topic: 'dog' } }, - }, - ], - }, - finishReason: 'STOP', - safetyRatings: [ - { - category: 'HARM_CATEGORY_HATE_SPEECH', - probability: 'NEGLIGIBLE', - probabilityScore: 0.11858909, - severity: 'HARM_SEVERITY_NEGLIGIBLE', - severityScore: 0.11456649, - }, - { - category: 'HARM_CATEGORY_DANGEROUS_CONTENT', - probability: 'NEGLIGIBLE', - probabilityScore: 0.13857833, - severity: 'HARM_SEVERITY_NEGLIGIBLE', - severityScore: 0.11417085, - }, - { - category: 'HARM_CATEGORY_HARASSMENT', - probability: 'NEGLIGIBLE', - probabilityScore: 0.28012377, - severity: 'HARM_SEVERITY_NEGLIGIBLE', - severityScore: 0.112405084, - }, - { - category: 'HARM_CATEGORY_SEXUALLY_EXPLICIT', - probability: 'NEGLIGIBLE', - }, - ], - }, - expectedOutput: { - index: 0, - message: { - role: 'model', - content: [ - { - toolRequest: { name: 'tellAFunnyJoke', input: { topic: 'dog' } }, - }, - ], - }, - finishReason: 'stop', - finishMessage: undefined, - custom: { - citationMetadata: undefined, - safetyRatings: [ - { - category: 'HARM_CATEGORY_HATE_SPEECH', - probability: 'NEGLIGIBLE', - probabilityScore: 0.11858909, - severity: 'HARM_SEVERITY_NEGLIGIBLE', - severityScore: 0.11456649, - }, - { - category: 'HARM_CATEGORY_DANGEROUS_CONTENT', - probability: 'NEGLIGIBLE', - probabilityScore: 0.13857833, - severity: 'HARM_SEVERITY_NEGLIGIBLE', - severityScore: 0.11417085, - }, - { - category: 'HARM_CATEGORY_HARASSMENT', - probability: 'NEGLIGIBLE', - probabilityScore: 0.28012377, - severity: 'HARM_SEVERITY_NEGLIGIBLE', - severityScore: 0.112405084, - }, - { - category: 'HARM_CATEGORY_SEXUALLY_EXPLICIT', - probability: 'NEGLIGIBLE', - }, - ], - }, - }, - }, - ]; - for (const test of testCases) { - it(test.should, () => { - assert.deepEqual( - fromGeminiCandidate(test.geminiCandidate as GenerateContentCandidate), - test.expectedOutput - ); - }); - } -}); diff --git a/js/plugins/checks/tests/vector-search/bigquery_test.ts b/js/plugins/checks/tests/vector-search/bigquery_test.ts deleted file mode 100644 index 1cbc54314..000000000 --- a/js/plugins/checks/tests/vector-search/bigquery_test.ts +++ /dev/null @@ -1,168 +0,0 @@ -/** - * Copyright 2024 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -import { BigQuery } from '@google-cloud/bigquery'; -import { Document } from 'genkit/retriever'; -import assert from 'node:assert'; -import { describe, it } from 'node:test'; -import { getBigQueryDocumentRetriever } from '../../src'; - -class MockBigQuery { - query: Function; - - constructor({ - mockRows, - shouldThrowError = false, - }: { - mockRows: any[]; - shouldThrowError?: boolean; - }) { - this.query = async (_options: { - query: string; - params: { ids: string[] }; - }) => { - if (shouldThrowError) { - throw new Error('Query failed'); - } - return [mockRows]; - }; - } -} - -describe('getBigQueryDocumentRetriever', () => { - it('returns a function that retrieves documents from BigQuery', async () => { - const doc1 = Document.fromText('content1'); - const doc2 = Document.fromText('content2'); - - const mockRows = [ - { - id: '1', - content: JSON.stringify(doc1.content), - metadata: null, - }, - { - id: '2', - content: JSON.stringify(doc2.content), - metadata: null, - }, - ]; - - const mockBigQuery = new MockBigQuery({ mockRows }) as unknown as BigQuery; - const documentRetriever = getBigQueryDocumentRetriever( - mockBigQuery, - 'test-table', - 'test-dataset' - ); - - const documents = await documentRetriever([ - { datapoint: { datapointId: '1' } }, - { datapoint: { datapointId: '2' } }, - ]); - - assert.deepStrictEqual(documents, [doc1, doc2]); - }); - - it('returns an empty array when no documents match', async () => { - const mockRows: any[] = []; - - const mockBigQuery = new MockBigQuery({ mockRows }) as unknown as BigQuery; - const documentRetriever = getBigQueryDocumentRetriever( - mockBigQuery, - 'test-table', - 'test-dataset' - ); - - const documents = await documentRetriever([ - { datapoint: { datapointId: '3' } }, - ]); - - assert.deepStrictEqual(documents, []); - }); - - it('handles BigQuery query errors', async () => { - const mockBigQuery = new MockBigQuery({ - mockRows: [], - shouldThrowError: true, - }) as unknown as BigQuery; - const documentRetriever = getBigQueryDocumentRetriever( - mockBigQuery, - 'test-table', - 'test-dataset' - ); - // no need to assert the error, just make sure it doesn't throw - await documentRetriever([{ datapoint: { datapointId: '1' } }]); - }); - - it('filters out invalid documents', async () => { - const validDoc = Document.fromText('valid content'); - const mockRows = [ - { - id: '1', - content: JSON.stringify(validDoc.content), - metadata: null, - }, - { - id: '2', - content: 'invalid JSON', - metadata: null, - }, - ]; - - const mockBigQuery = new MockBigQuery({ mockRows }) as unknown as BigQuery; - const documentRetriever = getBigQueryDocumentRetriever( - mockBigQuery, - 'test-table', - 'test-dataset' - ); - - const documents = await documentRetriever([ - { datapoint: { datapointId: '1' } }, - { datapoint: { datapointId: '2' } }, - ]); - - assert.deepStrictEqual(documents, [validDoc]); - }); - - it('handles missing content in documents', async () => { - const validDoc = Document.fromText('valid content'); - const mockRows = [ - { - id: '1', - content: JSON.stringify(validDoc.content), - metadata: null, - }, - { - id: '2', - content: null, - metadata: null, - }, - ]; - - const mockBigQuery = new MockBigQuery({ mockRows }) as unknown as BigQuery; - const documentRetriever = getBigQueryDocumentRetriever( - mockBigQuery, - 'test-table', - 'test-dataset' - ); - - const documents = await documentRetriever([ - { datapoint: { datapointId: '1' } }, - { datapoint: { datapointId: '2' } }, - ]); - - assert.deepStrictEqual(documents, [validDoc]); - }); -}); diff --git a/js/plugins/checks/tests/vector-search/query_public_endpoint_test.ts b/js/plugins/checks/tests/vector-search/query_public_endpoint_test.ts deleted file mode 100644 index 9419f2916..000000000 --- a/js/plugins/checks/tests/vector-search/query_public_endpoint_test.ts +++ /dev/null @@ -1,86 +0,0 @@ -/** - * Copyright 2024 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -import assert from 'assert'; -import { describe, it, Mock } from 'node:test'; -import { queryPublicEndpoint } from '../../src/vector-search/query_public_endpoint'; - -describe('queryPublicEndpoint', () => { - // FIXME -- t.mock.method is not supported node above 20 - it.skip('queryPublicEndpoint sends the correct request and retrieves neighbors', async (t) => { - t.mock.method(global, 'fetch', async (url, options) => { - return { - ok: true, - json: async () => ({ neighbors: ['neighbor1', 'neighbor2'] }), - } as any; - }); - - const params = { - featureVector: [0.1, 0.2, 0.3], - neighborCount: 5, - accessToken: 'test-access-token', - projectId: 'test-project-id', - location: 'us-central1', - indexEndpointId: 'idx123', - publicDomainName: 'example.com', - projectNumber: '123456789', - deployedIndexId: 'deployed-idx123', - }; - - const expectedResponse = { neighbors: ['neighbor1', 'neighbor2'] }; - - const response = await queryPublicEndpoint(params); - - const calls = ( - global.fetch as Mock< - (url: string, options: Record) => Promise - > - ).mock.calls; - - assert.strictEqual(calls.length, 1); - - const [url, options] = calls[0].arguments; - - const expectedUrl = `https://example.com/v1/projects/123456789/locations/us-central1/indexEndpoints/idx123:findNeighbors`; - - assert.strictEqual(url.toString(), expectedUrl); - - assert.strictEqual(options.method, 'POST'); - - assert.strictEqual(options.headers['Content-Type'], 'application/json'); - assert.strictEqual( - options.headers['Authorization'], - 'Bearer test-access-token' - ); - - const body = JSON.parse(options.body); - assert.deepStrictEqual(body, { - deployed_index_id: 'deployed-idx123', - queries: [ - { - datapoint: { - datapoint_id: '0', - feature_vector: [0.1, 0.2, 0.3], - }, - neighbor_count: 5, - }, - ], - }); - - // Verifying the response - assert.deepStrictEqual(response, expectedResponse); - }); -}); diff --git a/js/plugins/checks/tests/vector-search/upsert_datapoints_test.ts b/js/plugins/checks/tests/vector-search/upsert_datapoints_test.ts deleted file mode 100644 index 5b36a47d0..000000000 --- a/js/plugins/checks/tests/vector-search/upsert_datapoints_test.ts +++ /dev/null @@ -1,81 +0,0 @@ -/** - * Copyright 2024 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -import assert from 'assert'; -import { GoogleAuth } from 'google-auth-library'; -import { describe, it, Mock } from 'node:test'; -import { IIndexDatapoint } from '../../src/vector-search/types'; -import { upsertDatapoints } from '../../src/vector-search/upsert_datapoints'; - -describe('upsertDatapoints', () => { - // FIXME -- t.mock.method is not supported node above 20 - it.skip('upsertDatapoints sends the correct request and handles response', async (t) => { - // Mocking the fetch method within the test scope - t.mock.method(global, 'fetch', async (url, options) => { - return { - ok: true, - json: async () => ({}), - } as any; - }); - - // Mocking the GoogleAuth client - const mockAuthClient = { - getAccessToken: async () => 'test-access-token', - } as GoogleAuth; - - const params = { - datapoints: [ - { datapointId: 'dp1', featureVector: [0.1, 0.2, 0.3] }, - { datapointId: 'dp2', featureVector: [0.4, 0.5, 0.6] }, - ] as IIndexDatapoint[], - authClient: mockAuthClient, - projectId: 'test-project-id', - location: 'us-central1', - indexId: 'idx123', - }; - - await upsertDatapoints(params); - - // Verifying the fetch call - const calls = ( - global.fetch as Mock< - (url: string, options: Record) => Promise - > - ).mock.calls; - - assert.strictEqual(calls.length, 1); - const [url, options] = calls[0].arguments; - - assert.strictEqual( - url.toString(), - 'https://us-central1-aiplatform.googleapis.com/v1/projects/test-project-id/locations/us-central1/indexes/idx123:upsertDatapoints' - ); - assert.strictEqual(options.method, 'POST'); - assert.strictEqual(options.headers['Content-Type'], 'application/json'); - assert.strictEqual( - options.headers['Authorization'], - 'Bearer test-access-token' - ); - - const body = JSON.parse(options.body); - assert.deepStrictEqual(body, { - datapoints: [ - { datapoint_id: 'dp1', feature_vector: [0.1, 0.2, 0.3] }, - { datapoint_id: 'dp2', feature_vector: [0.4, 0.5, 0.6] }, - ], - }); - }); -}); diff --git a/js/plugins/checks/tests/vector-search/utils_test.ts b/js/plugins/checks/tests/vector-search/utils_test.ts deleted file mode 100644 index 38b130b3a..000000000 --- a/js/plugins/checks/tests/vector-search/utils_test.ts +++ /dev/null @@ -1,70 +0,0 @@ -/** - * Copyright 2024 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -import assert from 'assert'; -import { google } from 'googleapis'; -import { describe, it } from 'node:test'; -import { - getAccessToken, - getProjectNumber, -} from '../../src/vector-search/utils'; - -// Mocking the google.auth.getClient method -google.auth.getClient = async () => { - return { - getRequestHeaders: async () => ({ Authorization: 'Bearer test-token' }), - } as any; // Using `any` to bypass type checks for the mock -}; - -// Mocking the google.cloudresourcemanager method -google.cloudresourcemanager = () => { - return { - projects: { - get: async ({ projectId }) => { - return { - data: { - projectNumber: '123456789', - }, - }; - }, - }, - } as any; // Using `any` to bypass type checks for the mock -}; - -describe('utils', () => { - it('getProjectNumber retrieves the project number', async () => { - const projectId = 'test-project-id'; - const expectedProjectNumber = '123456789'; - - const projectNumber = await getProjectNumber(projectId); - assert.strictEqual(projectNumber, expectedProjectNumber); - }); - - // Mocking the GoogleAuth client - const mockAuthClient = { - getAccessToken: async () => ({ token: 'test-access-token' }), - }; - - it('getAccessToken retrieves the access token', async () => { - // Mocking the GoogleAuth.getClient method to return the mockAuthClient - const auth = { - getClient: async () => mockAuthClient, - } as any; // Using `any` to bypass type checks for the mock - - const accessToken = await getAccessToken(auth); - assert.strictEqual(accessToken, 'test-access-token'); - }); -}); From 7082b6017bbd61b997b2b90a7e54854d36458041 Mon Sep 17 00:00:00 2001 From: hunterheston Date: Thu, 7 Nov 2024 19:08:02 +0000 Subject: [PATCH 10/30] lock file updated. --- js/pnpm-lock.yaml | 144 +++++++++++++++++++++++----------------------- 1 file changed, 72 insertions(+), 72 deletions(-) diff --git a/js/pnpm-lock.yaml b/js/pnpm-lock.yaml index d17bcfdca..7b364f331 100644 --- a/js/pnpm-lock.yaml +++ b/js/pnpm-lock.yaml @@ -1096,7 +1096,7 @@ importers: version: link:../../plugins/vertexai '@opentelemetry/sdk-trace-base': specifier: ^1.25.0 - version: 1.26.0(@opentelemetry/api@1.9.0) + version: 1.25.1(@opentelemetry/api@1.9.0) genkit: specifier: workspace:* version: link:../../genkit @@ -1232,7 +1232,7 @@ importers: version: link:../../plugins/ollama genkitx-openai: specifier: ^0.10.1 - version: 0.10.1(@genkit-ai/ai@0.9.0-dev.3)(@genkit-ai/core@0.9.0-dev.3) + version: 0.10.1(@genkit-ai/ai@0.9.0-dev.4)(@genkit-ai/core@0.9.0-dev.4) devDependencies: rimraf: specifier: ^6.0.1 @@ -1999,11 +1999,11 @@ packages: '@firebase/util@1.9.5': resolution: {integrity: sha512-PP4pAFISDxsf70l3pEy34Mf3GkkUcVQ3MdKp6aSVb7tcpfUQxnsdV7twDd8EkfB6zZylH6wpUAoangQDmCUMqw==} - '@genkit-ai/ai@0.9.0-dev.3': - resolution: {integrity: sha512-fXi7onEpViZX86dPq0xWsqxivvXQMf9wH3boaNJFDJg22YvJMpb5MDV+jgzmXbKwGFCVaFAsePMBzv7Ikt703A==} + '@genkit-ai/ai@0.9.0-dev.4': + resolution: {integrity: sha512-j7mCfJnPupK9tqkESV+SVtwGAfGFB6CnIr/NXeTZleU6cupocP0uFkZKi72HbdMYk2VI38spplp5aIt4jW/wNA==} - '@genkit-ai/core@0.9.0-dev.3': - resolution: {integrity: sha512-fA8XUVYY9K77zWG0AVVHjcYe3XCYfnUuTOf11VMbvxVG3t/cHfkrKzLG/lxyO41O1vfL7xwIcaHLlpzXVQmZPQ==} + '@genkit-ai/core@0.9.0-dev.4': + resolution: {integrity: sha512-v6QpSedACJU/jKJGukJKHM5sPJdyYKPoyzAMyztWvVD12t2bkvXYL7+QyCeB/cUE7cijyO4w/2lRNyZciyAgMw==} '@google-cloud/aiplatform@3.25.0': resolution: {integrity: sha512-qKnJgbyCENjed8e1G5zZGFTxxNKhhaKQN414W2KIVHrLxMFmlMuG+3QkXPOWwXBnT5zZ7aMxypt5og0jCirpHg==} @@ -6766,9 +6766,9 @@ snapshots: dependencies: tslib: 2.6.2 - '@genkit-ai/ai@0.9.0-dev.3': + '@genkit-ai/ai@0.9.0-dev.4': dependencies: - '@genkit-ai/core': 0.9.0-dev.3 + '@genkit-ai/core': 0.9.0-dev.4 '@opentelemetry/api': 1.9.0 '@types/node': 20.16.9 colorette: 2.0.20 @@ -6778,16 +6778,16 @@ snapshots: transitivePeerDependencies: - supports-color - '@genkit-ai/core@0.9.0-dev.3': + '@genkit-ai/core@0.9.0-dev.4': dependencies: '@opentelemetry/api': 1.9.0 - '@opentelemetry/context-async-hooks': 1.26.0(@opentelemetry/api@1.9.0) - '@opentelemetry/core': 1.26.0(@opentelemetry/api@1.9.0) - '@opentelemetry/sdk-metrics': 1.26.0(@opentelemetry/api@1.9.0) + '@opentelemetry/context-async-hooks': 1.25.1(@opentelemetry/api@1.9.0) + '@opentelemetry/core': 1.25.1(@opentelemetry/api@1.9.0) + '@opentelemetry/sdk-metrics': 1.25.1(@opentelemetry/api@1.9.0) '@opentelemetry/sdk-node': 0.52.1(@opentelemetry/api@1.9.0) - '@opentelemetry/sdk-trace-base': 1.26.0(@opentelemetry/api@1.9.0) - ajv: 8.17.1 - ajv-formats: 3.0.1(ajv@8.17.1) + '@opentelemetry/sdk-trace-base': 1.25.1(@opentelemetry/api@1.9.0) + ajv: 8.12.0 + ajv-formats: 3.0.1(ajv@8.12.0) async-mutex: 0.5.0 body-parser: 1.20.3 cors: 2.8.5 @@ -6795,7 +6795,7 @@ snapshots: get-port: 5.1.0 json-schema: 0.4.0 zod: 3.23.8 - zod-to-json-schema: 3.23.3(zod@3.23.8) + zod-to-json-schema: 3.22.5(zod@3.23.8) transitivePeerDependencies: - supports-color @@ -7386,9 +7386,9 @@ snapshots: '@opentelemetry/instrumentation-amqplib@0.41.0(@opentelemetry/api@1.9.0)': dependencies: '@opentelemetry/api': 1.9.0 - '@opentelemetry/core': 1.26.0(@opentelemetry/api@1.9.0) + '@opentelemetry/core': 1.25.1(@opentelemetry/api@1.9.0) '@opentelemetry/instrumentation': 0.52.1(@opentelemetry/api@1.9.0) - '@opentelemetry/semantic-conventions': 1.27.0 + '@opentelemetry/semantic-conventions': 1.26.0 transitivePeerDependencies: - supports-color @@ -7397,8 +7397,8 @@ snapshots: '@opentelemetry/api': 1.9.0 '@opentelemetry/instrumentation': 0.52.1(@opentelemetry/api@1.9.0) '@opentelemetry/propagator-aws-xray': 1.3.1(@opentelemetry/api@1.9.0) - '@opentelemetry/resources': 1.26.0(@opentelemetry/api@1.9.0) - '@opentelemetry/semantic-conventions': 1.27.0 + '@opentelemetry/resources': 1.25.1(@opentelemetry/api@1.9.0) + '@opentelemetry/semantic-conventions': 1.26.0 '@types/aws-lambda': 8.10.122 transitivePeerDependencies: - supports-color @@ -7406,10 +7406,10 @@ snapshots: '@opentelemetry/instrumentation-aws-sdk@0.43.1(@opentelemetry/api@1.9.0)': dependencies: '@opentelemetry/api': 1.9.0 - '@opentelemetry/core': 1.26.0(@opentelemetry/api@1.9.0) + '@opentelemetry/core': 1.25.1(@opentelemetry/api@1.9.0) '@opentelemetry/instrumentation': 0.52.1(@opentelemetry/api@1.9.0) '@opentelemetry/propagation-utils': 0.30.10(@opentelemetry/api@1.9.0) - '@opentelemetry/semantic-conventions': 1.27.0 + '@opentelemetry/semantic-conventions': 1.26.0 transitivePeerDependencies: - supports-color @@ -7426,16 +7426,16 @@ snapshots: dependencies: '@opentelemetry/api': 1.9.0 '@opentelemetry/instrumentation': 0.52.1(@opentelemetry/api@1.9.0) - '@opentelemetry/semantic-conventions': 1.27.0 + '@opentelemetry/semantic-conventions': 1.26.0 transitivePeerDependencies: - supports-color '@opentelemetry/instrumentation-connect@0.38.0(@opentelemetry/api@1.9.0)': dependencies: '@opentelemetry/api': 1.9.0 - '@opentelemetry/core': 1.26.0(@opentelemetry/api@1.9.0) + '@opentelemetry/core': 1.25.1(@opentelemetry/api@1.9.0) '@opentelemetry/instrumentation': 0.52.1(@opentelemetry/api@1.9.0) - '@opentelemetry/semantic-conventions': 1.27.0 + '@opentelemetry/semantic-conventions': 1.26.0 '@types/connect': 3.4.36 transitivePeerDependencies: - supports-color @@ -7444,7 +7444,7 @@ snapshots: dependencies: '@opentelemetry/api': 1.9.0 '@opentelemetry/instrumentation': 0.52.1(@opentelemetry/api@1.9.0) - '@opentelemetry/semantic-conventions': 1.27.0 + '@opentelemetry/semantic-conventions': 1.26.0 transitivePeerDependencies: - supports-color @@ -7466,25 +7466,25 @@ snapshots: '@opentelemetry/instrumentation-express@0.41.1(@opentelemetry/api@1.9.0)': dependencies: '@opentelemetry/api': 1.9.0 - '@opentelemetry/core': 1.26.0(@opentelemetry/api@1.9.0) + '@opentelemetry/core': 1.25.1(@opentelemetry/api@1.9.0) '@opentelemetry/instrumentation': 0.52.1(@opentelemetry/api@1.9.0) - '@opentelemetry/semantic-conventions': 1.27.0 + '@opentelemetry/semantic-conventions': 1.26.0 transitivePeerDependencies: - supports-color '@opentelemetry/instrumentation-fastify@0.38.0(@opentelemetry/api@1.9.0)': dependencies: '@opentelemetry/api': 1.9.0 - '@opentelemetry/core': 1.26.0(@opentelemetry/api@1.9.0) + '@opentelemetry/core': 1.25.1(@opentelemetry/api@1.9.0) '@opentelemetry/instrumentation': 0.52.1(@opentelemetry/api@1.9.0) - '@opentelemetry/semantic-conventions': 1.27.0 + '@opentelemetry/semantic-conventions': 1.26.0 transitivePeerDependencies: - supports-color '@opentelemetry/instrumentation-fs@0.14.0(@opentelemetry/api@1.9.0)': dependencies: '@opentelemetry/api': 1.9.0 - '@opentelemetry/core': 1.26.0(@opentelemetry/api@1.9.0) + '@opentelemetry/core': 1.25.1(@opentelemetry/api@1.9.0) '@opentelemetry/instrumentation': 0.52.1(@opentelemetry/api@1.9.0) transitivePeerDependencies: - supports-color @@ -7514,9 +7514,9 @@ snapshots: '@opentelemetry/instrumentation-hapi@0.40.0(@opentelemetry/api@1.9.0)': dependencies: '@opentelemetry/api': 1.9.0 - '@opentelemetry/core': 1.26.0(@opentelemetry/api@1.9.0) + '@opentelemetry/core': 1.25.1(@opentelemetry/api@1.9.0) '@opentelemetry/instrumentation': 0.52.1(@opentelemetry/api@1.9.0) - '@opentelemetry/semantic-conventions': 1.27.0 + '@opentelemetry/semantic-conventions': 1.26.0 transitivePeerDependencies: - supports-color @@ -7535,7 +7535,7 @@ snapshots: '@opentelemetry/api': 1.9.0 '@opentelemetry/instrumentation': 0.52.1(@opentelemetry/api@1.9.0) '@opentelemetry/redis-common': 0.36.2 - '@opentelemetry/semantic-conventions': 1.27.0 + '@opentelemetry/semantic-conventions': 1.26.0 transitivePeerDependencies: - supports-color @@ -7543,7 +7543,7 @@ snapshots: dependencies: '@opentelemetry/api': 1.9.0 '@opentelemetry/instrumentation': 0.52.1(@opentelemetry/api@1.9.0) - '@opentelemetry/semantic-conventions': 1.27.0 + '@opentelemetry/semantic-conventions': 1.26.0 transitivePeerDependencies: - supports-color @@ -7551,16 +7551,16 @@ snapshots: dependencies: '@opentelemetry/api': 1.9.0 '@opentelemetry/instrumentation': 0.52.1(@opentelemetry/api@1.9.0) - '@opentelemetry/semantic-conventions': 1.27.0 + '@opentelemetry/semantic-conventions': 1.26.0 transitivePeerDependencies: - supports-color '@opentelemetry/instrumentation-koa@0.42.0(@opentelemetry/api@1.9.0)': dependencies: '@opentelemetry/api': 1.9.0 - '@opentelemetry/core': 1.26.0(@opentelemetry/api@1.9.0) + '@opentelemetry/core': 1.25.1(@opentelemetry/api@1.9.0) '@opentelemetry/instrumentation': 0.52.1(@opentelemetry/api@1.9.0) - '@opentelemetry/semantic-conventions': 1.27.0 + '@opentelemetry/semantic-conventions': 1.26.0 transitivePeerDependencies: - supports-color @@ -7575,7 +7575,7 @@ snapshots: dependencies: '@opentelemetry/api': 1.9.0 '@opentelemetry/instrumentation': 0.52.1(@opentelemetry/api@1.9.0) - '@opentelemetry/semantic-conventions': 1.27.0 + '@opentelemetry/semantic-conventions': 1.26.0 '@types/memcached': 2.2.10 transitivePeerDependencies: - supports-color @@ -7585,16 +7585,16 @@ snapshots: '@opentelemetry/api': 1.9.0 '@opentelemetry/instrumentation': 0.52.1(@opentelemetry/api@1.9.0) '@opentelemetry/sdk-metrics': 1.25.1(@opentelemetry/api@1.9.0) - '@opentelemetry/semantic-conventions': 1.27.0 + '@opentelemetry/semantic-conventions': 1.26.0 transitivePeerDependencies: - supports-color '@opentelemetry/instrumentation-mongoose@0.40.0(@opentelemetry/api@1.9.0)': dependencies: '@opentelemetry/api': 1.9.0 - '@opentelemetry/core': 1.26.0(@opentelemetry/api@1.9.0) + '@opentelemetry/core': 1.25.1(@opentelemetry/api@1.9.0) '@opentelemetry/instrumentation': 0.52.1(@opentelemetry/api@1.9.0) - '@opentelemetry/semantic-conventions': 1.27.0 + '@opentelemetry/semantic-conventions': 1.26.0 transitivePeerDependencies: - supports-color @@ -7602,7 +7602,7 @@ snapshots: dependencies: '@opentelemetry/api': 1.9.0 '@opentelemetry/instrumentation': 0.52.1(@opentelemetry/api@1.9.0) - '@opentelemetry/semantic-conventions': 1.27.0 + '@opentelemetry/semantic-conventions': 1.26.0 '@opentelemetry/sql-common': 0.40.1(@opentelemetry/api@1.9.0) transitivePeerDependencies: - supports-color @@ -7611,7 +7611,7 @@ snapshots: dependencies: '@opentelemetry/api': 1.9.0 '@opentelemetry/instrumentation': 0.52.1(@opentelemetry/api@1.9.0) - '@opentelemetry/semantic-conventions': 1.27.0 + '@opentelemetry/semantic-conventions': 1.26.0 '@types/mysql': 2.15.22 transitivePeerDependencies: - supports-color @@ -7620,7 +7620,7 @@ snapshots: dependencies: '@opentelemetry/api': 1.9.0 '@opentelemetry/instrumentation': 0.52.1(@opentelemetry/api@1.9.0) - '@opentelemetry/semantic-conventions': 1.27.0 + '@opentelemetry/semantic-conventions': 1.26.0 transitivePeerDependencies: - supports-color @@ -7628,7 +7628,7 @@ snapshots: dependencies: '@opentelemetry/api': 1.9.0 '@opentelemetry/instrumentation': 0.52.1(@opentelemetry/api@1.9.0) - '@opentelemetry/semantic-conventions': 1.27.0 + '@opentelemetry/semantic-conventions': 1.26.0 transitivePeerDependencies: - supports-color @@ -7636,7 +7636,7 @@ snapshots: dependencies: '@opentelemetry/api': 1.9.0 '@opentelemetry/instrumentation': 0.52.1(@opentelemetry/api@1.9.0) - '@opentelemetry/semantic-conventions': 1.27.0 + '@opentelemetry/semantic-conventions': 1.26.0 '@opentelemetry/sql-common': 0.40.1(@opentelemetry/api@1.9.0) '@types/pg': 8.6.1 '@types/pg-pool': 2.0.4 @@ -7657,7 +7657,7 @@ snapshots: '@opentelemetry/api': 1.9.0 '@opentelemetry/instrumentation': 0.52.1(@opentelemetry/api@1.9.0) '@opentelemetry/redis-common': 0.36.2 - '@opentelemetry/semantic-conventions': 1.27.0 + '@opentelemetry/semantic-conventions': 1.26.0 transitivePeerDependencies: - supports-color @@ -7666,16 +7666,16 @@ snapshots: '@opentelemetry/api': 1.9.0 '@opentelemetry/instrumentation': 0.52.1(@opentelemetry/api@1.9.0) '@opentelemetry/redis-common': 0.36.2 - '@opentelemetry/semantic-conventions': 1.27.0 + '@opentelemetry/semantic-conventions': 1.26.0 transitivePeerDependencies: - supports-color '@opentelemetry/instrumentation-restify@0.40.0(@opentelemetry/api@1.9.0)': dependencies: '@opentelemetry/api': 1.9.0 - '@opentelemetry/core': 1.26.0(@opentelemetry/api@1.9.0) + '@opentelemetry/core': 1.25.1(@opentelemetry/api@1.9.0) '@opentelemetry/instrumentation': 0.52.1(@opentelemetry/api@1.9.0) - '@opentelemetry/semantic-conventions': 1.27.0 + '@opentelemetry/semantic-conventions': 1.26.0 transitivePeerDependencies: - supports-color @@ -7683,7 +7683,7 @@ snapshots: dependencies: '@opentelemetry/api': 1.9.0 '@opentelemetry/instrumentation': 0.52.1(@opentelemetry/api@1.9.0) - '@opentelemetry/semantic-conventions': 1.27.0 + '@opentelemetry/semantic-conventions': 1.26.0 transitivePeerDependencies: - supports-color @@ -7691,7 +7691,7 @@ snapshots: dependencies: '@opentelemetry/api': 1.9.0 '@opentelemetry/instrumentation': 0.52.1(@opentelemetry/api@1.9.0) - '@opentelemetry/semantic-conventions': 1.27.0 + '@opentelemetry/semantic-conventions': 1.26.0 transitivePeerDependencies: - supports-color @@ -7699,7 +7699,7 @@ snapshots: dependencies: '@opentelemetry/api': 1.9.0 '@opentelemetry/instrumentation': 0.52.1(@opentelemetry/api@1.9.0) - '@opentelemetry/semantic-conventions': 1.27.0 + '@opentelemetry/semantic-conventions': 1.26.0 '@types/tedious': 4.0.14 transitivePeerDependencies: - supports-color @@ -7707,7 +7707,7 @@ snapshots: '@opentelemetry/instrumentation-undici@0.4.0(@opentelemetry/api@1.9.0)': dependencies: '@opentelemetry/api': 1.9.0 - '@opentelemetry/core': 1.26.0(@opentelemetry/api@1.9.0) + '@opentelemetry/core': 1.25.1(@opentelemetry/api@1.9.0) '@opentelemetry/instrumentation': 0.52.1(@opentelemetry/api@1.9.0) transitivePeerDependencies: - supports-color @@ -7764,7 +7764,7 @@ snapshots: '@opentelemetry/propagator-aws-xray@1.3.1(@opentelemetry/api@1.9.0)': dependencies: '@opentelemetry/api': 1.9.0 - '@opentelemetry/core': 1.26.0(@opentelemetry/api@1.9.0) + '@opentelemetry/core': 1.25.1(@opentelemetry/api@1.9.0) '@opentelemetry/propagator-b3@1.25.1(@opentelemetry/api@1.9.0)': dependencies: @@ -7781,34 +7781,34 @@ snapshots: '@opentelemetry/resource-detector-alibaba-cloud@0.29.0(@opentelemetry/api@1.9.0)': dependencies: '@opentelemetry/api': 1.9.0 - '@opentelemetry/resources': 1.26.0(@opentelemetry/api@1.9.0) - '@opentelemetry/semantic-conventions': 1.27.0 + '@opentelemetry/resources': 1.25.1(@opentelemetry/api@1.9.0) + '@opentelemetry/semantic-conventions': 1.26.0 '@opentelemetry/resource-detector-aws@1.5.2(@opentelemetry/api@1.9.0)': dependencies: '@opentelemetry/api': 1.9.0 - '@opentelemetry/core': 1.26.0(@opentelemetry/api@1.9.0) - '@opentelemetry/resources': 1.26.0(@opentelemetry/api@1.9.0) - '@opentelemetry/semantic-conventions': 1.27.0 + '@opentelemetry/core': 1.25.1(@opentelemetry/api@1.9.0) + '@opentelemetry/resources': 1.25.1(@opentelemetry/api@1.9.0) + '@opentelemetry/semantic-conventions': 1.26.0 '@opentelemetry/resource-detector-azure@0.2.9(@opentelemetry/api@1.9.0)': dependencies: '@opentelemetry/api': 1.9.0 - '@opentelemetry/resources': 1.26.0(@opentelemetry/api@1.9.0) - '@opentelemetry/semantic-conventions': 1.27.0 + '@opentelemetry/resources': 1.25.1(@opentelemetry/api@1.9.0) + '@opentelemetry/semantic-conventions': 1.26.0 '@opentelemetry/resource-detector-container@0.3.11(@opentelemetry/api@1.9.0)': dependencies: '@opentelemetry/api': 1.9.0 - '@opentelemetry/resources': 1.26.0(@opentelemetry/api@1.9.0) - '@opentelemetry/semantic-conventions': 1.27.0 + '@opentelemetry/resources': 1.25.1(@opentelemetry/api@1.9.0) + '@opentelemetry/semantic-conventions': 1.26.0 '@opentelemetry/resource-detector-gcp@0.29.10(@opentelemetry/api@1.9.0)(encoding@0.1.13)': dependencies: '@opentelemetry/api': 1.9.0 - '@opentelemetry/core': 1.26.0(@opentelemetry/api@1.9.0) - '@opentelemetry/resources': 1.26.0(@opentelemetry/api@1.9.0) - '@opentelemetry/semantic-conventions': 1.27.0 + '@opentelemetry/core': 1.25.1(@opentelemetry/api@1.9.0) + '@opentelemetry/resources': 1.25.1(@opentelemetry/api@1.9.0) + '@opentelemetry/semantic-conventions': 1.26.0 gcp-metadata: 6.1.0(encoding@0.1.13) transitivePeerDependencies: - encoding @@ -7877,7 +7877,7 @@ snapshots: '@opentelemetry/sql-common@0.40.1(@opentelemetry/api@1.9.0)': dependencies: '@opentelemetry/api': 1.9.0 - '@opentelemetry/core': 1.26.0(@opentelemetry/api@1.9.0) + '@opentelemetry/core': 1.25.1(@opentelemetry/api@1.9.0) '@pinecone-database/pinecone@2.2.0': dependencies: @@ -9317,10 +9317,10 @@ snapshots: - encoding - supports-color - genkitx-openai@0.10.1(@genkit-ai/ai@0.9.0-dev.3)(@genkit-ai/core@0.9.0-dev.3): + genkitx-openai@0.10.1(@genkit-ai/ai@0.9.0-dev.4)(@genkit-ai/core@0.9.0-dev.4): dependencies: - '@genkit-ai/ai': 0.9.0-dev.3 - '@genkit-ai/core': 0.9.0-dev.3 + '@genkit-ai/ai': 0.9.0-dev.4 + '@genkit-ai/core': 0.9.0-dev.4 openai: 4.53.0(encoding@0.1.13) zod: 3.23.8 transitivePeerDependencies: From da3c312419bf83cad3b4f6357d5906b2300a62d8 Mon Sep 17 00:00:00 2001 From: hunterheston Date: Fri, 8 Nov 2024 00:52:18 +0000 Subject: [PATCH 11/30] Removed byo updates --- js/pnpm-lock.yaml | 217 +++++++++++++++++-------- js/testapps/byo-evaluator/package.json | 3 +- js/testapps/byo-evaluator/src/index.ts | 40 ----- 3 files changed, 154 insertions(+), 106 deletions(-) diff --git a/js/pnpm-lock.yaml b/js/pnpm-lock.yaml index 7b364f331..7f67de2fb 100644 --- a/js/pnpm-lock.yaml +++ b/js/pnpm-lock.yaml @@ -757,9 +757,6 @@ importers: testapps/byo-evaluator: dependencies: - '@genkit-ai/checks': - specifier: workspace:* - version: link:../../plugins/checks '@genkit-ai/dev-local-vectorstore': specifier: workspace:* version: link:../../plugins/dev-local-vectorstore @@ -1096,7 +1093,7 @@ importers: version: link:../../plugins/vertexai '@opentelemetry/sdk-trace-base': specifier: ^1.25.0 - version: 1.25.1(@opentelemetry/api@1.9.0) + version: 1.26.0(@opentelemetry/api@1.9.0) genkit: specifier: workspace:* version: link:../../genkit @@ -2542,12 +2539,24 @@ packages: peerDependencies: '@opentelemetry/api': '>=1.0.0 <1.10.0' + '@opentelemetry/context-async-hooks@1.26.0': + resolution: {integrity: sha512-HedpXXYzzbaoutw6DFLWLDket2FwLkLpil4hGCZ1xYEIMTcivdfwEOISgdbLEWyG3HW52gTq2V9mOVJrONgiwg==} + engines: {node: '>=14'} + peerDependencies: + '@opentelemetry/api': '>=1.0.0 <1.10.0' + '@opentelemetry/core@1.25.1': resolution: {integrity: sha512-GeT/l6rBYWVQ4XArluLVB6WWQ8flHbdb6r2FCHC3smtdOAbrJBIv35tpV/yp9bmYUJf+xmZpu9DRTIeJVhFbEQ==} engines: {node: '>=14'} peerDependencies: '@opentelemetry/api': '>=1.0.0 <1.10.0' + '@opentelemetry/core@1.26.0': + resolution: {integrity: sha512-1iKxXXE8415Cdv0yjG3G6hQnB5eVEsJce3QaawX8SjDn0mAS0ZM8fAbZZJD4ajvhC15cePvosSCut404KrIIvQ==} + engines: {node: '>=14'} + peerDependencies: + '@opentelemetry/api': '>=1.0.0 <1.10.0' + '@opentelemetry/exporter-trace-otlp-grpc@0.52.1': resolution: {integrity: sha512-pVkSH20crBwMTqB3nIN4jpQKUEoB0Z94drIHpYyEqs7UBr+I0cpYyOR3bqjA/UasQUMROb3GX8ZX4/9cVRqGBQ==} engines: {node: '>=14'} @@ -2894,6 +2903,12 @@ packages: peerDependencies: '@opentelemetry/api': '>=1.0.0 <1.10.0' + '@opentelemetry/resources@1.26.0': + resolution: {integrity: sha512-CPNYchBE7MBecCSVy0HKpUISEeJOniWqcHaAHpmasZ3j9o6V3AyBzhRc90jdmemq0HOxDr6ylhUbDhBqqPpeNw==} + engines: {node: '>=14'} + peerDependencies: + '@opentelemetry/api': '>=1.0.0 <1.10.0' + '@opentelemetry/sdk-logs@0.52.1': resolution: {integrity: sha512-MBYh+WcPPsN8YpRHRmK1Hsca9pVlyyKd4BxOC4SsgHACnl/bPp4Cri9hWhVm5+2tiQ9Zf4qSc1Jshw9tOLGWQA==} engines: {node: '>=14'} @@ -2906,6 +2921,12 @@ packages: peerDependencies: '@opentelemetry/api': '>=1.3.0 <1.10.0' + '@opentelemetry/sdk-metrics@1.26.0': + resolution: {integrity: sha512-0SvDXmou/JjzSDOjUmetAAvcKQW6ZrvosU0rkbDGpXvvZN+pQF6JbK/Kd4hNdK4q/22yeruqvukXEJyySTzyTQ==} + engines: {node: '>=14'} + peerDependencies: + '@opentelemetry/api': '>=1.3.0 <1.10.0' + '@opentelemetry/sdk-node@0.52.1': resolution: {integrity: sha512-uEG+gtEr6eKd8CVWeKMhH2olcCHM9dEK68pe0qE0be32BcCRsvYURhHaD1Srngh1SQcnQzZ4TP324euxqtBOJA==} engines: {node: '>=14'} @@ -2918,6 +2939,12 @@ packages: peerDependencies: '@opentelemetry/api': '>=1.0.0 <1.10.0' + '@opentelemetry/sdk-trace-base@1.26.0': + resolution: {integrity: sha512-olWQldtvbK4v22ymrKLbIcBi9L2SpMO84sCPY54IVsJhP9fRsxJT194C/AVaAuJzLE30EdhhM1VmvVYR7az+cw==} + engines: {node: '>=14'} + peerDependencies: + '@opentelemetry/api': '>=1.0.0 <1.10.0' + '@opentelemetry/sdk-trace-node@1.25.1': resolution: {integrity: sha512-nMcjFIKxnFqoez4gUmihdBrbpsEnAX/Xj16sGvZm+guceYE0NE00vLhpDVK6f3q8Q4VFI5xG8JjlXKMB/SkTTQ==} engines: {node: '>=14'} @@ -2932,6 +2959,10 @@ packages: resolution: {integrity: sha512-U9PJlOswJPSgQVPI+XEuNLElyFWkb0hAiMg+DExD9V0St03X2lPHGMdxMY/LrVmoukuIpXJ12oyrOtEZ4uXFkw==} engines: {node: '>=14'} + '@opentelemetry/semantic-conventions@1.27.0': + resolution: {integrity: sha512-sAay1RrB+ONOem0OZanAR1ZI/k7yDpnOQSQmTMuGImUQb2y8EbSaCJ94FQluM74xoU03vlb2d2U90hZluL6nQg==} + engines: {node: '>=14'} + '@opentelemetry/sql-common@0.40.1': resolution: {integrity: sha512-nSDlnHSqzC3pXn/wZEZVLuAuJ1MYMXPBwtv2qAbCa3847SaHItdE7SzUq/Jtb0KZmh1zfAbNi3AAMjztTT4Ugg==} engines: {node: '>=14'} @@ -3350,6 +3381,9 @@ packages: ajv@8.12.0: resolution: {integrity: sha512-sRu1kpcO9yLtYxBKvqfTeh9KzZEwO3STyX1HT+4CaDzC6HpTGYhIhPIzj9XuKU7KYDwnaeh5hcOwjy1QuJzBPA==} + ajv@8.17.1: + resolution: {integrity: sha512-B/gBuNg5SiMTrPkC+A2+cW0RszwxYmn6VYxB/inlBStS5nx6xHIt/ehKRhIMhqusl7a8LjQoZnjCs5vhwxOQ1g==} + ansi-escapes@4.3.2: resolution: {integrity: sha512-gKXj5ALrKWQLsYG9jlTRmR/xKluxHV+Z9QEwNIgCfM1/uwPMCuzVVnh5mwTd+OuBZcwSIMbqssNWRm1lE51QaQ==} engines: {node: '>=8'} @@ -4023,6 +4057,9 @@ packages: fast-text-encoding@1.0.6: resolution: {integrity: sha512-VhXlQgj9ioXCqGstD37E/HBeqEGV/qOD/kmbVG8h5xKBYvM1L3lR1Zn4555cQ8GkYbJa8aJSipLPndE1k6zK2w==} + fast-uri@3.0.1: + resolution: {integrity: sha512-MWipKbbYiYI0UC7cl8m/i/IWTqfC8YXsqjzybjddLsFjStroQzsHXkc73JutMvBiXmOvapk+axIl79ig5t55Bw==} + fast-xml-parser@4.3.6: resolution: {integrity: sha512-M2SovcRxD4+vC493Uc2GZVcZaj66CCJhWurC4viynVSTvrpErCShNcDz1lAho6n9REQKvL/ll4A4/fw6Y9z8nw==} hasBin: true @@ -6323,6 +6360,11 @@ packages: peerDependencies: zod: ^3.22.4 + zod-to-json-schema@3.23.3: + resolution: {integrity: sha512-TYWChTxKQbRJp5ST22o/Irt9KC5nj7CdBKYB/AosCRdj/wxEMvv4NNaj9XVUHDOIp53ZxArGhnw5HMZziPFjog==} + peerDependencies: + zod: ^3.23.3 + zod@3.22.4: resolution: {integrity: sha512-iC+8Io04lddc+mVqQ9AZ7OQ2MrUKGN+oIQyq1vemgt46jwCwLfhq7/pwnBnNXXXZb8VTVLKwp9EDkx+ryxIWmg==} @@ -6781,13 +6823,13 @@ snapshots: '@genkit-ai/core@0.9.0-dev.4': dependencies: '@opentelemetry/api': 1.9.0 - '@opentelemetry/context-async-hooks': 1.25.1(@opentelemetry/api@1.9.0) - '@opentelemetry/core': 1.25.1(@opentelemetry/api@1.9.0) - '@opentelemetry/sdk-metrics': 1.25.1(@opentelemetry/api@1.9.0) + '@opentelemetry/context-async-hooks': 1.26.0(@opentelemetry/api@1.9.0) + '@opentelemetry/core': 1.26.0(@opentelemetry/api@1.9.0) + '@opentelemetry/sdk-metrics': 1.26.0(@opentelemetry/api@1.9.0) '@opentelemetry/sdk-node': 0.52.1(@opentelemetry/api@1.9.0) - '@opentelemetry/sdk-trace-base': 1.25.1(@opentelemetry/api@1.9.0) - ajv: 8.12.0 - ajv-formats: 3.0.1(ajv@8.12.0) + '@opentelemetry/sdk-trace-base': 1.26.0(@opentelemetry/api@1.9.0) + ajv: 8.17.1 + ajv-formats: 3.0.1(ajv@8.17.1) async-mutex: 0.5.0 body-parser: 1.20.3 cors: 2.8.5 @@ -6795,7 +6837,7 @@ snapshots: get-port: 5.1.0 json-schema: 0.4.0 zod: 3.23.8 - zod-to-json-schema: 3.22.5(zod@3.23.8) + zod-to-json-schema: 3.23.3(zod@3.23.8) transitivePeerDependencies: - supports-color @@ -7342,11 +7384,20 @@ snapshots: dependencies: '@opentelemetry/api': 1.9.0 + '@opentelemetry/context-async-hooks@1.26.0(@opentelemetry/api@1.9.0)': + dependencies: + '@opentelemetry/api': 1.9.0 + '@opentelemetry/core@1.25.1(@opentelemetry/api@1.9.0)': dependencies: '@opentelemetry/api': 1.9.0 '@opentelemetry/semantic-conventions': 1.25.1 + '@opentelemetry/core@1.26.0(@opentelemetry/api@1.9.0)': + dependencies: + '@opentelemetry/api': 1.9.0 + '@opentelemetry/semantic-conventions': 1.27.0 + '@opentelemetry/exporter-trace-otlp-grpc@0.52.1(@opentelemetry/api@1.9.0)': dependencies: '@grpc/grpc-js': 1.10.10 @@ -7386,9 +7437,9 @@ snapshots: '@opentelemetry/instrumentation-amqplib@0.41.0(@opentelemetry/api@1.9.0)': dependencies: '@opentelemetry/api': 1.9.0 - '@opentelemetry/core': 1.25.1(@opentelemetry/api@1.9.0) + '@opentelemetry/core': 1.26.0(@opentelemetry/api@1.9.0) '@opentelemetry/instrumentation': 0.52.1(@opentelemetry/api@1.9.0) - '@opentelemetry/semantic-conventions': 1.26.0 + '@opentelemetry/semantic-conventions': 1.27.0 transitivePeerDependencies: - supports-color @@ -7397,8 +7448,8 @@ snapshots: '@opentelemetry/api': 1.9.0 '@opentelemetry/instrumentation': 0.52.1(@opentelemetry/api@1.9.0) '@opentelemetry/propagator-aws-xray': 1.3.1(@opentelemetry/api@1.9.0) - '@opentelemetry/resources': 1.25.1(@opentelemetry/api@1.9.0) - '@opentelemetry/semantic-conventions': 1.26.0 + '@opentelemetry/resources': 1.26.0(@opentelemetry/api@1.9.0) + '@opentelemetry/semantic-conventions': 1.27.0 '@types/aws-lambda': 8.10.122 transitivePeerDependencies: - supports-color @@ -7406,10 +7457,10 @@ snapshots: '@opentelemetry/instrumentation-aws-sdk@0.43.1(@opentelemetry/api@1.9.0)': dependencies: '@opentelemetry/api': 1.9.0 - '@opentelemetry/core': 1.25.1(@opentelemetry/api@1.9.0) + '@opentelemetry/core': 1.26.0(@opentelemetry/api@1.9.0) '@opentelemetry/instrumentation': 0.52.1(@opentelemetry/api@1.9.0) '@opentelemetry/propagation-utils': 0.30.10(@opentelemetry/api@1.9.0) - '@opentelemetry/semantic-conventions': 1.26.0 + '@opentelemetry/semantic-conventions': 1.27.0 transitivePeerDependencies: - supports-color @@ -7426,16 +7477,16 @@ snapshots: dependencies: '@opentelemetry/api': 1.9.0 '@opentelemetry/instrumentation': 0.52.1(@opentelemetry/api@1.9.0) - '@opentelemetry/semantic-conventions': 1.26.0 + '@opentelemetry/semantic-conventions': 1.27.0 transitivePeerDependencies: - supports-color '@opentelemetry/instrumentation-connect@0.38.0(@opentelemetry/api@1.9.0)': dependencies: '@opentelemetry/api': 1.9.0 - '@opentelemetry/core': 1.25.1(@opentelemetry/api@1.9.0) + '@opentelemetry/core': 1.26.0(@opentelemetry/api@1.9.0) '@opentelemetry/instrumentation': 0.52.1(@opentelemetry/api@1.9.0) - '@opentelemetry/semantic-conventions': 1.26.0 + '@opentelemetry/semantic-conventions': 1.27.0 '@types/connect': 3.4.36 transitivePeerDependencies: - supports-color @@ -7444,7 +7495,7 @@ snapshots: dependencies: '@opentelemetry/api': 1.9.0 '@opentelemetry/instrumentation': 0.52.1(@opentelemetry/api@1.9.0) - '@opentelemetry/semantic-conventions': 1.26.0 + '@opentelemetry/semantic-conventions': 1.27.0 transitivePeerDependencies: - supports-color @@ -7466,25 +7517,25 @@ snapshots: '@opentelemetry/instrumentation-express@0.41.1(@opentelemetry/api@1.9.0)': dependencies: '@opentelemetry/api': 1.9.0 - '@opentelemetry/core': 1.25.1(@opentelemetry/api@1.9.0) + '@opentelemetry/core': 1.26.0(@opentelemetry/api@1.9.0) '@opentelemetry/instrumentation': 0.52.1(@opentelemetry/api@1.9.0) - '@opentelemetry/semantic-conventions': 1.26.0 + '@opentelemetry/semantic-conventions': 1.27.0 transitivePeerDependencies: - supports-color '@opentelemetry/instrumentation-fastify@0.38.0(@opentelemetry/api@1.9.0)': dependencies: '@opentelemetry/api': 1.9.0 - '@opentelemetry/core': 1.25.1(@opentelemetry/api@1.9.0) + '@opentelemetry/core': 1.26.0(@opentelemetry/api@1.9.0) '@opentelemetry/instrumentation': 0.52.1(@opentelemetry/api@1.9.0) - '@opentelemetry/semantic-conventions': 1.26.0 + '@opentelemetry/semantic-conventions': 1.27.0 transitivePeerDependencies: - supports-color '@opentelemetry/instrumentation-fs@0.14.0(@opentelemetry/api@1.9.0)': dependencies: '@opentelemetry/api': 1.9.0 - '@opentelemetry/core': 1.25.1(@opentelemetry/api@1.9.0) + '@opentelemetry/core': 1.26.0(@opentelemetry/api@1.9.0) '@opentelemetry/instrumentation': 0.52.1(@opentelemetry/api@1.9.0) transitivePeerDependencies: - supports-color @@ -7514,9 +7565,9 @@ snapshots: '@opentelemetry/instrumentation-hapi@0.40.0(@opentelemetry/api@1.9.0)': dependencies: '@opentelemetry/api': 1.9.0 - '@opentelemetry/core': 1.25.1(@opentelemetry/api@1.9.0) + '@opentelemetry/core': 1.26.0(@opentelemetry/api@1.9.0) '@opentelemetry/instrumentation': 0.52.1(@opentelemetry/api@1.9.0) - '@opentelemetry/semantic-conventions': 1.26.0 + '@opentelemetry/semantic-conventions': 1.27.0 transitivePeerDependencies: - supports-color @@ -7535,7 +7586,7 @@ snapshots: '@opentelemetry/api': 1.9.0 '@opentelemetry/instrumentation': 0.52.1(@opentelemetry/api@1.9.0) '@opentelemetry/redis-common': 0.36.2 - '@opentelemetry/semantic-conventions': 1.26.0 + '@opentelemetry/semantic-conventions': 1.27.0 transitivePeerDependencies: - supports-color @@ -7543,7 +7594,7 @@ snapshots: dependencies: '@opentelemetry/api': 1.9.0 '@opentelemetry/instrumentation': 0.52.1(@opentelemetry/api@1.9.0) - '@opentelemetry/semantic-conventions': 1.26.0 + '@opentelemetry/semantic-conventions': 1.27.0 transitivePeerDependencies: - supports-color @@ -7551,16 +7602,16 @@ snapshots: dependencies: '@opentelemetry/api': 1.9.0 '@opentelemetry/instrumentation': 0.52.1(@opentelemetry/api@1.9.0) - '@opentelemetry/semantic-conventions': 1.26.0 + '@opentelemetry/semantic-conventions': 1.27.0 transitivePeerDependencies: - supports-color '@opentelemetry/instrumentation-koa@0.42.0(@opentelemetry/api@1.9.0)': dependencies: '@opentelemetry/api': 1.9.0 - '@opentelemetry/core': 1.25.1(@opentelemetry/api@1.9.0) + '@opentelemetry/core': 1.26.0(@opentelemetry/api@1.9.0) '@opentelemetry/instrumentation': 0.52.1(@opentelemetry/api@1.9.0) - '@opentelemetry/semantic-conventions': 1.26.0 + '@opentelemetry/semantic-conventions': 1.27.0 transitivePeerDependencies: - supports-color @@ -7575,7 +7626,7 @@ snapshots: dependencies: '@opentelemetry/api': 1.9.0 '@opentelemetry/instrumentation': 0.52.1(@opentelemetry/api@1.9.0) - '@opentelemetry/semantic-conventions': 1.26.0 + '@opentelemetry/semantic-conventions': 1.27.0 '@types/memcached': 2.2.10 transitivePeerDependencies: - supports-color @@ -7585,16 +7636,16 @@ snapshots: '@opentelemetry/api': 1.9.0 '@opentelemetry/instrumentation': 0.52.1(@opentelemetry/api@1.9.0) '@opentelemetry/sdk-metrics': 1.25.1(@opentelemetry/api@1.9.0) - '@opentelemetry/semantic-conventions': 1.26.0 + '@opentelemetry/semantic-conventions': 1.27.0 transitivePeerDependencies: - supports-color '@opentelemetry/instrumentation-mongoose@0.40.0(@opentelemetry/api@1.9.0)': dependencies: '@opentelemetry/api': 1.9.0 - '@opentelemetry/core': 1.25.1(@opentelemetry/api@1.9.0) + '@opentelemetry/core': 1.26.0(@opentelemetry/api@1.9.0) '@opentelemetry/instrumentation': 0.52.1(@opentelemetry/api@1.9.0) - '@opentelemetry/semantic-conventions': 1.26.0 + '@opentelemetry/semantic-conventions': 1.27.0 transitivePeerDependencies: - supports-color @@ -7602,7 +7653,7 @@ snapshots: dependencies: '@opentelemetry/api': 1.9.0 '@opentelemetry/instrumentation': 0.52.1(@opentelemetry/api@1.9.0) - '@opentelemetry/semantic-conventions': 1.26.0 + '@opentelemetry/semantic-conventions': 1.27.0 '@opentelemetry/sql-common': 0.40.1(@opentelemetry/api@1.9.0) transitivePeerDependencies: - supports-color @@ -7611,7 +7662,7 @@ snapshots: dependencies: '@opentelemetry/api': 1.9.0 '@opentelemetry/instrumentation': 0.52.1(@opentelemetry/api@1.9.0) - '@opentelemetry/semantic-conventions': 1.26.0 + '@opentelemetry/semantic-conventions': 1.27.0 '@types/mysql': 2.15.22 transitivePeerDependencies: - supports-color @@ -7620,7 +7671,7 @@ snapshots: dependencies: '@opentelemetry/api': 1.9.0 '@opentelemetry/instrumentation': 0.52.1(@opentelemetry/api@1.9.0) - '@opentelemetry/semantic-conventions': 1.26.0 + '@opentelemetry/semantic-conventions': 1.27.0 transitivePeerDependencies: - supports-color @@ -7628,7 +7679,7 @@ snapshots: dependencies: '@opentelemetry/api': 1.9.0 '@opentelemetry/instrumentation': 0.52.1(@opentelemetry/api@1.9.0) - '@opentelemetry/semantic-conventions': 1.26.0 + '@opentelemetry/semantic-conventions': 1.27.0 transitivePeerDependencies: - supports-color @@ -7636,7 +7687,7 @@ snapshots: dependencies: '@opentelemetry/api': 1.9.0 '@opentelemetry/instrumentation': 0.52.1(@opentelemetry/api@1.9.0) - '@opentelemetry/semantic-conventions': 1.26.0 + '@opentelemetry/semantic-conventions': 1.27.0 '@opentelemetry/sql-common': 0.40.1(@opentelemetry/api@1.9.0) '@types/pg': 8.6.1 '@types/pg-pool': 2.0.4 @@ -7657,7 +7708,7 @@ snapshots: '@opentelemetry/api': 1.9.0 '@opentelemetry/instrumentation': 0.52.1(@opentelemetry/api@1.9.0) '@opentelemetry/redis-common': 0.36.2 - '@opentelemetry/semantic-conventions': 1.26.0 + '@opentelemetry/semantic-conventions': 1.27.0 transitivePeerDependencies: - supports-color @@ -7666,16 +7717,16 @@ snapshots: '@opentelemetry/api': 1.9.0 '@opentelemetry/instrumentation': 0.52.1(@opentelemetry/api@1.9.0) '@opentelemetry/redis-common': 0.36.2 - '@opentelemetry/semantic-conventions': 1.26.0 + '@opentelemetry/semantic-conventions': 1.27.0 transitivePeerDependencies: - supports-color '@opentelemetry/instrumentation-restify@0.40.0(@opentelemetry/api@1.9.0)': dependencies: '@opentelemetry/api': 1.9.0 - '@opentelemetry/core': 1.25.1(@opentelemetry/api@1.9.0) + '@opentelemetry/core': 1.26.0(@opentelemetry/api@1.9.0) '@opentelemetry/instrumentation': 0.52.1(@opentelemetry/api@1.9.0) - '@opentelemetry/semantic-conventions': 1.26.0 + '@opentelemetry/semantic-conventions': 1.27.0 transitivePeerDependencies: - supports-color @@ -7683,7 +7734,7 @@ snapshots: dependencies: '@opentelemetry/api': 1.9.0 '@opentelemetry/instrumentation': 0.52.1(@opentelemetry/api@1.9.0) - '@opentelemetry/semantic-conventions': 1.26.0 + '@opentelemetry/semantic-conventions': 1.27.0 transitivePeerDependencies: - supports-color @@ -7691,7 +7742,7 @@ snapshots: dependencies: '@opentelemetry/api': 1.9.0 '@opentelemetry/instrumentation': 0.52.1(@opentelemetry/api@1.9.0) - '@opentelemetry/semantic-conventions': 1.26.0 + '@opentelemetry/semantic-conventions': 1.27.0 transitivePeerDependencies: - supports-color @@ -7699,7 +7750,7 @@ snapshots: dependencies: '@opentelemetry/api': 1.9.0 '@opentelemetry/instrumentation': 0.52.1(@opentelemetry/api@1.9.0) - '@opentelemetry/semantic-conventions': 1.26.0 + '@opentelemetry/semantic-conventions': 1.27.0 '@types/tedious': 4.0.14 transitivePeerDependencies: - supports-color @@ -7707,7 +7758,7 @@ snapshots: '@opentelemetry/instrumentation-undici@0.4.0(@opentelemetry/api@1.9.0)': dependencies: '@opentelemetry/api': 1.9.0 - '@opentelemetry/core': 1.25.1(@opentelemetry/api@1.9.0) + '@opentelemetry/core': 1.26.0(@opentelemetry/api@1.9.0) '@opentelemetry/instrumentation': 0.52.1(@opentelemetry/api@1.9.0) transitivePeerDependencies: - supports-color @@ -7764,7 +7815,7 @@ snapshots: '@opentelemetry/propagator-aws-xray@1.3.1(@opentelemetry/api@1.9.0)': dependencies: '@opentelemetry/api': 1.9.0 - '@opentelemetry/core': 1.25.1(@opentelemetry/api@1.9.0) + '@opentelemetry/core': 1.26.0(@opentelemetry/api@1.9.0) '@opentelemetry/propagator-b3@1.25.1(@opentelemetry/api@1.9.0)': dependencies: @@ -7781,34 +7832,34 @@ snapshots: '@opentelemetry/resource-detector-alibaba-cloud@0.29.0(@opentelemetry/api@1.9.0)': dependencies: '@opentelemetry/api': 1.9.0 - '@opentelemetry/resources': 1.25.1(@opentelemetry/api@1.9.0) - '@opentelemetry/semantic-conventions': 1.26.0 + '@opentelemetry/resources': 1.26.0(@opentelemetry/api@1.9.0) + '@opentelemetry/semantic-conventions': 1.27.0 '@opentelemetry/resource-detector-aws@1.5.2(@opentelemetry/api@1.9.0)': dependencies: '@opentelemetry/api': 1.9.0 - '@opentelemetry/core': 1.25.1(@opentelemetry/api@1.9.0) - '@opentelemetry/resources': 1.25.1(@opentelemetry/api@1.9.0) - '@opentelemetry/semantic-conventions': 1.26.0 + '@opentelemetry/core': 1.26.0(@opentelemetry/api@1.9.0) + '@opentelemetry/resources': 1.26.0(@opentelemetry/api@1.9.0) + '@opentelemetry/semantic-conventions': 1.27.0 '@opentelemetry/resource-detector-azure@0.2.9(@opentelemetry/api@1.9.0)': dependencies: '@opentelemetry/api': 1.9.0 - '@opentelemetry/resources': 1.25.1(@opentelemetry/api@1.9.0) - '@opentelemetry/semantic-conventions': 1.26.0 + '@opentelemetry/resources': 1.26.0(@opentelemetry/api@1.9.0) + '@opentelemetry/semantic-conventions': 1.27.0 '@opentelemetry/resource-detector-container@0.3.11(@opentelemetry/api@1.9.0)': dependencies: '@opentelemetry/api': 1.9.0 - '@opentelemetry/resources': 1.25.1(@opentelemetry/api@1.9.0) - '@opentelemetry/semantic-conventions': 1.26.0 + '@opentelemetry/resources': 1.26.0(@opentelemetry/api@1.9.0) + '@opentelemetry/semantic-conventions': 1.27.0 '@opentelemetry/resource-detector-gcp@0.29.10(@opentelemetry/api@1.9.0)(encoding@0.1.13)': dependencies: '@opentelemetry/api': 1.9.0 - '@opentelemetry/core': 1.25.1(@opentelemetry/api@1.9.0) - '@opentelemetry/resources': 1.25.1(@opentelemetry/api@1.9.0) - '@opentelemetry/semantic-conventions': 1.26.0 + '@opentelemetry/core': 1.26.0(@opentelemetry/api@1.9.0) + '@opentelemetry/resources': 1.26.0(@opentelemetry/api@1.9.0) + '@opentelemetry/semantic-conventions': 1.27.0 gcp-metadata: 6.1.0(encoding@0.1.13) transitivePeerDependencies: - encoding @@ -7820,6 +7871,12 @@ snapshots: '@opentelemetry/core': 1.25.1(@opentelemetry/api@1.9.0) '@opentelemetry/semantic-conventions': 1.25.1 + '@opentelemetry/resources@1.26.0(@opentelemetry/api@1.9.0)': + dependencies: + '@opentelemetry/api': 1.9.0 + '@opentelemetry/core': 1.26.0(@opentelemetry/api@1.9.0) + '@opentelemetry/semantic-conventions': 1.27.0 + '@opentelemetry/sdk-logs@0.52.1(@opentelemetry/api@1.9.0)': dependencies: '@opentelemetry/api': 1.9.0 @@ -7834,6 +7891,12 @@ snapshots: '@opentelemetry/resources': 1.25.1(@opentelemetry/api@1.9.0) lodash.merge: 4.6.2 + '@opentelemetry/sdk-metrics@1.26.0(@opentelemetry/api@1.9.0)': + dependencies: + '@opentelemetry/api': 1.9.0 + '@opentelemetry/core': 1.26.0(@opentelemetry/api@1.9.0) + '@opentelemetry/resources': 1.26.0(@opentelemetry/api@1.9.0) + '@opentelemetry/sdk-node@0.52.1(@opentelemetry/api@1.9.0)': dependencies: '@opentelemetry/api': 1.9.0 @@ -7860,6 +7923,13 @@ snapshots: '@opentelemetry/resources': 1.25.1(@opentelemetry/api@1.9.0) '@opentelemetry/semantic-conventions': 1.25.1 + '@opentelemetry/sdk-trace-base@1.26.0(@opentelemetry/api@1.9.0)': + dependencies: + '@opentelemetry/api': 1.9.0 + '@opentelemetry/core': 1.26.0(@opentelemetry/api@1.9.0) + '@opentelemetry/resources': 1.26.0(@opentelemetry/api@1.9.0) + '@opentelemetry/semantic-conventions': 1.27.0 + '@opentelemetry/sdk-trace-node@1.25.1(@opentelemetry/api@1.9.0)': dependencies: '@opentelemetry/api': 1.9.0 @@ -7874,10 +7944,12 @@ snapshots: '@opentelemetry/semantic-conventions@1.26.0': {} + '@opentelemetry/semantic-conventions@1.27.0': {} + '@opentelemetry/sql-common@0.40.1(@opentelemetry/api@1.9.0)': dependencies: '@opentelemetry/api': 1.9.0 - '@opentelemetry/core': 1.25.1(@opentelemetry/api@1.9.0) + '@opentelemetry/core': 1.26.0(@opentelemetry/api@1.9.0) '@pinecone-database/pinecone@2.2.0': dependencies: @@ -8254,6 +8326,10 @@ snapshots: optionalDependencies: ajv: 8.12.0 + ajv-formats@3.0.1(ajv@8.17.1): + optionalDependencies: + ajv: 8.17.1 + ajv@8.12.0: dependencies: fast-deep-equal: 3.1.3 @@ -8261,6 +8337,13 @@ snapshots: require-from-string: 2.0.2 uri-js: 4.4.1 + ajv@8.17.1: + dependencies: + fast-deep-equal: 3.1.3 + fast-uri: 3.0.1 + json-schema-traverse: 1.0.0 + require-from-string: 2.0.2 + ansi-escapes@4.3.2: dependencies: type-fest: 0.21.3 @@ -9099,6 +9182,8 @@ snapshots: fast-text-encoding@1.0.6: optional: true + fast-uri@3.0.1: {} + fast-xml-parser@4.3.6: dependencies: strnum: 1.0.5 @@ -11800,6 +11885,10 @@ snapshots: dependencies: zod: 3.23.8 + zod-to-json-schema@3.23.3(zod@3.23.8): + dependencies: + zod: 3.23.8 + zod@3.22.4: {} zod@3.23.8: {} diff --git a/js/testapps/byo-evaluator/package.json b/js/testapps/byo-evaluator/package.json index c4ba52e44..391f2919a 100644 --- a/js/testapps/byo-evaluator/package.json +++ b/js/testapps/byo-evaluator/package.json @@ -19,7 +19,6 @@ "@genkit-ai/firebase": "workspace:*", "@genkit-ai/googleai": "workspace:*", "@genkit-ai/vertexai": "workspace:*", - "@genkit-ai/checks": "workspace:*", "genkit": "workspace:*", "path": "^0.12.7" }, @@ -27,4 +26,4 @@ "rimraf": "^6.0.1", "typescript": "^5.3.3" } -} +} \ No newline at end of file diff --git a/js/testapps/byo-evaluator/src/index.ts b/js/testapps/byo-evaluator/src/index.ts index 6dd946b92..9f9e4fdb8 100644 --- a/js/testapps/byo-evaluator/src/index.ts +++ b/js/testapps/byo-evaluator/src/index.ts @@ -36,7 +36,6 @@ import { isRegexMetric, regexMatcher, } from './regex/regex_evaluator.js'; -import { checks, ChecksEvaluationMetricType } from "@genkit-ai/checks" export const ai = genkit({ plugins: [ @@ -55,45 +54,6 @@ export const ai = genkit({ FUNNINESS, ], }), - checks({ - projectId: "checks-api-370419", - evaluation: { - metrics: [ - { - type: ChecksEvaluationMetricType.DANGEROUS_CONTENT, - threshold: .5, - }, - { - type: ChecksEvaluationMetricType.PII_SOLICITING_RECITING, - threshold: .5, - }, - { - type: ChecksEvaluationMetricType.HARASSMENT, - threshold: .5, - }, - { - type: ChecksEvaluationMetricType.SEXUALLY_EXPLICIT, - threshold: .5, - }, - { - type: ChecksEvaluationMetricType.HATE_SPEECH, - threshold: .5, - }, - { - type: ChecksEvaluationMetricType.MEDICAL_INFO, - threshold: .5, - }, - { - type: ChecksEvaluationMetricType.VIOLENCE_AND_GORE, - threshold: .5, - }, - { - type: ChecksEvaluationMetricType.OBSCENITY_AND_PROFANITY, - threshold: .5, - } - ], - }, - }) ], }); From a49150934e81a3562d36e8bd9386f5397f3aba98 Mon Sep 17 00:00:00 2001 From: hunterheston Date: Fri, 8 Nov 2024 17:27:29 +0000 Subject: [PATCH 12/30] update comments. --- js/plugins/checks/src/evaluation.ts | 7 ++----- js/plugins/checks/src/index.ts | 6 +++--- 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/js/plugins/checks/src/evaluation.ts b/js/plugins/checks/src/evaluation.ts index 3c5c71001..2569bfb5a 100644 --- a/js/plugins/checks/src/evaluation.ts +++ b/js/plugins/checks/src/evaluation.ts @@ -20,8 +20,7 @@ import { BaseEvalDataPoint } from 'genkit/evaluator'; import { runInNewSpan } from 'genkit/tracing'; /** - * Checks AI Safety policies. See API documentation for more information. - * TODO: add documentation link. + * Currently supported Checks AI Safety policies. */ export enum ChecksEvaluationMetricType { // The model facilitates, promotes or enables access to harmful goods, @@ -48,10 +47,8 @@ export enum ChecksEvaluationMetricType { } /** - * Evaluation metric config. Use `metricSpec` to define the behavior of the metric. + * Checks evaluation metric config. Use `threshold` to override the default violation threshold. * The value of `metricSpec` will be included in the request to the API. See the API documentation - * for details on the possible values of `metricSpec` for each metric. - * https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/evaluation#parameter-list */ export type ChecksEvaluationMetricConfig = { type: ChecksEvaluationMetricType; diff --git a/js/plugins/checks/src/index.ts b/js/plugins/checks/src/index.ts index 37c09f72f..ea9116f89 100644 --- a/js/plugins/checks/src/index.ts +++ b/js/plugins/checks/src/index.ts @@ -27,11 +27,11 @@ export { }; export interface PluginOptions { - /** The Google Cloud project id to call. This is the project with quota for the Checks API*/ + /** The Google Cloud project id to call. Must have quota for the Checks API. */ projectId?: string; - /** Provide custom authentication configuration for connecting to Vertex AI. */ + /** Provide custom authentication configuration for connecting to Checks API. */ googleAuth?: GoogleAuthOptions; - /** Configure Vertex AI evaluators */ + /** Configure Checks evaluators. */ evaluation?: { metrics: ChecksEvaluationMetric[]; }; From 871c3fdd74e50dfdec84dd1756ab189c14f6b422 Mon Sep 17 00:00:00 2001 From: hunterheston Date: Fri, 8 Nov 2024 19:54:00 +0000 Subject: [PATCH 13/30] Log a warning with there is a default quota project. It seems to override the user configured value. --- js/plugins/checks/src/evaluation.ts | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/js/plugins/checks/src/evaluation.ts b/js/plugins/checks/src/evaluation.ts index 2569bfb5a..0d036392b 100644 --- a/js/plugins/checks/src/evaluation.ts +++ b/js/plugins/checks/src/evaluation.ts @@ -169,12 +169,16 @@ async function checksEvalInstance( const client = await auth.getClient(); const url = "https://checks.googleapis.com/v1alpha/aisafety:classifyContent" + if (client.quotaProjectId) { + console.warn(`Checks Evaluator: Your Google cloud authentication has a default quota project(${client.quotaProjectId}) associated with it which will overrid the projectId in your Checks plugin config(${projectId}).`) + } + const response = await client.request({ url, method: "POST", body: JSON.stringify(request), headers: { - "X-Goog-User-Project": projectId, + "x-goog-user-project": projectId, "Content-Type": "application/json", } }) From 74e559b12421210d8534ff3e52ff0b881e960e75 Mon Sep 17 00:00:00 2001 From: hunterheston Date: Fri, 8 Nov 2024 19:56:07 +0000 Subject: [PATCH 14/30] pnpm format --- js/plugins/checks/package.json | 2 +- js/plugins/checks/src/evaluation.ts | 64 +++++++++++++------------- js/plugins/checks/src/index.ts | 12 ++--- js/testapps/byo-evaluator/package.json | 2 +- 4 files changed, 41 insertions(+), 39 deletions(-) diff --git a/js/plugins/checks/package.json b/js/plugins/checks/package.json index 2a95bc73a..a23933b7d 100644 --- a/js/plugins/checks/package.json +++ b/js/plugins/checks/package.json @@ -63,4 +63,4 @@ "default": "./lib/index.js" } } -} \ No newline at end of file +} diff --git a/js/plugins/checks/src/evaluation.ts b/js/plugins/checks/src/evaluation.ts index 0d036392b..b1700932b 100644 --- a/js/plugins/checks/src/evaluation.ts +++ b/js/plugins/checks/src/evaluation.ts @@ -15,9 +15,9 @@ */ import { Action, Genkit, z } from 'genkit'; -import { GoogleAuth } from 'google-auth-library'; import { BaseEvalDataPoint } from 'genkit/evaluator'; import { runInNewSpan } from 'genkit/tracing'; +import { GoogleAuth } from 'google-auth-library'; /** * Currently supported Checks AI Safety policies. @@ -63,24 +63,25 @@ export function checksEvaluators( ai: Genkit, auth: GoogleAuth, metrics: ChecksEvaluationMetric[], - projectId: string, + projectId: string ): Action[] { + const policy_configs: ChecksEvaluationMetricConfig[] = metrics.map( + (metric) => { + const metricType = isConfig(metric) ? metric.type : metric; + const threshold = isConfig(metric) ? metric.threshold : undefined; - const policy_configs: ChecksEvaluationMetricConfig[] = metrics.map((metric) => { - const metricType = isConfig(metric) ? metric.type : metric; - const threshold = isConfig(metric) ? metric.threshold : undefined; - - return { - type: metricType, - threshold, + return { + type: metricType, + threshold, + }; } - }); + ); const evaluators = policy_configs.map((policy_config) => { - return createPolicyEvaluator(projectId, auth, ai, policy_config) - }) + return createPolicyEvaluator(projectId, auth, ai, policy_config); + }); - return evaluators + return evaluators; } function isConfig( @@ -94,9 +95,9 @@ const ResponseSchema = z.object({ z.object({ policyType: z.string(), score: z.number(), - violationResult: z.string() + violationResult: z.string(), }) - ) + ), }); function createPolicyEvaluator( @@ -111,19 +112,19 @@ function createPolicyEvaluator( { name: `checks/${policyType.toLowerCase()}`, displayName: policyType, - definition: `Evaluates text against the Checks ${policyType} policy.` + definition: `Evaluates text against the Checks ${policyType} policy.`, }, async (datapoint: BaseEvalDataPoint) => { const partialRequest = { input: { text_input: { - content: datapoint.output as string + content: datapoint.output as string, }, }, policies: { policy_type: policy_config.type, threshold: policy_config.threshold, - } + }, }; const response = await checksEvalInstance( @@ -137,13 +138,13 @@ function createPolicyEvaluator( evaluation: { score: response.policyResults[0].score, details: { - reasoning: response.policyResults[0].violationResult - } + reasoning: response.policyResults[0].violationResult, + }, }, testCaseId: datapoint.testCaseId, }; } - ) + ); } async function checksEvalInstance( @@ -152,7 +153,6 @@ async function checksEvalInstance( partialRequest: any, responseSchema: ResponseType ): Promise> { - return await runInNewSpan( { metadata: { @@ -164,24 +164,26 @@ async function checksEvalInstance( ...partialRequest, }; - metadata.input = request; const client = await auth.getClient(); - const url = "https://checks.googleapis.com/v1alpha/aisafety:classifyContent" + const url = + 'https://checks.googleapis.com/v1alpha/aisafety:classifyContent'; if (client.quotaProjectId) { - console.warn(`Checks Evaluator: Your Google cloud authentication has a default quota project(${client.quotaProjectId}) associated with it which will overrid the projectId in your Checks plugin config(${projectId}).`) + console.warn( + `Checks Evaluator: Your Google cloud authentication has a default quota project(${client.quotaProjectId}) associated with it which will overrid the projectId in your Checks plugin config(${projectId}).` + ); } const response = await client.request({ url, - method: "POST", + method: 'POST', body: JSON.stringify(request), headers: { - "x-goog-user-project": projectId, - "Content-Type": "application/json", - } - }) + 'x-goog-user-project': projectId, + 'Content-Type': 'application/json', + }, + }); metadata.output = response.data; try { @@ -191,4 +193,4 @@ async function checksEvalInstance( } } ); -} \ No newline at end of file +} diff --git a/js/plugins/checks/src/index.ts b/js/plugins/checks/src/index.ts index ea9116f89..b43b2fed3 100644 --- a/js/plugins/checks/src/index.ts +++ b/js/plugins/checks/src/index.ts @@ -14,7 +14,7 @@ * limitations under the License. */ -import { Genkit, z } from 'genkit'; +import { Genkit } from 'genkit'; import { GenkitPlugin, genkitPlugin } from 'genkit/plugin'; import { GoogleAuth, GoogleAuthOptions } from 'google-auth-library'; import { @@ -22,9 +22,7 @@ import { ChecksEvaluationMetricType, checksEvaluators, } from './evaluation.js'; -export { - ChecksEvaluationMetricType as ChecksEvaluationMetricType, -}; +export { ChecksEvaluationMetricType as ChecksEvaluationMetricType }; export interface PluginOptions { /** The Google Cloud project id to call. Must have quota for the Checks API. */ @@ -43,7 +41,7 @@ const CLOUD_PLATFROM_OAUTH_SCOPE = const CHECKS_OAUTH_SCOPE = 'https://www.googleapis.com/auth/checks'; /** - * Add Google Checks evaluators. + * Add Google Checks evaluators. */ export function checks(options?: PluginOptions): GenkitPlugin { return genkitPlugin('checks', async (ai: Genkit) => { @@ -63,7 +61,9 @@ export function checks(options?: PluginOptions): GenkitPlugin { authClient = new GoogleAuth(authOptions); } else { authClient = new GoogleAuth( - authOptions ?? { scopes: [CLOUD_PLATFROM_OAUTH_SCOPE, CHECKS_OAUTH_SCOPE] } + authOptions ?? { + scopes: [CLOUD_PLATFROM_OAUTH_SCOPE, CHECKS_OAUTH_SCOPE], + } ); } diff --git a/js/testapps/byo-evaluator/package.json b/js/testapps/byo-evaluator/package.json index 391f2919a..53c996e56 100644 --- a/js/testapps/byo-evaluator/package.json +++ b/js/testapps/byo-evaluator/package.json @@ -26,4 +26,4 @@ "rimraf": "^6.0.1", "typescript": "^5.3.3" } -} \ No newline at end of file +} From a7824bcb98650646ba78328b84e282148679b36a Mon Sep 17 00:00:00 2001 From: hunterheston Date: Mon, 11 Nov 2024 17:10:11 +0000 Subject: [PATCH 15/30] PR Comments: 1. Updated package json, removed vertexai refs. 2. Removed optional deps pulled from vertexai plugin. 3. Action -> EvaluatorAction 4. console.warn -> logger.warn 5. Moved quota project warning into index.ts --- js/plugins/checks/package.json | 18 +++++------------- js/plugins/checks/src/evaluation.ts | 12 +++--------- js/plugins/checks/src/index.ts | 10 +++++++++- js/pnpm-lock.yaml | 7 ------- 4 files changed, 17 insertions(+), 30 deletions(-) diff --git a/js/plugins/checks/package.json b/js/plugins/checks/package.json index a23933b7d..fc42df739 100644 --- a/js/plugins/checks/package.json +++ b/js/plugins/checks/package.json @@ -1,21 +1,17 @@ { "name": "@genkit-ai/checks", - "description": "Genkit AI framework plugin for Google Cloud Vertex AI APIs including Gemini APIs, Imagen, and more.", + "description": "Google Checks AI Safety plugins for classifying the safety of text against Checks AI safety policies.", "keywords": [ "genkit", "genkit-plugin", - "genkit-embedder", - "genkit-model", "google cloud", - "vertex ai", - "imagen", - "image-generation", - "gemini", - "google gemini", "google ai", "ai", "genai", - "generative-ai" + "generative-ai", + "checks", + "google checks", + "guardrails" ], "version": "0.9.0-dev.2", "type": "commonjs", @@ -42,10 +38,6 @@ "peerDependencies": { "genkit": "workspace:*" }, - "optionalDependencies": { - "firebase-admin": ">=12.2", - "@google-cloud/bigquery": "^7.8.0" - }, "devDependencies": { "@types/node": "^20.11.16", "npm-run-all": "^4.1.5", diff --git a/js/plugins/checks/src/evaluation.ts b/js/plugins/checks/src/evaluation.ts index b1700932b..58ad2b115 100644 --- a/js/plugins/checks/src/evaluation.ts +++ b/js/plugins/checks/src/evaluation.ts @@ -14,7 +14,7 @@ * limitations under the License. */ -import { Action, Genkit, z } from 'genkit'; +import { EvaluatorAction, Genkit, z } from 'genkit'; import { BaseEvalDataPoint } from 'genkit/evaluator'; import { runInNewSpan } from 'genkit/tracing'; import { GoogleAuth } from 'google-auth-library'; @@ -64,7 +64,7 @@ export function checksEvaluators( auth: GoogleAuth, metrics: ChecksEvaluationMetric[], projectId: string -): Action[] { +): EvaluatorAction[] { const policy_configs: ChecksEvaluationMetricConfig[] = metrics.map( (metric) => { const metricType = isConfig(metric) ? metric.type : metric; @@ -105,7 +105,7 @@ function createPolicyEvaluator( auth: GoogleAuth, ai: Genkit, policy_config: ChecksEvaluationMetricConfig -): Action { +): EvaluatorAction { const policyType = policy_config.type as string; return ai.defineEvaluator( @@ -169,12 +169,6 @@ async function checksEvalInstance( const url = 'https://checks.googleapis.com/v1alpha/aisafety:classifyContent'; - if (client.quotaProjectId) { - console.warn( - `Checks Evaluator: Your Google cloud authentication has a default quota project(${client.quotaProjectId}) associated with it which will overrid the projectId in your Checks plugin config(${projectId}).` - ); - } - const response = await client.request({ url, method: 'POST', diff --git a/js/plugins/checks/src/index.ts b/js/plugins/checks/src/index.ts index b43b2fed3..64c62dc2e 100644 --- a/js/plugins/checks/src/index.ts +++ b/js/plugins/checks/src/index.ts @@ -15,6 +15,7 @@ */ import { Genkit } from 'genkit'; +import { logger } from 'genkit/logging'; import { GenkitPlugin, genkitPlugin } from 'genkit/plugin'; import { GoogleAuth, GoogleAuthOptions } from 'google-auth-library'; import { @@ -45,7 +46,7 @@ const CHECKS_OAUTH_SCOPE = 'https://www.googleapis.com/auth/checks'; */ export function checks(options?: PluginOptions): GenkitPlugin { return genkitPlugin('checks', async (ai: Genkit) => { - let authClient; + let authClient: GoogleAuth; let authOptions = options?.googleAuth; // Allow customers to pass in cloud credentials from environment variables @@ -67,6 +68,13 @@ export function checks(options?: PluginOptions): GenkitPlugin { ); } + const client = await authClient.getClient(); + if (client.quotaProjectId) { + logger.warn( + `Checks Evaluator: Your Google cloud authentication has a default quota project(${client.quotaProjectId}) associated with it which will overrid the projectId in your Checks plugin config(${options?.projectId}).` + ); + } + const projectId = options?.projectId || (await authClient.getProjectId()); const confError = (parameter: string, envVariableName: string) => { diff --git a/js/pnpm-lock.yaml b/js/pnpm-lock.yaml index 7f67de2fb..d1d6a03d9 100644 --- a/js/pnpm-lock.yaml +++ b/js/pnpm-lock.yaml @@ -216,13 +216,6 @@ importers: node-fetch: specifier: ^3.3.2 version: 3.3.2 - optionalDependencies: - '@google-cloud/bigquery': - specifier: ^7.8.0 - version: 7.8.0(encoding@0.1.13) - firebase-admin: - specifier: '>=12.2' - version: 12.3.1(encoding@0.1.13) devDependencies: '@types/node': specifier: ^20.11.16 From b0f1f8f98ca9df3faf72b0f007201af059bee635 Mon Sep 17 00:00:00 2001 From: hunterheston Date: Mon, 11 Nov 2024 18:06:40 +0000 Subject: [PATCH 16/30] Add readme contents. --- js/plugins/checks/README.md | 92 +++++++++++++++++++++++++++++++++++++ 1 file changed, 92 insertions(+) diff --git a/js/plugins/checks/README.md b/js/plugins/checks/README.md index 324aee45e..f4bedd13f 100644 --- a/js/plugins/checks/README.md +++ b/js/plugins/checks/README.md @@ -1,3 +1,95 @@ +# Checks +Checks is an AI safety platform built by Google: [checks.google.com/ai-safety](https://checks.google.com/ai-safety). + +This plugin provides evaluators for each Checks AI safety policy. Text is cassified by calling the [Checks Guardrails API](https://console.cloud.google.com/marketplace/product/google/checks.googleapis.com). + +| Note: The Guardrails is currently in private priview and you will need to request quota. + +Curently that list includes: +``` text +DANGEROUS_CONTENT +PII_SOLICITING_RECITING +HARASSMENT +SEXUALLY_EXPLICIT +HATE_SPEECH +MEDICAL_INFO +VIOLENCE_AND_GORE +OBSCENITY_AND_PROFANITY +``` + +## How to use + +### Configure the plugin +Add the `checks` plugin to your Genkit entrypoint and configured the evaluators you want to use: +```ts +import { checks, ChecksEvaluationMetricType } from '@genkit-ai/checks'; + + +export const ai = genkit({ + plugins: [ + checks({ + // Project to charge quota to. + // Note: If your credentials have a quota project associated with them. + // That value will take precedence over this. + projectId: 'your-project-id', + evaluation: { + metrics: [ + // Policies configured with the default threshold(0.5). + ChecksEvaluationMetricType.DANGEROUS_CONTENT, + ChecksEvaluationMetricType.HARASSMENT, + ChecksEvaluationMetricType.HATE_SPEECH, + ChecksEvaluationMetricType.MEDICAL_INFO, + ChecksEvaluationMetricType.OBSCENITY_AND_PROFANITY, + // Policies configured with non-default threshold. + { + type: ChecksEvaluationMetricType.PII_SOLICITING_RECITING, + threshold: 0.6, + }, + { + type: ChecksEvaluationMetricType.SEXUALLY_EXPLICIT, + threshold: 0.3, + }, + { + type: ChecksEvaluationMetricType.VIOLENCE_AND_GORE, + threshold: 0.55, + }, + ], + }, + }), + ], +}); + +``` + +### Create a test dataset +Create a JSON file with the data you want to test. Add as many test cases as you want. `output` is the text that will be classified. +```JSON +// test-dataset.json +[ + { + "testCaseId": "test_case_id_1", + "input": "The input to your model.", + "output": "Example model output which. This is what will be evaluated." + } +] + +``` + +### Run the evaluators +```bash +# Run just the DANGEROUS_CONTENT classifier. +genkit eval:run test-dataset.json --evaluators=checks/dangerous_content +``` + +```bash +# Run all classifiers. +genkit eval:run test-dataset.json --evaluators=checks/dangerous_content,checks/pii_soliciting_reciting,checks/harassment,checks/sexually_explicit,checks/hate_speech,checks/medical_info,checks/violence_and_gore,checks/obscenity_and_profanity +``` + +### View the results +Run `genkit start` and open the genkit ui. Usually at `localhost:4000` and select the Evaluate tab. + + # Genkit The sources for this package are in the main [Genkit](https://github.com/firebase/genkit) repo. Please file issues and pull requests against that repo. From bcb4e4ae7bb87dcc4c399a58f64519efc0424887 Mon Sep 17 00:00:00 2001 From: hunterheston Date: Mon, 11 Nov 2024 18:08:28 +0000 Subject: [PATCH 17/30] Add guardrails doc link --- js/plugins/checks/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/js/plugins/checks/README.md b/js/plugins/checks/README.md index f4bedd13f..23986e8d6 100644 --- a/js/plugins/checks/README.md +++ b/js/plugins/checks/README.md @@ -3,7 +3,7 @@ Checks is an AI safety platform built by Google: [checks.google.com/ai-safety](h This plugin provides evaluators for each Checks AI safety policy. Text is cassified by calling the [Checks Guardrails API](https://console.cloud.google.com/marketplace/product/google/checks.googleapis.com). -| Note: The Guardrails is currently in private priview and you will need to request quota. +> Note: The Guardrails is currently in private priview and you will need to request quota. See Guardrails documentation: [developers.devsite.corp.google.com/checks/guide/api/ai-safety](https://developers.devsite.corp.google.com/checks/guide/api/ai-safety?db=sherzat) Curently that list includes: ``` text From 1ae7e98a8bc8b5d31491a35ed94452af159c3eab Mon Sep 17 00:00:00 2001 From: hunterheston Date: Mon, 11 Nov 2024 18:10:35 +0000 Subject: [PATCH 18/30] pnpm format --- js/plugins/checks/README.md | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/js/plugins/checks/README.md b/js/plugins/checks/README.md index 23986e8d6..f55da1fa4 100644 --- a/js/plugins/checks/README.md +++ b/js/plugins/checks/README.md @@ -1,12 +1,14 @@ # Checks + Checks is an AI safety platform built by Google: [checks.google.com/ai-safety](https://checks.google.com/ai-safety). -This plugin provides evaluators for each Checks AI safety policy. Text is cassified by calling the [Checks Guardrails API](https://console.cloud.google.com/marketplace/product/google/checks.googleapis.com). +This plugin provides evaluators for each Checks AI safety policy. Text is cassified by calling the [Checks Guardrails API](https://console.cloud.google.com/marketplace/product/google/checks.googleapis.com). > Note: The Guardrails is currently in private priview and you will need to request quota. See Guardrails documentation: [developers.devsite.corp.google.com/checks/guide/api/ai-safety](https://developers.devsite.corp.google.com/checks/guide/api/ai-safety?db=sherzat) Curently that list includes: -``` text + +```text DANGEROUS_CONTENT PII_SOLICITING_RECITING HARASSMENT @@ -20,15 +22,16 @@ OBSCENITY_AND_PROFANITY ## How to use ### Configure the plugin + Add the `checks` plugin to your Genkit entrypoint and configured the evaluators you want to use: + ```ts import { checks, ChecksEvaluationMetricType } from '@genkit-ai/checks'; - export const ai = genkit({ plugins: [ checks({ - // Project to charge quota to. + // Project to charge quota to. // Note: If your credentials have a quota project associated with them. // That value will take precedence over this. projectId: 'your-project-id', @@ -58,11 +61,12 @@ export const ai = genkit({ }), ], }); - ``` ### Create a test dataset -Create a JSON file with the data you want to test. Add as many test cases as you want. `output` is the text that will be classified. + +Create a JSON file with the data you want to test. Add as many test cases as you want. `output` is the text that will be classified. + ```JSON // test-dataset.json [ @@ -76,19 +80,20 @@ Create a JSON file with the data you want to test. Add as many test cases as you ``` ### Run the evaluators + ```bash -# Run just the DANGEROUS_CONTENT classifier. +# Run just the DANGEROUS_CONTENT classifier. genkit eval:run test-dataset.json --evaluators=checks/dangerous_content ``` ```bash -# Run all classifiers. +# Run all classifiers. genkit eval:run test-dataset.json --evaluators=checks/dangerous_content,checks/pii_soliciting_reciting,checks/harassment,checks/sexually_explicit,checks/hate_speech,checks/medical_info,checks/violence_and_gore,checks/obscenity_and_profanity ``` ### View the results -Run `genkit start` and open the genkit ui. Usually at `localhost:4000` and select the Evaluate tab. +Run `genkit start` and open the genkit ui. Usually at `localhost:4000` and select the Evaluate tab. # Genkit From a439dc478ce29c7a76d4a311c80a733d8a7bfefc Mon Sep 17 00:00:00 2001 From: hunterheston Date: Mon, 11 Nov 2024 18:20:56 +0000 Subject: [PATCH 19/30] Link to the onboarding form directly in the readme. --- js/plugins/checks/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/js/plugins/checks/README.md b/js/plugins/checks/README.md index f55da1fa4..2b2e9b143 100644 --- a/js/plugins/checks/README.md +++ b/js/plugins/checks/README.md @@ -4,7 +4,7 @@ Checks is an AI safety platform built by Google: [checks.google.com/ai-safety](h This plugin provides evaluators for each Checks AI safety policy. Text is cassified by calling the [Checks Guardrails API](https://console.cloud.google.com/marketplace/product/google/checks.googleapis.com). -> Note: The Guardrails is currently in private priview and you will need to request quota. See Guardrails documentation: [developers.devsite.corp.google.com/checks/guide/api/ai-safety](https://developers.devsite.corp.google.com/checks/guide/api/ai-safety?db=sherzat) +> Note: The Guardrails is currently in private priview and you will need to request quota. To request quota fill out this [Google form](https://docs.google.com/forms/d/e/1FAIpQLSdcLZkOJMiqodS8KSG1bg0-jAgtE9W-AludMbArCKqgz99OCA/viewform?usp=sf_link) Curently that list includes: From 77ea3b31817b255143715b656be2187a22fa0580 Mon Sep 17 00:00:00 2001 From: hunterheston Date: Mon, 11 Nov 2024 21:57:22 +0000 Subject: [PATCH 20/30] Typeo in readme --- js/plugins/checks/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/js/plugins/checks/README.md b/js/plugins/checks/README.md index 2b2e9b143..3c1c39817 100644 --- a/js/plugins/checks/README.md +++ b/js/plugins/checks/README.md @@ -4,7 +4,7 @@ Checks is an AI safety platform built by Google: [checks.google.com/ai-safety](h This plugin provides evaluators for each Checks AI safety policy. Text is cassified by calling the [Checks Guardrails API](https://console.cloud.google.com/marketplace/product/google/checks.googleapis.com). -> Note: The Guardrails is currently in private priview and you will need to request quota. To request quota fill out this [Google form](https://docs.google.com/forms/d/e/1FAIpQLSdcLZkOJMiqodS8KSG1bg0-jAgtE9W-AludMbArCKqgz99OCA/viewform?usp=sf_link) +> Note: The Guardrails is currently in private preview and you will need to request quota. To request quota fill out this [Google form](https://docs.google.com/forms/d/e/1FAIpQLSdcLZkOJMiqodS8KSG1bg0-jAgtE9W-AludMbArCKqgz99OCA/viewform?usp=sf_link) Curently that list includes: From 4578fbbf32e30d81aec15d9d782c1c70fe55de25 Mon Sep 17 00:00:00 2001 From: hunterheston Date: Mon, 11 Nov 2024 22:26:02 +0000 Subject: [PATCH 21/30] Removed dependencies other than google-auth --- js/plugins/checks/package.json | 7 ++----- js/pnpm-lock.yaml | 9 --------- 2 files changed, 2 insertions(+), 14 deletions(-) diff --git a/js/plugins/checks/package.json b/js/plugins/checks/package.json index fc42df739..5014163db 100644 --- a/js/plugins/checks/package.json +++ b/js/plugins/checks/package.json @@ -30,10 +30,7 @@ "author": "genkit", "license": "Apache-2.0", "dependencies": { - "@google-cloud/aiplatform": "^3.23.0", - "google-auth-library": "^9.6.3", - "googleapis": "^140.0.1", - "node-fetch": "^3.3.2" + "google-auth-library": "^9.6.3" }, "peerDependencies": { "genkit": "workspace:*" @@ -55,4 +52,4 @@ "default": "./lib/index.js" } } -} +} \ No newline at end of file diff --git a/js/pnpm-lock.yaml b/js/pnpm-lock.yaml index d1d6a03d9..d3f5f461a 100644 --- a/js/pnpm-lock.yaml +++ b/js/pnpm-lock.yaml @@ -201,21 +201,12 @@ importers: plugins/checks: dependencies: - '@google-cloud/aiplatform': - specifier: ^3.23.0 - version: 3.25.0(encoding@0.1.13) genkit: specifier: workspace:* version: link:../../genkit google-auth-library: specifier: ^9.6.3 version: 9.7.0(encoding@0.1.13) - googleapis: - specifier: ^140.0.1 - version: 140.0.1(encoding@0.1.13) - node-fetch: - specifier: ^3.3.2 - version: 3.3.2 devDependencies: '@types/node': specifier: ^20.11.16 From 754f59e422032dac55146a7874017dcc35323f18 Mon Sep 17 00:00:00 2001 From: hunterheston Date: Mon, 11 Nov 2024 22:49:55 +0000 Subject: [PATCH 22/30] pnpm format --- js/plugins/checks/package.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/js/plugins/checks/package.json b/js/plugins/checks/package.json index 5014163db..25dece22f 100644 --- a/js/plugins/checks/package.json +++ b/js/plugins/checks/package.json @@ -52,4 +52,4 @@ "default": "./lib/index.js" } } -} \ No newline at end of file +} From 253b1ce47fd49276b9eb9206c76a0b126fd6d98d Mon Sep 17 00:00:00 2001 From: hunterheston Date: Tue, 12 Nov 2024 17:07:36 +0000 Subject: [PATCH 23/30] Update pnpm lock --- js/pnpm-lock.yaml | 80 +++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 71 insertions(+), 9 deletions(-) diff --git a/js/pnpm-lock.yaml b/js/pnpm-lock.yaml index ec7747e2f..9cd39f836 100644 --- a/js/pnpm-lock.yaml +++ b/js/pnpm-lock.yaml @@ -208,6 +208,34 @@ importers: specifier: ^4.9.0 version: 4.9.5 + plugins/checks: + dependencies: + genkit: + specifier: workspace:* + version: link:../../genkit + google-auth-library: + specifier: ^9.6.3 + version: 9.14.2(encoding@0.1.13) + devDependencies: + '@types/node': + specifier: ^20.11.16 + version: 20.16.9 + npm-run-all: + specifier: ^4.1.5 + version: 4.1.5 + rimraf: + specifier: ^6.0.1 + version: 6.0.1 + tsup: + specifier: ^8.0.2 + version: 8.3.5(postcss@8.4.47)(tsx@4.19.2)(typescript@4.9.5) + tsx: + specifier: ^4.7.0 + version: 4.19.2 + typescript: + specifier: ^4.9.0 + version: 4.9.5 + plugins/chroma: dependencies: chromadb: @@ -7261,7 +7289,7 @@ snapshots: '@google-cloud/logging-winston@6.0.0(encoding@0.1.13)(winston@3.13.0)': dependencies: '@google-cloud/logging': 11.0.0(encoding@0.1.13) - google-auth-library: 9.11.0(encoding@0.1.13) + google-auth-library: 9.14.2(encoding@0.1.13) lodash.mapvalues: 4.6.0 winston: 3.13.0 winston-transport: 4.7.0 @@ -7280,7 +7308,7 @@ snapshots: eventid: 2.0.1 extend: 3.0.2 gcp-metadata: 6.1.0(encoding@0.1.13) - google-auth-library: 9.11.0(encoding@0.1.13) + google-auth-library: 9.14.2(encoding@0.1.13) google-gax: 4.3.2(encoding@0.1.13) on-finished: 2.4.1 pumpify: 2.0.1 @@ -7298,7 +7326,7 @@ snapshots: '@opentelemetry/core': 1.25.1(@opentelemetry/api@1.9.0) '@opentelemetry/resources': 1.25.1(@opentelemetry/api@1.9.0) '@opentelemetry/sdk-metrics': 1.25.1(@opentelemetry/api@1.9.0) - google-auth-library: 9.11.0(encoding@0.1.13) + google-auth-library: 9.14.2(encoding@0.1.13) googleapis: 137.1.0(encoding@0.1.13) transitivePeerDependencies: - encoding @@ -7313,7 +7341,7 @@ snapshots: '@opentelemetry/core': 1.25.1(@opentelemetry/api@1.9.0) '@opentelemetry/resources': 1.25.1(@opentelemetry/api@1.9.0) '@opentelemetry/sdk-trace-base': 1.25.1(@opentelemetry/api@1.9.0) - google-auth-library: 9.11.0(encoding@0.1.13) + google-auth-library: 9.14.2(encoding@0.1.13) transitivePeerDependencies: - encoding - supports-color @@ -7485,7 +7513,7 @@ snapshots: dependencies: '@jest/types': 29.6.3 '@sinonjs/fake-timers': 10.3.0 - '@types/node': 20.11.30 + '@types/node': 20.16.9 jest-message-util: 29.7.0 jest-mock: 29.7.0 jest-util: 29.7.0 @@ -7577,7 +7605,7 @@ snapshots: '@jest/schemas': 29.6.3 '@types/istanbul-lib-coverage': 2.0.6 '@types/istanbul-reports': 3.0.4 - '@types/node': 20.11.30 + '@types/node': 20.16.9 '@types/yargs': 17.0.33 chalk: 4.1.2 @@ -10006,7 +10034,7 @@ snapshots: '@types/long': 4.0.2 abort-controller: 3.0.0 duplexify: 4.1.3 - google-auth-library: 9.11.0(encoding@0.1.13) + google-auth-library: 9.14.2(encoding@0.1.13) node-fetch: 2.7.0(encoding@0.1.13) object-hash: 3.0.0 proto3-json-serializer: 2.0.1 @@ -10054,7 +10082,7 @@ snapshots: googleapis@137.1.0(encoding@0.1.13): dependencies: - google-auth-library: 9.11.0(encoding@0.1.13) + google-auth-library: 9.14.2(encoding@0.1.13) googleapis-common: 7.2.0(encoding@0.1.13) transitivePeerDependencies: - encoding @@ -10660,7 +10688,7 @@ snapshots: jest-util@29.7.0: dependencies: '@jest/types': 29.6.3 - '@types/node': 20.11.30 + '@types/node': 20.16.9 chalk: 4.1.2 ci-info: 3.9.0 graceful-fs: 4.2.11 @@ -11405,6 +11433,13 @@ snapshots: postcss: 8.4.47 tsx: 4.19.1 + postcss-load-config@6.0.1(postcss@8.4.47)(tsx@4.19.2): + dependencies: + lilconfig: 3.1.2 + optionalDependencies: + postcss: 8.4.47 + tsx: 4.19.2 + postcss-load-config@6.0.1(postcss@8.4.47)(tsx@4.7.1): dependencies: lilconfig: 3.1.2 @@ -12111,6 +12146,33 @@ snapshots: - tsx - yaml + tsup@8.3.5(postcss@8.4.47)(tsx@4.19.2)(typescript@4.9.5): + dependencies: + bundle-require: 5.0.0(esbuild@0.24.0) + cac: 6.7.14 + chokidar: 4.0.1 + consola: 3.2.3 + debug: 4.3.7 + esbuild: 0.24.0 + joycon: 3.1.1 + picocolors: 1.1.1 + postcss-load-config: 6.0.1(postcss@8.4.47)(tsx@4.19.2) + resolve-from: 5.0.0 + rollup: 4.25.0 + source-map: 0.8.0-beta.0 + sucrase: 3.35.0 + tinyexec: 0.3.1 + tinyglobby: 0.2.10 + tree-kill: 1.2.2 + optionalDependencies: + postcss: 8.4.47 + typescript: 4.9.5 + transitivePeerDependencies: + - jiti + - supports-color + - tsx + - yaml + tsup@8.3.5(postcss@8.4.47)(tsx@4.7.1)(typescript@4.9.5): dependencies: bundle-require: 5.0.0(esbuild@0.24.0) From c2c36eb064d104a16f7fb994fa7976661f652777 Mon Sep 17 00:00:00 2001 From: hunterheston Date: Tue, 12 Nov 2024 17:42:05 +0000 Subject: [PATCH 24/30] update readme --- js/plugins/checks/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/js/plugins/checks/README.md b/js/plugins/checks/README.md index 3c1c39817..b45af6931 100644 --- a/js/plugins/checks/README.md +++ b/js/plugins/checks/README.md @@ -93,7 +93,7 @@ genkit eval:run test-dataset.json --evaluators=checks/dangerous_content,checks/p ### View the results -Run `genkit start` and open the genkit ui. Usually at `localhost:4000` and select the Evaluate tab. +Run `genkit start -- tsx --watch src/index.ts` and open the genkit ui. Usually at `localhost:4000` and select the Evaluate tab. # Genkit From ddb5e52bb14a1eed8042f3e353c462d4cd2fee6b Mon Sep 17 00:00:00 2001 From: hunterheston Date: Tue, 12 Nov 2024 23:16:36 +0000 Subject: [PATCH 25/30] Add plugin to release scripts. --- .github/workflows/bump-js-plugins-version.yml | 11 +++++++++++ scripts/release_main.sh | 4 ++++ scripts/release_next.sh | 5 ++++- 3 files changed, 19 insertions(+), 1 deletion(-) diff --git a/.github/workflows/bump-js-plugins-version.yml b/.github/workflows/bump-js-plugins-version.yml index 386e4bfc3..5d7076ce9 100644 --- a/.github/workflows/bump-js-plugins-version.yml +++ b/.github/workflows/bump-js-plugins-version.yml @@ -171,3 +171,14 @@ jobs: preid: ${{ inputs.preid }} commit-message: 'chore: bump @genkit-ai/vertexai version to {{version}}' tag-prefix: '@genkit-ai/vertexai@' + - name: 'js/plugins/checks version bump' + uses: 'phips28/gh-action-bump-version@master' + env: + GITHUB_TOKEN: ${{ secrets.GENKIT_RELEASER_GITHUB_TOKEN }} + PACKAGEJSON_DIR: js/plugins/checks + with: + default: ${{ inputs.releaseType }} + version-type: ${{ inputs.releaseType }} + preid: ${{ inputs.preid }} + commit-message: 'chore: bump @genkit-ai/checks version to {{version}}' + tag-prefix: '@genkit-ai/checks@' diff --git a/scripts/release_main.sh b/scripts/release_main.sh index b28e12eb2..5923e41db 100755 --- a/scripts/release_main.sh +++ b/scripts/release_main.sh @@ -76,3 +76,7 @@ cd js/plugins/langchain pnpm publish --registry https://wombat-dressing-room.appspot.com cd $CURRENT +cd js/plugins/checks +pnpm publish --registry https://wombat-dressing-room.appspot.com +cd $CURRENT + diff --git a/scripts/release_next.sh b/scripts/release_next.sh index 3e62e982b..0afb135b1 100755 --- a/scripts/release_next.sh +++ b/scripts/release_next.sh @@ -67,7 +67,7 @@ cd $CURRENT cd js/plugins/pinecone pnpm publish --tag next --publish-branch next --registry https://wombat-dressing-room.appspot.com -cd $CURRENT +cd $CURRENTscripts/release_main.sh cd js/plugins/vertexai pnpm publish --tag next --publish-branch next --registry https://wombat-dressing-room.appspot.com @@ -81,3 +81,6 @@ cd js/plugins/langchain pnpm publish --tag next --publish-branch next --registry https://wombat-dressing-room.appspot.com cd $CURRENT +cd js/plugins/checks +pnpm publish --tag next --publish-branch next --registry https://wombat-dressing-room.appspot.com +cd $CURRENT From 665101cf5f46361ec65db9b9e985851bb7487744 Mon Sep 17 00:00:00 2001 From: hunterheston Date: Tue, 12 Nov 2024 23:18:31 +0000 Subject: [PATCH 26/30] Remove pasted text --- scripts/release_next.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/release_next.sh b/scripts/release_next.sh index 0afb135b1..f38161bc5 100755 --- a/scripts/release_next.sh +++ b/scripts/release_next.sh @@ -67,7 +67,7 @@ cd $CURRENT cd js/plugins/pinecone pnpm publish --tag next --publish-branch next --registry https://wombat-dressing-room.appspot.com -cd $CURRENTscripts/release_main.sh +cd $CURRENT cd js/plugins/vertexai pnpm publish --tag next --publish-branch next --registry https://wombat-dressing-room.appspot.com From 583f78f2d9210fe6f5c4b46dbe2a59ebbd4111df Mon Sep 17 00:00:00 2001 From: hunterheston Date: Wed, 13 Nov 2024 22:26:24 +0000 Subject: [PATCH 27/30] Bump version of checks plugin --- js/plugins/checks/package.json | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/js/plugins/checks/package.json b/js/plugins/checks/package.json index 25dece22f..018a34e19 100644 --- a/js/plugins/checks/package.json +++ b/js/plugins/checks/package.json @@ -13,7 +13,7 @@ "google checks", "guardrails" ], - "version": "0.9.0-dev.2", + "version": "0.9.0-rc.3", "type": "commonjs", "scripts": { "check": "tsc", @@ -52,4 +52,4 @@ "default": "./lib/index.js" } } -} +} \ No newline at end of file From 2669c5cb6935b650d1ed98818447efa048525212 Mon Sep 17 00:00:00 2001 From: hunterheston Date: Wed, 13 Nov 2024 22:28:15 +0000 Subject: [PATCH 28/30] pnpm format --- js/plugins/checks/package.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/js/plugins/checks/package.json b/js/plugins/checks/package.json index 018a34e19..9bac82ea3 100644 --- a/js/plugins/checks/package.json +++ b/js/plugins/checks/package.json @@ -52,4 +52,4 @@ "default": "./lib/index.js" } } -} \ No newline at end of file +} From 80b69bc04fc73f2bc01e74b78604efa6909fcb3a Mon Sep 17 00:00:00 2001 From: hunterheston Date: Thu, 14 Nov 2024 16:45:33 +0000 Subject: [PATCH 29/30] 1. Remove comment from json markdown. 2. Remove helper function for config errors. --- js/plugins/checks/README.md | 1 - js/plugins/checks/src/index.ts | 9 +++------ 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/js/plugins/checks/README.md b/js/plugins/checks/README.md index b45af6931..5cf804438 100644 --- a/js/plugins/checks/README.md +++ b/js/plugins/checks/README.md @@ -68,7 +68,6 @@ export const ai = genkit({ Create a JSON file with the data you want to test. Add as many test cases as you want. `output` is the text that will be classified. ```JSON -// test-dataset.json [ { "testCaseId": "test_case_id_1", diff --git a/js/plugins/checks/src/index.ts b/js/plugins/checks/src/index.ts index 64c62dc2e..103830a92 100644 --- a/js/plugins/checks/src/index.ts +++ b/js/plugins/checks/src/index.ts @@ -77,13 +77,10 @@ export function checks(options?: PluginOptions): GenkitPlugin { const projectId = options?.projectId || (await authClient.getProjectId()); - const confError = (parameter: string, envVariableName: string) => { - return new Error( - `Checks Plugin is missing the '${parameter}' configuration. Please set the '${envVariableName}' environment variable or explicitly pass '${parameter}' into genkit config.` - ); - }; if (!projectId) { - throw confError('project', 'GCLOUD_PROJECT'); + throw new Error( + `Checks Plugin is missing the 'projectId' configuration. Please set the 'GCLOUD_PROJECT' environment variable or explicitly pass 'projectId' into genkit config.` + ); } const metrics = From c8168f6fc85afc7819b49790effaf1eb51037ad2 Mon Sep 17 00:00:00 2001 From: hunterheston Date: Thu, 14 Nov 2024 17:13:12 +0000 Subject: [PATCH 30/30] Merge with main duplicated the checks publish lines. --- scripts/release_main.sh | 4 ---- 1 file changed, 4 deletions(-) diff --git a/scripts/release_main.sh b/scripts/release_main.sh index 9276098a4..67866b890 100755 --- a/scripts/release_main.sh +++ b/scripts/release_main.sh @@ -83,7 +83,3 @@ cd $CURRENT cd js/plugins/checks pnpm publish --registry https://wombat-dressing-room.appspot.com cd $CURRENT - -cd js/plugins/checks -pnpm publish --registry https://wombat-dressing-room.appspot.com -cd $CURRENT