Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add automatic function calling prototype #118

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 85 additions & 3 deletions Examples/GenerativeAICLI/Sources/GenerateContent.swift
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,20 @@ struct GenerateContent: AsyncParsableCommand {
name: modelNameOrDefault(),
apiKey: apiKey,
generationConfig: config,
safetySettings: safetySettings
safetySettings: safetySettings,
tools: [Tool(functionDeclarations: [
FunctionDeclaration(
name: "get_exchange_rate",
description: "Get the exchange rate for currencies between countries",
parameters: getExchangeRateSchema(),
function: getExchangeRateWrapper
),
])],
requestOptions: RequestOptions(apiVersion: "v1beta")
)

let chat = model.startChat()

var parts = [ModelContent.Part]()

if let textPrompt = textPrompt {
Expand All @@ -96,15 +107,16 @@ struct GenerateContent: AsyncParsableCommand {
let input = [ModelContent(parts: parts)]

if isStreaming {
let contentStream = model.generateContentStream(input)
let contentStream = chat.sendMessageStream(input)
print("Generated Content <streaming>:")
for try await content in contentStream {
if let text = content.text {
print(text)
}
}
} else {
let content = try await model.generateContent(input)
// Unary generate content
let content = try await chat.sendMessage(input)
if let text = content.text {
print("Generated Content:\n\(text)")
}
Expand All @@ -123,6 +135,76 @@ struct GenerateContent: AsyncParsableCommand {
return "gemini-1.0-pro"
}
}

// MARK: - Callable Functions

// Returns exchange rates from the Frankfurter API
// This is an example function that a developer might provide.
func getExchangeRate(amount: Double, date: String, from: String,
to: String) async throws -> String {
var urlComponents = URLComponents(string: "https://api.frankfurter.app")!
urlComponents.path = "/\(date)"
urlComponents.queryItems = [
.init(name: "amount", value: String(amount)),
.init(name: "from", value: from),
.init(name: "to", value: to),
]

let (data, _) = try await URLSession.shared.data(from: urlComponents.url!)
return String(data: data, encoding: .utf8)!
}

// This is a wrapper for the `getExchangeRate` function.
func getExchangeRateWrapper(args: JSONObject) async throws -> JSONObject {
// 1. Validate and extract the parameters provided by the model (from a `FunctionCall`)
guard case let .string(date) = args["currency_date"] else {
fatalError()
}
guard case let .string(from) = args["currency_from"] else {
fatalError()
}
guard case let .string(to) = args["currency_to"] else {
fatalError()
}
guard case let .number(amount) = args["amount"] else {
fatalError()
}

// 2. Call the wrapped function
let response = try await getExchangeRate(amount: amount, date: date, from: from, to: to)

// 3. Return the exchange rates as a JSON object (returned to the model in a `FunctionResponse`)
return ["content": .string(response)]
}

// Returns the schema of the `getExchangeRate` function
func getExchangeRateSchema() -> Schema {
return Schema(
type: .object,
properties: [
"currency_date": Schema(
type: .string,
description: """
A date that must always be in YYYY-MM-DD format or the value 'latest' if a time period
is not specified
"""
),
"currency_from": Schema(
type: .string,
description: "The currency to convert from in ISO 4217 format"
),
"currency_to": Schema(
type: .string,
description: "The currency to convert to in ISO 4217 format"
),
"amount": Schema(
type: .number,
description: "The amount of currency to convert as a double value"
),
],
required: ["currency_date", "currency_from", "currency_to", "amount"]
)
}
}

enum CLIError: Error {
Expand Down
32 changes: 31 additions & 1 deletion Sources/GoogleAI/Chat.swift
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,32 @@ public class Chat {
// Make sure we inject the role into the content received.
let toAdd = ModelContent(role: "model", parts: reply.parts)

var functionResponses = [FunctionResponse]()
for part in reply.parts {
if case let .functionCall(functionCall) = part {
try functionResponses.append(await model.executeFunction(functionCall: functionCall))
}
}

// Call the functions requested by the model, if any.
let functionResponseContent = try ModelContent(
role: "function",
functionResponses.map { functionResponse in
ModelContent.Part.functionResponse(functionResponse)
}
)

// Append the request and successful result to history, then return the value.
history.append(contentsOf: newContent)
history.append(toAdd)
return result

// If no function calls requested, return the results.
if functionResponses.isEmpty {
return result
}

// Re-send the message with the function responses.
return try await sendMessage([functionResponseContent])
}

/// See ``sendMessageStream(_:)-4abs3``.
Expand Down Expand Up @@ -162,6 +184,14 @@ public class Chat {
}

parts.append(part)

case .functionCall:
// TODO(andrewheard): Add function call to the chat history when encoding is implemented.
fatalError("Function calling not yet implemented in chat.")

case .functionResponse:
// TODO(andrewheard): Add function response to chat history when encoding is implemented.
fatalError("Function calling not yet implemented in chat.")
}
}
}
Expand Down
139 changes: 139 additions & 0 deletions Sources/GoogleAI/FunctionCalling.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import Foundation

