diff --git a/Sources/GoogleAI/GenerativeAIRequest.swift b/Sources/GoogleAI/GenerativeAIRequest.swift index 1afd65d..d468576 100644 --- a/Sources/GoogleAI/GenerativeAIRequest.swift +++ b/Sources/GoogleAI/GenerativeAIRequest.swift @@ -26,9 +26,8 @@ protocol GenerativeAIRequest: Encodable { /// Configuration parameters for sending requests to the backend. @available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *) public struct RequestOptions { - /// The request’s timeout interval in seconds; if not specified uses the default value for a - /// `URLRequest`. - let timeout: TimeInterval? + /// The request’s timeout interval in seconds. + let timeout: TimeInterval /// The API version to use in requests to the backend. let apiVersion: String @@ -36,10 +35,9 @@ public struct RequestOptions { /// Initializes a request options object. /// /// - Parameters: - /// - timeout The request’s timeout interval in seconds; if not specified uses the default value - /// for a `URLRequest`. - /// - apiVersion The API version to use in requests to the backend; defaults to "v1beta". - public init(timeout: TimeInterval? = nil, apiVersion: String = "v1beta") { + /// - timeout: The request’s timeout interval in seconds; defaults to 300 seconds (5 minutes). + /// - apiVersion: The API version to use in requests to the backend; defaults to "v1beta". + public init(timeout: TimeInterval = 300.0, apiVersion: String = "v1beta") { self.timeout = timeout self.apiVersion = apiVersion } diff --git a/Sources/GoogleAI/GenerativeAIService.swift b/Sources/GoogleAI/GenerativeAIService.swift index 8d90473..0f32d6a 100644 --- a/Sources/GoogleAI/GenerativeAIService.swift +++ b/Sources/GoogleAI/GenerativeAIService.swift @@ -156,10 +156,7 @@ struct GenerativeAIService { let encoder = JSONEncoder() encoder.keyEncodingStrategy = .convertToSnakeCase urlRequest.httpBody = try encoder.encode(request) - - if let timeoutInterval = request.options.timeout { - urlRequest.timeoutInterval = timeoutInterval - } + urlRequest.timeoutInterval = request.options.timeout return urlRequest } diff --git a/Tests/GoogleAITests/GenerativeModelTests.swift b/Tests/GoogleAITests/GenerativeModelTests.swift index ccd8979..5a20343 100644 --- a/Tests/GoogleAITests/GenerativeModelTests.swift +++ b/Tests/GoogleAITests/GenerativeModelTests.swift @@ -611,6 +611,20 @@ final class GenerativeModelTests: XCTestCase { XCTAssertEqual(response.candidates.count, 1) } + func testGenerateContent_requestOptions_defaultTimeout() async throws { + let expectedTimeout = 300.0 // Default in timeout in RequestOptions() + MockURLProtocol + .requestHandler = try httpRequestHandler( + forResource: "unary-success-basic-reply-short", + withExtension: "json", + timeout: expectedTimeout + ) + + let response = try await model.generateContent(testPrompt) + + XCTAssertEqual(response.candidates.count, 1) + } + // MARK: - Generate Content (Streaming) func testGenerateContentStream_failureInvalidAPIKey() async throws { @@ -967,6 +981,25 @@ final class GenerativeModelTests: XCTestCase { XCTAssertEqual(responses, 1) } + func testGenerateContentStream_requestOptions_defaultTimeout() async throws { + let expectedTimeout = 300.0 // Default in timeout in RequestOptions() + MockURLProtocol + .requestHandler = try httpRequestHandler( + forResource: "streaming-success-basic-reply-short", + withExtension: "txt", + timeout: expectedTimeout + ) + + var responses = 0 + let stream = model.generateContentStream(testPrompt) + for try await content in stream { + XCTAssertNotNil(content.text) + responses += 1 + } + + XCTAssertEqual(responses, 1) + } + // MARK: - Count Tokens func testCountTokens_succeeds() async throws { @@ -1019,6 +1052,20 @@ final class GenerativeModelTests: XCTestCase { XCTAssertEqual(response.totalTokens, 6) } + func testCountTokens_requestOptions_defaultTimeout() async throws { + let expectedTimeout = 300.0 + MockURLProtocol + .requestHandler = try httpRequestHandler( + forResource: "success-total-tokens", + withExtension: "json", + timeout: expectedTimeout + ) + + let response = try await model.countTokens(testPrompt) + + XCTAssertEqual(response.totalTokens, 6) + } + // MARK: - Model Resource Name func testModelResourceName_noPrefix() async throws { @@ -1067,8 +1114,8 @@ final class GenerativeModelTests: XCTestCase { private func httpRequestHandler(forResource name: String, withExtension ext: String, statusCode: Int = 200, - timeout: TimeInterval = URLRequest - .defaultTimeoutInterval()) throws -> ((URLRequest) throws -> ( + timeout: TimeInterval = RequestOptions() + .timeout) throws -> ((URLRequest) throws -> ( URLResponse, AsyncLineSequence? )) {