From e6f66a92373382b811985795212cb9d7e545a364 Mon Sep 17 00:00:00 2001 From: Andrew Heard Date: Wed, 15 May 2024 16:25:37 -0400 Subject: [PATCH 1/2] Add default RequestOptions.timeout of 300 seconds --- Sources/GoogleAI/GenerativeAIRequest.swift | 7 +- .../GoogleAITests/GenerativeModelTests.swift | 78 ++++++++++++++++++- 2 files changed, 78 insertions(+), 7 deletions(-) diff --git a/Sources/GoogleAI/GenerativeAIRequest.swift b/Sources/GoogleAI/GenerativeAIRequest.swift index 1afd65d..f5c5aff 100644 --- a/Sources/GoogleAI/GenerativeAIRequest.swift +++ b/Sources/GoogleAI/GenerativeAIRequest.swift @@ -36,10 +36,9 @@ public struct RequestOptions { /// Initializes a request options object. /// /// - Parameters: - /// - timeout The request’s timeout interval in seconds; if not specified uses the default value - /// for a `URLRequest`. - /// - apiVersion The API version to use in requests to the backend; defaults to "v1beta". - public init(timeout: TimeInterval? = nil, apiVersion: String = "v1beta") { + /// - timeout: The request’s timeout interval in seconds; defaults to 300 seconds (5 minutes). + /// - apiVersion: The API version to use in requests to the backend; defaults to "v1beta". + public init(timeout: TimeInterval? = 300.0, apiVersion: String = "v1beta") { self.timeout = timeout self.apiVersion = apiVersion } diff --git a/Tests/GoogleAITests/GenerativeModelTests.swift b/Tests/GoogleAITests/GenerativeModelTests.swift index ccd8979..79f9b8d 100644 --- a/Tests/GoogleAITests/GenerativeModelTests.swift +++ b/Tests/GoogleAITests/GenerativeModelTests.swift @@ -611,6 +611,27 @@ final class GenerativeModelTests: XCTestCase { XCTAssertEqual(response.candidates.count, 1) } + func testGenerateContent_requestOptions_nilTimeout() async throws { + let expectedTimeout: TimeInterval? = nil + MockURLProtocol + .requestHandler = try httpRequestHandler( + forResource: "unary-success-basic-reply-short", + withExtension: "json", + timeout: expectedTimeout + ) + let requestOptions = RequestOptions(timeout: expectedTimeout) + model = GenerativeModel( + name: "my-model", + apiKey: "API_KEY", + requestOptions: requestOptions, + urlSession: urlSession + ) + + let response = try await model.generateContent(testPrompt) + + XCTAssertEqual(response.candidates.count, 1) + } + // MARK: - Generate Content (Streaming) func testGenerateContentStream_failureInvalidAPIKey() async throws { @@ -967,6 +988,32 @@ final class GenerativeModelTests: XCTestCase { XCTAssertEqual(responses, 1) } + func testGenerateContentStream_requestOptions_nilTimeout() async throws { + let expectedTimeout: TimeInterval? = nil + MockURLProtocol + .requestHandler = try httpRequestHandler( + forResource: "streaming-success-basic-reply-short", + withExtension: "txt", + timeout: expectedTimeout + ) + let requestOptions = RequestOptions(timeout: expectedTimeout) + model = GenerativeModel( + name: "my-model", + apiKey: "API_KEY", + requestOptions: requestOptions, + urlSession: urlSession + ) + + var responses = 0 + let stream = model.generateContentStream(testPrompt) + for try await content in stream { + XCTAssertNotNil(content.text) + responses += 1 + } + + XCTAssertEqual(responses, 1) + } + // MARK: - Count Tokens func testCountTokens_succeeds() async throws { @@ -1019,6 +1066,27 @@ final class GenerativeModelTests: XCTestCase { XCTAssertEqual(response.totalTokens, 6) } + func testCountTokens_requestOptions_nilTimeout() async throws { + let expectedTimeout: TimeInterval? = nil + MockURLProtocol + .requestHandler = try httpRequestHandler( + forResource: "success-total-tokens", + withExtension: "json", + timeout: expectedTimeout + ) + let requestOptions = RequestOptions(timeout: expectedTimeout) + model = GenerativeModel( + name: "my-model", + apiKey: "API_KEY", + requestOptions: requestOptions, + urlSession: urlSession + ) + + let response = try await model.countTokens(testPrompt) + + XCTAssertEqual(response.totalTokens, 6) + } + // MARK: - Model Resource Name func testModelResourceName_noPrefix() async throws { @@ -1067,8 +1135,8 @@ final class GenerativeModelTests: XCTestCase { private func httpRequestHandler(forResource name: String, withExtension ext: String, statusCode: Int = 200, - timeout: TimeInterval = URLRequest - .defaultTimeoutInterval()) throws -> ((URLRequest) throws -> ( + timeout: TimeInterval? = RequestOptions() + .timeout) throws -> ((URLRequest) throws -> ( URLResponse, AsyncLineSequence? )) { @@ -1076,7 +1144,11 @@ final class GenerativeModelTests: XCTestCase { return { request in let requestURL = try XCTUnwrap(request.url) XCTAssertEqual(requestURL.path.occurrenceCount(of: "models/"), 1) - XCTAssertEqual(request.timeoutInterval, timeout) + if let timeout { + XCTAssertEqual(request.timeoutInterval, timeout) + } else { + XCTAssertEqual(request.timeoutInterval, URLRequest.defaultTimeoutInterval()) + } let response = try XCTUnwrap(HTTPURLResponse( url: requestURL, statusCode: statusCode, From 587769627b9b0b011b9f34f824a034276b9d051a Mon Sep 17 00:00:00 2001 From: Andrew Heard Date: Wed, 15 May 2024 17:47:38 -0400 Subject: [PATCH 2/2] Make `timeout: TimeInterval?` non-optional --- Sources/GoogleAI/GenerativeAIRequest.swift | 7 ++-- Sources/GoogleAI/GenerativeAIService.swift | 5 +-- .../GoogleAITests/GenerativeModelTests.swift | 41 ++++--------------- 3 files changed, 12 insertions(+), 41 deletions(-) diff --git a/Sources/GoogleAI/GenerativeAIRequest.swift b/Sources/GoogleAI/GenerativeAIRequest.swift index f5c5aff..d468576 100644 --- a/Sources/GoogleAI/GenerativeAIRequest.swift +++ b/Sources/GoogleAI/GenerativeAIRequest.swift @@ -26,9 +26,8 @@ protocol GenerativeAIRequest: Encodable { /// Configuration parameters for sending requests to the backend. @available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *) public struct RequestOptions { - /// The request’s timeout interval in seconds; if not specified uses the default value for a - /// `URLRequest`. - let timeout: TimeInterval? + /// The request’s timeout interval in seconds. + let timeout: TimeInterval /// The API version to use in requests to the backend. let apiVersion: String @@ -38,7 +37,7 @@ public struct RequestOptions { /// - Parameters: /// - timeout: The request’s timeout interval in seconds; defaults to 300 seconds (5 minutes). /// - apiVersion: The API version to use in requests to the backend; defaults to "v1beta". - public init(timeout: TimeInterval? = 300.0, apiVersion: String = "v1beta") { + public init(timeout: TimeInterval = 300.0, apiVersion: String = "v1beta") { self.timeout = timeout self.apiVersion = apiVersion } diff --git a/Sources/GoogleAI/GenerativeAIService.swift b/Sources/GoogleAI/GenerativeAIService.swift index 8d90473..0f32d6a 100644 --- a/Sources/GoogleAI/GenerativeAIService.swift +++ b/Sources/GoogleAI/GenerativeAIService.swift @@ -156,10 +156,7 @@ struct GenerativeAIService { let encoder = JSONEncoder() encoder.keyEncodingStrategy = .convertToSnakeCase urlRequest.httpBody = try encoder.encode(request) - - if let timeoutInterval = request.options.timeout { - urlRequest.timeoutInterval = timeoutInterval - } + urlRequest.timeoutInterval = request.options.timeout return urlRequest } diff --git a/Tests/GoogleAITests/GenerativeModelTests.swift b/Tests/GoogleAITests/GenerativeModelTests.swift index 79f9b8d..5a20343 100644 --- a/Tests/GoogleAITests/GenerativeModelTests.swift +++ b/Tests/GoogleAITests/GenerativeModelTests.swift @@ -611,21 +611,14 @@ final class GenerativeModelTests: XCTestCase { XCTAssertEqual(response.candidates.count, 1) } - func testGenerateContent_requestOptions_nilTimeout() async throws { - let expectedTimeout: TimeInterval? = nil + func testGenerateContent_requestOptions_defaultTimeout() async throws { + let expectedTimeout = 300.0 // Default in timeout in RequestOptions() MockURLProtocol .requestHandler = try httpRequestHandler( forResource: "unary-success-basic-reply-short", withExtension: "json", timeout: expectedTimeout ) - let requestOptions = RequestOptions(timeout: expectedTimeout) - model = GenerativeModel( - name: "my-model", - apiKey: "API_KEY", - requestOptions: requestOptions, - urlSession: urlSession - ) let response = try await model.generateContent(testPrompt) @@ -988,21 +981,14 @@ final class GenerativeModelTests: XCTestCase { XCTAssertEqual(responses, 1) } - func testGenerateContentStream_requestOptions_nilTimeout() async throws { - let expectedTimeout: TimeInterval? = nil + func testGenerateContentStream_requestOptions_defaultTimeout() async throws { + let expectedTimeout = 300.0 // Default in timeout in RequestOptions() MockURLProtocol .requestHandler = try httpRequestHandler( forResource: "streaming-success-basic-reply-short", withExtension: "txt", timeout: expectedTimeout ) - let requestOptions = RequestOptions(timeout: expectedTimeout) - model = GenerativeModel( - name: "my-model", - apiKey: "API_KEY", - requestOptions: requestOptions, - urlSession: urlSession - ) var responses = 0 let stream = model.generateContentStream(testPrompt) @@ -1066,21 +1052,14 @@ final class GenerativeModelTests: XCTestCase { XCTAssertEqual(response.totalTokens, 6) } - func testCountTokens_requestOptions_nilTimeout() async throws { - let expectedTimeout: TimeInterval? = nil + func testCountTokens_requestOptions_defaultTimeout() async throws { + let expectedTimeout = 300.0 MockURLProtocol .requestHandler = try httpRequestHandler( forResource: "success-total-tokens", withExtension: "json", timeout: expectedTimeout ) - let requestOptions = RequestOptions(timeout: expectedTimeout) - model = GenerativeModel( - name: "my-model", - apiKey: "API_KEY", - requestOptions: requestOptions, - urlSession: urlSession - ) let response = try await model.countTokens(testPrompt) @@ -1135,7 +1114,7 @@ final class GenerativeModelTests: XCTestCase { private func httpRequestHandler(forResource name: String, withExtension ext: String, statusCode: Int = 200, - timeout: TimeInterval? = RequestOptions() + timeout: TimeInterval = RequestOptions() .timeout) throws -> ((URLRequest) throws -> ( URLResponse, AsyncLineSequence? @@ -1144,11 +1123,7 @@ final class GenerativeModelTests: XCTestCase { return { request in let requestURL = try XCTUnwrap(request.url) XCTAssertEqual(requestURL.path.occurrenceCount(of: "models/"), 1) - if let timeout { - XCTAssertEqual(request.timeoutInterval, timeout) - } else { - XCTAssertEqual(request.timeoutInterval, URLRequest.defaultTimeoutInterval()) - } + XCTAssertEqual(request.timeoutInterval, timeout) let response = try XCTUnwrap(HTTPURLResponse( url: requestURL, statusCode: statusCode,