From 5feec50dada410657299a580bdb13a4b4121e46b Mon Sep 17 00:00:00 2001
From: Walter Korman
Date: Thu, 19 Dec 2024 17:37:27 -0800
Subject: [PATCH] feat (provider/google-vertex): Add imagen support. (#4124)

---
 .changeset/neat-rice-travel.md                 |   5 +
 .../03-ai-sdk-core/35-image-generation.mdx     |  10 +-
 .../01-ai-sdk-providers/11-google-vertex.mdx   |  27 ++++
 .../ai-core/src/e2e/google-vertex.test.ts      |  63 +++++++-
 .../src/generate-image/google-vertex.ts        |  22 +++
 .../src/google-vertex-image-model.test.ts      | 134 ++++++++++++++++++
 .../src/google-vertex-image-model.ts           |  86 +++++++++++
 .../src/google-vertex-provider.test.ts         |   4 +
 .../src/google-vertex-provider.ts              |  20 ++-
 9 files changed, 365 insertions(+), 6 deletions(-)
 create mode 100644 .changeset/neat-rice-travel.md
 create mode 100644 examples/ai-core/src/generate-image/google-vertex.ts
 create mode 100644 packages/google-vertex/src/google-vertex-image-model.test.ts
 create mode 100644 packages/google-vertex/src/google-vertex-image-model.ts

diff --git a/.changeset/neat-rice-travel.md b/.changeset/neat-rice-travel.md
new file mode 100644
index 000000000000..5d0d126d0495
--- /dev/null
+++ b/.changeset/neat-rice-travel.md
@@ -0,0 +1,5 @@
+---
+'@ai-sdk/google-vertex': patch
+---
+
+feat (provider/google-vertex): Add imagen support.
diff --git a/content/docs/03-ai-sdk-core/35-image-generation.mdx b/content/docs/03-ai-sdk-core/35-image-generation.mdx
index b09d40b89c74..53ef78e83bd2 100644
--- a/content/docs/03-ai-sdk-core/35-image-generation.mdx
+++ b/content/docs/03-ai-sdk-core/35-image-generation.mdx
@@ -93,7 +93,9 @@ const { image } = await generateImage({
 
 ## Image Models
 
-| Provider                                                  | Model      | Supported Sizes                 |
-| --------------------------------------------------------- | ---------- | ------------------------------- |
-| [OpenAI](/providers/ai-sdk-providers/openai#image-models) | `dall-e-3` | 1024x1024, 1792x1024, 1024x1792 |
-| [OpenAI](/providers/ai-sdk-providers/openai#image-models) | `dall-e-2` | 256x256, 512x512, 1024x1024     |
+| Provider                                                                 | Model                          | Supported Sizes                                                                                                |
+| ------------------------------------------------------------------------ | ------------------------------ | -------------------------------------------------------------------------------------------------------------- |
+| [Google Vertex](/providers/ai-sdk-providers/google-vertex#image-models)  | `imagen-3.0-generate-001`      | See [aspect ratios](https://cloud.google.com/vertex-ai/generative-ai/docs/image/generate-images#aspect-ratio) |
+| [Google Vertex](/providers/ai-sdk-providers/google-vertex#image-models)  | `imagen-3.0-fast-generate-001` | See [aspect ratios](https://cloud.google.com/vertex-ai/generative-ai/docs/image/generate-images#aspect-ratio) |
+| [OpenAI](/providers/ai-sdk-providers/openai#image-models)                | `dall-e-3`                     | 1024x1024, 1792x1024, 1024x1792                                                                                |
+| [OpenAI](/providers/ai-sdk-providers/openai#image-models)                | `dall-e-2`                     | 256x256, 512x512, 1024x1024                                                                                    |
diff --git a/content/providers/01-ai-sdk-providers/11-google-vertex.mdx b/content/providers/01-ai-sdk-providers/11-google-vertex.mdx
index a4bee6f00252..4967d67872df 100644
--- a/content/providers/01-ai-sdk-providers/11-google-vertex.mdx
+++ b/content/providers/01-ai-sdk-providers/11-google-vertex.mdx
@@ -559,6 +559,33 @@ The following optional settings are available for Google Vertex AI embedding mod
   model ID as a string if needed.
 </Note>
 
+### Image Models
+
+You can create [Imagen](https://cloud.google.com/vertex-ai/generative-ai/docs/image/overview) models that call the [Imagen on Vertex AI API](https://cloud.google.com/vertex-ai/generative-ai/docs/image/generate-images)
+using the `.image()` factory method. For more on image generation with the AI SDK, see [generateImage()](/docs/reference/ai-sdk-core/generate-image).
+
+Note that Imagen does not support an explicit size parameter. Instead, size is driven by the [aspect ratio](https://cloud.google.com/vertex-ai/generative-ai/docs/image/generate-images#aspect-ratio) of the generated image.
+
+```ts
+import { vertex } from '@ai-sdk/google-vertex';
+import { experimental_generateImage as generateImage } from 'ai';
+
+const { image } = await generateImage({
+  model: vertex.image('imagen-3.0-generate-001'),
+  prompt: 'A futuristic cityscape at sunset',
+  providerOptions: {
+    vertex: { aspectRatio: '16:9' },
+  },
+});
+```
+
+#### Model Capabilities
+
+| Model                          | Supported Sizes                                                                                                |
+| ------------------------------ | -------------------------------------------------------------------------------------------------------------- |
+| `imagen-3.0-generate-001`      | See [aspect ratios](https://cloud.google.com/vertex-ai/generative-ai/docs/image/generate-images#aspect-ratio) |
+| `imagen-3.0-fast-generate-001` | See [aspect ratios](https://cloud.google.com/vertex-ai/generative-ai/docs/image/generate-images#aspect-ratio) |
+
 ## Google Vertex Anthropic Provider Usage
 
 The Google Vertex Anthropic provider for the [AI SDK](https://sdk.vercel.ai/docs) offers support for Anthropic's Claude models through the Google Vertex AI APIs. This section provides details on how to set up and use the Google Vertex Anthropic provider.
diff --git a/examples/ai-core/src/e2e/google-vertex.test.ts b/examples/ai-core/src/e2e/google-vertex.test.ts
index 6ac0e1d11ca4..7d5bf9a0063a 100644
--- a/examples/ai-core/src/e2e/google-vertex.test.ts
+++ b/examples/ai-core/src/e2e/google-vertex.test.ts
@@ -10,11 +10,34 @@ import {
   streamObject,
   embed,
   embedMany,
+  experimental_generateImage as generateImage,
 } from 'ai';
 import fs from 'fs';
 import { GoogleGenerativeAIProviderMetadata } from '@ai-sdk/google';
 
-const LONG_TEST_MILLIS = 10000;
+const LONG_TEST_MILLIS = 20000;
+
+const mimeTypeSignatures = [
+  { mimeType: 'image/gif' as const, bytes: [0x47, 0x49, 0x46] },
+  { mimeType: 'image/png' as const, bytes: [0x89, 0x50, 0x4e, 0x47] },
+  { mimeType: 'image/jpeg' as const, bytes: [0xff, 0xd8] },
+  { mimeType: 'image/webp' as const, bytes: [0x52, 0x49, 0x46, 0x46] },
+];
+
+function detectImageMimeType(
+  image: Uint8Array,
+): 'image/jpeg' | 'image/png' | 'image/gif' | 'image/webp' | undefined {
+  for (const { bytes, mimeType } of mimeTypeSignatures) {
+    if (
+      image.length >= bytes.length &&
+      bytes.every((byte, index) => image[index] === byte)
+    ) {
+      return mimeType;
+    }
+  }
+
+  return undefined;
+}
 
 // Model variants to test against
 const MODEL_VARIANTS = {
@@ -26,6 +49,7 @@ const MODEL_VARIANTS = {
     // 'gemini-1.0-pro-001',
   ],
   embedding: ['textembedding-gecko', 'textembedding-gecko-multilingual'],
+  image: ['imagen-3.0-generate-001', 'imagen-3.0-fast-generate-001'],
 } as const;
 
 // Define runtime variants
@@ -452,5 +476,42 @@ describe.each(Object.values(RUNTIME_VARIANTS))(
         expect(result.usage?.tokens).toBeGreaterThan(0);
       });
     });
+
+    describe.each(MODEL_VARIANTS.image)('Image Model: %s', modelId => {
+      it('should generate an image with correct dimensions and format', async () => {
+        const model = vertex.image(modelId);
+        const { image } = await generateImage({
+          model,
+          prompt: 'A burrito launched through a tunnel',
+          providerOptions: {
+            vertex: {
+              aspectRatio: '3:4',
+            },
+          },
+        });
+
+        // Verify we got a Uint8Array back
+        expect(image.uint8Array).toBeInstanceOf(Uint8Array);
+
+        // Check the file size is reasonable (at least 10KB, less than 10MB)
+        expect(image.uint8Array.length).toBeGreaterThan(10 * 1024);
+        expect(image.uint8Array.length).toBeLessThan(10 * 1024 * 1024);
+
+        // Verify PNG format
+        const mimeType = detectImageMimeType(image.uint8Array);
+        expect(mimeType).toBe('image/png');
+
+        // Create a temporary buffer to verify image dimensions
+        const tempBuffer = Buffer.from(image.uint8Array);
+
+        // PNG dimensions are stored at bytes 16-24
+        const width = tempBuffer.readUInt32BE(16);
+        const height = tempBuffer.readUInt32BE(20);
+
+        // https://cloud.google.com/vertex-ai/generative-ai/docs/image/generate-images#performance-limits
+        expect(width).toBe(896);
+        expect(height).toBe(1280);
+      });
+    });
   },
 );
diff --git a/examples/ai-core/src/generate-image/google-vertex.ts b/examples/ai-core/src/generate-image/google-vertex.ts
new file mode 100644
index 000000000000..699ead66c91b
--- /dev/null
+++ b/examples/ai-core/src/generate-image/google-vertex.ts
@@ -0,0 +1,22 @@
+import 'dotenv/config';
+import { vertex } from '@ai-sdk/google-vertex';
+import { experimental_generateImage as generateImage } from 'ai';
+import fs from 'fs';
+
+async function main() {
+  const { image } = await generateImage({
+    model: vertex.image('imagen-3.0-generate-001'),
+    prompt: 'A burrito launched through a tunnel',
+    providerOptions: {
+      vertex: {
+        aspectRatio: '16:9',
+      },
+    },
+  });
+
+  const filename = `image-${Date.now()}.png`;
+  fs.writeFileSync(filename, image.uint8Array);
+  console.log(`Image saved to ${filename}`);
+}
+
+main().catch(console.error);
diff --git a/packages/google-vertex/src/google-vertex-image-model.test.ts b/packages/google-vertex/src/google-vertex-image-model.test.ts
new file mode 100644
index 000000000000..e45ded4f6fd8
--- /dev/null
+++ b/packages/google-vertex/src/google-vertex-image-model.test.ts
@@ -0,0 +1,134 @@
+import { JsonTestServer } from '@ai-sdk/provider-utils/test';
+import { GoogleVertexImageModel } from './google-vertex-image-model';
+import { describe, it, expect, vi } from 'vitest';
+
+const prompt = 'A cute baby sea otter';
+
+const model = new GoogleVertexImageModel('imagen-3.0-generate-001', {
+  provider: 'google-vertex',
+  baseURL: 'https://api.example.com',
+  headers: { 'api-key': 'test-key' },
+});
+
+describe('GoogleVertexImageModel', () => {
+  describe('doGenerate', () => {
+    const server = new JsonTestServer(
+      'https://api.example.com/models/imagen-3.0-generate-001:predict',
+    );
+
+    server.setupTestEnvironment();
+
+    function prepareJsonResponse() {
+      server.responseBodyJson = {
+        predictions: [
+          { bytesBase64Encoded: 'base64-image-1' },
+          { bytesBase64Encoded: 'base64-image-2' },
+        ],
+      };
+    }
+
+    it('should pass the correct parameters', async () => {
+      prepareJsonResponse();
+
+      await model.doGenerate({
+        prompt,
+        n: 2,
+        size: undefined,
+        providerOptions: { vertex: { aspectRatio: '1:1' } },
+      });
+
+      expect(await server.getRequestBodyJson()).toStrictEqual({
+        instances: [{ prompt }],
+        parameters: {
+          sampleCount: 2,
+          aspectRatio: '1:1',
+        },
+      });
+    });
+
+    it('should pass headers', async () => {
+      prepareJsonResponse();
+
+      const modelWithHeaders = new GoogleVertexImageModel(
+        'imagen-3.0-generate-001',
+        {
+          provider: 'google-vertex',
+          baseURL: 'https://api.example.com',
+          headers: {
+            'Custom-Provider-Header': 'provider-header-value',
+          },
+        },
+      );
+
+      await modelWithHeaders.doGenerate({
+        prompt,
+        n: 2,
+        size: undefined,
+        providerOptions: {},
+        headers: {
+          'Custom-Request-Header': 'request-header-value',
+        },
+      });
+
+      const requestHeaders = await server.getRequestHeaders();
+
+      expect(requestHeaders).toStrictEqual({
+        'content-type': 'application/json',
+        'custom-provider-header': 'provider-header-value',
+        'custom-request-header': 'request-header-value',
+      });
+    });
+
+    it('should extract the generated images', async () => {
+      prepareJsonResponse();
+
+      const result = await model.doGenerate({
+        prompt,
+        n: 2,
+        size: undefined,
+        providerOptions: {},
+      });
+
+      expect(result.images).toStrictEqual(['base64-image-1', 'base64-image-2']);
+    });
+
+    it('throws when size is specified', async () => {
+      const model = new GoogleVertexImageModel('imagen-3.0-generate-001', {
+        provider: 'vertex',
+        baseURL: 'https://example.com',
+      });
+
+      await expect(
+        model.doGenerate({
+          prompt: 'test prompt',
+          n: 1,
+          size: '1024x1024',
+          providerOptions: {},
+        }),
+      ).rejects.toThrow(/Google Vertex does not support the `size` option./);
+    });
+
+    it('sends aspect ratio in the request', async () => {
+      prepareJsonResponse();
+
+      await model.doGenerate({
+        prompt: 'test prompt',
+        n: 1,
+        size: undefined,
+        providerOptions: {
+          vertex: {
+            aspectRatio: '16:9',
+          },
+        },
+      });
+
+      expect(await server.getRequestBodyJson()).toStrictEqual({
+        instances: [{ prompt: 'test prompt' }],
+        parameters: {
+          sampleCount: 1,
+          aspectRatio: '16:9',
+        },
+      });
+    });
+  });
+});
diff --git a/packages/google-vertex/src/google-vertex-image-model.ts b/packages/google-vertex/src/google-vertex-image-model.ts
new file mode 100644
index 000000000000..8dcde2f98b88
--- /dev/null
+++ b/packages/google-vertex/src/google-vertex-image-model.ts
@@ -0,0 +1,86 @@
+import { ImageModelV1, JSONValue } from '@ai-sdk/provider';
+import {
+  Resolvable,
+  postJsonToApi,
+  combineHeaders,
+  createJsonResponseHandler,
+  resolve,
+} from '@ai-sdk/provider-utils';
+import { z } from 'zod';
+import { googleVertexFailedResponseHandler } from './google-vertex-error';
+
+export type GoogleVertexImageModelId =
+  | 'imagen-3.0-generate-001'
+  | 'imagen-3.0-fast-generate-001';
+
+interface GoogleVertexImageModelConfig {
+  provider: string;
+  baseURL: string;
+  headers?: Resolvable<Record<string, string | undefined>>;
+  fetch?: typeof fetch;
+}
+
+// https://cloud.google.com/vertex-ai/generative-ai/docs/image/generate-images
+export class GoogleVertexImageModel implements ImageModelV1 {
+  readonly specificationVersion = 'v1';
+
+  get provider(): string {
+    return this.config.provider;
+  }
+
+  constructor(
+    readonly modelId: GoogleVertexImageModelId,
+    private config: GoogleVertexImageModelConfig,
+  ) {}
+
+  async doGenerate({
+    prompt,
+    n,
+    size,
+    providerOptions,
+    headers,
+    abortSignal,
+  }: Parameters<ImageModelV1['doGenerate']>[0]): Promise<
+    Awaited<ReturnType<ImageModelV1['doGenerate']>>
+  > {
+    if (size) {
+      throw new Error(
+        'Google Vertex does not support the `size` option. Use ' +
+          '`providerOptions.vertex.aspectRatio` instead. See ' +
+          'https://cloud.google.com/vertex-ai/generative-ai/docs/image/generate-images#aspect-ratio',
+      );
+    }
+
+    const body = {
+      instances: [{ prompt }],
+      parameters: {
+        sampleCount: n,
+        ...(providerOptions.vertex ?? {}),
+      },
+    };
+
+    const { value: response } = await postJsonToApi({
+      url: `${this.config.baseURL}/models/${this.modelId}:predict`,
+      headers: combineHeaders(await resolve(this.config.headers), headers),
+      body,
+      failedResponseHandler: googleVertexFailedResponseHandler,
+      successfulResponseHandler: createJsonResponseHandler(
+        vertexImageResponseSchema,
+      ),
+      abortSignal: abortSignal,
+      fetch: this.config.fetch,
+    });
+
+    return {
+      images: response.predictions.map(
+        (p: { bytesBase64Encoded: string }) => p.bytesBase64Encoded,
+      ),
+    };
+  }
+}
+
+// minimal version of the schema, focussed on what is needed for the implementation
+// this approach limits breakages when the API changes and increases efficiency
+const vertexImageResponseSchema = z.object({
+  predictions: z.array(z.object({ bytesBase64Encoded: z.string() })),
+});
diff --git a/packages/google-vertex/src/google-vertex-provider.test.ts b/packages/google-vertex/src/google-vertex-provider.test.ts
index 12cf144fb63f..43d4deca5d33 100644
--- a/packages/google-vertex/src/google-vertex-provider.test.ts
+++ b/packages/google-vertex/src/google-vertex-provider.test.ts
@@ -18,6 +18,10 @@ vi.mock('./google-vertex-embedding-model', () => ({
   GoogleVertexEmbeddingModel: vi.fn(),
 }));
 
+vi.mock('./google-vertex-image-model', () => ({
+  GoogleVertexImageModel: vi.fn(),
+}));
+
 describe('google-vertex-provider', () => {
   beforeEach(() => {
     vi.clearAllMocks();
diff --git a/packages/google-vertex/src/google-vertex-provider.ts b/packages/google-vertex/src/google-vertex-provider.ts
index a58a8656d819..e572419d2d1b 100644
--- a/packages/google-vertex/src/google-vertex-provider.ts
+++ b/packages/google-vertex/src/google-vertex-provider.ts
@@ -1,4 +1,4 @@
-import { LanguageModelV1, ProviderV1 } from '@ai-sdk/provider';
+import { LanguageModelV1, ProviderV1, ImageModelV1 } from '@ai-sdk/provider';
 import {
   FetchFunction,
   generateId,
@@ -16,6 +16,10 @@ import {
 } from './google-vertex-embedding-settings';
 import { GoogleVertexEmbeddingModel } from './google-vertex-embedding-model';
 import { GoogleGenerativeAILanguageModel } from '@ai-sdk/google/internal';
+import {
+  GoogleVertexImageModel,
+  GoogleVertexImageModelId,
+} from './google-vertex-image-model';
 
 export interface GoogleVertexProvider extends ProviderV1 {
   /**
@@ -30,6 +34,11 @@ Creates a model for text generation.
     modelId: GoogleVertexModelId,
     settings?: GoogleVertexSettings,
   ) => LanguageModelV1;
+
+  /**
+   * Creates a model for image generation.
+   */
+  image(modelId: GoogleVertexImageModelId): ImageModelV1;
 }
 
 export interface GoogleVertexProviderSettings {
@@ -123,6 +132,14 @@ export function createVertex(
     baseURL: loadBaseURL(),
   });
 
+  const createImageModel = (modelId: GoogleVertexImageModelId) =>
+    new GoogleVertexImageModel(modelId, {
+      provider: `google.vertex.image`,
+      baseURL: loadBaseURL(),
+      headers: options.headers ?? {},
+      fetch: options.fetch,
+    });
+
   const provider = function (
     modelId: GoogleVertexModelId,
     settings?: GoogleVertexSettings,
@@ -138,6 +155,7 @@ export function createVertex(
 
   provider.languageModel = createChatModel;
   provider.textEmbeddingModel = createEmbeddingModel;
+  provider.image = createImageModel;
 
   return provider as GoogleVertexProvider;
 }
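
Beyond the bundled example, a minimal sketch of exercising the new `image()` factory through a custom provider instance created with `createVertex` might look like the following. It assumes the package's existing `project` and `location` provider settings and application default credentials; the environment variable names, model ID, and aspect ratio here are illustrative only.

```ts
import 'dotenv/config';
import { createVertex } from '@ai-sdk/google-vertex';
import { experimental_generateImage as generateImage } from 'ai';
import fs from 'fs';

async function main() {
  // Custom provider instance instead of the default `vertex` export.
  // The env var names below are illustrative assumptions.
  const vertex = createVertex({
    project: process.env.GOOGLE_VERTEX_PROJECT,
    location: process.env.GOOGLE_VERTEX_LOCATION ?? 'us-central1',
  });

  const { image } = await generateImage({
    model: vertex.image('imagen-3.0-fast-generate-001'),
    prompt: 'A watercolor painting of a lighthouse at dawn',
    providerOptions: {
      // Imagen takes no `size` option; the aspect ratio drives the output dimensions.
      vertex: { aspectRatio: '4:3' },
    },
  });

  fs.writeFileSync(`imagen-${Date.now()}.png`, image.uint8Array);
}

main().catch(console.error);
```

The sketch differs from the bundled example only in that it configures the provider explicitly rather than relying on the default `vertex` export, which is useful when the project or region cannot come from ambient configuration.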