-
Notifications
You must be signed in to change notification settings - Fork 152
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
0976d0f
commit 2063447
Showing
4 changed files
with
434 additions
and
0 deletions.
There are no files selected for viewing
128 changes: 128 additions & 0 deletions
128
Examples/GenerativeAISample/FunctionCallingSample/Screens/FunctionCallingScreen.swift
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,128 @@ | ||
// Copyright 2023 Google LLC | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
import GenerativeAIUIComponents | ||
import GoogleGenerativeAI | ||
import SwiftUI | ||
|
||
struct FunctionCallingScreen: View { | ||
@EnvironmentObject | ||
var viewModel: FunctionCallingViewModel | ||
|
||
@State | ||
private var userPrompt = "What is 100 Euros in U.S. Dollars?" | ||
|
||
enum FocusedField: Hashable { | ||
case message | ||
} | ||
|
||
@FocusState | ||
var focusedField: FocusedField? | ||
|
||
var body: some View { | ||
VStack { | ||
ScrollViewReader { scrollViewProxy in | ||
List { | ||
Text("Interact with a currency conversion API using function calling in Gemini.") | ||
ForEach(viewModel.messages) { message in | ||
MessageView(message: message) | ||
} | ||
if let error = viewModel.error { | ||
ErrorView(error: error) | ||
.tag("errorView") | ||
} | ||
} | ||
.listStyle(.plain) | ||
.onChange(of: viewModel.messages, perform: { newValue in | ||
if viewModel.hasError { | ||
// Wait for a short moment to make sure we can actually scroll to the bottom. | ||
DispatchQueue.main.asyncAfter(deadline: .now() + 0.05) { | ||
withAnimation { | ||
scrollViewProxy.scrollTo("errorView", anchor: .bottom) | ||
} | ||
focusedField = .message | ||
} | ||
} else { | ||
guard let lastMessage = viewModel.messages.last else { return } | ||
|
||
// Wait for a short moment to make sure we can actually scroll to the bottom. | ||
DispatchQueue.main.asyncAfter(deadline: .now() + 0.05) { | ||
withAnimation { | ||
scrollViewProxy.scrollTo(lastMessage.id, anchor: .bottom) | ||
} | ||
focusedField = .message | ||
} | ||
} | ||
}) | ||
} | ||
InputField("Message...", text: $userPrompt) { | ||
Image(systemName: viewModel.busy ? "stop.circle.fill" : "arrow.up.circle.fill") | ||
.font(.title) | ||
} | ||
.focused($focusedField, equals: .message) | ||
.onSubmit { sendOrStop() } | ||
} | ||
.toolbar { | ||
ToolbarItem(placement: .primaryAction) { | ||
Button(action: newChat) { | ||
Image(systemName: "square.and.pencil") | ||
} | ||
} | ||
} | ||
.navigationTitle("Function Calling") | ||
.onAppear { | ||
focusedField = .message | ||
} | ||
} | ||
|
||
private func sendMessage() { | ||
Task { | ||
let prompt = userPrompt | ||
userPrompt = "" | ||
await viewModel.sendMessage(prompt, streaming: true) | ||
} | ||
} | ||
|
||
private func sendOrStop() { | ||
if viewModel.busy { | ||
viewModel.stop() | ||
} else { | ||
sendMessage() | ||
} | ||
} | ||
|
||
private func newChat() { | ||
viewModel.startNewChat() | ||
} | ||
} | ||
|
||
struct FunctionCallingScreen_Previews: PreviewProvider { | ||
struct ContainerView: View { | ||
@EnvironmentObject | ||
var viewModel: FunctionCallingViewModel | ||
|
||
var body: some View { | ||
FunctionCallingScreen() | ||
.onAppear { | ||
viewModel.messages = ChatMessage.samples | ||
} | ||
} | ||
} | ||
|
||
static var previews: some View { | ||
NavigationStack { | ||
FunctionCallingScreen().environmentObject(FunctionCallingViewModel()) | ||
} | ||
} | ||
} |
266 changes: 266 additions & 0 deletions
266
Examples/GenerativeAISample/FunctionCallingSample/ViewModels/FunctionCallingViewModel.swift
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,266 @@ | ||
// Copyright 2023 Google LLC | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
import Foundation | ||
import GoogleGenerativeAI | ||
import UIKit | ||
|
||
@MainActor | ||
class FunctionCallingViewModel: ObservableObject { | ||
/// This array holds both the user's and the system's chat messages | ||
@Published var messages = [ChatMessage]() | ||
|
||
/// Indicates we're waiting for the model to finish | ||
@Published var busy = false | ||
|
||
@Published var error: Error? | ||
var hasError: Bool { | ||
return error != nil | ||
} | ||
|
||
/// Function calls pending processing | ||
private var functionCalls = [FunctionCall]() | ||
|
||
private var model: GenerativeModel | ||
private var chat: Chat | ||
|
||
private var chatTask: Task<Void, Never>? | ||
|
||
init() { | ||
model = GenerativeModel( | ||
name: "gemini-1.0-pro", | ||
apiKey: APIKey.default, | ||
tools: [Tool(functionDeclarations: [ | ||
FunctionDeclaration( | ||
name: "get_exchange_rate", | ||
description: "Get the exchange rate for currencies between countries", | ||
parameters: [ | ||
"currency_from": Schema( | ||
type: .string, | ||
format: "enum", | ||
description: "The currency to convert from in ISO 4217 format", | ||
enumValues: ["USD", "EUR", "JPY", "GBP", "AUD", "CAD"] | ||
), | ||
"currency_to": Schema( | ||
type: .string, | ||
format: "enum", | ||
description: "The currency to convert to in ISO 4217 format", | ||
enumValues: ["USD", "EUR", "JPY", "GBP", "AUD", "CAD"] | ||
), | ||
], | ||
requiredParameters: ["currency_from", "currency_to"] | ||
), | ||
])], | ||
requestOptions: RequestOptions(apiVersion: "v1beta") | ||
) | ||
chat = model.startChat() | ||
} | ||
|
||
func sendMessage(_ text: String, streaming: Bool = true) async { | ||
error = nil | ||
chatTask?.cancel() | ||
|
||
chatTask = Task { | ||
busy = true | ||
defer { | ||
busy = false | ||
} | ||
|
||
// first, add the user's message to the chat | ||
let userMessage = ChatMessage(message: text, participant: .user) | ||
messages.append(userMessage) | ||
|
||
// add a pending message while we're waiting for a response from the backend | ||
let systemMessage = ChatMessage.pending(participant: .system) | ||
messages.append(systemMessage) | ||
|
||
print(messages) | ||
do { | ||
repeat { | ||
if streaming { | ||
try await internalSendMessageStreaming(text) | ||
} else { | ||
try await internalSendMessage(text) | ||
} | ||
} while !functionCalls.isEmpty | ||
} catch { | ||
self.error = error | ||
print(error.localizedDescription) | ||
messages.removeLast() | ||
} | ||
} | ||
} | ||
|
||
func startNewChat() { | ||
stop() | ||
error = nil | ||
chat = model.startChat() | ||
messages.removeAll() | ||
} | ||
|
||
func stop() { | ||
chatTask?.cancel() | ||
error = nil | ||
} | ||
|
||
private func internalSendMessageStreaming(_ text: String) async throws { | ||
let functionResponses = try await processFunctionCalls() | ||
let responseStream: AsyncThrowingStream<GenerateContentResponse, Error> | ||
if functionResponses.isEmpty { | ||
responseStream = chat.sendMessageStream(text) | ||
} else { | ||
for functionResponse in functionResponses { | ||
messages.insert(functionResponse.chatMessage(), at: messages.count - 1) | ||
} | ||
responseStream = chat.sendMessageStream(functionResponses.modelContent()) | ||
} | ||
for try await chunk in responseStream { | ||
processResponseContent(content: chunk) | ||
} | ||
} | ||
|
||
private func internalSendMessage(_ text: String) async throws { | ||
let functionResponses = try await processFunctionCalls() | ||
let response: GenerateContentResponse | ||
if functionResponses.isEmpty { | ||
response = try await chat.sendMessage(text) | ||
} else { | ||
for functionResponse in functionResponses { | ||
messages.insert(functionResponse.chatMessage(), at: messages.count - 1) | ||
} | ||
response = try await chat.sendMessage(functionResponses.modelContent()) | ||
} | ||
processResponseContent(content: response) | ||
} | ||
|
||
func processResponseContent(content: GenerateContentResponse) { | ||
guard let candidate = content.candidates.first else { | ||
fatalError("No candidate.") | ||
} | ||
|
||
for part in candidate.content.parts { | ||
switch part { | ||
case let .text(text): | ||
// replace pending message with backend response | ||
messages[messages.count - 1].message += text | ||
messages[messages.count - 1].pending = false | ||
case let .functionCall(functionCall): | ||
messages.insert(functionCall.chatMessage(), at: messages.count - 1) | ||
functionCalls.append(functionCall) | ||
case .data, .functionResponse: | ||
fatalError("Unsupported response content.") | ||
} | ||
} | ||
} | ||
|
||
func processFunctionCalls() async throws -> [FunctionResponse] { | ||
var functionResponses = [FunctionResponse]() | ||
for functionCall in functionCalls { | ||
switch functionCall.name { | ||
case "get_exchange_rate": | ||
let exchangeRates = getExchangeRate(args: functionCall.args) | ||
functionResponses.append(FunctionResponse( | ||
name: "get_exchange_rate", | ||
response: exchangeRates | ||
)) | ||
default: | ||
fatalError("Unknown function named \"\(functionCall.name)\".") | ||
} | ||
} | ||
functionCalls = [] | ||
|
||
return functionResponses | ||
} | ||
|
||
// MARK: - Callable Functions | ||
|
||
func getExchangeRate(args: JSONObject) -> JSONObject { | ||
// 1. Validate and extract the parameters provided by the model (from a `FunctionCall`) | ||
guard case let .string(from) = args["currency_from"] else { | ||
fatalError("Missing `currency_from` parameter.") | ||
} | ||
guard case let .string(to) = args["currency_to"] else { | ||
fatalError("Missing `currency_to` parameter.") | ||
} | ||
|
||
// 2. Get the exchange rate | ||
let allRates: [String: [String: Double]] = [ | ||
"AUD": ["CAD": 0.89265, "EUR": 0.6072, "GBP": 0.51714, "JPY": 97.75, "USD": 0.66379], | ||
"CAD": ["AUD": 1.1203, "EUR": 0.68023, "GBP": 0.57933, "JPY": 109.51, "USD": 0.74362], | ||
"EUR": ["AUD": 1.6469, "CAD": 1.4701, "GBP": 0.85168, "JPY": 160.99, "USD": 1.0932], | ||
"GBP": ["AUD": 1.9337, "CAD": 1.7261, "EUR": 1.1741, "JPY": 189.03, "USD": 1.2836], | ||
"JPY": ["AUD": 0.01023, "CAD": 0.00913, "EUR": 0.00621, "GBP": 0.00529, "USD": 0.00679], | ||
"USD": ["AUD": 1.5065, "CAD": 1.3448, "EUR": 0.91475, "GBP": 0.77907, "JPY": 147.26], | ||
] | ||
guard let fromRates = allRates[from] else { | ||
return ["error": .string("No data for currency \(from).")] | ||
} | ||
guard let toRate = fromRates[to] else { | ||
return ["error": .string("No data for currency \(to).")] | ||
} | ||
|
||
// 3. Return the exchange rates as a JSON object (returned to the model in a `FunctionResponse`) | ||
return ["rates": .number(toRate)] | ||
} | ||
} | ||
|
||
private extension FunctionCall { | ||
func chatMessage() -> ChatMessage { | ||
let encoder = JSONEncoder() | ||
encoder.outputFormatting = .prettyPrinted | ||
|
||
let jsonData: Data | ||
do { | ||
jsonData = try encoder.encode(self) | ||
} catch { | ||
fatalError("JSON Encoding Failed: \(error.localizedDescription)") | ||
} | ||
guard let json = String(data: jsonData, encoding: .utf8) else { | ||
fatalError("Failed to convert JSON data to a String.") | ||
} | ||
let messageText = "Function call requested by model:\n```\n\(json)\n```" | ||
|
||
return ChatMessage(message: messageText, participant: .system) | ||
} | ||
} | ||
|
||
private extension FunctionResponse { | ||
func chatMessage() -> ChatMessage { | ||
let encoder = JSONEncoder() | ||
encoder.outputFormatting = .prettyPrinted | ||
|
||
let jsonData: Data | ||
do { | ||
jsonData = try encoder.encode(self) | ||
} catch { | ||
fatalError("JSON Encoding Failed: \(error.localizedDescription)") | ||
} | ||
guard let json = String(data: jsonData, encoding: .utf8) else { | ||
fatalError("Failed to convert JSON data to a String.") | ||
} | ||
let messageText = "Function response returned by app:\n```\n\(json)\n```" | ||
|
||
return ChatMessage(message: messageText, participant: .user) | ||
} | ||
} | ||
|
||
private extension [FunctionResponse] { | ||
func modelContent() -> [ModelContent] { | ||
return self.map { ModelContent( | ||
role: "function", | ||
parts: [ModelContent.Part.functionResponse($0)] | ||
) | ||
} | ||
} | ||
} |
Oops, something went wrong.