Skip to content

Commit

Permalink
Allow PartsRepresentable to throw errors (#88)
Browse files Browse the repository at this point in the history
* almost nonbreaking change

* style

* make partsvalue accessible only when error is never

* fix tests

* fix macos

* put errors into api methods

* style

* remove generic error

* add non-erroring protocol so force unwraps aren't required

* api review feedback: use more specific error case and add failure tests

* specialize error

* style

* code feedback changes

* use partsValue

* use consistent closure name
  • Loading branch information
morganchen12 authored Feb 21, 2024
1 parent e2cebcd commit ac4aea1
Show file tree
Hide file tree
Showing 9 changed files with 340 additions and 155 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ class PhotoReasoningViewModel: ObservableObject {

let prompt = "Look at the image(s), and then answer the following question: \(userInput)"

var images = [PartsRepresentable]()
var images = [any ThrowingPartsRepresentable]()
for item in selectedItems {
if let data = try? await item.loadTransferable(type: Data.self) {
guard let image = UIImage(data: data) else {
Expand Down
40 changes: 33 additions & 7 deletions Sources/GoogleAI/Chat.swift
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ public class Chat {
public var history: [ModelContent]

/// See ``sendMessage(_:)-3ify5``.
public func sendMessage(_ parts: PartsRepresentable...) async throws -> GenerateContentResponse {
public func sendMessage(_ parts: any ThrowingPartsRepresentable...) async throws
-> GenerateContentResponse {
return try await sendMessage([ModelContent(parts: parts)])
}

Expand All @@ -40,9 +41,19 @@ public class Chat {
/// - Parameter content: The new content to send as a single chat message.
/// - Returns: The model's response if no error occurred.
/// - Throws: A ``GenerateContentError`` if an error occurred.
public func sendMessage(_ content: [ModelContent]) async throws -> GenerateContentResponse {
public func sendMessage(_ content: @autoclosure () throws -> [ModelContent]) async throws
-> GenerateContentResponse {
// Ensure that the new content has the role set.
let newContent: [ModelContent] = content.map(populateContentRole(_:))
let newContent: [ModelContent]
do {
newContent = try content().map(populateContentRole(_:))
} catch let underlying {
if let contentError = underlying as? ImageConversionError {
throw GenerateContentError.promptImageContentError(underlying: contentError)
} else {
throw GenerateContentError.internalError(underlying: underlying)
}
}

// Send the history alongside the new message as context.
let request = history + newContent
Expand All @@ -67,24 +78,39 @@ public class Chat {

/// See ``sendMessageStream(_:)-4abs3``.
@available(macOS 12.0, *)
public func sendMessageStream(_ parts: PartsRepresentable...)
public func sendMessageStream(_ parts: any ThrowingPartsRepresentable...)
-> AsyncThrowingStream<GenerateContentResponse, Error> {
return sendMessageStream([ModelContent(parts: parts)])
return try sendMessageStream([ModelContent(parts: parts)])
}

/// Sends a message using the existing history of this chat as context. If successful, the message
/// and response will be added to the history. If unsuccessful, history will remain unchanged.
/// - Parameter content: The new content to send as a single chat message.
/// - Returns: A stream containing the model's response or an error if an error occurred.
@available(macOS 12.0, *)
public func sendMessageStream(_ content: [ModelContent])
public func sendMessageStream(_ content: @autoclosure () throws -> [ModelContent])
-> AsyncThrowingStream<GenerateContentResponse, Error> {
let resolvedContent: [ModelContent]
do {
resolvedContent = try content()
} catch let underlying {
return AsyncThrowingStream { continuation in
let error: Error
if let contentError = underlying as? ImageConversionError {
error = GenerateContentError.promptImageContentError(underlying: contentError)
} else {
error = GenerateContentError.internalError(underlying: underlying)
}
continuation.finish(throwing: error)
}
}

return AsyncThrowingStream { continuation in
Task {
var aggregatedContent: [ModelContent] = []

// Ensure that the new content has the role set.
let newContent: [ModelContent] = content.map(populateContentRole(_:))
let newContent: [ModelContent] = resolvedContent.map(populateContentRole(_:))

// Send the history alongside the new message as context.
let request = history + newContent
Expand Down
3 changes: 3 additions & 0 deletions Sources/GoogleAI/GenerateContentError.swift
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ import Foundation
/// Errors that occur when generating content from a model.
@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *)
public enum GenerateContentError: Error {
/// An error occurred when constructing the prompt. Examine the related error for details.
case promptImageContentError(underlying: ImageConversionError)

/// An internal error occurred. See the underlying error for more context.
case internalError(underlying: Error)

Expand Down
70 changes: 46 additions & 24 deletions Sources/GoogleAI/GenerativeModel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -96,11 +96,12 @@ public final class GenerativeModel {
/// [few-shot](https://developers.google.com/machine-learning/glossary/generative#few-shot-prompting)
/// prompts, see ``generateContent(_:)-58rm0``.
///
/// - Parameter content: The input(s) given to the model as a prompt (see ``PartsRepresentable``
/// - Parameter content: The input(s) given to the model as a prompt (see
/// ``ThrowingPartsRepresentable``
/// for conforming types).
/// - Returns: The content generated by the model.
/// - Throws: A ``GenerateContentError`` if the request failed.
public func generateContent(_ parts: PartsRepresentable...)
public func generateContent(_ parts: any ThrowingPartsRepresentable...)
async throws -> GenerateContentResponse {
return try await generateContent([ModelContent(parts: parts)])
}
Expand All @@ -110,18 +111,21 @@ public final class GenerativeModel {
/// - Parameter content: The input(s) given to the model as a prompt.
/// - Returns: The generated content response from the model.
/// - Throws: A ``GenerateContentError`` if the request failed.
public func generateContent(_ content: [ModelContent]) async throws
public func generateContent(_ content: @autoclosure () throws -> [ModelContent]) async throws
-> GenerateContentResponse {
let generateContentRequest = GenerateContentRequest(model: modelResourceName,
contents: content,
generationConfig: generationConfig,
safetySettings: safetySettings,
isStreaming: false,
options: requestOptions)
let response: GenerateContentResponse
do {
let generateContentRequest = try GenerateContentRequest(model: modelResourceName,
contents: content(),
generationConfig: generationConfig,
safetySettings: safetySettings,
isStreaming: false,
options: requestOptions)
response = try await generativeAIService.loadRequest(request: generateContentRequest)
} catch {
if let imageError = error as? ImageConversionError {
throw GenerateContentError.promptImageContentError(underlying: imageError)
}
throw GenerativeModel.generateContentError(from: error)
}

Expand All @@ -148,14 +152,15 @@ public final class GenerativeModel {
/// [few-shot](https://developers.google.com/machine-learning/glossary/generative#few-shot-prompting)
/// prompts, see ``generateContent(_:)-58rm0``.
///
/// - Parameter content: The input(s) given to the model as a prompt (see ``PartsRepresentable``
/// - Parameter content: The input(s) given to the model as a prompt (see
/// ``ThrowingPartsRepresentable``
/// for conforming types).
/// - Returns: A stream wrapping content generated by the model or a ``GenerateContentError``
/// error if an error occurred.
@available(macOS 12.0, *)
public func generateContentStream(_ parts: PartsRepresentable...)
public func generateContentStream(_ parts: any ThrowingPartsRepresentable...)
-> AsyncThrowingStream<GenerateContentResponse, Error> {
return generateContentStream([ModelContent(parts: parts)])
return try generateContentStream([ModelContent(parts: parts)])
}

/// Generates new content from input content given to the model as a prompt.
Expand All @@ -164,10 +169,25 @@ public final class GenerativeModel {
/// - Returns: A stream wrapping content generated by the model or a ``GenerateContentError``
/// error if an error occurred.
@available(macOS 12.0, *)
public func generateContentStream(_ content: [ModelContent])
public func generateContentStream(_ content: @autoclosure () throws -> [ModelContent])
-> AsyncThrowingStream<GenerateContentResponse, Error> {
let evaluatedContent: [ModelContent]
do {
evaluatedContent = try content()
} catch let underlying {
return AsyncThrowingStream { continuation in
let error: Error
if let contentError = underlying as? ImageConversionError {
error = GenerateContentError.promptImageContentError(underlying: contentError)
} else {
error = GenerateContentError.internalError(underlying: underlying)
}
continuation.finish(throwing: error)
}
}

let generateContentRequest = GenerateContentRequest(model: modelResourceName,
contents: content,
contents: evaluatedContent,
generationConfig: generationConfig,
safetySettings: safetySettings,
isStreaming: true,
Expand Down Expand Up @@ -218,12 +238,14 @@ public final class GenerativeModel {
/// [few-shot](https://developers.google.com/machine-learning/glossary/generative#few-shot-prompting)
/// input, see ``countTokens(_:)-9spwl``.
///
/// - Parameter content: The input(s) given to the model as a prompt (see ``PartsRepresentable``
/// - Parameter content: The input(s) given to the model as a prompt (see
/// ``ThrowingPartsRepresentable``
/// for conforming types).
/// - Returns: The results of running the model's tokenizer on the input; contains
/// ``CountTokensResponse/totalTokens``.
/// - Throws: A ``CountTokensError`` if the tokenization request failed.
public func countTokens(_ parts: PartsRepresentable...) async throws -> CountTokensResponse {
public func countTokens(_ parts: any ThrowingPartsRepresentable...) async throws
-> CountTokensResponse {
return try await countTokens([ModelContent(parts: parts)])
}

Expand All @@ -232,16 +254,16 @@ public final class GenerativeModel {
/// - Parameter content: The input given to the model as a prompt.
/// - Returns: The results of running the model's tokenizer on the input; contains
/// ``CountTokensResponse/totalTokens``.
/// - Throws: A ``CountTokensError`` if the tokenization request failed.
public func countTokens(_ content: [ModelContent]) async throws
/// - Throws: A ``CountTokensError`` if the tokenization request failed or the input content was
/// invalid.
public func countTokens(_ content: @autoclosure () throws -> [ModelContent]) async throws
-> CountTokensResponse {
let countTokensRequest = CountTokensRequest(
model: modelResourceName,
contents: content,
options: requestOptions
)

do {
let countTokensRequest = try CountTokensRequest(
model: modelResourceName,
contents: content(),
options: requestOptions
)
return try await generativeAIService.loadRequest(request: countTokensRequest)
} catch {
throw CountTokensError.internalError(underlying: error)
Expand Down
25 changes: 21 additions & 4 deletions Sources/GoogleAI/ModelContent.swift
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,14 @@ public struct ModelContent: Codable, Equatable {
public let parts: [Part]

/// Creates a new value from any data or `Array` of data interpretable as a
/// ``Part``. See ``PartsRepresentable`` for types that can be interpreted as `Part`s.
/// ``Part``. See ``ThrowingPartsRepresentable`` for types that can be interpreted as `Part`s.
public init(role: String? = "user", parts: some ThrowingPartsRepresentable) throws {
self.role = role
try self.parts = parts.tryPartsValue()
}

/// Creates a new value from any data or `Array` of data interpretable as a
/// ``Part``. See ``ThrowingPartsRepresentable`` for types that can be interpreted as `Part`s.
public init(role: String? = "user", parts: some PartsRepresentable) {
self.role = role
self.parts = parts.partsValue
Expand All @@ -116,9 +123,19 @@ public struct ModelContent: Codable, Equatable {
self.parts = parts
}

/// Creates a new value from any data interpretable as a ``Part``. See ``PartsRepresentable``
/// Creates a new value from any data interpretable as a ``Part``. See
/// ``ThrowingPartsRepresentable``
/// for types that can be interpreted as `Part`s.
public init(role: String? = "user", _ parts: any ThrowingPartsRepresentable...) throws {
let content = try parts.flatMap { try $0.tryPartsValue() }
self.init(role: role, parts: content)
}

/// Creates a new value from any data interpretable as a ``Part``. See
/// ``ThrowingPartsRepresentable``
/// for types that can be interpreted as `Part`s.
public init(role: String? = "user", _ parts: PartsRepresentable...) {
self.init(role: role, parts: parts)
public init(role: String? = "user", _ parts: [PartsRepresentable]) {
let content = parts.flatMap { $0.partsValue }
self.init(role: role, parts: content)
}
}
107 changes: 107 additions & 0 deletions Sources/GoogleAI/PartsRepresentable+Image.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import UniformTypeIdentifiers
#if canImport(UIKit)
import UIKit // For UIImage extensions.
#elseif canImport(AppKit)
import AppKit // For NSImage extensions.
#endif

private let imageCompressionQuality: CGFloat = 0.8

/// Failures that can occur when converting an image type into model content data.
///
/// For some image types like `CIImage`, creating valid model content requires creating a JPEG
/// representation of the image that may not yet exist, which may be computationally expensive.
/// These cases are thrown by the `tryPartsValue()` implementations on the image types that
/// conform to ``ThrowingPartsRepresentable``.
public enum ImageConversionError: Error {
/// The image (the receiver of the call `tryPartsValue()`) was invalid; for example, no
/// `CGImage` could be obtained from an `NSImage`.
case invalidUnderlyingImage

/// A valid image destination could not be allocated for the encoded JPEG output.
case couldNotAllocateDestination

/// JPEG image data conversion failed. The associated value is the image that could not be
/// encoded, which may be an instance of `NSImageRep`, `UIImage`, `CGImage`, or `CIImage`.
case couldNotConvertToJPEG(Any)
}

#if canImport(UIKit)
/// Lets a `UIImage` be passed wherever a ``ThrowingPartsRepresentable`` is accepted.
@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *)
extension UIImage: ThrowingPartsRepresentable {
  /// Encodes the receiver as JPEG and wraps the bytes in a single `image/jpeg` data part.
  ///
  /// - Returns: A one-element array holding the JPEG data part.
  /// - Throws: ``ImageConversionError/couldNotConvertToJPEG(_:)`` carrying the receiver when
  ///   `jpegData(compressionQuality:)` returns `nil`.
  public func tryPartsValue() throws -> [ModelContent.Part] {
    if let encoded = jpegData(compressionQuality: imageCompressionQuality) {
      let jpegPart = ModelContent.Part.data(mimetype: "image/jpeg", encoded)
      return [jpegPart]
    }
    throw ImageConversionError.couldNotConvertToJPEG(self)
  }
}

#elseif canImport(AppKit)
/// Enables `NSImage` to be representable as ``ThrowingPartsRepresentable``.
@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *)
extension NSImage: ThrowingPartsRepresentable {
  /// Encodes the receiver as JPEG and wraps the bytes in a single `image/jpeg` data part.
  ///
  /// - Returns: A one-element array holding the JPEG data part.
  /// - Throws: ``ImageConversionError/invalidUnderlyingImage`` when no `CGImage` can be obtained
  ///   from the receiver, or ``ImageConversionError/couldNotConvertToJPEG(_:)`` carrying the
  ///   intermediate bitmap representation when JPEG encoding fails.
  public func tryPartsValue() throws -> [ModelContent.Part] {
    guard let cgImage = cgImage(forProposedRect: nil, context: nil, hints: nil) else {
      throw ImageConversionError.invalidUnderlyingImage
    }
    let bmp = NSBitmapImageRep(cgImage: cgImage)
    // Use the shared file-level quality constant rather than a duplicated literal so every
    // image type is compressed with the same factor.
    guard let data = bmp.representation(
      using: .jpeg,
      properties: [.compressionFactor: imageCompressionQuality]
    ) else {
      throw ImageConversionError.couldNotConvertToJPEG(bmp)
    }
    return [ModelContent.Part.data(mimetype: "image/jpeg", data)]
  }
}
#endif

/// Lets a `CGImage` be passed wherever a ``ThrowingPartsRepresentable`` is accepted.
@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *)
extension CGImage: ThrowingPartsRepresentable {
  /// Writes the receiver through an ImageIO JPEG destination and wraps the resulting bytes in
  /// a single `image/jpeg` data part.
  ///
  /// - Returns: A one-element array holding the JPEG data part.
  /// - Throws: ``ImageConversionError/couldNotAllocateDestination`` when the image destination
  ///   cannot be created, or ``ImageConversionError/couldNotConvertToJPEG(_:)`` carrying the
  ///   receiver when finalizing the destination fails.
  public func tryPartsValue() throws -> [ModelContent.Part] {
    let jpegBuffer = NSMutableData()
    let destinationOptions = [
      kCGImageDestinationLossyCompressionQuality: imageCompressionQuality,
    ] as CFDictionary
    let jpegTypeIdentifier = UTType.jpeg.identifier as CFString

    guard let destination = CGImageDestinationCreateWithData(
      jpegBuffer, jpegTypeIdentifier, 1, nil
    ) else {
      throw ImageConversionError.couldNotAllocateDestination
    }

    // Same call order as before: add the image first, then apply the destination properties.
    CGImageDestinationAddImage(destination, self, nil)
    CGImageDestinationSetProperties(destination, destinationOptions)

    guard CGImageDestinationFinalize(destination) else {
      throw ImageConversionError.couldNotConvertToJPEG(self)
    }
    return [ModelContent.Part.data(mimetype: "image/jpeg", jpegBuffer as Data)]
  }
}

/// Enables `CIImages` to be representable as model content.
@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *)
extension CIImage: ThrowingPartsRepresentable {
  /// Renders the receiver to JPEG with a Core Image context and wraps the bytes in a single
  /// `image/jpeg` data part.
  ///
  /// The image's own color space is used when it has one; otherwise sRGB is used.
  ///
  /// - Returns: A one-element array holding the JPEG data part.
  /// - Throws: ``ImageConversionError/couldNotConvertToJPEG(_:)`` carrying the receiver when no
  ///   color space is available or the JPEG representation cannot be produced.
  public func tryPartsValue() throws -> [ModelContent.Part] {
    let context = CIContext()
    // ImageIO declares the quality key as a `CFString`, while `jpegRepresentation` takes
    // `CIImageRepresentationOption` keys, so wrap the raw value explicitly. This applies the
    // same compression quality used for the other image types in this file.
    let qualityKey = CIImageRepresentationOption(
      rawValue: kCGImageDestinationLossyCompressionQuality as String
    )
    guard let targetColorSpace = colorSpace ?? CGColorSpace(name: CGColorSpace.sRGB),
          let jpegData = context.jpegRepresentation(
            of: self,
            colorSpace: targetColorSpace,
            options: [qualityKey: imageCompressionQuality]
          ) else {
      throw ImageConversionError.couldNotConvertToJPEG(self)
    }
    return [.data(mimetype: "image/jpeg", jpegData)]
  }
}
Loading

0 comments on commit ac4aea1

Please sign in to comment.