feat: add support for multimodal in Vertex (huggingface#1338)
* feat: add support for multimodal in Vertex

* Nit changes and remove tools if multimodal

* revert model name change

* Fix tools/multimodal condition

* chores(lint): fix formatting

---------
Co-authored-by: Thomas <thomas.poc@gmail.com>
Co-authored-by: Nathan Sarrazin <sarrazin.nathan@gmail.com>
ArthurGoupil authored Sep 9, 2024
1 parent c53a4b5 commit 9549e2b
Showing 2 changed files with 59 additions and 15 deletions.
14 changes: 11 additions & 3 deletions README.md
@@ -775,21 +775,29 @@ MODELS=`[
{
"name": "gemini-1.5-pro",
"displayName": "Vertex Gemini Pro 1.5",
"multimodal": true,
"endpoints" : [{
"type": "vertex",
"project": "abc-xyz",
"location": "europe-west3",
"model": "gemini-1.5-pro-preview-0409", // model-name
// Optional
"safetyThreshold": "BLOCK_MEDIUM_AND_ABOVE",
"apiEndpoint": "", // alternative api endpoint url,
// Optional
"tools": [{
"googleSearchRetrieval": {
"disableAttribution": true
}
}],
"multimodal": {
"image": {
"supportedMimeTypes": ["image/png", "image/jpeg", "image/webp"],
"preferredMimeType": "image/png",
"maxSizeInMB": 5,
"maxWidth": 2000,
"maxHeight": 1000;
}
}
}]
},
]`
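
The `multimodal` block on the endpoint mirrors the image-processor options validated in `endpointVertex.ts` below; every field is optional and falls back to the defaults defined there. A minimal sketch of the shape it is parsed into, assuming only the field names shown above (the authoritative validator is `createImageProcessorOptionsValidator`):

// Sketch of the parsed "multimodal.image" options; field names are taken
// from the README example above, semantics inferred from the option names.
interface ImageProcessorOptions {
  supportedMimeTypes: string[]; // MIME types the processor accepts as input
  preferredMimeType: string; // type images are converted to before being sent
  maxSizeInMB: number; // upper bound on the encoded image size
  maxWidth: number; // larger images are scaled down to fit
  maxHeight: number;
}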
60 changes: 48 additions & 12 deletions src/lib/server/endpoints/google/endpointVertex.ts
@@ -9,6 +9,7 @@ import type { Endpoint } from "../endpoints";
import { z } from "zod";
import type { Message } from "$lib/types/Message";
import type { TextGenerationStreamOutput } from "@huggingface/inference";
import { createImageProcessorOptionsValidator, makeImageProcessor } from "../images";

export const endpointVertexParametersSchema = z.object({
weight: z.number().int().positive().default(1),
@@ -27,10 +28,28 @@ export const endpointVertexParametersSchema = z.object({
])
.optional(),
tools: z.array(z.any()).optional(),
multimodal: z
.object({
image: createImageProcessorOptionsValidator({
supportedMimeTypes: [
"image/png",
"image/jpeg",
"image/webp",
"image/avif",
"image/tiff",
"image/gif",
],
preferredMimeType: "image/webp",
maxSizeInMB: Infinity,
maxWidth: 4096,
maxHeight: 4096,
}),
})
.default({}),
});
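
Because `multimodal` is declared with `.default({})`, and that parent default can only satisfy the schema if the `image` sub-validator supplies its own default, a config that never mentions multimodal still parses into the fully populated defaults above. A hedged sketch; the non-multimodal fields are placeholders, since part of the schema is collapsed in this diff:

// Hypothetical parse of a config that omits "multimodal" entirely.
const parsed = endpointVertexParametersSchema.parse({
  project: "abc-xyz",
  location: "europe-west3",
  // ...whatever other required fields the collapsed part of the schema demands
});
console.log(parsed.multimodal.image.preferredMimeType); // "image/webp"
console.log(parsed.multimodal.image.maxWidth); // 4096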

export function endpointVertex(input: z.input<typeof endpointVertexParametersSchema>): Endpoint {
const { project, location, model, apiEndpoint, safetyThreshold, tools, multimodal } =
endpointVertexParametersSchema.parse(input);

const vertex_ai = new VertexAI({
@@ -42,6 +61,8 @@ export function endpointVertex(input: z.input<typeof endpointVertexParametersSch
return async ({ messages, preprompt, generateSettings }) => {
const parameters = { ...model.parameters, ...generateSettings };

const hasFiles = messages.some((message) => message.files && message.files.length > 0);

const generativeModel = vertex_ai.getGenerativeModel({
model: model.id ?? model.name,
safetySettings: safetyThreshold
@@ -73,7 +94,8 @@ export function endpointVertex(input: z.input<typeof endpointVertexParametersSch
stopSequences: parameters?.stop,
temperature: parameters?.temperature ?? 1,
},
// tools and multimodal are mutually exclusive
tools: !hasFiles ? tools : undefined,
});

// Preprompt is the same as the first system message.
@@ -83,16 +105,30 @@ export function endpointVertex(input: z.input<typeof endpointVertexParametersSch
messages.shift();
}

const vertexMessages = await Promise.all(
messages.map(async ({ from, content, files }: Omit<Message, "id">): Promise<Content> => {
const imageProcessor = makeImageProcessor(multimodal.image);
const processedFiles =
files && files.length > 0
? await Promise.all(files.map(async (file) => imageProcessor(file)))
: [];

return {
role: from === "user" ? "user" : "model",
parts: [
...processedFiles.map((processedFile) => ({
inlineData: {
data: processedFile.image.toString("base64"),
mimeType: processedFile.mime,
},
})),
{
text: content,
},
],
};
})
);
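
For a user turn with one attached image, each element of `vertexMessages` comes out shaped roughly as below, an illustrative sketch assuming the processor converted the upload to the default preferred type (`image/webp`) from the schema above:

// Illustrative shape of one entry of vertexMessages (values are placeholders):
// {
//   role: "user",
//   parts: [
//     { inlineData: { data: "<base64-encoded image bytes>", mimeType: "image/webp" } },
//     { text: "What's in this image?" },
//   ],
// }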

const result = await generativeModel.generateContentStream({
contents: vertexMessages,