/// A predicted function call returned from the model.
///
/// REST Docs: https://ai.google.dev/api/rest/v1beta/Content#functioncall
public struct FunctionCall: Equatable, Encodable {
/// The name of the function to call.
public let name: String

/// The function parameters and values.
public let args: JSONObject
}

// REST Docs: https://ai.google.dev/api/rest/v1beta/Tool#schema
public class Schema: Encodable {
let type: DataType

let format: String?

let description: String?

let nullable: Bool?

let enumValues: [String]?

let items: Schema?

let properties: [String: Schema]?

let required: [String]?

public init(type: DataType, format: String? = nil, description: String? = nil,
nullable: Bool? = nil,
enumValues: [String]? = nil, items: Schema? = nil,
properties: [String: Schema]? = nil,
required: [String]? = nil) {
self.type = type
self.format = format
self.description = description
self.nullable = nullable
self.enumValues = enumValues
self.items = items
self.properties = properties
self.required = required
}
}

// REST Docs: https://ai.google.dev/api/rest/v1beta/Tool#Type
public enum DataType: String, Encodable {
case string = "STRING"
case number = "NUMBER"
case integer = "INTEGER"
case boolean = "BOOLEAN"
case array = "ARRAY"
case object = "OBJECT"
}

// REST Docs: https://ai.google.dev/api/rest/v1beta/Tool#FunctionDeclaration
public struct FunctionDeclaration {
let name: String

let description: String

let parameters: Schema

let function: ((JSONObject) async throws -> JSONObject)?

public init(name: String, description: String, parameters: Schema,
function: ((JSONObject) async throws -> JSONObject)?) {
self.name = name
self.description = description
self.parameters = parameters
self.function = function
}
}

// REST Docs: https://ai.google.dev/api/rest/v1beta/Tool
public struct Tool: Encodable {
let functionDeclarations: [FunctionDeclaration]?

public init(functionDeclarations: [FunctionDeclaration]?) {
self.functionDeclarations = functionDeclarations
}
}

// REST Docs: https://ai.google.dev/api/rest/v1beta/Content#functionresponse
public struct FunctionResponse: Equatable, Encodable {
let name: String

let response: JSONObject
}

// MARK: - Codable Conformance

extension FunctionCall: Decodable {
enum CodingKeys: CodingKey {
case name
case args
}

public init(from decoder: Decoder) throws {
let container = try decoder.container(keyedBy: CodingKeys.self)
name = try container.decode(String.self, forKey: .name)
if let args = try container.decodeIfPresent(JSONObject.self, forKey: .args) {
self.args = args
} else {
args = JSONObject()
}
}
}

extension FunctionDeclaration: Encodable {
enum CodingKeys: String, CodingKey {
case name
case description
case parameters
}

public func encode(to encoder: Encoder) throws {
var container = encoder.container(keyedBy: CodingKeys.self)
try container.encode(name, forKey: .name)
try container.encode(description, forKey: .description)
try container.encode(parameters, forKey: .parameters)
}
}
2 changes: 2 additions & 0 deletions Sources/GoogleAI/GenerateContentRequest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ struct GenerateContentRequest {
let contents: [ModelContent]
let generationConfig: GenerationConfig?
let safetySettings: [SafetySetting]?
let tools: [Tool]?
let isStreaming: Bool
let options: RequestOptions
}
Expand All @@ -31,6 +32,7 @@ extension GenerateContentRequest: Encodable {
case contents
case generationConfig
case safetySettings
case tools
}
}

Expand Down
Loading
Loading