diff --git a/.changeset/short-pots-provide.md b/.changeset/short-pots-provide.md
new file mode 100644
index 00000000..0f2ce96e
--- /dev/null
+++ b/.changeset/short-pots-provide.md
@@ -0,0 +1,5 @@
+---
+"@google/generative-ai": minor
+---
+
+Add `apiVersion` property to `RequestOptions` to allow the user to choose the API endpoint version.
diff --git a/docs/reference/generative-ai.harmblockthreshold.md b/docs/reference/generative-ai.harmblockthreshold.md
index a2917d47..95e1bcb9 100644
--- a/docs/reference/generative-ai.harmblockthreshold.md
+++ b/docs/reference/generative-ai.harmblockthreshold.md
@@ -4,7 +4,7 @@
 
 ## HarmBlockThreshold enum
 
-Threshhold above which a prompt or candidate will be blocked.
+Threshold above which a prompt or candidate will be blocked.
 
 **Signature:**
 
diff --git a/docs/reference/generative-ai.md b/docs/reference/generative-ai.md
index ed5b752d..23605fc6 100644
--- a/docs/reference/generative-ai.md
+++ b/docs/reference/generative-ai.md
@@ -18,7 +18,7 @@
 | --- | --- |
 | [BlockReason](./generative-ai.blockreason.md) | Reason that a prompt was blocked. |
 | [FinishReason](./generative-ai.finishreason.md) | Reason that a candidate finished. |
-| [HarmBlockThreshold](./generative-ai.harmblockthreshold.md) | Threshhold above which a prompt or candidate will be blocked. |
+| [HarmBlockThreshold](./generative-ai.harmblockthreshold.md) | Threshold above which a prompt or candidate will be blocked. |
 | [HarmCategory](./generative-ai.harmcategory.md) | Harm categories that would cause prompts or candidates to be blocked. |
 | [HarmProbability](./generative-ai.harmprobability.md) | Probability that a prompt or candidate matches a harm category. |
 | [TaskType](./generative-ai.tasktype.md) | Task type for embedding content. |
diff --git a/docs/reference/generative-ai.requestoptions.apiversion.md b/docs/reference/generative-ai.requestoptions.apiversion.md
new file mode 100644
index 00000000..2edf5f19
--- /dev/null
+++ b/docs/reference/generative-ai.requestoptions.apiversion.md
@@ -0,0 +1,13 @@
+<!-- Do not edit this file. It is automatically generated by API Documenter. -->
+
+[Home](./index.md) > [@google/generative-ai](./generative-ai.md) > [RequestOptions](./generative-ai.requestoptions.md) > [apiVersion](./generative-ai.requestoptions.apiversion.md)
+
+## RequestOptions.apiVersion property
+
+Version of the API endpoint to call (e.g. "v1" or "v1beta"). If not specified, defaults to the latest stable version.
+
+**Signature:**
+
+```typescript
+apiVersion?: string;
+```
diff --git a/docs/reference/generative-ai.requestoptions.md b/docs/reference/generative-ai.requestoptions.md
index f3ca57c7..8ef9862a 100644
--- a/docs/reference/generative-ai.requestoptions.md
+++ b/docs/reference/generative-ai.requestoptions.md
@@ -16,5 +16,6 @@ export interface RequestOptions
 
 | Property | Modifiers | Type | Description |
 | --- | --- | --- | --- |
-| [timeout?](./generative-ai.requestoptions.timeout.md) | | number | _(Optional)_ |
+| [apiVersion?](./generative-ai.requestoptions.apiversion.md) | | string | _(Optional)_ Version of the API endpoint to call (e.g. "v1" or "v1beta"). If not specified, defaults to the latest stable version. |
+| [timeout?](./generative-ai.requestoptions.timeout.md) | | number | _(Optional)_ Request timeout in milliseconds. |
 
diff --git a/docs/reference/generative-ai.requestoptions.timeout.md b/docs/reference/generative-ai.requestoptions.timeout.md
index a6c1eeb8..60526d20 100644
--- a/docs/reference/generative-ai.requestoptions.timeout.md
+++ b/docs/reference/generative-ai.requestoptions.timeout.md
@@ -4,6 +4,8 @@
 
 ## RequestOptions.timeout property
 
+Request timeout in milliseconds.
+
 **Signature:**
 
 ```typescript
diff --git a/packages/main/src/methods/count-tokens.ts b/packages/main/src/methods/count-tokens.ts
index f41d1eeb..1b065169 100644
--- a/packages/main/src/methods/count-tokens.ts
+++ b/packages/main/src/methods/count-tokens.ts
@@ -28,7 +28,7 @@ export async function countTokens(
   params: CountTokensRequest,
   requestOptions?: RequestOptions,
 ): Promise<CountTokensResponse> {
-  const url = new RequestUrl(model, Task.COUNT_TOKENS, apiKey, false);
+  const url = new RequestUrl(model, Task.COUNT_TOKENS, apiKey, false, {});
   const response = await makeRequest(
     url,
     JSON.stringify({ ...params, model }),
diff --git a/packages/main/src/methods/embed-content.ts b/packages/main/src/methods/embed-content.ts
index 377839e6..cd4cc26a 100644
--- a/packages/main/src/methods/embed-content.ts
+++ b/packages/main/src/methods/embed-content.ts
@@ -30,7 +30,7 @@ export async function embedContent(
   params: EmbedContentRequest,
   requestOptions?: RequestOptions,
 ): Promise<EmbedContentResponse> {
-  const url = new RequestUrl(model, Task.EMBED_CONTENT, apiKey, false);
+  const url = new RequestUrl(model, Task.EMBED_CONTENT, apiKey, false, {});
   const response = await makeRequest(
     url,
     JSON.stringify(params),
@@ -45,7 +45,13 @@ export async function batchEmbedContents(
   params: BatchEmbedContentsRequest,
   requestOptions?: RequestOptions,
 ): Promise<BatchEmbedContentsResponse> {
-  const url = new RequestUrl(model, Task.BATCH_EMBED_CONTENTS, apiKey, false);
+  const url = new RequestUrl(
+    model,
+    Task.BATCH_EMBED_CONTENTS,
+    apiKey,
+    false,
+    {},
+  );
   const requestsWithModel: EmbedContentRequest[] = params.requests.map(
     (request) => {
       return { ...request, model };
diff --git a/packages/main/src/methods/generate-content.ts b/packages/main/src/methods/generate-content.ts
index 31f2ce8f..2af684f6 100644
--- a/packages/main/src/methods/generate-content.ts
+++ b/packages/main/src/methods/generate-content.ts
@@ -37,6 +37,7 @@ export async function generateContentStream(
     Task.STREAM_GENERATE_CONTENT,
     apiKey,
     /* stream */ true,
+    requestOptions,
   );
   const response = await makeRequest(
     url,
@@ -57,6 +58,7 @@ export async function generateContent(
     Task.GENERATE_CONTENT,
     apiKey,
     /* stream */ false,
+    requestOptions,
   );
   const response = await makeRequest(
     url,
diff --git a/packages/main/src/requests/request.test.ts b/packages/main/src/requests/request.test.ts
index 6981fc74..8ef93970 100644
--- a/packages/main/src/requests/request.test.ts
+++ b/packages/main/src/requests/request.test.ts
@@ -19,7 +19,7 @@ import { expect, use } from "chai";
 import { restore, stub } from "sinon";
 import * as sinonChai from "sinon-chai";
 import * as chaiAsPromised from "chai-as-promised";
-import { RequestUrl, Task, makeRequest } from "./request";
+import { DEFAULT_API_VERSION, RequestUrl, Task, makeRequest } from "./request";
 
 use(sinonChai);
 use(chaiAsPromised);
@@ -29,6 +29,7 @@ const fakeRequestUrl = new RequestUrl(
   Task.GENERATE_CONTENT,
   "key",
   true,
+  {},
 );
 
 describe("request methods", () => {
@@ -42,6 +43,7 @@ describe("request methods", () => {
         Task.GENERATE_CONTENT,
         "key",
         true,
+        {},
       );
       expect(url.toString()).to.include("models/model-name:generateContent");
       expect(url.toString()).to.not.include("key");
@@ -53,17 +55,39 @@ describe("request methods", () => {
         Task.GENERATE_CONTENT,
         "key",
         false,
+        {},
       );
       expect(url.toString()).to.include("models/model-name:generateContent");
       expect(url.toString()).to.not.include("key");
       expect(url.toString()).to.not.include("alt=sse");
     });
+    it("default apiVersion", async () => {
+      const url = new RequestUrl(
+        "models/model-name",
+        Task.GENERATE_CONTENT,
+        "key",
+        false,
+        {},
+      );
+      expect(url.toString()).to.include(DEFAULT_API_VERSION);
+    });
+    it("custom apiVersion", async () => {
+      const url = new RequestUrl(
+        "models/model-name",
+        Task.GENERATE_CONTENT,
+        "key",
+        false,
+        { apiVersion: "v2beta" },
+      );
+      expect(url.toString()).to.include("/v2beta/models/model-name");
+    });
     it("non-stream - tunedModels/", async () => {
       const url = new RequestUrl(
         "tunedModels/model-name",
         Task.GENERATE_CONTENT,
         "key",
         false,
+        {},
       );
       expect(url.toString()).to.include(
         "tunedModels/model-name:generateContent",
diff --git a/packages/main/src/requests/request.ts b/packages/main/src/requests/request.ts
index 1ec343ef..b816b258 100644
--- a/packages/main/src/requests/request.ts
+++ b/packages/main/src/requests/request.ts
@@ -20,7 +20,7 @@ import { GoogleGenerativeAIError } from "../errors";
 
 const BASE_URL = "https://generativelanguage.googleapis.com";
 
-const API_VERSION = "v1";
+export const DEFAULT_API_VERSION = "v1";
 
 /**
  * We can't `require` package.json if this runs on web. We will use rollup to
@@ -43,9 +43,11 @@ export class RequestUrl {
     public task: Task,
     public apiKey: string,
     public stream: boolean,
+    public requestOptions: RequestOptions,
   ) {}
   toString(): string {
-    let url = `${BASE_URL}/${API_VERSION}/${this.model}:${this.task}`;
+    const apiVersion = this.requestOptions?.apiVersion || DEFAULT_API_VERSION;
+    let url = `${BASE_URL}/${apiVersion}/${this.model}:${this.task}`;
     if (this.stream) {
       url += "?alt=sse";
     }
diff --git a/packages/main/test-integration/node/index.test.ts b/packages/main/test-integration/node/index.test.ts
index 8d78bb6d..2a747d04 100644
--- a/packages/main/test-integration/node/index.test.ts
+++ b/packages/main/test-integration/node/index.test.ts
@@ -131,6 +131,24 @@ describe("generateContent", function () {
     const response = result.response;
     expect(response.text()).to.not.be.empty;
   });
+  it("non-streaming, simple interface, custom API version", async () => {
+    const genAI = new GoogleGenerativeAI(process.env.GEMINI_API_KEY || "");
+    const model = genAI.getGenerativeModel(
+      {
+        model: "gemini-pro",
+        safetySettings: [
+          {
+            category: HarmCategory.HARM_CATEGORY_HARASSMENT,
+            threshold: HarmBlockThreshold.BLOCK_ONLY_HIGH,
+          },
+        ],
+      },
+      { apiVersion: "v1beta" },
+    );
+    const result = await model.generateContent("What do cats eat?");
+    const response = result.response;
+    expect(response.text()).to.not.be.empty;
+  });
   it("non-streaming, image buffer provided", async () => {
     const imageBuffer = fs.readFileSync(
       join(__dirname, "../../test-utils/cat.png"),
diff --git a/packages/main/test-integration/web/index.test.ts b/packages/main/test-integration/web/index.test.ts
index e9cad202..49631cc4 100644
--- a/packages/main/test-integration/web/index.test.ts
+++ b/packages/main/test-integration/web/index.test.ts
@@ -68,6 +68,24 @@ describe("generateContent", function () {
     const response = result.response;
     expect(response.text()).to.not.be.empty;
   });
+  it("non-streaming, simple interface, custom API version", async () => {
+    const genAI = new GoogleGenerativeAI(process.env.GEMINI_API_KEY || "");
+    const model = genAI.getGenerativeModel(
+      {
+        model: "gemini-pro",
+        safetySettings: [
+          {
+            category: HarmCategory.HARM_CATEGORY_HARASSMENT,
+            threshold: HarmBlockThreshold.BLOCK_ONLY_HIGH,
+          },
+        ],
+      },
+      { apiVersion: "v1beta" },
+    );
+    const result = await model.generateContent("What do cats eat?");
+    const response = result.response;
+    expect(response.text()).to.not.be.empty;
+  });
 });
 
 describe("startChat", function () {
@@ -150,11 +168,7 @@ describe("startChat", function () {
     const question1 = "What is the capital of Oregon?";
     const question2 = "How many people live there?";
     const question3 = "What is the closest river?";
-    const chat = model.startChat({
-      generationConfig: {
-        maxOutputTokens: 100,
-      },
-    });
+    const chat = model.startChat();
     const result1 = await chat.sendMessageStream(question1);
     const response1 = await result1.response;
     expect(response1.text()).to.not.be.empty;
@@ -190,11 +204,7 @@ describe("startChat", function () {
     const question1 = "What are the most interesting cities in Oregon?";
     const question2 = "How many people live there?";
     const question3 = "What is the closest river?";
-    const chat = model.startChat({
-      generationConfig: {
-        maxOutputTokens: 100,
-      },
-    });
+    const chat = model.startChat();
     const promise1 = chat.sendMessageStream(question1).then(async (result1) => {
       for await (const response of result1.stream) {
         expect(response.text()).to.not.be.empty;
diff --git a/packages/main/types/requests.ts b/packages/main/types/requests.ts
index ca71ef28..887cf9c6 100644
--- a/packages/main/types/requests.ts
+++ b/packages/main/types/requests.ts
@@ -104,5 +104,13 @@ export interface BatchEmbedContentsRequest {
  * @public
  */
 export interface RequestOptions {
+  /**
+   * Request timeout in milliseconds.
+   */
   timeout?: number;
+  /**
+   * Version of the API endpoint to call (e.g. "v1" or "v1beta"). If not
+   * specified, defaults to the latest stable version.
+   */
+  apiVersion?: string;
 }
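
Illustration only (not part of the diff): a minimal sketch of how a caller would opt into a non-default endpoint version once this change lands, mirroring the new integration tests above. The model name, prompt, and `GEMINI_API_KEY` environment variable are taken from those tests; the `timeout` value and the standalone `main()` wrapper are assumptions added to make the snippet runnable.

```typescript
import { GoogleGenerativeAI } from "@google/generative-ai";

async function main(): Promise<void> {
  const genAI = new GoogleGenerativeAI(process.env.GEMINI_API_KEY || "");

  // RequestOptions is the second argument to getGenerativeModel. With
  // { apiVersion: "v1beta" }, RequestUrl builds
  // https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:generateContent
  // instead of falling back to DEFAULT_API_VERSION ("v1").
  const model = genAI.getGenerativeModel(
    { model: "gemini-pro" },
    { apiVersion: "v1beta", timeout: 30_000 }, // timeout is the pre-existing RequestOptions field
  );

  const result = await model.generateContent("What do cats eat?");
  console.log(result.response.text());
}

main().catch(console.error);
```

As the diff stands, only `generateContent` and `generateContentStream` forward `requestOptions` into `RequestUrl`; `countTokens`, `embedContent`, and `batchEmbedContents` still pass an empty object, so `apiVersion` does not affect those endpoints.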