feat (provider/google-vertex): Add imagen support. #4124

Merged: 4 commits, Dec 20, 2024
5 changes: 5 additions & 0 deletions .changeset/neat-rice-travel.md
@@ -0,0 +1,5 @@
---
'@ai-sdk/google-vertex': patch
---

feat (provider/google-vertex): Add imagen support.
10 changes: 6 additions & 4 deletions content/docs/03-ai-sdk-core/35-image-generation.mdx
@@ -93,7 +93,9 @@ const { image } = await generateImage({

## Image Models

- | Provider | Model | Supported Sizes |
- | --------------------------------------------------------- | ---------- | ------------------------------- |
- | [OpenAI](/providers/ai-sdk-providers/openai#image-models) | `dall-e-3` | 1024x1024, 1792x1024, 1024x1792 |
- | [OpenAI](/providers/ai-sdk-providers/openai#image-models) | `dall-e-2` | 256x256, 512x512, 1024x1024 |
+ | Provider | Model | Supported Sizes |
+ | ----------------------------------------------------------------------- | ------------------------------ | ------------------------------------------------------------------------------------------------------------- |
+ | [Google Vertex](/providers/ai-sdk-providers/google-vertex#image-models) | `imagen-3.0-generate-001` | See [aspect ratios](https://cloud.google.com/vertex-ai/generative-ai/docs/image/generate-images#aspect-ratio) |
+ | [Google Vertex](/providers/ai-sdk-providers/google-vertex#image-models) | `imagen-3.0-fast-generate-001` | See [aspect ratios](https://cloud.google.com/vertex-ai/generative-ai/docs/image/generate-images#aspect-ratio) |
+ | [OpenAI](/providers/ai-sdk-providers/openai#image-models) | `dall-e-3` | 1024x1024, 1792x1024, 1024x1792 |
+ | [OpenAI](/providers/ai-sdk-providers/openai#image-models) | `dall-e-2` | 256x256, 512x512, 1024x1024 |
27 changes: 27 additions & 0 deletions content/providers/01-ai-sdk-providers/11-google-vertex.mdx
@@ -559,6 +559,33 @@ The following optional settings are available for Google Vertex AI embedding mod
model ID as a string if needed.
</Note>

### Image Models

You can create [Imagen](https://cloud.google.com/vertex-ai/generative-ai/docs/image/overview) models that call the [Imagen on Vertex AI API](https://cloud.google.com/vertex-ai/generative-ai/docs/image/generate-images)
using the `.image()` factory method. For more on image generation with the AI SDK, see [generateImage()](/docs/reference/ai-sdk-core/generate-image).

Note that Imagen does not support an explicit `size` parameter, and passing one throws an error. Instead, set the [aspect ratio](https://cloud.google.com/vertex-ai/generative-ai/docs/image/generate-images#aspect-ratio) through `providerOptions.vertex.aspectRatio`, which determines the dimensions of the generated image.

```ts
import { vertex } from '@ai-sdk/google-vertex';
import { experimental_generateImage as generateImage } from 'ai';

const { image } = await generateImage({
  model: vertex.image('imagen-3.0-generate-001'),
  prompt: 'A futuristic cityscape at sunset',
  providerOptions: {
    vertex: { aspectRatio: '16:9' },
  },
});
```

#### Model Capabilities

| Model | Supported Sizes |
| ------------------------------ | ------------------------------------------------------------------------------------------------------------- |
| `imagen-3.0-generate-001` | See [aspect ratios](https://cloud.google.com/vertex-ai/generative-ai/docs/image/generate-images#aspect-ratio) |
| `imagen-3.0-fast-generate-001` | See [aspect ratios](https://cloud.google.com/vertex-ai/generative-ai/docs/image/generate-images#aspect-ratio) |
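
Because the provider rejects `size` and expects an aspect ratio string instead, it can help to validate that string before sending a request. The following is a minimal sketch of such a check. `parseAspectRatio` is a hypothetical helper and not part of `@ai-sdk/google-vertex`. It only checks the `W:H` format and does not know which ratios a given Imagen model accepts, so consult the linked aspect ratio documentation for that.

```typescript
// Hypothetical helper (not part of the AI SDK): parses "W:H" strings such as "16:9".
interface AspectRatio {
  width: number;
  height: number;
}

function parseAspectRatio(value: string): AspectRatio {
  // Accept only two positive integers separated by a colon.
  const match = /^(\d+):(\d+)$/.exec(value);
  if (match === null) {
    throw new Error(`Expected an aspect ratio such as "16:9", got "${value}"`);
  }

  const width = Number(match[1]);
  const height = Number(match[2]);
  if (width === 0 || height === 0) {
    throw new Error(`Aspect ratio terms must be positive, got "${value}"`);
  }

  return { width, height };
}
```

A size-style string such as `'1024x1024'` fails this check, which mirrors the provider's refusal of the `size` option.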

## Google Vertex Anthropic Provider Usage

The Google Vertex Anthropic provider for the [AI SDK](https://sdk.vercel.ai/docs) offers support for Anthropic's Claude models through the Google Vertex AI APIs. This section provides details on how to set up and use the Google Vertex Anthropic provider.
63 changes: 62 additions & 1 deletion examples/ai-core/src/e2e/google-vertex.test.ts
@@ -10,11 +10,34 @@ import {
  streamObject,
  embed,
  embedMany,
  experimental_generateImage as generateImage,
} from 'ai';
import fs from 'fs';
import { GoogleGenerativeAIProviderMetadata } from '@ai-sdk/google';

- const LONG_TEST_MILLIS = 10000;
+ const LONG_TEST_MILLIS = 20000;

const mimeTypeSignatures = [
  { mimeType: 'image/gif' as const, bytes: [0x47, 0x49, 0x46] },
  { mimeType: 'image/png' as const, bytes: [0x89, 0x50, 0x4e, 0x47] },
  { mimeType: 'image/jpeg' as const, bytes: [0xff, 0xd8] },
  { mimeType: 'image/webp' as const, bytes: [0x52, 0x49, 0x46, 0x46] },
];

function detectImageMimeType(
  image: Uint8Array,
): 'image/jpeg' | 'image/png' | 'image/gif' | 'image/webp' | undefined {
  for (const { bytes, mimeType } of mimeTypeSignatures) {
    if (
      image.length >= bytes.length &&
      bytes.every((byte, index) => image[index] === byte)
    ) {
      return mimeType;
    }
  }

  return undefined;
}

// Model variants to test against
const MODEL_VARIANTS = {
@@ -26,6 +49,7 @@ const MODEL_VARIANTS = {
// 'gemini-1.0-pro-001',
],
embedding: ['textembedding-gecko', 'textembedding-gecko-multilingual'],
image: ['imagen-3.0-generate-001', 'imagen-3.0-fast-generate-001'],
} as const;

// Define runtime variants
@@ -452,5 +476,42 @@ describe.each(Object.values(RUNTIME_VARIANTS))(
        expect(result.usage?.tokens).toBeGreaterThan(0);
      });
    });

    describe.each(MODEL_VARIANTS.image)('Image Model: %s', modelId => {
      it('should generate an image with correct dimensions and format', async () => {
        const model = vertex.image(modelId);
        const { image } = await generateImage({
          model,
          prompt: 'A burrito launched through a tunnel',
          providerOptions: {
            vertex: {
              aspectRatio: '3:4',
            },
          },
        });

        // Verify we got a Uint8Array back
        expect(image.uint8Array).toBeInstanceOf(Uint8Array);

        // Check the file size is reasonable (at least 10KB, less than 10MB)
        expect(image.uint8Array.length).toBeGreaterThan(10 * 1024);
        expect(image.uint8Array.length).toBeLessThan(10 * 1024 * 1024);

        // Verify PNG format
        const mimeType = detectImageMimeType(image.uint8Array);
        expect(mimeType).toBe('image/png');

        // Create a temporary buffer to verify image dimensions
        const tempBuffer = Buffer.from(image.uint8Array);

        // The PNG IHDR chunk stores width at bytes 16-19 and height at bytes 20-23 (big-endian)
        const width = tempBuffer.readUInt32BE(16);
        const height = tempBuffer.readUInt32BE(20);

        // https://cloud.google.com/vertex-ai/generative-ai/docs/image/generate-images#performance-limits
        expect(width).toBe(896);
        expect(height).toBe(1280);
      });
    });
  },
);
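
The dimension check in the test above reads two big-endian integers out of the PNG header. As a standalone illustration, here is a minimal sketch of that lookup using `DataView` instead of `Buffer`. `readPngDimensions` is a hypothetical name, not something this PR adds, and the sketch assumes the input is a well-formed PNG whose first chunk is `IHDR`.

```typescript
// Hypothetical helper: reads width and height from a PNG's IHDR chunk.
// Layout: 8-byte signature, 4-byte chunk length, 4-byte chunk type ("IHDR"),
// then 4-byte width and 4-byte height, all big-endian.
const PNG_SIGNATURE = [0x89, 0x50, 0x4e, 0x47, 0x0d, 0x0a, 0x1a, 0x0a];
const IHDR_TYPE = [0x49, 0x48, 0x44, 0x52]; // "IHDR" in ASCII

function readPngDimensions(bytes: Uint8Array): {
  width: number;
  height: number;
} {
  if (bytes.length < 24) {
    throw new Error('Too short to contain a PNG IHDR chunk');
  }
  if (!PNG_SIGNATURE.every((byte, index) => bytes[index] === byte)) {
    throw new Error('Missing PNG signature');
  }
  if (!IHDR_TYPE.every((byte, index) => bytes[12 + index] === byte)) {
    throw new Error('First chunk is not IHDR');
  }

  // DataView reads big-endian by default, which matches the PNG format.
  const view = new DataView(bytes.buffer, bytes.byteOffset, bytes.byteLength);
  return { width: view.getUint32(16), height: view.getUint32(20) };
}
```

Note that the 896x1280 result expected for the `3:4` request is a ratio of 0.70 rather than exactly 0.75, so assertions are better written against the documented output dimensions than against a value computed from the ratio.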
22 changes: 22 additions & 0 deletions examples/ai-core/src/generate-image/google-vertex.ts
@@ -0,0 +1,22 @@
import 'dotenv/config';
import { vertex } from '@ai-sdk/google-vertex';
import { experimental_generateImage as generateImage } from 'ai';
import fs from 'fs';

async function main() {
  const { image } = await generateImage({
    model: vertex.image('imagen-3.0-generate-001'),
    prompt: 'A burrito launched through a tunnel',
    providerOptions: {
      vertex: {
        aspectRatio: '16:9',
      },
    },
  });

  const filename = `image-${Date.now()}.png`;
  fs.writeFileSync(filename, image.uint8Array);
  console.log(`Image saved to ${filename}`);
}

main().catch(console.error);
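
The example above hardcodes a `.png` extension, which fits the PNG output that the e2e test asserts. If the output format were ever in doubt, the extension could be chosen from the file's leading bytes instead. The sketch below shows one way to do that. `imageExtension` and `hasBytesAt` are hypothetical helpers, not part of this PR. Unlike the four-byte `RIFF` check in the e2e test, this version also looks for `WEBP` at offset 8, since `RIFF` alone also matches other container formats such as WAV.

```typescript
// Hypothetical helpers: pick a file extension from an image's magic bytes.
function hasBytesAt(
  bytes: Uint8Array,
  signature: number[],
  offset = 0,
): boolean {
  return (
    bytes.length >= offset + signature.length &&
    signature.every((byte, index) => bytes[offset + index] === byte)
  );
}

function imageExtension(
  bytes: Uint8Array,
): 'png' | 'jpg' | 'gif' | 'webp' | undefined {
  if (hasBytesAt(bytes, [0x89, 0x50, 0x4e, 0x47])) return 'png';
  if (hasBytesAt(bytes, [0xff, 0xd8])) return 'jpg';
  if (hasBytesAt(bytes, [0x47, 0x49, 0x46])) return 'gif';
  // WebP is a RIFF container whose form type at offset 8 is "WEBP".
  if (
    hasBytesAt(bytes, [0x52, 0x49, 0x46, 0x46]) &&
    hasBytesAt(bytes, [0x57, 0x45, 0x42, 0x50], 8)
  ) {
    return 'webp';
  }
  return undefined;
}
```

A caller could then build the filename as `image-${Date.now()}.${imageExtension(bytes) ?? 'bin'}`, where the `'bin'` fallback is likewise just an illustrative choice.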
134 changes: 134 additions & 0 deletions packages/google-vertex/src/google-vertex-image-model.test.ts
@@ -0,0 +1,134 @@
import { JsonTestServer } from '@ai-sdk/provider-utils/test';
import { GoogleVertexImageModel } from './google-vertex-image-model';
import { describe, it, expect, vi } from 'vitest';

const prompt = 'A cute baby sea otter';

const model = new GoogleVertexImageModel('imagen-3.0-generate-001', {
  provider: 'google-vertex',
  baseURL: 'https://api.example.com',
  headers: { 'api-key': 'test-key' },
});

describe('GoogleVertexImageModel', () => {
  describe('doGenerate', () => {
    const server = new JsonTestServer(
      'https://api.example.com/models/imagen-3.0-generate-001:predict',
    );

    server.setupTestEnvironment();

    function prepareJsonResponse() {
      server.responseBodyJson = {
        predictions: [
          { bytesBase64Encoded: 'base64-image-1' },
          { bytesBase64Encoded: 'base64-image-2' },
        ],
      };
    }

    it('should pass the correct parameters', async () => {
      prepareJsonResponse();

      await model.doGenerate({
        prompt,
        n: 2,
        size: undefined,
        providerOptions: { vertex: { aspectRatio: '1:1' } },
      });

      expect(await server.getRequestBodyJson()).toStrictEqual({
        instances: [{ prompt }],
        parameters: {
          sampleCount: 2,
          aspectRatio: '1:1',
        },
      });
    });

    it('should pass headers', async () => {
      prepareJsonResponse();

      const modelWithHeaders = new GoogleVertexImageModel(
        'imagen-3.0-generate-001',
        {
          provider: 'google-vertex',
          baseURL: 'https://api.example.com',
          headers: {
            'Custom-Provider-Header': 'provider-header-value',
          },
        },
      );

      await modelWithHeaders.doGenerate({
        prompt,
        n: 2,
        size: undefined,
        providerOptions: {},
        headers: {
          'Custom-Request-Header': 'request-header-value',
        },
      });

      const requestHeaders = await server.getRequestHeaders();

      expect(requestHeaders).toStrictEqual({
        'content-type': 'application/json',
        'custom-provider-header': 'provider-header-value',
        'custom-request-header': 'request-header-value',
      });
    });

    it('should extract the generated images', async () => {
      prepareJsonResponse();

      const result = await model.doGenerate({
        prompt,
        n: 2,
        size: undefined,
        providerOptions: {},
      });

      expect(result.images).toStrictEqual(['base64-image-1', 'base64-image-2']);
    });

    it('throws when size is specified', async () => {
      const model = new GoogleVertexImageModel('imagen-3.0-generate-001', {
        provider: 'vertex',
        baseURL: 'https://example.com',
      });

      await expect(
        model.doGenerate({
          prompt: 'test prompt',
          n: 1,
          size: '1024x1024',
          providerOptions: {},
        }),
      ).rejects.toThrow(/Google Vertex does not support the `size` option./);
    });

    it('sends aspect ratio in the request', async () => {
      prepareJsonResponse();

      await model.doGenerate({
        prompt: 'test prompt',
        n: 1,
        size: undefined,
        providerOptions: {
          vertex: {
            aspectRatio: '16:9',
          },
        },
      });

      expect(await server.getRequestBodyJson()).toStrictEqual({
        instances: [{ prompt: 'test prompt' }],
        parameters: {
          sampleCount: 1,
          aspectRatio: '16:9',
        },
      });
    });
  });
});
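
The tests above mock a response of the form `{ predictions: [{ bytesBase64Encoded }] }` and expect the model to return the base64 strings unchanged. For readers who want to see the shape handling without `zod`, here is a minimal sketch that validates the same structure by hand and additionally decodes each entry to bytes. `extractImageBytes` is a hypothetical function, not the PR's implementation, and the decoding step is an assumption about what a caller might do next. It relies on Node's global `Buffer`.

```typescript
// Hypothetical helper: validates a predict response and decodes each image.
interface Prediction {
  bytesBase64Encoded: string;
}

function extractImageBytes(response: unknown): Uint8Array[] {
  if (typeof response !== 'object' || response === null) {
    throw new Error('Expected the response to be an object');
  }

  const predictions = (response as { predictions?: unknown }).predictions;
  if (!Array.isArray(predictions)) {
    throw new Error('Expected `predictions` to be an array');
  }

  return predictions.map((prediction, index) => {
    const encoded = (prediction as Partial<Prediction> | null)
      ?.bytesBase64Encoded;
    if (typeof encoded !== 'string') {
      throw new Error(`Prediction ${index} has no bytesBase64Encoded string`);
    }
    // Copy into a plain Uint8Array so callers do not depend on Buffer.
    return new Uint8Array(Buffer.from(encoded, 'base64'));
  });
}
```

The placeholder values in the tests (`'base64-image-1'`) are not real image data, so a decoding step like this would only be meaningful against an actual API response.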
86 changes: 86 additions & 0 deletions packages/google-vertex/src/google-vertex-image-model.ts
@@ -0,0 +1,86 @@
import { ImageModelV1, JSONValue } from '@ai-sdk/provider';
import {
  Resolvable,
  postJsonToApi,
  combineHeaders,
  createJsonResponseHandler,
  resolve,
} from '@ai-sdk/provider-utils';
import { z } from 'zod';
import { googleVertexFailedResponseHandler } from './google-vertex-error';

export type GoogleVertexImageModelId =
  | 'imagen-3.0-generate-001'
  | 'imagen-3.0-fast-generate-001';

interface GoogleVertexImageModelConfig {
  provider: string;
  baseURL: string;
  headers?: Resolvable<Record<string, string | undefined>>;
  fetch?: typeof fetch;
}

// https://cloud.google.com/vertex-ai/generative-ai/docs/image/generate-images
export class GoogleVertexImageModel implements ImageModelV1 {
  readonly specificationVersion = 'v1';

  get provider(): string {
    return this.config.provider;
  }

  constructor(
    readonly modelId: GoogleVertexImageModelId,
    private config: GoogleVertexImageModelConfig,
  ) {}

  async doGenerate({
    prompt,
    n,
    size,
    providerOptions,
    headers,
    abortSignal,
  }: Parameters<ImageModelV1['doGenerate']>[0]): Promise<
    Awaited<ReturnType<ImageModelV1['doGenerate']>>
  > {
    if (size) {
      throw new Error(
        'Google Vertex does not support the `size` option. Use ' +
          '`providerOptions.vertex.aspectRatio` instead. See ' +
          'https://cloud.google.com/vertex-ai/generative-ai/docs/image/generate-images#aspect-ratio',
      );
    }

    const body = {
      instances: [{ prompt }],
      parameters: {
        sampleCount: n,
        ...(providerOptions.vertex ?? {}),
      },
    };

    const { value: response } = await postJsonToApi({
      url: `${this.config.baseURL}/models/${this.modelId}:predict`,
      headers: combineHeaders(await resolve(this.config.headers), headers),
      body,
      failedResponseHandler: googleVertexFailedResponseHandler,
      successfulResponseHandler: createJsonResponseHandler(
        vertexImageResponseSchema,
      ),
      abortSignal: abortSignal,
      fetch: this.config.fetch,
    });

    return {
      images: response.predictions.map(
        (p: { bytesBase64Encoded: string }) => p.bytesBase64Encoded,
      ),
    };
  }
}

// minimal version of the schema, focussed on what is needed for the implementation
// this approach limits breakages when the API changes and increases efficiency
const vertexImageResponseSchema = z.object({
  predictions: z.array(z.object({ bytesBase64Encoded: z.string() })),
});
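
To make the request construction in `doGenerate` easier to follow in isolation, the sketch below restates it as a pure function with no network call. `buildImagenRequestBody` and its types are hypothetical names introduced for illustration. The logic follows the diff above: reject `size`, map `n` to `sampleCount`, and spread `providerOptions.vertex` into `parameters`.

```typescript
// Hypothetical restatement of the request-body logic in doGenerate.
type ProviderOptions = Record<string, Record<string, unknown>>;

interface ImagenRequestBody {
  instances: Array<{ prompt: string }>;
  parameters: Record<string, unknown>;
}

function buildImagenRequestBody(options: {
  prompt: string;
  n: number;
  size?: string;
  providerOptions: ProviderOptions;
}): ImagenRequestBody {
  // Imagen takes an aspect ratio rather than pixel dimensions.
  if (options.size !== undefined) {
    throw new Error(
      'Google Vertex does not support the `size` option. Use ' +
        '`providerOptions.vertex.aspectRatio` instead.',
    );
  }

  return {
    instances: [{ prompt: options.prompt }],
    parameters: {
      sampleCount: options.n,
      // Spread last, so keys in providerOptions.vertex take precedence.
      ...(options.providerOptions.vertex ?? {}),
    },
  };
}
```

One consequence of the spread order, which holds for the code in the diff as well, is that a `sampleCount` key inside `providerOptions.vertex` would override the value derived from `n`.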
4 changes: 4 additions & 0 deletions packages/google-vertex/src/google-vertex-provider.test.ts
@@ -18,6 +18,10 @@ vi.mock('./google-vertex-embedding-model', () => ({
  GoogleVertexEmbeddingModel: vi.fn(),
}));

vi.mock('./google-vertex-image-model', () => ({
  GoogleVertexImageModel: vi.fn(),
}));

describe('google-vertex-provider', () => {
  beforeEach(() => {
    vi.clearAllMocks();