Skip to content

Commit

Permalink
Refactor makeRequest (#87)
Browse files Browse the repository at this point in the history
  • Loading branch information
hsubox76 authored Apr 8, 2024
1 parent 6fcca28 commit 0931d2c
Show file tree
Hide file tree
Showing 8 changed files with 228 additions and 105 deletions.
5 changes: 5 additions & 0 deletions .changeset/modern-brooms-sparkle.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"@google/generative-ai": patch
---

Refactor makeRequest to make fetch mockable.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ jobs:
matrix:
# lock version 20 for now as 20.12.0 makes global fetch unstubbable
# until we can rewrite tests to stub some other way
node-version: ['18.x', '20.11.1']
node-version: ['18.x', '20.x']
steps:
- uses: actions/checkout@v4
- name: Use Node.js
Expand Down
8 changes: 2 additions & 6 deletions packages/main/src/methods/count-tokens.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,23 +20,19 @@ import {
CountTokensResponse,
RequestOptions,
} from "../../types";
import { RequestUrl, Task, makeRequest } from "../requests/request";
import { Task, makeRequest } from "../requests/request";

export async function countTokens(
apiKey: string,
model: string,
params: CountTokensRequest,
requestOptions?: RequestOptions,
): Promise<CountTokensResponse> {
const url = new RequestUrl(
const response = await makeRequest(
model,
Task.COUNT_TOKENS,
apiKey,
false,
requestOptions,
);
const response = await makeRequest(
url,
JSON.stringify({ ...params, model }),
requestOptions,
);
Expand Down
20 changes: 6 additions & 14 deletions packages/main/src/methods/embed-content.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,23 +22,19 @@ import {
EmbedContentResponse,
RequestOptions,
} from "../../types";
import { RequestUrl, Task, makeRequest } from "../requests/request";
import { Task, makeRequest } from "../requests/request";

export async function embedContent(
apiKey: string,
model: string,
params: EmbedContentRequest,
requestOptions?: RequestOptions,
): Promise<EmbedContentResponse> {
const url = new RequestUrl(
const response = await makeRequest(
model,
Task.EMBED_CONTENT,
apiKey,
false,
requestOptions,
);
const response = await makeRequest(
url,
JSON.stringify(params),
requestOptions,
);
Expand All @@ -51,20 +47,16 @@ export async function batchEmbedContents(
params: BatchEmbedContentsRequest,
requestOptions?: RequestOptions,
): Promise<BatchEmbedContentsResponse> {
const url = new RequestUrl(
model,
Task.BATCH_EMBED_CONTENTS,
apiKey,
false,
requestOptions,
);
const requestsWithModel: EmbedContentRequest[] = params.requests.map(
(request) => {
return { ...request, model };
},
);
const response = await makeRequest(
url,
model,
Task.BATCH_EMBED_CONTENTS,
apiKey,
false,
JSON.stringify({ requests: requestsWithModel }),
requestOptions,
);
Expand Down
46 changes: 33 additions & 13 deletions packages/main/src/methods/generate-content.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,10 @@ describe("generateContent()", () => {
const result = await generateContent("key", "model", fakeRequestParams);
expect(result.response.text()).to.include("Helena");
expect(makeRequestStub).to.be.calledWith(
match.instanceOf(request.RequestUrl),
"model",
request.Task.GENERATE_CONTENT,
"key",
false,
match((value: string) => {
return value.includes("contents");
}),
Expand All @@ -73,7 +76,10 @@ describe("generateContent()", () => {
expect(result.response.text()).to.include("Use Freshly Ground Coffee");
expect(result.response.text()).to.include("30 minutes of brewing");
expect(makeRequestStub).to.be.calledWith(
match.instanceOf(request.RequestUrl),
"model",
request.Task.GENERATE_CONTENT,
"key",
false,
match.any,
);
});
Expand All @@ -88,7 +94,10 @@ describe("generateContent()", () => {
result.response.candidates[0].citationMetadata.citationSources.length,
).to.equal(1);
expect(makeRequestStub).to.be.calledWith(
match.instanceOf(request.RequestUrl),
"model",
request.Task.GENERATE_CONTENT,
"key",
false,
match.any,
);
});
Expand All @@ -102,7 +111,10 @@ describe("generateContent()", () => {
const result = await generateContent("key", "model", fakeRequestParams);
expect(result.response.text).to.throw("SAFETY");
expect(makeRequestStub).to.be.calledWith(
match.instanceOf(request.RequestUrl),
"model",
request.Task.GENERATE_CONTENT,
"key",
false,
match.any,
);
});
Expand All @@ -116,7 +128,10 @@ describe("generateContent()", () => {
const result = await generateContent("key", "model", fakeRequestParams);
expect(result.response.text).to.throw("SAFETY");
expect(makeRequestStub).to.be.calledWith(
match.instanceOf(request.RequestUrl),
"model",
request.Task.GENERATE_CONTENT,
"key",
false,
match.any,
);
});
Expand All @@ -128,7 +143,10 @@ describe("generateContent()", () => {
const result = await generateContent("key", "model", fakeRequestParams);
expect(result.response.text()).to.equal("");
expect(makeRequestStub).to.be.calledWith(
match.instanceOf(request.RequestUrl),
"model",
request.Task.GENERATE_CONTENT,
"key",
false,
match.any,
);
});
Expand All @@ -140,20 +158,22 @@ describe("generateContent()", () => {
const result = await generateContent("key", "model", fakeRequestParams);
expect(result.response.text()).to.include("30 minutes of brewing");
expect(makeRequestStub).to.be.calledWith(
match.instanceOf(request.RequestUrl),
"model",
request.Task.GENERATE_CONTENT,
"key",
false,
match.any,
);
});
it("image rejected (400)", async () => {
const mockResponse = getMockResponse("unary-failure-image-rejected.json");
const mockFetch = stub(globalThis, "fetch").resolves({
ok: false,
status: 400,
json: mockResponse.json,
} as Response);
const errorJson = await mockResponse.json();
const makeRequestStub = stub(request, "makeRequest").rejects(
new Error(`[400 ] ${errorJson.error.message}`),
);
await expect(
generateContent("key", "model", fakeRequestParams),
).to.be.rejectedWith(/400.*invalid argument/);
expect(mockFetch).to.be.called;
expect(makeRequestStub).to.be.called;
});
});
14 changes: 3 additions & 11 deletions packages/main/src/methods/generate-content.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import {
GenerateContentStreamResult,
RequestOptions,
} from "../../types";
import { RequestUrl, Task, makeRequest } from "../requests/request";
import { Task, makeRequest } from "../requests/request";
import { addHelpers } from "../requests/response-helpers";
import { processStream } from "../requests/stream-reader";

Expand All @@ -32,15 +32,11 @@ export async function generateContentStream(
params: GenerateContentRequest,
requestOptions?: RequestOptions,
): Promise<GenerateContentStreamResult> {
const url = new RequestUrl(
const response = await makeRequest(
model,
Task.STREAM_GENERATE_CONTENT,
apiKey,
/* stream */ true,
requestOptions,
);
const response = await makeRequest(
url,
JSON.stringify(params),
requestOptions,
);
Expand All @@ -53,15 +49,11 @@ export async function generateContent(
params: GenerateContentRequest,
requestOptions?: RequestOptions,
): Promise<GenerateContentResult> {
const url = new RequestUrl(
const response = await makeRequest(
model,
Task.GENERATE_CONTENT,
apiKey,
/* stream */ false,
requestOptions,
);
const response = await makeRequest(
url,
JSON.stringify(params),
requestOptions,
);
Expand Down
Loading

0 comments on commit 0931d2c

Please sign in to comment.