From 0931d2ce051215db72785d76fe3ae4e0bc3b5475 Mon Sep 17 00:00:00 2001 From: Christina Holland Date: Mon, 8 Apr 2024 09:31:36 -0700 Subject: [PATCH] Refactor makeRequest (#87) --- .changeset/modern-brooms-sparkle.md | 5 + .github/workflows/test.yml | 2 +- packages/main/src/methods/count-tokens.ts | 8 +- packages/main/src/methods/embed-content.ts | 20 +-- .../main/src/methods/generate-content.test.ts | 46 +++-- packages/main/src/methods/generate-content.ts | 14 +- packages/main/src/requests/request.test.ts | 162 ++++++++++++------ packages/main/src/requests/request.ts | 76 ++++++-- 8 files changed, 228 insertions(+), 105 deletions(-) create mode 100644 .changeset/modern-brooms-sparkle.md diff --git a/.changeset/modern-brooms-sparkle.md b/.changeset/modern-brooms-sparkle.md new file mode 100644 index 00000000..cff92ae6 --- /dev/null +++ b/.changeset/modern-brooms-sparkle.md @@ -0,0 +1,5 @@ +--- +"@google/generative-ai": patch +--- + +Refactor makeRequest to make fetch mockable. diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index d4ca9b71..e7b68293 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -26,7 +26,7 @@ jobs: matrix: # lock version 20 for now as 20.12.0 makes global fetch unstubbable # until we can rewrite tests to stub some other way - node-version: ['18.x', '20.11.1'] + node-version: ['18.x', '20.x'] steps: - uses: actions/checkout@v4 - name: Use Node.js diff --git a/packages/main/src/methods/count-tokens.ts b/packages/main/src/methods/count-tokens.ts index e370b81b..772d5b7b 100644 --- a/packages/main/src/methods/count-tokens.ts +++ b/packages/main/src/methods/count-tokens.ts @@ -20,7 +20,7 @@ import { CountTokensResponse, RequestOptions, } from "../../types"; -import { RequestUrl, Task, makeRequest } from "../requests/request"; +import { Task, makeRequest } from "../requests/request"; export async function countTokens( apiKey: string, @@ -28,15 +28,11 @@ export async function countTokens( params: CountTokensRequest, requestOptions?: RequestOptions, ): Promise { - const url = new RequestUrl( + const response = await makeRequest( model, Task.COUNT_TOKENS, apiKey, false, - requestOptions, - ); - const response = await makeRequest( - url, JSON.stringify({ ...params, model }), requestOptions, ); diff --git a/packages/main/src/methods/embed-content.ts b/packages/main/src/methods/embed-content.ts index 277eb54e..f5060e56 100644 --- a/packages/main/src/methods/embed-content.ts +++ b/packages/main/src/methods/embed-content.ts @@ -22,7 +22,7 @@ import { EmbedContentResponse, RequestOptions, } from "../../types"; -import { RequestUrl, Task, makeRequest } from "../requests/request"; +import { Task, makeRequest } from "../requests/request"; export async function embedContent( apiKey: string, @@ -30,15 +30,11 @@ export async function embedContent( params: EmbedContentRequest, requestOptions?: RequestOptions, ): Promise { - const url = new RequestUrl( + const response = await makeRequest( model, Task.EMBED_CONTENT, apiKey, false, - requestOptions, - ); - const response = await makeRequest( - url, JSON.stringify(params), requestOptions, ); @@ -51,20 +47,16 @@ export async function batchEmbedContents( params: BatchEmbedContentsRequest, requestOptions?: RequestOptions, ): Promise { - const url = new RequestUrl( - model, - Task.BATCH_EMBED_CONTENTS, - apiKey, - false, - requestOptions, - ); const requestsWithModel: EmbedContentRequest[] = params.requests.map( (request) => { return { ...request, model }; }, ); const response = await makeRequest( - url, + model, + Task.BATCH_EMBED_CONTENTS, + apiKey, + false, JSON.stringify({ requests: requestsWithModel }), requestOptions, ); diff --git a/packages/main/src/methods/generate-content.test.ts b/packages/main/src/methods/generate-content.test.ts index 1b6c2dbe..cea96bda 100644 --- a/packages/main/src/methods/generate-content.test.ts +++ b/packages/main/src/methods/generate-content.test.ts @@ -58,7 +58,10 @@ describe("generateContent()", () => { const result = await generateContent("key", "model", fakeRequestParams); expect(result.response.text()).to.include("Helena"); expect(makeRequestStub).to.be.calledWith( - match.instanceOf(request.RequestUrl), + "model", + request.Task.GENERATE_CONTENT, + "key", + false, match((value: string) => { return value.includes("contents"); }), @@ -73,7 +76,10 @@ describe("generateContent()", () => { expect(result.response.text()).to.include("Use Freshly Ground Coffee"); expect(result.response.text()).to.include("30 minutes of brewing"); expect(makeRequestStub).to.be.calledWith( - match.instanceOf(request.RequestUrl), + "model", + request.Task.GENERATE_CONTENT, + "key", + false, match.any, ); }); @@ -88,7 +94,10 @@ describe("generateContent()", () => { result.response.candidates[0].citationMetadata.citationSources.length, ).to.equal(1); expect(makeRequestStub).to.be.calledWith( - match.instanceOf(request.RequestUrl), + "model", + request.Task.GENERATE_CONTENT, + "key", + false, match.any, ); }); @@ -102,7 +111,10 @@ describe("generateContent()", () => { const result = await generateContent("key", "model", fakeRequestParams); expect(result.response.text).to.throw("SAFETY"); expect(makeRequestStub).to.be.calledWith( - match.instanceOf(request.RequestUrl), + "model", + request.Task.GENERATE_CONTENT, + "key", + false, match.any, ); }); @@ -116,7 +128,10 @@ describe("generateContent()", () => { const result = await generateContent("key", "model", fakeRequestParams); expect(result.response.text).to.throw("SAFETY"); expect(makeRequestStub).to.be.calledWith( - match.instanceOf(request.RequestUrl), + "model", + request.Task.GENERATE_CONTENT, + "key", + false, match.any, ); }); @@ -128,7 +143,10 @@ describe("generateContent()", () => { const result = await generateContent("key", "model", fakeRequestParams); expect(result.response.text()).to.equal(""); expect(makeRequestStub).to.be.calledWith( - match.instanceOf(request.RequestUrl), + "model", + request.Task.GENERATE_CONTENT, + "key", + false, match.any, ); }); @@ -140,20 +158,22 @@ describe("generateContent()", () => { const result = await generateContent("key", "model", fakeRequestParams); expect(result.response.text()).to.include("30 minutes of brewing"); expect(makeRequestStub).to.be.calledWith( - match.instanceOf(request.RequestUrl), + "model", + request.Task.GENERATE_CONTENT, + "key", + false, match.any, ); }); it("image rejected (400)", async () => { const mockResponse = getMockResponse("unary-failure-image-rejected.json"); - const mockFetch = stub(globalThis, "fetch").resolves({ - ok: false, - status: 400, - json: mockResponse.json, - } as Response); + const errorJson = await mockResponse.json(); + const makeRequestStub = stub(request, "makeRequest").rejects( + new Error(`[400 ] ${errorJson.error.message}`), + ); await expect( generateContent("key", "model", fakeRequestParams), ).to.be.rejectedWith(/400.*invalid argument/); - expect(mockFetch).to.be.called; + expect(makeRequestStub).to.be.called; }); }); diff --git a/packages/main/src/methods/generate-content.ts b/packages/main/src/methods/generate-content.ts index 60c39243..d9aa3aab 100644 --- a/packages/main/src/methods/generate-content.ts +++ b/packages/main/src/methods/generate-content.ts @@ -22,7 +22,7 @@ import { GenerateContentStreamResult, RequestOptions, } from "../../types"; -import { RequestUrl, Task, makeRequest } from "../requests/request"; +import { Task, makeRequest } from "../requests/request"; import { addHelpers } from "../requests/response-helpers"; import { processStream } from "../requests/stream-reader"; @@ -32,15 +32,11 @@ export async function generateContentStream( params: GenerateContentRequest, requestOptions?: RequestOptions, ): Promise { - const url = new RequestUrl( + const response = await makeRequest( model, Task.STREAM_GENERATE_CONTENT, apiKey, /* stream */ true, - requestOptions, - ); - const response = await makeRequest( - url, JSON.stringify(params), requestOptions, ); @@ -53,15 +49,11 @@ export async function generateContent( params: GenerateContentRequest, requestOptions?: RequestOptions, ): Promise { - const url = new RequestUrl( + const response = await makeRequest( model, Task.GENERATE_CONTENT, apiKey, /* stream */ false, - requestOptions, - ); - const response = await makeRequest( - url, JSON.stringify(params), requestOptions, ); diff --git a/packages/main/src/requests/request.test.ts b/packages/main/src/requests/request.test.ts index 6ad02c80..288f1d9a 100644 --- a/packages/main/src/requests/request.test.ts +++ b/packages/main/src/requests/request.test.ts @@ -16,7 +16,7 @@ */ import { expect, use } from "chai"; -import { restore, stub } from "sinon"; +import { match, restore, stub } from "sinon"; import * as sinonChai from "sinon-chai"; import * as chaiAsPromised from "chai-as-promised"; import { @@ -24,20 +24,13 @@ import { DEFAULT_BASE_URL, RequestUrl, Task, - makeRequest, + _makeRequestInternal, + constructRequest, } from "./request"; use(sinonChai); use(chaiAsPromised); -const fakeRequestUrl = new RequestUrl( - "model-name", - Task.GENERATE_CONTENT, - "key", - true, - {}, -); - describe("request methods", () => { afterEach(() => { restore(); @@ -113,80 +106,139 @@ describe("request methods", () => { expect(url.toString()).to.not.include("alt=sse"); }); }); - describe("makeRequest", () => { - it("no error", async () => { - const fetchStub = stub(globalThis, "fetch").resolves({ - ok: true, - } as Response); - const response = await makeRequest(fakeRequestUrl, ""); - expect(fetchStub).to.be.calledWith(fakeRequestUrl.toString(), { - method: "POST", - headers: { - "Content-Type": "application/json", - "x-goog-api-client": "genai-js/__PACKAGE_VERSION__", - "x-goog-api-key": fakeRequestUrl.apiKey, - }, - body: "", - }); - expect(response.ok).to.be.true; + describe("constructRequest", () => { + it("handles basic request", async () => { + const request = await constructRequest( + "model-name", + Task.GENERATE_CONTENT, + "key", + true, + "", + {}, + ); + expect( + (request.fetchOptions.headers as Headers).get("x-goog-api-client"), + ).to.equal("genai-js/__PACKAGE_VERSION__"); + expect( + (request.fetchOptions.headers as Headers).get("x-goog-api-key"), + ).to.equal("key"); + expect( + (request.fetchOptions.headers as Headers).get("Content-Type"), + ).to.equal("application/json"); }); it("passes apiClient", async () => { - const fetchStub = stub(globalThis, "fetch").resolves({ + const request = await constructRequest( + "model-name", + Task.GENERATE_CONTENT, + "key", + true, + "", + { + apiClient: "client/version", + }, + ); + expect( + (request.fetchOptions.headers as Headers).get("x-goog-api-client"), + ).to.equal("client/version genai-js/__PACKAGE_VERSION__"); + }); + it("passes timeout", async () => { + const request = await constructRequest( + "model-name", + Task.GENERATE_CONTENT, + "key", + true, + "", + { + timeout: 5000, + }, + ); + expect(request.fetchOptions.signal).to.be.instanceOf(AbortSignal); + }); + }); + describe("_makeRequestInternal", () => { + it("no error", async () => { + const fetchStub = stub().resolves({ ok: true, } as Response); - const response = await makeRequest(fakeRequestUrl, "", { - apiClient: "client/version", - }); - expect(fetchStub).to.be.calledWith(fakeRequestUrl.toString(), { + const response = await _makeRequestInternal( + "model-name", + Task.GENERATE_CONTENT, + "key", + true, + "", + {}, + fetchStub as typeof fetch, + ); + expect(fetchStub).to.be.calledWith(match.string, { method: "POST", - headers: { - "Content-Type": "application/json", - "x-goog-api-client": "client/version genai-js/__PACKAGE_VERSION__", - "x-goog-api-key": fakeRequestUrl.apiKey, - }, + headers: match.instanceOf(Headers), body: "", }); expect(response.ok).to.be.true; }); it("error with timeout", async () => { - const fetchStub = stub(globalThis, "fetch").resolves({ + const fetchStub = stub().resolves({ ok: false, status: 500, statusText: "AbortError", } as Response); await expect( - makeRequest(fakeRequestUrl, "", { - timeout: 0, - }), + _makeRequestInternal( + "model-name", + Task.GENERATE_CONTENT, + "key", + true, + "", + { + timeout: 0, + }, + fetchStub as typeof fetch, + ), ).to.be.rejectedWith("500 AbortError"); expect(fetchStub).to.be.calledOnce; }); it("Network error, no response.json()", async () => { - const fetchStub = stub(globalThis, "fetch").resolves({ + const fetchStub = stub().resolves({ ok: false, status: 500, statusText: "Server Error", } as Response); - await expect(makeRequest(fakeRequestUrl, "")).to.be.rejectedWith( - /500 Server Error/, - ); + await expect( + _makeRequestInternal( + "model-name", + Task.GENERATE_CONTENT, + "key", + true, + "", + {}, + fetchStub as typeof fetch, + ), + ).to.be.rejectedWith(/500 Server Error/); expect(fetchStub).to.be.calledOnce; }); it("Network error, includes response.json()", async () => { - const fetchStub = stub(globalThis, "fetch").resolves({ + const fetchStub = stub().resolves({ ok: false, status: 500, statusText: "Server Error", json: () => Promise.resolve({ error: { message: "extra info" } }), } as Response); - await expect(makeRequest(fakeRequestUrl, "")).to.be.rejectedWith( - /500 Server Error.*extra info/, - ); + await expect( + _makeRequestInternal( + "model-name", + Task.GENERATE_CONTENT, + "key", + true, + "", + {}, + fetchStub as typeof fetch, + ), + ).to.be.rejectedWith(/500 Server Error.*extra info/); expect(fetchStub).to.be.calledOnce; }); it("Network error, includes response.json() and details", async () => { - const fetchStub = stub(globalThis, "fetch").resolves({ + const fetchStub = stub().resolves({ ok: false, status: 500, statusText: "Server Error", @@ -204,7 +256,17 @@ describe("request methods", () => { }, }), } as Response); - await expect(makeRequest(fakeRequestUrl, "")).to.be.rejectedWith( + await expect( + _makeRequestInternal( + "model-name", + Task.GENERATE_CONTENT, + "key", + true, + "", + {}, + fetchStub as typeof fetch, + ), + ).to.be.rejectedWith( /500 Server Error.*extra info.*generic::invalid_argument/, ); expect(fetchStub).to.be.calledOnce; diff --git a/packages/main/src/requests/request.ts b/packages/main/src/requests/request.ts index 7ab86934..5f30aad2 100644 --- a/packages/main/src/requests/request.ts +++ b/packages/main/src/requests/request.ts @@ -68,23 +68,79 @@ export function getClientHeaders(requestOptions: RequestOptions): string { return clientHeaders.join(" "); } +export async function getHeaders(url: RequestUrl): Promise { + const headers = new Headers(); + headers.append("Content-Type", "application/json"); + headers.append("x-goog-api-client", getClientHeaders(url.requestOptions)); + headers.append("x-goog-api-key", url.apiKey); + return headers; +} + +export async function constructRequest( + model: string, + task: Task, + apiKey: string, + stream: boolean, + body: string, + requestOptions?: RequestOptions, +): Promise<{ url: string; fetchOptions: RequestInit }> { + const url = new RequestUrl(model, task, apiKey, stream, requestOptions); + return { + url: url.toString(), + fetchOptions: { + ...buildFetchOptions(requestOptions), + method: "POST", + headers: await getHeaders(url), + body, + }, + }; +} + +/** + * Wrapper for _makeRequestInternal that automatically uses native fetch, + * allowing _makeRequestInternal to be tested with a mocked fetch function. + */ export async function makeRequest( - url: RequestUrl, + model: string, + task: Task, + apiKey: string, + stream: boolean, body: string, requestOptions?: RequestOptions, ): Promise { + return _makeRequestInternal( + model, + task, + apiKey, + stream, + body, + requestOptions, + fetch, + ); +} + +export async function _makeRequestInternal( + model: string, + task: Task, + apiKey: string, + stream: boolean, + body: string, + requestOptions?: RequestOptions, + // Allows this to be stubbed for tests + fetchFn = fetch, +): Promise { + const url = new RequestUrl(model, task, apiKey, stream, requestOptions); let response; try { - response = await fetch(url.toString(), { - ...buildFetchOptions(requestOptions), - method: "POST", - headers: { - "Content-Type": "application/json", - "x-goog-api-client": getClientHeaders(requestOptions), - "x-goog-api-key": url.apiKey, - }, + const request = await constructRequest( + model, + task, + apiKey, + stream, body, - }); + requestOptions, + ); + response = await fetchFn(request.url, request.fetchOptions); if (!response.ok) { let message = ""; try {