feat (provider/google-vertex): Add imagen support. (#4124)
shaper authored Dec 20, 2024
1 parent cf90a25 commit 5feec50
Showing 9 changed files with 365 additions and 6 deletions.
5 changes: 5 additions & 0 deletions .changeset/neat-rice-travel.md
@@ -0,0 +1,5 @@
---
'@ai-sdk/google-vertex': patch
---

feat (provider/google-vertex): Add imagen support.
10 changes: 6 additions & 4 deletions content/docs/03-ai-sdk-core/35-image-generation.mdx
@@ -93,7 +93,9 @@ const { image } = await generateImage({

## Image Models

-| Provider | Model | Supported Sizes |
-| --------------------------------------------------------- | ---------- | ------------------------------- |
-| [OpenAI](/providers/ai-sdk-providers/openai#image-models) | `dall-e-3` | 1024x1024, 1792x1024, 1024x1792 |
-| [OpenAI](/providers/ai-sdk-providers/openai#image-models) | `dall-e-2` | 256x256, 512x512, 1024x1024 |
+| Provider | Model | Supported Sizes |
+| ----------------------------------------------------------------------- | ------------------------------ | ------------------------------------------------------------------------------------------------------------- |
+| [Google Vertex](/providers/ai-sdk-providers/google-vertex#image-models) | `imagen-3.0-generate-001` | See [aspect ratios](https://cloud.google.com/vertex-ai/generative-ai/docs/image/generate-images#aspect-ratio) |
+| [Google Vertex](/providers/ai-sdk-providers/google-vertex#image-models) | `imagen-3.0-fast-generate-001` | See [aspect ratios](https://cloud.google.com/vertex-ai/generative-ai/docs/image/generate-images#aspect-ratio) |
+| [OpenAI](/providers/ai-sdk-providers/openai#image-models) | `dall-e-3` | 1024x1024, 1792x1024, 1024x1792 |
+| [OpenAI](/providers/ai-sdk-providers/openai#image-models) | `dall-e-2` | 256x256, 512x512, 1024x1024 |
27 changes: 27 additions & 0 deletions content/providers/01-ai-sdk-providers/11-google-vertex.mdx
@@ -559,6 +559,33 @@ The following optional settings are available for Google Vertex AI embedding mod
model ID as a string if needed.
</Note>

### Image Models

You can create [Imagen](https://cloud.google.com/vertex-ai/generative-ai/docs/image/overview) models that call the [Imagen on Vertex AI API](https://cloud.google.com/vertex-ai/generative-ai/docs/image/generate-images)
using the `.image()` factory method. For more on image generation with the AI SDK, see [generateImage()](/docs/reference/ai-sdk-core/generate-image).

Note that Imagen does not support an explicit size parameter. Instead, the output size is determined by the [aspect ratio](https://cloud.google.com/vertex-ai/generative-ai/docs/image/generate-images#aspect-ratio) you request via `providerOptions.vertex.aspectRatio`.

```ts
import { vertex } from '@ai-sdk/google-vertex';
import { experimental_generateImage as generateImage } from 'ai';

const { image } = await generateImage({
  model: vertex.image('imagen-3.0-generate-001'),
  prompt: 'A futuristic cityscape at sunset',
  providerOptions: {
    vertex: { aspectRatio: '16:9' },
  },
});
```

#### Model Capabilities

| Model | Supported Sizes |
| ------------------------------ | ------------------------------------------------------------------------------------------------------------- |
| `imagen-3.0-generate-001` | See [aspect ratios](https://cloud.google.com/vertex-ai/generative-ai/docs/image/generate-images#aspect-ratio) |
| `imagen-3.0-fast-generate-001` | See [aspect ratios](https://cloud.google.com/vertex-ai/generative-ai/docs/image/generate-images#aspect-ratio) |

## Google Vertex Anthropic Provider Usage

The Google Vertex Anthropic provider for the [AI SDK](https://sdk.vercel.ai/docs) offers support for Anthropic's Claude models through the Google Vertex AI APIs. This section provides details on how to set up and use the Google Vertex Anthropic provider.
63 changes: 62 additions & 1 deletion examples/ai-core/src/e2e/google-vertex.test.ts
@@ -10,11 +10,34 @@ import {
  streamObject,
  embed,
  embedMany,
  experimental_generateImage as generateImage,
} from 'ai';
import fs from 'fs';
import { GoogleGenerativeAIProviderMetadata } from '@ai-sdk/google';

-const LONG_TEST_MILLIS = 10000;
+const LONG_TEST_MILLIS = 20000;

const mimeTypeSignatures = [
  { mimeType: 'image/gif' as const, bytes: [0x47, 0x49, 0x46] },
  { mimeType: 'image/png' as const, bytes: [0x89, 0x50, 0x4e, 0x47] },
  { mimeType: 'image/jpeg' as const, bytes: [0xff, 0xd8] },
  { mimeType: 'image/webp' as const, bytes: [0x52, 0x49, 0x46, 0x46] },
];

function detectImageMimeType(
  image: Uint8Array,
): 'image/jpeg' | 'image/png' | 'image/gif' | 'image/webp' | undefined {
  for (const { bytes, mimeType } of mimeTypeSignatures) {
    if (
      image.length >= bytes.length &&
      bytes.every((byte, index) => image[index] === byte)
    ) {
      return mimeType;
    }
  }

  return undefined;
}

// Model variants to test against
const MODEL_VARIANTS = {
@@ -26,6 +49,7 @@ const MODEL_VARIANTS = {
    // 'gemini-1.0-pro-001',
  ],
  embedding: ['textembedding-gecko', 'textembedding-gecko-multilingual'],
  image: ['imagen-3.0-generate-001', 'imagen-3.0-fast-generate-001'],
} as const;

// Define runtime variants
@@ -452,5 +476,42 @@ describe.each(Object.values(RUNTIME_VARIANTS))(
        expect(result.usage?.tokens).toBeGreaterThan(0);
      });
    });

    describe.each(MODEL_VARIANTS.image)('Image Model: %s', modelId => {
      it('should generate an image with correct dimensions and format', async () => {
        const model = vertex.image(modelId);
        const { image } = await generateImage({
          model,
          prompt: 'A burrito launched through a tunnel',
          providerOptions: {
            vertex: {
              aspectRatio: '3:4',
            },
          },
        });

        // Verify we got a Uint8Array back
        expect(image.uint8Array).toBeInstanceOf(Uint8Array);

        // Check the file size is reasonable (more than 10KB, less than 10MB)
        expect(image.uint8Array.length).toBeGreaterThan(10 * 1024);
        expect(image.uint8Array.length).toBeLessThan(10 * 1024 * 1024);

        // Verify PNG format
        const mimeType = detectImageMimeType(image.uint8Array);
        expect(mimeType).toBe('image/png');

        // Create a temporary buffer to verify image dimensions
        const tempBuffer = Buffer.from(image.uint8Array);

        // In the PNG IHDR chunk, width is stored big-endian at bytes 16-19
        // and height at bytes 20-23
        const width = tempBuffer.readUInt32BE(16);
        const height = tempBuffer.readUInt32BE(20);

        // https://cloud.google.com/vertex-ai/generative-ai/docs/image/generate-images#performance-limits
        expect(width).toBe(896);
        expect(height).toBe(1280);
      });
    });
  },
);
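
The dimension check in the test above relies on the fixed layout of the PNG header: an 8-byte signature, then the IHDR chunk (4-byte length, 4-byte type), then width and height as big-endian 32-bit integers. A minimal standalone sketch of that read follows. The helpers `readPngDimensions` and `makePngHeader` are hypothetical names used for illustration and are not part of this commit.

```typescript
// Hypothetical helpers, not part of the commit.
const PNG_SIGNATURE = [0x89, 0x50, 0x4e, 0x47, 0x0d, 0x0a, 0x1a, 0x0a];
const IHDR_TYPE = [0x49, 0x48, 0x44, 0x52]; // ASCII 'IHDR'

function readPngDimensions(
  bytes: Uint8Array,
): { width: number; height: number } | undefined {
  // Signature (8) + chunk length (4) + chunk type (4) + width (4) + height (4)
  if (bytes.length < 24) return undefined;
  if (!PNG_SIGNATURE.every((b, i) => bytes[i] === b)) return undefined;
  // The first chunk must be IHDR; its type sits at bytes 12-15.
  if (!IHDR_TYPE.every((b, i) => bytes[12 + i] === b)) return undefined;

  const view = new DataView(bytes.buffer, bytes.byteOffset, bytes.byteLength);
  return {
    width: view.getUint32(16, false), // false = big-endian
    height: view.getUint32(20, false),
  };
}

// Builds a synthetic 24-byte PNG header so the sketch runs without a real image.
function makePngHeader(width: number, height: number): Uint8Array {
  const bytes = new Uint8Array(24);
  bytes.set(PNG_SIGNATURE, 0);
  const view = new DataView(bytes.buffer);
  view.setUint32(8, 13, false); // IHDR data length is always 13
  bytes.set(IHDR_TYPE, 12);
  view.setUint32(16, width, false);
  view.setUint32(20, height, false);
  return bytes;
}

console.log(readPngDimensions(makePngHeader(896, 1280)));
```

Checking the signature and chunk type before reading avoids interpreting arbitrary bytes as dimensions when the payload is not a PNG.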
22 changes: 22 additions & 0 deletions examples/ai-core/src/generate-image/google-vertex.ts
@@ -0,0 +1,22 @@
import 'dotenv/config';
import { vertex } from '@ai-sdk/google-vertex';
import { experimental_generateImage as generateImage } from 'ai';
import fs from 'fs';

async function main() {
  const { image } = await generateImage({
    model: vertex.image('imagen-3.0-generate-001'),
    prompt: 'A burrito launched through a tunnel',
    providerOptions: {
      vertex: {
        aspectRatio: '16:9',
      },
    },
  });

  const filename = `image-${Date.now()}.png`;
  fs.writeFileSync(filename, image.uint8Array);
  console.log(`Image saved to ${filename}`);
}

main().catch(console.error);
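
The example above always writes a `.png` file. If the output format were ever in doubt, the extension could be derived from the leading bytes instead, in the spirit of the `detectImageMimeType` helper in the e2e test. The sketch below is one way to do that; `sniffImageMimeType` and `filenameFor` are hypothetical names, and the WebP branch additionally checks for `WEBP` at offset 8, since `RIFF` alone also prefixes other container formats such as WAV.

```typescript
// Hypothetical helpers, not part of the commit.
type ImageMimeType = 'image/gif' | 'image/png' | 'image/jpeg' | 'image/webp';

const EXTENSIONS: Record<ImageMimeType, string> = {
  'image/gif': 'gif',
  'image/png': 'png',
  'image/jpeg': 'jpg',
  'image/webp': 'webp',
};

// True when `data` contains `signature` starting at `offset`.
function hasBytesAt(data: Uint8Array, signature: number[], offset = 0): boolean {
  return (
    data.length >= offset + signature.length &&
    signature.every((byte, i) => data[offset + i] === byte)
  );
}

function sniffImageMimeType(data: Uint8Array): ImageMimeType | undefined {
  if (hasBytesAt(data, [0x89, 0x50, 0x4e, 0x47])) return 'image/png';
  if (hasBytesAt(data, [0xff, 0xd8])) return 'image/jpeg';
  if (hasBytesAt(data, [0x47, 0x49, 0x46])) return 'image/gif';
  // WebP is a RIFF container: 'RIFF' at offset 0 and 'WEBP' at offset 8.
  if (
    hasBytesAt(data, [0x52, 0x49, 0x46, 0x46]) &&
    hasBytesAt(data, [0x57, 0x45, 0x42, 0x50], 8)
  ) {
    return 'image/webp';
  }
  return undefined;
}

// Falls back to a generic extension when the format is not recognized.
function filenameFor(data: Uint8Array, stem: string): string {
  const mimeType = sniffImageMimeType(data);
  return mimeType ? `${stem}.${EXTENSIONS[mimeType]}` : `${stem}.bin`;
}
```

With the SDK result from the example, this would be called as `filenameFor(image.uint8Array, 'image-' + Date.now())`.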
134 changes: 134 additions & 0 deletions packages/google-vertex/src/google-vertex-image-model.test.ts
@@ -0,0 +1,134 @@
import { JsonTestServer } from '@ai-sdk/provider-utils/test';
import { GoogleVertexImageModel } from './google-vertex-image-model';
import { describe, it, expect, vi } from 'vitest';

const prompt = 'A cute baby sea otter';

const model = new GoogleVertexImageModel('imagen-3.0-generate-001', {
  provider: 'google-vertex',
  baseURL: 'https://api.example.com',
  headers: { 'api-key': 'test-key' },
});

describe('GoogleVertexImageModel', () => {
  describe('doGenerate', () => {
    const server = new JsonTestServer(
      'https://api.example.com/models/imagen-3.0-generate-001:predict',
    );

    server.setupTestEnvironment();

    function prepareJsonResponse() {
      server.responseBodyJson = {
        predictions: [
          { bytesBase64Encoded: 'base64-image-1' },
          { bytesBase64Encoded: 'base64-image-2' },
        ],
      };
    }

    it('should pass the correct parameters', async () => {
      prepareJsonResponse();

      await model.doGenerate({
        prompt,
        n: 2,
        size: undefined,
        providerOptions: { vertex: { aspectRatio: '1:1' } },
      });

      expect(await server.getRequestBodyJson()).toStrictEqual({
        instances: [{ prompt }],
        parameters: {
          sampleCount: 2,
          aspectRatio: '1:1',
        },
      });
    });

    it('should pass headers', async () => {
      prepareJsonResponse();

      const modelWithHeaders = new GoogleVertexImageModel(
        'imagen-3.0-generate-001',
        {
          provider: 'google-vertex',
          baseURL: 'https://api.example.com',
          headers: {
            'Custom-Provider-Header': 'provider-header-value',
          },
        },
      );

      await modelWithHeaders.doGenerate({
        prompt,
        n: 2,
        size: undefined,
        providerOptions: {},
        headers: {
          'Custom-Request-Header': 'request-header-value',
        },
      });

      const requestHeaders = await server.getRequestHeaders();

      expect(requestHeaders).toStrictEqual({
        'content-type': 'application/json',
        'custom-provider-header': 'provider-header-value',
        'custom-request-header': 'request-header-value',
      });
    });

    it('should extract the generated images', async () => {
      prepareJsonResponse();

      const result = await model.doGenerate({
        prompt,
        n: 2,
        size: undefined,
        providerOptions: {},
      });

      expect(result.images).toStrictEqual(['base64-image-1', 'base64-image-2']);
    });

    it('throws when size is specified', async () => {
      const model = new GoogleVertexImageModel('imagen-3.0-generate-001', {
        provider: 'vertex',
        baseURL: 'https://example.com',
      });

      await expect(
        model.doGenerate({
          prompt: 'test prompt',
          n: 1,
          size: '1024x1024',
          providerOptions: {},
        }),
      ).rejects.toThrow(/Google Vertex does not support the `size` option./);
    });

    it('sends aspect ratio in the request', async () => {
      prepareJsonResponse();

      await model.doGenerate({
        prompt: 'test prompt',
        n: 1,
        size: undefined,
        providerOptions: {
          vertex: {
            aspectRatio: '16:9',
          },
        },
      });

      expect(await server.getRequestBodyJson()).toStrictEqual({
        instances: [{ prompt: 'test prompt' }],
        parameters: {
          sampleCount: 1,
          aspectRatio: '16:9',
        },
      });
    });
  });
});
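
The "should pass headers" test above expects provider-level and request-level headers to arrive merged, with lowercase names. A rough sketch of such a merge is shown below. `mergeHeaders` is a hypothetical stand-in written for illustration; it is not the `combineHeaders` function from `@ai-sdk/provider-utils`, and the explicit lowercasing only mimics the lowercase names the test asserts on (where that normalization happens in the real stack is not shown in this diff).

```typescript
// Hypothetical helper, not part of the commit.
type HeaderMap = Record<string, string | undefined>;

function mergeHeaders(
  ...sources: Array<HeaderMap | undefined>
): Record<string, string> {
  const merged: Record<string, string> = {};
  for (const source of sources) {
    for (const [name, value] of Object.entries(source ?? {})) {
      // Later sources win; undefined values are skipped rather than sent.
      if (value !== undefined) merged[name.toLowerCase()] = value;
    }
  }
  return merged;
}

const merged = mergeHeaders(
  { 'Content-Type': 'application/json' },
  { 'Custom-Provider-Header': 'provider-header-value' },
  { 'Custom-Request-Header': 'request-header-value' },
);
console.log(merged);
```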
86 changes: 86 additions & 0 deletions packages/google-vertex/src/google-vertex-image-model.ts
@@ -0,0 +1,86 @@
import { ImageModelV1, JSONValue } from '@ai-sdk/provider';
import {
  Resolvable,
  postJsonToApi,
  combineHeaders,
  createJsonResponseHandler,
  resolve,
} from '@ai-sdk/provider-utils';
import { z } from 'zod';
import { googleVertexFailedResponseHandler } from './google-vertex-error';

export type GoogleVertexImageModelId =
  | 'imagen-3.0-generate-001'
  | 'imagen-3.0-fast-generate-001';

interface GoogleVertexImageModelConfig {
  provider: string;
  baseURL: string;
  headers?: Resolvable<Record<string, string | undefined>>;
  fetch?: typeof fetch;
}

// https://cloud.google.com/vertex-ai/generative-ai/docs/image/generate-images
export class GoogleVertexImageModel implements ImageModelV1 {
  readonly specificationVersion = 'v1';

  get provider(): string {
    return this.config.provider;
  }

  constructor(
    readonly modelId: GoogleVertexImageModelId,
    private config: GoogleVertexImageModelConfig,
  ) {}

  async doGenerate({
    prompt,
    n,
    size,
    providerOptions,
    headers,
    abortSignal,
  }: Parameters<ImageModelV1['doGenerate']>[0]): Promise<
    Awaited<ReturnType<ImageModelV1['doGenerate']>>
  > {
    if (size) {
      throw new Error(
        'Google Vertex does not support the `size` option. Use ' +
          '`providerOptions.vertex.aspectRatio` instead. See ' +
          'https://cloud.google.com/vertex-ai/generative-ai/docs/image/generate-images#aspect-ratio',
      );
    }

    const body = {
      instances: [{ prompt }],
      parameters: {
        sampleCount: n,
        ...(providerOptions.vertex ?? {}),
      },
    };

    const { value: response } = await postJsonToApi({
      url: `${this.config.baseURL}/models/${this.modelId}:predict`,
      headers: combineHeaders(await resolve(this.config.headers), headers),
      body,
      failedResponseHandler: googleVertexFailedResponseHandler,
      successfulResponseHandler: createJsonResponseHandler(
        vertexImageResponseSchema,
      ),
      abortSignal: abortSignal,
      fetch: this.config.fetch,
    });

    return {
      images: response.predictions.map(
        (p: { bytesBase64Encoded: string }) => p.bytesBase64Encoded,
      ),
    };
  }
}

// minimal version of the schema, focussed on what is needed for the implementation
// this approach limits breakages when the API changes and increases efficiency
const vertexImageResponseSchema = z.object({
  predictions: z.array(z.object({ bytesBase64Encoded: z.string() })),
});
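
The request and response handling in `doGenerate` can be viewed as two pure steps: build the `:predict` payload (rejecting `size`), then pull the base64 strings out of `predictions`. The sketch below restates those steps without the SDK so they can be run in isolation. `buildPredictRequest` and `extractImages` are hypothetical names, and the hand-written validation merely stands in for the zod schema used in the commit.

```typescript
// Hypothetical standalone restatement of the steps in doGenerate; not part of the commit.
interface PredictRequestArgs {
  baseURL: string;
  modelId: string;
  prompt: string;
  n: number;
  size?: string;
  providerOptions: Record<string, Record<string, unknown> | undefined>;
}

function buildPredictRequest(args: PredictRequestArgs) {
  if (args.size) {
    throw new Error(
      'Google Vertex does not support the `size` option. Use ' +
        '`providerOptions.vertex.aspectRatio` instead.',
    );
  }
  return {
    url: `${args.baseURL}/models/${args.modelId}:predict`,
    body: {
      instances: [{ prompt: args.prompt }],
      // Provider options under `vertex` are spread into `parameters` as-is.
      parameters: { sampleCount: args.n, ...(args.providerOptions.vertex ?? {}) },
    },
  };
}

// Manual validation standing in for the zod schema: only `predictions[].bytesBase64Encoded` is read.
function extractImages(response: unknown): string[] {
  if (typeof response !== 'object' || response === null) {
    throw new Error('Response is not an object');
  }
  const predictions = (response as { predictions?: unknown }).predictions;
  if (!Array.isArray(predictions)) {
    throw new Error('Response has no predictions array');
  }
  return predictions.map((p: unknown, i: number) => {
    const encoded = (p as { bytesBase64Encoded?: unknown } | null)
      ?.bytesBase64Encoded;
    if (typeof encoded !== 'string') {
      throw new Error(`predictions[${i}] has no bytesBase64Encoded string`);
    }
    return encoded;
  });
}

const request = buildPredictRequest({
  baseURL: 'https://api.example.com',
  modelId: 'imagen-3.0-generate-001',
  prompt: 'A cute baby sea otter',
  n: 2,
  providerOptions: { vertex: { aspectRatio: '1:1' } },
});
console.log(request.url, JSON.stringify(request.body));
```

Because the `vertex` options are spread after `sampleCount`, a caller-supplied `sampleCount` in `providerOptions.vertex` would override `n`; the same holds for the implementation in the diff.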
4 changes: 4 additions & 0 deletions packages/google-vertex/src/google-vertex-provider.test.ts
@@ -18,6 +18,10 @@ vi.mock('./google-vertex-embedding-model', () => ({
  GoogleVertexEmbeddingModel: vi.fn(),
}));

vi.mock('./google-vertex-image-model', () => ({
  GoogleVertexImageModel: vi.fn(),
}));

describe('google-vertex-provider', () => {
  beforeEach(() => {
    vi.clearAllMocks();