🐛 fix: Gemini cannot input images when server database is enabled #3370

Merged · 8 commits · Sep 9, 2024

70 changes: 46 additions & 24 deletions src/libs/agent-runtime/google/index.test.ts
@@ -5,6 +5,7 @@ import OpenAI from 'openai';
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';

import { OpenAIChatMessage } from '@/libs/agent-runtime';
import * as imageToBase64Module from '@/utils/imageToBase64';

import * as debugStreamModule from '../utils/debugStream';
import { LobeGoogleAI } from './index';
@@ -303,36 +304,57 @@ describe('LobeGoogleAI', () => {

describe('private method', () => {
describe('convertContentToGooglePart', () => {
it('should throw TypeError when image URL does not contain base64 data', () => {
// Provide an image URL that does not contain base64 data
const invalidImageUrl = 'http://example.com/image.png';
it('should handle URL type images', async () => {
const imageUrl = 'http://example.com/image.png';
const mockBase64 = 'mockBase64Data';

expect(() =>
// Mock the imageUrlToBase64 function
vi.spyOn(imageToBase64Module, 'imageUrlToBase64').mockResolvedValueOnce(mockBase64);

const result = await instance['convertContentToGooglePart']({
type: 'image_url',
image_url: { url: imageUrl },
});

expect(result).toEqual({
inlineData: {
data: mockBase64,
mimeType: 'image/png',
},
});

expect(imageToBase64Module.imageUrlToBase64).toHaveBeenCalledWith(imageUrl);
});

it('should throw TypeError for unsupported image URL types', async () => {
const unsupportedImageUrl = 'unsupported://example.com/image.png';

await expect(
instance['convertContentToGooglePart']({
type: 'image_url',
image_url: { url: invalidImageUrl },
image_url: { url: unsupportedImageUrl },
}),
).toThrow(TypeError);
).rejects.toThrow(TypeError);
});
});

describe('buildGoogleMessages', () => {
it('get default result with gemini-pro', () => {
it('get default result with gemini-pro', async () => {
const messages: OpenAIChatMessage[] = [{ content: 'Hello', role: 'user' }];

const contents = instance['buildGoogleMessages'](messages, 'gemini-pro');
const contents = await instance['buildGoogleMessages'](messages, 'gemini-pro');

expect(contents).toHaveLength(1);
expect(contents).toEqual([{ parts: [{ text: 'Hello' }], role: 'user' }]);
});

it('messages should end with user if using gemini-pro', () => {
it('messages should end with user if using gemini-pro', async () => {
const messages: OpenAIChatMessage[] = [
{ content: 'Hello', role: 'user' },
{ content: 'Hi', role: 'assistant' },
];

const contents = instance['buildGoogleMessages'](messages, 'gemini-pro');
const contents = await instance['buildGoogleMessages'](messages, 'gemini-pro');

expect(contents).toHaveLength(3);
expect(contents).toEqual([
@@ -342,13 +364,13 @@ describe('LobeGoogleAI', () => {
]);
});

it('should include system role if there is a system role prompt', () => {
it('should include system role if there is a system role prompt', async () => {
const messages: OpenAIChatMessage[] = [
{ content: 'you are ChatGPT', role: 'system' },
{ content: 'Who are you', role: 'user' },
];

const contents = instance['buildGoogleMessages'](messages, 'gemini-pro');
const contents = await instance['buildGoogleMessages'](messages, 'gemini-pro');

expect(contents).toHaveLength(3);
expect(contents).toEqual([
@@ -358,13 +380,13 @@ describe('LobeGoogleAI', () => {
]);
});

it('should not modify the length if model is gemini-1.5-pro', () => {
it('should not modify the length if model is gemini-1.5-pro', async () => {
const messages: OpenAIChatMessage[] = [
{ content: 'Hello', role: 'user' },
{ content: 'Hi', role: 'assistant' },
];

const contents = instance['buildGoogleMessages'](messages, 'gemini-1.5-pro-latest');
const contents = await instance['buildGoogleMessages'](messages, 'gemini-1.5-pro-latest');

expect(contents).toHaveLength(2);
expect(contents).toEqual([
@@ -373,7 +395,7 @@ describe('LobeGoogleAI', () => {
]);
});

it('should use specified model when images are included in messages', () => {
it('should use specified model when images are included in messages', async () => {
const messages: OpenAIChatMessage[] = [
{
content: [
@@ -386,7 +408,7 @@ describe('LobeGoogleAI', () => {
const model = 'gemini-1.5-flash-latest';

// Call the buildGoogleMessages method
const contents = instance['buildGoogleMessages'](messages, model);
const contents = await instance['buildGoogleMessages'](messages, model);

expect(contents).toHaveLength(1);
expect(contents).toEqual([
@@ -501,35 +523,35 @@ describe('LobeGoogleAI', () => {
});

describe('convertOAIMessagesToGoogleMessage', () => {
it('should correctly convert assistant message', () => {
it('should correctly convert assistant message', async () => {
const message: OpenAIChatMessage = {
role: 'assistant',
content: 'Hello',
};

const converted = instance['convertOAIMessagesToGoogleMessage'](message);
const converted = await instance['convertOAIMessagesToGoogleMessage'](message);

expect(converted).toEqual({
role: 'model',
parts: [{ text: 'Hello' }],
});
});

it('should correctly convert user message', () => {
it('should correctly convert user message', async () => {
const message: OpenAIChatMessage = {
role: 'user',
content: 'Hi',
};

const converted = instance['convertOAIMessagesToGoogleMessage'](message);
const converted = await instance['convertOAIMessagesToGoogleMessage'](message);

expect(converted).toEqual({
role: 'user',
parts: [{ text: 'Hi' }],
});
});

it('should correctly convert message with inline base64 image parts', () => {
it('should correctly convert message with inline base64 image parts', async () => {
const message: OpenAIChatMessage = {
role: 'user',
content: [
Expand All @@ -538,7 +560,7 @@ describe('LobeGoogleAI', () => {
],
};

const converted = instance['convertOAIMessagesToGoogleMessage'](message);
const converted = await instance['convertOAIMessagesToGoogleMessage'](message);

expect(converted).toEqual({
role: 'user',
Expand All @@ -548,7 +570,7 @@ describe('LobeGoogleAI', () => {
],
});
});
it.skip('should correctly convert message with image url parts', () => {
it.skip('should correctly convert message with image url parts', async () => {
const message: OpenAIChatMessage = {
role: 'user',
content: [
Expand All @@ -557,7 +579,7 @@ describe('LobeGoogleAI', () => {
],
};

const converted = instance['convertOAIMessagesToGoogleMessage'](message);
const converted = await instance['convertOAIMessagesToGoogleMessage'](message);

expect(converted).toEqual({
role: 'user',
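
The through-line of this test file is that the conversion helpers became async, so synchronous throw assertions have to become promise rejections. A hedged before/after sketch (the `part` value stands in for any unsupported image content, and `instance` is the runtime constructed elsewhere in the suite):

// before (sync): the TypeError is thrown directly
expect(() => instance['convertContentToGooglePart'](part)).toThrow(TypeError);

// after (async): the TypeError surfaces as a rejected promise,
// so the assertion must await the rejection
await expect(instance['convertContentToGooglePart'](part)).rejects.toThrow(TypeError);
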
49 changes: 30 additions & 19 deletions src/libs/agent-runtime/google/index.ts
@@ -10,6 +10,8 @@ import {
import { JSONSchema7 } from 'json-schema';
import { transform } from 'lodash-es';

import { imageUrlToBase64 } from '@/utils/imageToBase64';

import { LobeRuntimeAI } from '../BaseAI';
import { AgentRuntimeErrorType, ILobeAgentRuntimeErrorType } from '../error';
import {
@@ -52,7 +54,7 @@ export class LobeGoogleAI implements LobeRuntimeAI {
try {
const model = payload.model;

const contents = this.buildGoogleMessages(payload.messages, model);
const contents = await this.buildGoogleMessages(payload.messages, model);

const geminiStreamResult = await this.client
.getGenerativeModel(
@@ -109,7 +111,7 @@ export class LobeGoogleAI implements LobeRuntimeAI {
}
}

private convertContentToGooglePart = (content: UserMessageContentPart): Part => {
private convertContentToGooglePart = async (content: UserMessageContentPart): Promise<Part> => {
switch (content.type) {
case 'text': {
return { text: content.text };
@@ -130,51 +132,60 @@ export class LobeGoogleAI implements LobeRuntimeAI {
};
}

// if (type === 'url') {
// return {
// fileData: {
// fileUri: content.image_url.url,
// mimeType: mimeType || 'image/png',
// },
// };
// }
if (type === 'url') {
const base64Image = await imageUrlToBase64(content.image_url.url);

return {
inlineData: {
data: base64Image,
mimeType: mimeType || 'image/png',
},
};
}

throw new TypeError(`currently we don't support image url: ${content.image_url.url}`);
}
}
};

private convertOAIMessagesToGoogleMessage = (message: OpenAIChatMessage): Content => {
private convertOAIMessagesToGoogleMessage = async (
message: OpenAIChatMessage,
): Promise<Content> => {
const content = message.content as string | UserMessageContentPart[];

return {
parts:
typeof content === 'string'
? [{ text: content }]
: content.map((c) => this.convertContentToGooglePart(c)),
: await Promise.all(content.map(async (c) => await this.convertContentToGooglePart(c))),
role: message.role === 'assistant' ? 'model' : 'user',
};
};

// convert messages from the Vercel AI SDK Format to the format
// that is expected by the Google GenAI SDK
private buildGoogleMessages = (messages: OpenAIChatMessage[], model: string): Content[] => {
private buildGoogleMessages = async (
messages: OpenAIChatMessage[],
model: string,
): Promise<Content[]> => {
// if the model is gemini-1.5-pro-latest, we don't need any special handling
if (model === 'gemini-1.5-pro-latest') {
return messages
const pools = messages
.filter((message) => message.role !== 'function')
.map((msg) => this.convertOAIMessagesToGoogleMessage(msg));
.map(async (msg) => await this.convertOAIMessagesToGoogleMessage(msg));

return Promise.all(pools);
}

const contents: Content[] = [];
let lastRole = 'model';

messages.forEach((message) => {
for (const message of messages) {
// current to filter function message
if (message.role === 'function') {
return;
continue;
}
const googleMessage = this.convertOAIMessagesToGoogleMessage(message);
const googleMessage = await this.convertOAIMessagesToGoogleMessage(message);

// if the last message is a model message and the current message is a model message
// then we need to add a user message to separate them
@@ -187,7 +198,7 @@ export class LobeGoogleAI implements LobeRuntimeAI {

// update the last role
lastRole = googleMessage.role;
});
}

// if the last message is a user message, then we need to add a model message to separate them
if (lastRole === 'model') {
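
Taken together, these runtime changes mean a user message that mixes text with a remote image now resolves the image to base64 before it is handed to the Google SDK. A hedged usage sketch, assuming `instance` is a constructed LobeGoogleAI runtime as in the tests (the URL is illustrative):

import { OpenAIChatMessage } from '@/libs/agent-runtime';

const messages: OpenAIChatMessage[] = [
  {
    role: 'user',
    content: [
      { type: 'text', text: 'What is in this picture?' },
      { type: 'image_url', image_url: { url: 'https://example.com/image.png' } },
    ],
  },
];

// buildGoogleMessages is now async: every http(s) image URL is fetched and converted
// to base64 so it can be sent as an inlineData part instead of a file URI.
const contents = await instance['buildGoogleMessages'](messages, 'gemini-1.5-flash-latest');
// ≈ [{ role: 'user', parts: [{ text: 'What is in this picture?' },
//      { inlineData: { data: '<base64>', mimeType: 'image/png' } }] }]
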
16 changes: 16 additions & 0 deletions src/utils/imageToBase64.ts
@@ -35,3 +35,19 @@ export const imageToBase64 = ({

return canvas.toDataURL(type);
};

export const imageUrlToBase64 = async (imageUrl: string): Promise<string> => {
try {
const res = await fetch(imageUrl);
const arrayBuffer = await res.arrayBuffer();

return typeof btoa === 'function'
? btoa(
new Uint8Array(arrayBuffer).reduce((data, byte) => data + String.fromCharCode(byte), ''),
)
: Buffer.from(arrayBuffer).toString('base64');
} catch (error) {
console.error('Error converting image to base64:', error);
throw error;
}
};
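
The new helper is what lets remote image URLs (the form images take when the server database is enabled) reach Gemini. A small usage sketch; the URL is illustrative, and the btoa/Buffer branch above means the same call works in both browser and Node runtimes:

import { imageUrlToBase64 } from '@/utils/imageToBase64';

// Fetches the image and returns the raw base64 payload (no `data:` URI prefix),
// ready to be wrapped into Gemini's inlineData: { data, mimeType }.
const base64 = await imageUrlToBase64('https://example.com/image.png');
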