diff --git a/Sources/GoogleAI/GenerativeModel.swift b/Sources/GoogleAI/GenerativeModel.swift index d59f660..f5e4b21 100644 --- a/Sources/GoogleAI/GenerativeModel.swift +++ b/Sources/GoogleAI/GenerativeModel.swift @@ -108,10 +108,7 @@ public final class GenerativeModel { do { response = try await generativeAIService.loadRequest(request: generateContentRequest) } catch { - if let error = error as? RPCError, error.isInvalidAPIKeyError() { - throw GenerateContentError.invalidAPIKey - } - throw GenerateContentError.internalError(underlying: error) + throw GenerativeModel.generateContentError(from: error) } // Check the prompt feedback to see if the prompt was blocked. @@ -168,7 +165,7 @@ public final class GenerativeModel { do { response = try await responseIterator.next() } catch { - throw GenerateContentError.internalError(underlying: error) + throw GenerativeModel.generateContentError(from: error) } // The responseIterator will return `nil` when it's done. @@ -240,6 +237,18 @@ public final class GenerativeModel { return modelResourcePrefix + name } } + + /// Returns a `GenerateContentError` (for public consumption) from an internal error. + /// + /// If `error` is already a `GenerateContentError` the error is returned unchanged. + private static func generateContentError(from error: Error) -> GenerateContentError { + if let error = error as? GenerateContentError { + return error + } else if let error = error as? RPCError, error.isInvalidAPIKeyError() { + return GenerateContentError.invalidAPIKey + } + return GenerateContentError.internalError(underlying: error) + } } /// See ``GenerativeModel/countTokens(_:)-9spwl``. diff --git a/Tests/GoogleAITests/GenerativeModelTests.swift b/Tests/GoogleAITests/GenerativeModelTests.swift index bb50e65..db19ddc 100644 --- a/Tests/GoogleAITests/GenerativeModelTests.swift +++ b/Tests/GoogleAITests/GenerativeModelTests.swift @@ -439,6 +439,26 @@ final class GenerativeModelTests: XCTestCase { // MARK: - Generate Content (Streaming) + func testGenerateContentStream_failureInvalidAPIKey() async throws { + MockURLProtocol + .requestHandler = try httpRequestHandler( + forResource: "unary-failure-api-key", + withExtension: "json" + ) + + do { + let stream = model.generateContentStream("Hi") + for try await _ in stream { + XCTFail("No content is there, this shouldn't happen.") + } + } catch GenerateContentError.invalidAPIKey { + // invalidAPIKey error is as expected, nothing else to check. + return + } + + XCTFail("Should have caught an error.") + } + func testGenerateContentStream_failureEmptyContent() async throws { MockURLProtocol .requestHandler = try httpRequestHandler(