Commit

Allow users to choose API endpoint version (#56)
hsubox76 authored Mar 5, 2024
1 parent d9c3f4d commit 932e1be
Showing 14 changed files with 110 additions and 19 deletions.
5 changes: 5 additions & 0 deletions .changeset/short-pots-provide.md
@@ -0,0 +1,5 @@
---
"@google/generative-ai": minor
---

Add `apiVersion` property to `RequestOptions` to allow user to choose API endpoint version.
2 changes: 1 addition & 1 deletion docs/reference/generative-ai.harmblockthreshold.md
@@ -4,7 +4,7 @@

## HarmBlockThreshold enum

- Threshhold above which a prompt or candidate will be blocked.
+ Threshold above which a prompt or candidate will be blocked.

**Signature:**

2 changes: 1 addition & 1 deletion docs/reference/generative-ai.md
@@ -18,7 +18,7 @@
| --- | --- |
| [BlockReason](./generative-ai.blockreason.md) | Reason that a prompt was blocked. |
| [FinishReason](./generative-ai.finishreason.md) | Reason that a candidate finished. |
- | [HarmBlockThreshold](./generative-ai.harmblockthreshold.md) | Threshhold above which a prompt or candidate will be blocked. |
+ | [HarmBlockThreshold](./generative-ai.harmblockthreshold.md) | Threshold above which a prompt or candidate will be blocked. |
| [HarmCategory](./generative-ai.harmcategory.md) | Harm categories that would cause prompts or candidates to be blocked. |
| [HarmProbability](./generative-ai.harmprobability.md) | Probability that a prompt or candidate matches a harm category. |
| [TaskType](./generative-ai.tasktype.md) | Task type for embedding content. |
13 changes: 13 additions & 0 deletions docs/reference/generative-ai.requestoptions.apiversion.md
@@ -0,0 +1,13 @@
<!-- Do not edit this file. It is automatically generated by API Documenter. -->

[Home](./index.md) &gt; [@google/generative-ai](./generative-ai.md) &gt; [RequestOptions](./generative-ai.requestoptions.md) &gt; [apiVersion](./generative-ai.requestoptions.apiversion.md)

## RequestOptions.apiVersion property

Version of API endpoint to call (e.g. "v1" or "v1beta"). If not specified, defaults to latest stable version.

**Signature:**

```typescript
apiVersion?: string;
```
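The fallback documented above can be sketched in isolation. This is a standalone illustration, not the SDK's own code; the real resolution happens inside `RequestUrl` later in this diff, and "v1" is the stable default at the time of this change.

```typescript
// Minimal sketch of the documented default: an unset apiVersion
// resolves to the latest stable version ("v1").
interface RequestOptions {
  timeout?: number;
  apiVersion?: string;
}

const DEFAULT_API_VERSION = "v1";

function resolveApiVersion(options?: RequestOptions): string {
  // `||` (rather than `??`) matches the SDK diff, so an empty string
  // also falls back to the default.
  return options?.apiVersion || DEFAULT_API_VERSION;
}
```

For example, `resolveApiVersion({})` yields `"v1"`, while `resolveApiVersion({ apiVersion: "v1beta" })` yields `"v1beta"`.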
3 changes: 2 additions & 1 deletion docs/reference/generative-ai.requestoptions.md
@@ -16,5 +16,6 @@ export interface RequestOptions

| Property | Modifiers | Type | Description |
| --- | --- | --- | --- |
- | [timeout?](./generative-ai.requestoptions.timeout.md) | | number | _(Optional)_ |
+ | [apiVersion?](./generative-ai.requestoptions.apiversion.md) | | string | _(Optional)_ Version of API endpoint to call (e.g. "v1" or "v1beta"). If not specified, defaults to latest stable version. |
+ | [timeout?](./generative-ai.requestoptions.timeout.md) | | number | _(Optional)_ Request timeout in milliseconds. |

2 changes: 2 additions & 0 deletions docs/reference/generative-ai.requestoptions.timeout.md
@@ -4,6 +4,8 @@

## RequestOptions.timeout property

Request timeout in milliseconds.

**Signature:**

```typescript
2 changes: 1 addition & 1 deletion packages/main/src/methods/count-tokens.ts
@@ -28,7 +28,7 @@ export async function countTokens(
params: CountTokensRequest,
requestOptions?: RequestOptions,
): Promise<CountTokensResponse> {
- const url = new RequestUrl(model, Task.COUNT_TOKENS, apiKey, false);
+ const url = new RequestUrl(model, Task.COUNT_TOKENS, apiKey, false, {});
const response = await makeRequest(
url,
JSON.stringify({ ...params, model }),
10 changes: 8 additions & 2 deletions packages/main/src/methods/embed-content.ts
@@ -30,7 +30,7 @@ export async function embedContent(
params: EmbedContentRequest,
requestOptions?: RequestOptions,
): Promise<EmbedContentResponse> {
- const url = new RequestUrl(model, Task.EMBED_CONTENT, apiKey, false);
+ const url = new RequestUrl(model, Task.EMBED_CONTENT, apiKey, false, {});
const response = await makeRequest(
url,
JSON.stringify(params),
@@ -45,7 +45,13 @@ export async function batchEmbedContents(
params: BatchEmbedContentsRequest,
requestOptions?: RequestOptions,
): Promise<BatchEmbedContentsResponse> {
- const url = new RequestUrl(model, Task.BATCH_EMBED_CONTENTS, apiKey, false);
+ const url = new RequestUrl(
+   model,
+   Task.BATCH_EMBED_CONTENTS,
+   apiKey,
+   false,
+   {},
+ );
const requestsWithModel: EmbedContentRequest[] = params.requests.map(
(request) => {
return { ...request, model };
2 changes: 2 additions & 0 deletions packages/main/src/methods/generate-content.ts
@@ -37,6 +37,7 @@ export async function generateContentStream(
Task.STREAM_GENERATE_CONTENT,
apiKey,
/* stream */ true,
requestOptions,
);
const response = await makeRequest(
url,
@@ -57,6 +58,7 @@ export async function generateContent(
Task.GENERATE_CONTENT,
apiKey,
/* stream */ false,
requestOptions,
);
const response = await makeRequest(
url,
26 changes: 25 additions & 1 deletion packages/main/src/requests/request.test.ts
@@ -19,7 +19,7 @@ import { expect, use } from "chai";
import { restore, stub } from "sinon";
import * as sinonChai from "sinon-chai";
import * as chaiAsPromised from "chai-as-promised";
- import { RequestUrl, Task, makeRequest } from "./request";
+ import { DEFAULT_API_VERSION, RequestUrl, Task, makeRequest } from "./request";

use(sinonChai);
use(chaiAsPromised);
@@ -29,6 +29,7 @@ const fakeRequestUrl = new RequestUrl(
Task.GENERATE_CONTENT,
"key",
true,
{},
);

describe("request methods", () => {
@@ -42,6 +43,7 @@ describe("request methods", () => {
Task.GENERATE_CONTENT,
"key",
true,
{},
);
expect(url.toString()).to.include("models/model-name:generateContent");
expect(url.toString()).to.not.include("key");
@@ -53,17 +55,39 @@
Task.GENERATE_CONTENT,
"key",
false,
{},
);
expect(url.toString()).to.include("models/model-name:generateContent");
expect(url.toString()).to.not.include("key");
expect(url.toString()).to.not.include("alt=sse");
});
it("default apiVersion", async () => {
const url = new RequestUrl(
"models/model-name",
Task.GENERATE_CONTENT,
"key",
false,
{},
);
expect(url.toString()).to.include(DEFAULT_API_VERSION);
});
it("custom apiVersion", async () => {
const url = new RequestUrl(
"models/model-name",
Task.GENERATE_CONTENT,
"key",
false,
{ apiVersion: "v2beta" },
);
expect(url.toString()).to.include("/v2beta/models/model-name");
});
it("non-stream - tunedModels/", async () => {
const url = new RequestUrl(
"tunedModels/model-name",
Task.GENERATE_CONTENT,
"key",
false,
{},
);
expect(url.toString()).to.include(
"tunedModels/model-name:generateContent",
6 changes: 4 additions & 2 deletions packages/main/src/requests/request.ts
@@ -20,7 +20,7 @@ import { GoogleGenerativeAIError } from "../errors";

const BASE_URL = "https://generativelanguage.googleapis.com";

- const API_VERSION = "v1";
+ export const DEFAULT_API_VERSION = "v1";

/**
* We can't `require` package.json if this runs on web. We will use rollup to
@@ -43,9 +43,11 @@ export class RequestUrl {
public task: Task,
public apiKey: string,
public stream: boolean,
public requestOptions: RequestOptions,
) {}
toString(): string {
-     let url = `${BASE_URL}/${API_VERSION}/${this.model}:${this.task}`;
+     const apiVersion = this.requestOptions?.apiVersion || DEFAULT_API_VERSION;
+     let url = `${BASE_URL}/${apiVersion}/${this.model}:${this.task}`;
if (this.stream) {
url += "?alt=sse";
}
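The new URL construction can be exercised on its own. The sketch below mirrors `RequestUrl.toString()` from the hunk above as a free function; it is an illustration of the behavior, not the SDK module itself.

```typescript
// Standalone sketch of RequestUrl.toString() after this change.
const BASE_URL = "https://generativelanguage.googleapis.com";
const DEFAULT_API_VERSION = "v1";

interface RequestOptions {
  apiVersion?: string;
}

function requestUrlToString(
  model: string,
  task: string,
  stream: boolean,
  requestOptions?: RequestOptions,
): string {
  // Unset (or empty) apiVersion falls back to the stable default.
  const apiVersion = requestOptions?.apiVersion || DEFAULT_API_VERSION;
  let url = `${BASE_URL}/${apiVersion}/${model}:${task}`;
  if (stream) {
    url += "?alt=sse"; // server-sent events for streaming responses
  }
  return url;
}
```

With `{ apiVersion: "v2beta" }`, `requestUrlToString("models/model-name", "generateContent", false, ...)` produces a URL containing `/v2beta/models/model-name`, which is exactly what the "custom apiVersion" unit test in this diff asserts.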
18 changes: 18 additions & 0 deletions packages/main/test-integration/node/index.test.ts
@@ -131,6 +131,24 @@ describe("generateContent", function () {
const response = result.response;
expect(response.text()).to.not.be.empty;
});
it("non-streaming, simple interface, custom API version", async () => {
const genAI = new GoogleGenerativeAI(process.env.GEMINI_API_KEY || "");
const model = genAI.getGenerativeModel(
{
model: "gemini-pro",
safetySettings: [
{
category: HarmCategory.HARM_CATEGORY_HARASSMENT,
threshold: HarmBlockThreshold.BLOCK_ONLY_HIGH,
},
],
},
{ apiVersion: "v1beta" },
);
const result = await model.generateContent("What do cats eat?");
const response = result.response;
expect(response.text()).to.not.be.empty;
});
it("non-streaming, image buffer provided", async () => {
const imageBuffer = fs.readFileSync(
join(__dirname, "../../test-utils/cat.png"),
30 changes: 20 additions & 10 deletions packages/main/test-integration/web/index.test.ts
@@ -68,6 +68,24 @@ describe("generateContent", function () {
const response = result.response;
expect(response.text()).to.not.be.empty;
});
it("non-streaming, simple interface, custom API version", async () => {
const genAI = new GoogleGenerativeAI(process.env.GEMINI_API_KEY || "");
const model = genAI.getGenerativeModel(
{
model: "gemini-pro",
safetySettings: [
{
category: HarmCategory.HARM_CATEGORY_HARASSMENT,
threshold: HarmBlockThreshold.BLOCK_ONLY_HIGH,
},
],
},
{ apiVersion: "v1beta" },
);
const result = await model.generateContent("What do cats eat?");
const response = result.response;
expect(response.text()).to.not.be.empty;
});
});

describe("startChat", function () {
@@ -150,11 +168,7 @@
const question1 = "What is the capital of Oregon?";
const question2 = "How many people live there?";
const question3 = "What is the closest river?";
- const chat = model.startChat({
-   generationConfig: {
-     maxOutputTokens: 100,
-   },
- });
+ const chat = model.startChat();
const result1 = await chat.sendMessageStream(question1);
const response1 = await result1.response;
expect(response1.text()).to.not.be.empty;
@@ -190,11 +204,7 @@
const question1 = "What are the most interesting cities in Oregon?";
const question2 = "How many people live there?";
const question3 = "What is the closest river?";
- const chat = model.startChat({
-   generationConfig: {
-     maxOutputTokens: 100,
-   },
- });
+ const chat = model.startChat();
const promise1 = chat.sendMessageStream(question1).then(async (result1) => {
for await (const response of result1.stream) {
expect(response.text()).to.not.be.empty;
8 changes: 8 additions & 0 deletions packages/main/types/requests.ts
@@ -104,5 +104,13 @@ export interface BatchEmbedContentsRequest {
* @public
*/
export interface RequestOptions {
/**
* Request timeout in milliseconds.
*/
timeout?: number;
/**
* Version of API endpoint to call (e.g. "v1" or "v1beta"). If not specified,
* defaults to latest stable version.
*/
apiVersion?: string;
}
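The two fields of the interface are independent; a caller can set either or both. A shape-only example (the values below are illustrative, not SDK defaults):

```typescript
// Mirrors the RequestOptions interface above.
interface RequestOptions {
  timeout?: number;
  apiVersion?: string;
}

// Illustrative values: a 10-second timeout and the beta endpoint.
const opts: RequestOptions = {
  timeout: 10_000,
  apiVersion: "v1beta",
};
```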
