Skip to content

Commit

Permalink
feat(js/plugins/vertexai): support resposeMimeType in next (#1149)
Browse files Browse the repository at this point in the history
  • Loading branch information
cabljac authored Oct 31, 2024
1 parent 325c5b7 commit 971d40e
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 40 deletions.
8 changes: 4 additions & 4 deletions js/plugins/vertexai/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@
"@anthropic-ai/sdk": "^0.24.3",
"@anthropic-ai/vertex-sdk": "^0.4.0",
"@google-cloud/aiplatform": "^3.23.0",
"@google-cloud/vertexai": "^1.1.0",
"google-auth-library": "^9.6.3",
"@google-cloud/vertexai": "^1.9.0",
"google-auth-library": "^9.14.2",
"googleapis": "^140.0.1",
"node-fetch": "^3.3.2",
"openai": "^4.52.7"
Expand All @@ -48,8 +48,8 @@
"genkit": "workspace:*"
},
"optionalDependencies": {
"firebase-admin": ">=12.2",
"@google-cloud/bigquery": "^7.8.0"
"@google-cloud/bigquery": "^7.8.0",
"firebase-admin": ">=12.2"
},
"devDependencies": {
"@types/node": "^20.11.16",
Expand Down
34 changes: 25 additions & 9 deletions js/plugins/vertexai/src/gemini.ts
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,10 @@ function fromGeminiFunctionResponsePart(part: GeminiPart): Part {
}

// Converts vertex part to genkit part
function fromGeminiPart(part: GeminiPart): Part {
function fromGeminiPart(part: GeminiPart, jsonMode: boolean): Part {
if (jsonMode && part.text !== undefined) {
return { data: JSON.parse(part.text) };
}
if (part.text !== undefined) return { text: part.text };
if (part.functionCall) return fromGeminiFunctionCallPart(part);
if (part.functionResponse) return fromGeminiFunctionResponsePart(part);
Expand All @@ -355,14 +358,15 @@ function fromGeminiPart(part: GeminiPart): Part {
}

export function fromGeminiCandidate(
candidate: GenerateContentCandidate
candidate: GenerateContentCandidate,
jsonMode: boolean
): CandidateData {
const parts = candidate.content.parts || [];
const genkitCandidate: CandidateData = {
index: candidate.index || 0, // reasonable default?
message: {
role: 'model',
content: parts.map(fromGeminiPart),
content: parts.map((p) => fromGeminiPart(p, jsonMode)),
},
finishReason: fromGeminiFinishReason(candidate.finishReason),
finishMessage: candidate.finishMessage,
Expand Down Expand Up @@ -463,11 +467,18 @@ export function defineGeminiModel(
}
}

const tools = request.tools?.length
? [{ functionDeclarations: request.tools?.map(toGeminiTool) }]
: [];

// Cannot use tools and function calling at the same time
const jsonMode =
(request.output?.format === 'json' || !!request.output?.schema) &&
tools.length === 0;

const chatRequest: StartChatParams = {
systemInstruction,
tools: request.tools?.length
? [{ functionDeclarations: request.tools?.map(toGeminiTool) }]
: [],
tools,
history: messages
.slice(0, -1)
.map((message) => toGeminiMessage(message, model)),
Expand All @@ -477,6 +488,7 @@ export function defineGeminiModel(
maxOutputTokens: request.config?.maxOutputTokens,
topK: request.config?.topK,
topP: request.config?.topP,
responseMimeType: jsonMode ? 'application/json' : undefined,
stopSequences: request.config?.stopSequences,
},
safetySettings: request.config?.safetySettings,
Expand Down Expand Up @@ -511,7 +523,7 @@ export function defineGeminiModel(
.sendMessageStream(msg.parts);
for await (const item of result.stream) {
(item as GenerateContentResponse).candidates?.forEach((candidate) => {
const c = fromGeminiCandidate(candidate);
const c = fromGeminiCandidate(candidate, jsonMode);
streamingCallback({
index: c.index,
content: c.message.content,
Expand All @@ -523,7 +535,9 @@ export function defineGeminiModel(
throw new Error('No valid candidates returned.');
}
return {
candidates: response.candidates?.map(fromGeminiCandidate) || [],
candidates:
response.candidates?.map((c) => fromGeminiCandidate(c, jsonMode)) ||
[],
custom: response,
};
} else {
Expand All @@ -537,7 +551,9 @@ export function defineGeminiModel(
throw new Error('No valid candidates returned.');
}
const responseCandidates =
result.response.candidates?.map(fromGeminiCandidate) || [];
result.response.candidates?.map((c) =>
fromGeminiCandidate(c, jsonMode)
) || [];
return {
candidates: responseCandidates,
custom: result.response,
Expand Down
70 changes: 43 additions & 27 deletions js/pnpm-lock.yaml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 971d40e

Please sign in to comment.