From 5e631cdf08295a594a2c76e367913201ed558480 Mon Sep 17 00:00:00 2001 From: Walter Korman Date: Wed, 18 Dec 2024 23:20:35 -0800 Subject: [PATCH] address feedback --- ...t-cooks-breathe.md => neat-rice-travel.md} | 1 + .../ai-core/src/e2e/google-vertex.test.ts | 19 +- .../src/generate-image/google-vertex.ts | 3 +- packages/ai/core/index.ts | 1 + .../src/google-vertex-image-model.test.ts | 188 ++++++++++-------- .../src/google-vertex-image-model.ts | 104 +++++----- .../src/google-vertex-provider.test.ts | 4 + pnpm-lock.yaml | 14 +- 8 files changed, 184 insertions(+), 150 deletions(-) rename .changeset/{silent-cooks-breathe.md => neat-rice-travel.md} (88%) diff --git a/.changeset/silent-cooks-breathe.md b/.changeset/neat-rice-travel.md similarity index 88% rename from .changeset/silent-cooks-breathe.md rename to .changeset/neat-rice-travel.md index 5d0d126d0495..4617fd0c116b 100644 --- a/.changeset/silent-cooks-breathe.md +++ b/.changeset/neat-rice-travel.md @@ -1,5 +1,6 @@ --- '@ai-sdk/google-vertex': patch +'ai': patch --- feat (provider/google-vertex): Add imagen support. diff --git a/examples/ai-core/src/e2e/google-vertex.test.ts b/examples/ai-core/src/e2e/google-vertex.test.ts index 4e4c8ecf4c04..c92f0681c13f 100644 --- a/examples/ai-core/src/e2e/google-vertex.test.ts +++ b/examples/ai-core/src/e2e/google-vertex.test.ts @@ -4,6 +4,7 @@ import { vertex as vertexEdge } from '@ai-sdk/google-vertex/edge'; import { vertex as vertexNode } from '@ai-sdk/google-vertex'; import { z } from 'zod'; import { + detectImageMimeType, generateText, generateObject, streamText, @@ -462,7 +463,11 @@ describe.each(Object.values(RUNTIME_VARIANTS))( const { image } = await generateImage({ model, prompt: 'A burrito launched through a tunnel', - size: '1024x1024', + providerOptions: { + vertex: { + aspectRatio: '3:4', + }, + }, }); // Verify we got a Uint8Array back @@ -472,10 +477,9 @@ describe.each(Object.values(RUNTIME_VARIANTS))( expect(image.uint8Array.length).toBeGreaterThan(10 * 1024); expect(image.uint8Array.length).toBeLessThan(10 * 1024 * 1024); - // Verify PNG format by checking magic numbers - const pngSignature = [137, 80, 78, 71, 13, 10, 26, 10]; - const actualSignature = Array.from(image.uint8Array.slice(0, 8)); - expect(actualSignature).toEqual(pngSignature); + // Verify PNG format + const mimeType = detectImageMimeType(image.uint8Array); + expect(mimeType).toBe('image/png'); // Create a temporary buffer to verify image dimensions const tempBuffer = Buffer.from(image.uint8Array); @@ -484,8 +488,9 @@ describe.each(Object.values(RUNTIME_VARIANTS))( const width = tempBuffer.readUInt32BE(16); const height = tempBuffer.readUInt32BE(20); - expect(width).toBe(1024); - expect(height).toBe(1024); + // https://cloud.google.com/vertex-ai/generative-ai/docs/image/generate-images#performance-limits + expect(width).toBe(896); + expect(height).toBe(1280); }); }); }, diff --git a/examples/ai-core/src/generate-image/google-vertex.ts b/examples/ai-core/src/generate-image/google-vertex.ts index f69fc0d5eab2..699ead66c91b 100644 --- a/examples/ai-core/src/generate-image/google-vertex.ts +++ b/examples/ai-core/src/generate-image/google-vertex.ts @@ -7,10 +7,9 @@ async function main() { const { image } = await generateImage({ model: vertex.image('imagen-3.0-generate-001'), prompt: 'A burrito launched through a tunnel', - size: '1024x1024', providerOptions: { vertex: { - // Vertex AI specific options if needed + aspectRatio: '16:9', }, }, }); diff --git a/packages/ai/core/index.ts b/packages/ai/core/index.ts index e975f474c664..3931546d4fb3 100644 --- a/packages/ai/core/index.ts +++ b/packages/ai/core/index.ts @@ -11,3 +11,4 @@ export * from './registry'; export * from './tool'; export * from './types'; export { cosineSimilarity } from './util/cosine-similarity'; +export { detectImageMimeType } from './util/detect-image-mimetype'; diff --git a/packages/google-vertex/src/google-vertex-image-model.test.ts b/packages/google-vertex/src/google-vertex-image-model.test.ts index 3639c4ad7b17..e45ded4f6fd8 100644 --- a/packages/google-vertex/src/google-vertex-image-model.test.ts +++ b/packages/google-vertex/src/google-vertex-image-model.test.ts @@ -1,5 +1,6 @@ import { JsonTestServer } from '@ai-sdk/provider-utils/test'; import { GoogleVertexImageModel } from './google-vertex-image-model'; +import { describe, it, expect, vi } from 'vitest'; const prompt = 'A cute baby sea otter'; @@ -9,104 +10,125 @@ const model = new GoogleVertexImageModel('imagen-3.0-generate-001', { headers: { 'api-key': 'test-key' }, }); -describe('doGenerate', () => { - const server = new JsonTestServer( - 'https://api.example.com/models/imagen-3.0-generate-001:predict', - ); - - server.setupTestEnvironment(); - - function prepareJsonResponse() { - server.responseBodyJson = { - predictions: [ - { bytesBase64Encoded: 'base64-image-1' }, - { bytesBase64Encoded: 'base64-image-2' }, - ], - }; - } - - it('should pass the correct parameters', async () => { - prepareJsonResponse(); - - await model.doGenerate({ - prompt, - n: 2, - size: '1024x1024', - providerOptions: { customOption: { value: 123 } }, - }); +describe('GoogleVertexImageModel', () => { + describe('doGenerate', () => { + const server = new JsonTestServer( + 'https://api.example.com/models/imagen-3.0-generate-001:predict', + ); - expect(await server.getRequestBodyJson()).toStrictEqual({ - instances: [{ prompt }], - parameters: { - sampleCount: 2, - aspectRatio: '1:1', - customOption: { value: 123 }, - }, + server.setupTestEnvironment(); + + function prepareJsonResponse() { + server.responseBodyJson = { + predictions: [ + { bytesBase64Encoded: 'base64-image-1' }, + { bytesBase64Encoded: 'base64-image-2' }, + ], + }; + } + + it('should pass the correct parameters', async () => { + prepareJsonResponse(); + + await model.doGenerate({ + prompt, + n: 2, + size: undefined, + providerOptions: { vertex: { aspectRatio: '1:1' } }, + }); + + expect(await server.getRequestBodyJson()).toStrictEqual({ + instances: [{ prompt }], + parameters: { + sampleCount: 2, + aspectRatio: '1:1', + }, + }); }); - }); - it('should pass headers', async () => { - prepareJsonResponse(); + it('should pass headers', async () => { + prepareJsonResponse(); + + const modelWithHeaders = new GoogleVertexImageModel( + 'imagen-3.0-generate-001', + { + provider: 'google-vertex', + baseURL: 'https://api.example.com', + headers: { + 'Custom-Provider-Header': 'provider-header-value', + }, + }, + ); - const modelWithHeaders = new GoogleVertexImageModel( - 'imagen-3.0-generate-001', - { - provider: 'google-vertex', - baseURL: 'https://api.example.com', + await modelWithHeaders.doGenerate({ + prompt, + n: 2, + size: undefined, + providerOptions: {}, headers: { - 'Custom-Provider-Header': 'provider-header-value', + 'Custom-Request-Header': 'request-header-value', }, - }, - ); - - await modelWithHeaders.doGenerate({ - prompt, - n: 2, - size: '1024x1024', - providerOptions: {}, - headers: { - 'Custom-Request-Header': 'request-header-value', - }, - }); + }); - const requestHeaders = await server.getRequestHeaders(); + const requestHeaders = await server.getRequestHeaders(); - expect(requestHeaders).toStrictEqual({ - 'content-type': 'application/json', - 'custom-provider-header': 'provider-header-value', - 'custom-request-header': 'request-header-value', + expect(requestHeaders).toStrictEqual({ + 'content-type': 'application/json', + 'custom-provider-header': 'provider-header-value', + 'custom-request-header': 'request-header-value', + }); }); - }); - - it('should extract the generated images', async () => { - prepareJsonResponse(); - const result = await model.doGenerate({ - prompt, - n: 2, - size: undefined, - providerOptions: {}, - }); + it('should extract the generated images', async () => { + prepareJsonResponse(); - expect(result.images).toStrictEqual(['base64-image-1', 'base64-image-2']); - }); + const result = await model.doGenerate({ + prompt, + n: 2, + size: undefined, + providerOptions: {}, + }); - it('should handle different aspect ratios', async () => { - prepareJsonResponse(); + expect(result.images).toStrictEqual(['base64-image-1', 'base64-image-2']); + }); - await model.doGenerate({ - prompt, - n: 1, - size: '1280x896', - providerOptions: {}, + it('throws when size is specified', async () => { + const model = new GoogleVertexImageModel('imagen-3.0-generate-001', { + provider: 'vertex', + baseURL: 'https://example.com', + }); + + await expect( + model.doGenerate({ + prompt: 'test prompt', + n: 1, + size: '1024x1024', + providerOptions: {}, + }), + ).rejects.toThrow(/Google Vertex does not support the `size` option./); }); - expect(await server.getRequestBodyJson()).toStrictEqual({ - instances: [{ prompt }], - parameters: { - sampleCount: 1, - aspectRatio: '4:3', - }, + it('sends aspect ratio in the request', async () => { + prepareJsonResponse(); + + await model.doGenerate({ + prompt: 'test prompt', + n: 1, + size: undefined, + providerOptions: { + vertex: { + aspectRatio: '16:9', + }, + }, + }); + + expect(await server.getRequestBodyJson()).toStrictEqual({ + instances: [{ prompt: 'test prompt' }], + parameters: { + sampleCount: 1, + aspectRatio: '16:9', + }, + }); }); }); }); diff --git a/packages/google-vertex/src/google-vertex-image-model.ts b/packages/google-vertex/src/google-vertex-image-model.ts index 04cc4f752842..8dcde2f98b88 100644 --- a/packages/google-vertex/src/google-vertex-image-model.ts +++ b/packages/google-vertex/src/google-vertex-image-model.ts @@ -1,11 +1,19 @@ import { ImageModelV1, JSONValue } from '@ai-sdk/provider'; -import { Resolvable, resolve } from '@ai-sdk/provider-utils'; +import { + Resolvable, + postJsonToApi, + combineHeaders, + createJsonResponseHandler, + resolve, +} from '@ai-sdk/provider-utils'; +import { z } from 'zod'; +import { googleVertexFailedResponseHandler } from './google-vertex-error'; export type GoogleVertexImageModelId = | 'imagen-3.0-generate-001' | 'imagen-3.0-fast-generate-001'; -interface GoogleVertexImageModelOptions { +interface GoogleVertexImageModelConfig { provider: string; baseURL: string; headers?: Resolvable>; @@ -17,68 +25,62 @@ export class GoogleVertexImageModel implements ImageModelV1 { readonly specificationVersion = 'v1'; get provider(): string { - return this.options.provider; + return this.config.provider; } constructor( readonly modelId: GoogleVertexImageModelId, - private options: GoogleVertexImageModelOptions, + private config: GoogleVertexImageModelConfig, ) {} - async doGenerate(options: { - prompt: string; - n: number; - size: `${number}x${number}` | undefined; - providerOptions: Record>; - abortSignal?: AbortSignal; - headers?: Record; - }): Promise<{ images: string[] }> { - const [width, height] = (options.size ?? '1024x1024') - .split('x') - .map(Number); + async doGenerate({ + prompt, + n, + size, + providerOptions, + headers, + abortSignal, + }: Parameters[0]): Promise< + Awaited> + > { + if (size) { + throw new Error( + 'Google Vertex does not support the `size` option. Use ' + + '`providerOptions.vertex.aspectRatio` instead. See ' + + 'https://cloud.google.com/vertex-ai/generative-ai/docs/image/generate-images#aspect-ratio', + ); + } - const response = await (this.options.fetch ?? fetch)( - `${this.options.baseURL}/models/${this.modelId}:predict`, - { - method: 'POST', - headers: { - 'Content-Type': 'application/json', - ...(await resolve(this.options.headers)), - ...options.headers, - }, - signal: options.abortSignal, - body: JSON.stringify({ - instances: [{ prompt: options.prompt }], - parameters: { - sampleCount: options.n, - aspectRatio: this.getAspectRatio(width, height), - ...options.providerOptions, - }, - }), + const body = { + instances: [{ prompt }], + parameters: { + sampleCount: n, + ...(providerOptions.vertex ?? {}), }, - ); + }; - if (!response.ok) { - throw new Error(`HTTP error! status: ${response.status}`); - } + const { value: response } = await postJsonToApi({ + url: `${this.config.baseURL}/models/${this.modelId}:predict`, + headers: combineHeaders(await resolve(this.config.headers), headers), + body, + failedResponseHandler: googleVertexFailedResponseHandler, + successfulResponseHandler: createJsonResponseHandler( + vertexImageResponseSchema, + ), + abortSignal: abortSignal, + fetch: this.config.fetch, + }); - const data = await response.json(); return { - images: data.predictions.map( + images: response.predictions.map( (p: { bytesBase64Encoded: string }) => p.bytesBase64Encoded, ), }; } - - private getAspectRatio(width: number, height: number): string { - // Map common dimensions to Imagen's supported aspect ratios - if (width === height) return '1:1'; - if (width === 896 && height === 1280) return '3:4'; - if (width === 1280 && height === 896) return '4:3'; - if (width === 768 && height === 1408) return '9:16'; - if (width === 1408 && height === 768) return '16:9'; - - // Default to 1:1 if no match - return '1:1'; - } } + +// minimal version of the schema, focussed on what is needed for the implementation +// this approach limits breakages when the API changes and increases efficiency +const vertexImageResponseSchema = z.object({ + predictions: z.array(z.object({ bytesBase64Encoded: z.string() })), +}); diff --git a/packages/google-vertex/src/google-vertex-provider.test.ts b/packages/google-vertex/src/google-vertex-provider.test.ts index 12cf144fb63f..43d4deca5d33 100644 --- a/packages/google-vertex/src/google-vertex-provider.test.ts +++ b/packages/google-vertex/src/google-vertex-provider.test.ts @@ -18,6 +18,10 @@ vi.mock('./google-vertex-embedding-model', () => ({ GoogleVertexEmbeddingModel: vi.fn(), })); +vi.mock('./google-vertex-image-model', () => ({ + GoogleVertexImageModel: vi.fn(), +})); + describe('google-vertex-provider', () => { beforeEach(() => { vi.clearAllMocks(); diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 778acd632814..9066b128559c 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -21590,7 +21590,7 @@ snapshots: eslint: 8.57.1 eslint-import-resolver-node: 0.3.9 eslint-import-resolver-typescript: 3.6.1(@typescript-eslint/parser@7.2.0(eslint@8.57.1)(typescript@5.6.3))(eslint-import-resolver-node@0.3.9)(eslint-plugin-import@2.29.1(eslint@8.57.1))(eslint@8.57.1) - eslint-plugin-import: 2.29.1(@typescript-eslint/parser@7.2.0(eslint@8.57.1)(typescript@5.6.3))(eslint-import-resolver-typescript@3.6.1)(eslint@8.57.1) + eslint-plugin-import: 2.29.1(@typescript-eslint/parser@7.2.0(eslint@8.57.1)(typescript@5.6.3))(eslint-import-resolver-typescript@3.6.1(@typescript-eslint/parser@7.2.0(eslint@8.57.1)(typescript@5.6.3))(eslint-import-resolver-node@0.3.9)(eslint-plugin-import@2.29.1(eslint@8.57.1))(eslint@8.57.1))(eslint@8.57.1) eslint-plugin-jsx-a11y: 6.9.0(eslint@8.57.1) eslint-plugin-react: 7.35.0(eslint@8.57.1) eslint-plugin-react-hooks: 4.6.2(eslint@8.57.1) @@ -21639,8 +21639,8 @@ snapshots: debug: 4.3.7(supports-color@9.4.0) enhanced-resolve: 5.17.1 eslint: 8.57.1 - eslint-module-utils: 2.8.0(@typescript-eslint/parser@7.2.0(eslint@8.57.1)(typescript@5.6.3))(eslint-import-resolver-node@0.3.9)(eslint-import-resolver-typescript@3.6.1)(eslint@8.57.1) - eslint-plugin-import: 2.29.1(@typescript-eslint/parser@7.2.0(eslint@8.57.1)(typescript@5.6.3))(eslint-import-resolver-typescript@3.6.1)(eslint@8.57.1) + eslint-module-utils: 2.8.0(@typescript-eslint/parser@7.2.0(eslint@8.57.1)(typescript@5.6.3))(eslint-import-resolver-node@0.3.9)(eslint-import-resolver-typescript@3.6.1(@typescript-eslint/parser@7.2.0(eslint@8.57.1)(typescript@5.6.3))(eslint-import-resolver-node@0.3.9)(eslint-plugin-import@2.29.1(eslint@8.57.1))(eslint@8.57.1))(eslint@8.57.1) + eslint-plugin-import: 2.29.1(@typescript-eslint/parser@7.2.0(eslint@8.57.1)(typescript@5.6.3))(eslint-import-resolver-typescript@3.6.1(@typescript-eslint/parser@7.2.0(eslint@8.57.1)(typescript@5.6.3))(eslint-import-resolver-node@0.3.9)(eslint-plugin-import@2.29.1(eslint@8.57.1))(eslint@8.57.1))(eslint@8.57.1) fast-glob: 3.3.2 get-tsconfig: 4.7.2 is-core-module: 2.13.1 @@ -21662,7 +21662,7 @@ snapshots: transitivePeerDependencies: - supports-color - eslint-module-utils@2.8.0(@typescript-eslint/parser@7.2.0(eslint@8.57.1)(typescript@5.6.3))(eslint-import-resolver-node@0.3.9)(eslint-import-resolver-typescript@3.6.1)(eslint@8.57.1): + eslint-module-utils@2.8.0(@typescript-eslint/parser@7.2.0(eslint@8.57.1)(typescript@5.6.3))(eslint-import-resolver-node@0.3.9)(eslint-import-resolver-typescript@3.6.1(@typescript-eslint/parser@7.2.0(eslint@8.57.1)(typescript@5.6.3))(eslint-import-resolver-node@0.3.9)(eslint-plugin-import@2.29.1(eslint@8.57.1))(eslint@8.57.1))(eslint@8.57.1): dependencies: debug: 3.2.7 optionalDependencies: @@ -21684,7 +21684,7 @@ snapshots: transitivePeerDependencies: - supports-color - eslint-module-utils@2.8.1(@typescript-eslint/parser@7.2.0(eslint@8.57.1)(typescript@5.6.3))(eslint-import-resolver-node@0.3.9)(eslint-import-resolver-typescript@3.6.1)(eslint@8.57.1): + eslint-module-utils@2.8.1(@typescript-eslint/parser@7.2.0(eslint@8.57.1)(typescript@5.6.3))(eslint-import-resolver-node@0.3.9)(eslint-import-resolver-typescript@3.6.1(@typescript-eslint/parser@7.2.0(eslint@8.57.1)(typescript@5.6.3))(eslint-import-resolver-node@0.3.9)(eslint-plugin-import@2.29.1(eslint@8.57.1))(eslint@8.57.1))(eslint@8.57.1): dependencies: debug: 3.2.7 optionalDependencies: @@ -21722,7 +21722,7 @@ snapshots: - eslint-import-resolver-webpack - supports-color - eslint-plugin-import@2.29.1(@typescript-eslint/parser@7.2.0(eslint@8.57.1)(typescript@5.6.3))(eslint-import-resolver-typescript@3.6.1)(eslint@8.57.1): + eslint-plugin-import@2.29.1(@typescript-eslint/parser@7.2.0(eslint@8.57.1)(typescript@5.6.3))(eslint-import-resolver-typescript@3.6.1(@typescript-eslint/parser@7.2.0(eslint@8.57.1)(typescript@5.6.3))(eslint-import-resolver-node@0.3.9)(eslint-plugin-import@2.29.1(eslint@8.57.1))(eslint@8.57.1))(eslint@8.57.1): dependencies: array-includes: 3.1.8 array.prototype.findlastindex: 1.2.5 @@ -21732,7 +21732,7 @@ snapshots: doctrine: 2.1.0 eslint: 8.57.1 eslint-import-resolver-node: 0.3.9 - eslint-module-utils: 2.8.1(@typescript-eslint/parser@7.2.0(eslint@8.57.1)(typescript@5.6.3))(eslint-import-resolver-node@0.3.9)(eslint-import-resolver-typescript@3.6.1)(eslint@8.57.1) + eslint-module-utils: 2.8.1(@typescript-eslint/parser@7.2.0(eslint@8.57.1)(typescript@5.6.3))(eslint-import-resolver-node@0.3.9)(eslint-import-resolver-typescript@3.6.1(@typescript-eslint/parser@7.2.0(eslint@8.57.1)(typescript@5.6.3))(eslint-import-resolver-node@0.3.9)(eslint-plugin-import@2.29.1(eslint@8.57.1))(eslint@8.57.1))(eslint@8.57.1) hasown: 2.0.2 is-core-module: 2.15.0 is-glob: 4.0.3