Skip to content

Commit

Permalink
Throw GenerateContentError.invalidAPIKey in generateContentStream (
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewheard authored Jan 25, 2024
1 parent 31f762d commit db2da90
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 5 deletions.
19 changes: 14 additions & 5 deletions Sources/GoogleAI/GenerativeModel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -108,10 +108,7 @@ public final class GenerativeModel {
do {
response = try await generativeAIService.loadRequest(request: generateContentRequest)
} catch {
if let error = error as? RPCError, error.isInvalidAPIKeyError() {
throw GenerateContentError.invalidAPIKey
}
throw GenerateContentError.internalError(underlying: error)
throw GenerativeModel.generateContentError(from: error)
}

// Check the prompt feedback to see if the prompt was blocked.
Expand Down Expand Up @@ -168,7 +165,7 @@ public final class GenerativeModel {
do {
response = try await responseIterator.next()
} catch {
throw GenerateContentError.internalError(underlying: error)
throw GenerativeModel.generateContentError(from: error)
}

// The responseIterator will return `nil` when it's done.
Expand Down Expand Up @@ -240,6 +237,18 @@ public final class GenerativeModel {
return modelResourcePrefix + name
}
}

/// Returns a `GenerateContentError` (for public consumption) from an internal error.
///
/// If `error` is already a `GenerateContentError` the error is returned unchanged.
private static func generateContentError(from error: Error) -> GenerateContentError {
if let error = error as? GenerateContentError {
return error
} else if let error = error as? RPCError, error.isInvalidAPIKeyError() {
return GenerateContentError.invalidAPIKey
}
return GenerateContentError.internalError(underlying: error)
}
}

/// See ``GenerativeModel/countTokens(_:)-9spwl``.
Expand Down
20 changes: 20 additions & 0 deletions Tests/GoogleAITests/GenerativeModelTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -439,6 +439,26 @@ final class GenerativeModelTests: XCTestCase {

// MARK: - Generate Content (Streaming)

func testGenerateContentStream_failureInvalidAPIKey() async throws {
MockURLProtocol
.requestHandler = try httpRequestHandler(
forResource: "unary-failure-api-key",
withExtension: "json"
)

do {
let stream = model.generateContentStream("Hi")
for try await _ in stream {
XCTFail("No content is there, this shouldn't happen.")
}
} catch GenerateContentError.invalidAPIKey {
// invalidAPIKey error is as expected, nothing else to check.
return
}

XCTFail("Should have caught an error.")
}

func testGenerateContentStream_failureEmptyContent() async throws {
MockURLProtocol
.requestHandler = try httpRequestHandler(
Expand Down

0 comments on commit db2da90

Please sign in to comment.