From ac26decd2b1afc7875fbdf6c5a4e99425c154de4 Mon Sep 17 00:00:00 2001 From: Andrew Heard Date: Tue, 9 Apr 2024 21:06:41 -0400 Subject: [PATCH 1/4] Add function calling sample --- .../Screens/FunctionCallingScreen.swift | 128 ++++++++ .../ViewModels/FunctionCallingViewModel.swift | 273 ++++++++++++++++++ .../project.pbxproj | 32 ++ .../GenerativeAISample/ContentView.swift | 8 + 4 files changed, 441 insertions(+) create mode 100644 Examples/GenerativeAISample/FunctionCallingSample/Screens/FunctionCallingScreen.swift create mode 100644 Examples/GenerativeAISample/FunctionCallingSample/ViewModels/FunctionCallingViewModel.swift diff --git a/Examples/GenerativeAISample/FunctionCallingSample/Screens/FunctionCallingScreen.swift b/Examples/GenerativeAISample/FunctionCallingSample/Screens/FunctionCallingScreen.swift new file mode 100644 index 0000000..a210978 --- /dev/null +++ b/Examples/GenerativeAISample/FunctionCallingSample/Screens/FunctionCallingScreen.swift @@ -0,0 +1,128 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import GenerativeAIUIComponents +import GoogleGenerativeAI +import SwiftUI + +struct FunctionCallingScreen: View { + @EnvironmentObject + var viewModel: FunctionCallingViewModel + + @State + private var userPrompt = "What is 100 Euros in U.S. Dollars?" + + enum FocusedField: Hashable { + case message + } + + @FocusState + var focusedField: FocusedField? + + var body: some View { + VStack { + ScrollViewReader { scrollViewProxy in + List { + Text("Interact with a currency conversion API using function calling in Gemini.") + ForEach(viewModel.messages) { message in + MessageView(message: message) + } + if let error = viewModel.error { + ErrorView(error: error) + .tag("errorView") + } + } + .listStyle(.plain) + .onChange(of: viewModel.messages, perform: { newValue in + if viewModel.hasError { + // wait for a short moment to make sure we can actually scroll to the bottom + DispatchQueue.main.asyncAfter(deadline: .now() + 0.05) { + withAnimation { + scrollViewProxy.scrollTo("errorView", anchor: .bottom) + } + focusedField = .message + } + } else { + guard let lastMessage = viewModel.messages.last else { return } + + // wait for a short moment to make sure we can actually scroll to the bottom + DispatchQueue.main.asyncAfter(deadline: .now() + 0.05) { + withAnimation { + scrollViewProxy.scrollTo(lastMessage.id, anchor: .bottom) + } + focusedField = .message + } + } + }) + } + InputField("Message...", text: $userPrompt) { + Image(systemName: viewModel.busy ? "stop.circle.fill" : "arrow.up.circle.fill") + .font(.title) + } + .focused($focusedField, equals: .message) + .onSubmit { sendOrStop() } + } + .toolbar { + ToolbarItem(placement: .primaryAction) { + Button(action: newChat) { + Image(systemName: "square.and.pencil") + } + } + } + .navigationTitle("Function Calling") + .onAppear { + focusedField = .message + } + } + + private func sendMessage() { + Task { + let prompt = userPrompt + userPrompt = "" + await viewModel.sendMessage(prompt, streaming: true) + } + } + + private func sendOrStop() { + if viewModel.busy { + viewModel.stop() + } else { + sendMessage() + } + } + + private func newChat() { + viewModel.startNewChat() + } +} + +struct FunctionCallingScreen_Previews: PreviewProvider { + struct ContainerView: View { + @EnvironmentObject + var viewModel: FunctionCallingViewModel + + var body: some View { + FunctionCallingScreen() + .onAppear { + viewModel.messages = ChatMessage.samples + } + } + } + + static var previews: some View { + NavigationStack { + FunctionCallingScreen().environmentObject(FunctionCallingViewModel()) + } + } +} diff --git a/Examples/GenerativeAISample/FunctionCallingSample/ViewModels/FunctionCallingViewModel.swift b/Examples/GenerativeAISample/FunctionCallingSample/ViewModels/FunctionCallingViewModel.swift new file mode 100644 index 0000000..31a541a --- /dev/null +++ b/Examples/GenerativeAISample/FunctionCallingSample/ViewModels/FunctionCallingViewModel.swift @@ -0,0 +1,273 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import Foundation +import GoogleGenerativeAI +import UIKit + +@MainActor +class FunctionCallingViewModel: ObservableObject { + /// This array holds both the user's and the system's chat messages + @Published var messages = [ChatMessage]() + + /// Indicates we're waiting for the model to finish + @Published var busy = false + + @Published var error: Error? + var hasError: Bool { + return error != nil + } + + /// Function calls pending processing + private var functionCalls = [FunctionCall]() + + private var model: GenerativeModel + private var chat: Chat + + private var chatTask: Task? + + init() { + model = GenerativeModel( + name: "gemini-1.0-pro", + apiKey: APIKey.default, + tools: [Tool(functionDeclarations: [ + FunctionDeclaration( + name: "get_exchange_rate", + description: "Get the exchange rate for currencies between countries", + parameters: [ + "currency_from": Schema( + type: .string, + format: "enum", + description: "The currency to convert from in ISO 4217 format", + enumValues: ["USD", "EUR", "JPY", "GBP", "AUD", "CAD"] + ), + "currency_to": Schema( + type: .string, + format: "enum", + description: "The currency to convert to in ISO 4217 format", + enumValues: ["USD", "EUR", "JPY", "GBP", "AUD", "CAD"] + ), + ], + requiredParameters: ["currency_from", "currency_to"] + ), + ])], + requestOptions: RequestOptions(apiVersion: "v1beta") + ) + chat = model.startChat() + } + + func sendMessage(_ text: String, streaming: Bool = true) async { + error = nil + chatTask?.cancel() + + chatTask = Task { + busy = true + defer { + busy = false + } + + // first, add the user's message to the chat + let userMessage = ChatMessage(message: text, participant: .user) + messages.append(userMessage) + + // add a pending message while we're waiting for a response from the backend + let systemMessage = ChatMessage.pending(participant: .system) + messages.append(systemMessage) + + print(messages) + do { + repeat { + if streaming { + try await internalSendMessageStreaming(text) + } else { + try await internalSendMessage(text) + } + } while !functionCalls.isEmpty + messages[pendingMessageIndex()].pending = false + } catch { + self.error = error + print(error.localizedDescription) + messages.removeLast() + } + } + } + + func startNewChat() { + stop() + error = nil + chat = model.startChat() + messages.removeAll() + } + + func stop() { + chatTask?.cancel() + error = nil + } + + private func internalSendMessageStreaming(_ text: String) async throws { + let functionResponses = try await processFunctionCalls() + let responseStream: AsyncThrowingStream + if functionResponses.isEmpty { + responseStream = chat.sendMessageStream(text) + } else { + for functionResponse in functionResponses { + messages.insert(functionResponse.chatMessage(), at: pendingMessageIndex()) + } + responseStream = chat.sendMessageStream(functionResponses.modelContent()) + } + for try await chunk in responseStream { + processResponseContent(content: chunk) + } + } + + private func internalSendMessage(_ text: String) async throws { + let functionResponses = try await processFunctionCalls() + let response: GenerateContentResponse + if functionResponses.isEmpty { + response = try await chat.sendMessage(text) + } else { + for functionResponse in functionResponses { + messages.insert(functionResponse.chatMessage(), at: pendingMessageIndex()) + } + response = try await chat.sendMessage(functionResponses.modelContent()) + } + processResponseContent(content: response) + messages[pendingMessageIndex()].pending = false + } + + func processResponseContent(content: GenerateContentResponse) { + guard let candidate = content.candidates.first else { + fatalError("No candidate.") + } + + for part in candidate.content.parts { + switch part { + case let .text(text): + // replace pending message with backend response + messages[pendingMessageIndex()].message += text + case let .functionCall(functionCall): + messages.insert(functionCall.chatMessage(), at: pendingMessageIndex()) + functionCalls.append(functionCall) + case .data, .functionResponse: + fatalError("Unsupported response content.") + } + } + } + + func processFunctionCalls() async throws -> [FunctionResponse] { + var functionResponses = [FunctionResponse]() + for functionCall in functionCalls { + switch functionCall.name { + case "get_exchange_rate": + let exchangeRates = getExchangeRate(args: functionCall.args) + functionResponses.append(FunctionResponse( + name: "get_exchange_rate", + response: exchangeRates + )) + default: + fatalError("Unknown function named \"\(functionCall.name)\".") + } + } + functionCalls = [] + + return functionResponses + } + + private func pendingMessageIndex() -> Int { + return messages.lastIndex(where: { chatMessage in + chatMessage.participant == .system && chatMessage.pending + }) ?? messages.endIndex + } + + // MARK: - Callable Functions + + func getExchangeRate(args: JSONObject) -> JSONObject { + // 1. Validate and extract the parameters provided by the model (from a `FunctionCall`) + guard case let .string(from) = args["currency_from"] else { + fatalError("Missing `currency_from` parameter.") + } + guard case let .string(to) = args["currency_to"] else { + fatalError("Missing `currency_to` parameter.") + } + + // 2. Get the exchange rate + let allRates: [String: [String: Double]] = [ + "AUD": ["CAD": 0.89265, "EUR": 0.6072, "GBP": 0.51714, "JPY": 97.75, "USD": 0.66379], + "CAD": ["AUD": 1.1203, "EUR": 0.68023, "GBP": 0.57933, "JPY": 109.51, "USD": 0.74362], + "EUR": ["AUD": 1.6469, "CAD": 1.4701, "GBP": 0.85168, "JPY": 160.99, "USD": 1.0932], + "GBP": ["AUD": 1.9337, "CAD": 1.7261, "EUR": 1.1741, "JPY": 189.03, "USD": 1.2836], + "JPY": ["AUD": 0.01023, "CAD": 0.00913, "EUR": 0.00621, "GBP": 0.00529, "USD": 0.00679], + "USD": ["AUD": 1.5065, "CAD": 1.3448, "EUR": 0.91475, "GBP": 0.77907, "JPY": 147.26], + ] + guard let fromRates = allRates[from] else { + return ["error": .string("No data for currency \(from).")] + } + guard let toRate = fromRates[to] else { + return ["error": .string("No data for currency \(to).")] + } + + // 3. Return the exchange rates as a JSON object (returned to the model in a `FunctionResponse`) + return ["rates": .number(toRate)] + } +} + +private extension FunctionCall { + func chatMessage() -> ChatMessage { + let encoder = JSONEncoder() + encoder.outputFormatting = .prettyPrinted + + let jsonData: Data + do { + jsonData = try encoder.encode(self) + } catch { + fatalError("JSON Encoding Failed: \(error.localizedDescription)") + } + guard let json = String(data: jsonData, encoding: .utf8) else { + fatalError("Failed to convert JSON data to a String.") + } + let messageText = "Function call requested by model:\n```\n\(json)\n```" + + return ChatMessage(message: messageText, participant: .system) + } +} + +private extension FunctionResponse { + func chatMessage() -> ChatMessage { + let encoder = JSONEncoder() + encoder.outputFormatting = .prettyPrinted + + let jsonData: Data + do { + jsonData = try encoder.encode(self) + } catch { + fatalError("JSON Encoding Failed: \(error.localizedDescription)") + } + guard let json = String(data: jsonData, encoding: .utf8) else { + fatalError("Failed to convert JSON data to a String.") + } + let messageText = "Function response returned by app:\n```\n\(json)\n```" + + return ChatMessage(message: messageText, participant: .user) + } +} + +private extension [FunctionResponse] { + func modelContent() -> [ModelContent] { + return self.map { ModelContent( + role: "function", + parts: [ModelContent.Part.functionResponse($0)] + ) + } + } +} diff --git a/Examples/GenerativeAISample/GenerativeAISample.xcodeproj/project.pbxproj b/Examples/GenerativeAISample/GenerativeAISample.xcodeproj/project.pbxproj index 795fe47..f2eece5 100644 --- a/Examples/GenerativeAISample/GenerativeAISample.xcodeproj/project.pbxproj +++ b/Examples/GenerativeAISample/GenerativeAISample.xcodeproj/project.pbxproj @@ -7,6 +7,8 @@ objects = { /* Begin PBXBuildFile section */ + 86FBBA072BBE0D49006031A1 /* FunctionCallingScreen.swift in Sources */ = {isa = PBXBuildFile; fileRef = 86FBBA052BBE0D49006031A1 /* FunctionCallingScreen.swift */; }; + 86FBBA272BBF0594006031A1 /* FunctionCallingViewModel.swift in Sources */ = {isa = PBXBuildFile; fileRef = 86FBBA252BBF0594006031A1 /* FunctionCallingViewModel.swift */; }; 880266762B0FC39000CF7CB6 /* PhotoReasoningViewModel.swift in Sources */ = {isa = PBXBuildFile; fileRef = 8802666F2B0FC39000CF7CB6 /* PhotoReasoningViewModel.swift */; }; 880266792B0FC39000CF7CB6 /* PhotoReasoningScreen.swift in Sources */ = {isa = PBXBuildFile; fileRef = 880266752B0FC39000CF7CB6 /* PhotoReasoningScreen.swift */; }; 881B753A2B0FDCE600528058 /* APIKey.swift in Sources */ = {isa = PBXBuildFile; fileRef = 88209C192B0FBDC300F64795 /* APIKey.swift */; }; @@ -64,6 +66,8 @@ /* End PBXBuildFile section */ /* Begin PBXFileReference section */ + 86FBBA052BBE0D49006031A1 /* FunctionCallingScreen.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = FunctionCallingScreen.swift; sourceTree = ""; }; + 86FBBA252BBF0594006031A1 /* FunctionCallingViewModel.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = FunctionCallingViewModel.swift; sourceTree = ""; }; 8802666F2B0FC39000CF7CB6 /* PhotoReasoningViewModel.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = PhotoReasoningViewModel.swift; sourceTree = ""; }; 880266752B0FC39000CF7CB6 /* PhotoReasoningScreen.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = PhotoReasoningScreen.swift; sourceTree = ""; }; 88209C142B0F928F00F64795 /* GenerativeAI-Info.plist */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = text.plist.xml; path = "GenerativeAI-Info.plist"; sourceTree = ""; }; @@ -140,6 +144,31 @@ /* End PBXFrameworksBuildPhase section */ /* Begin PBXGroup section */ + 8610C1BB2BBE09BC00987CF2 /* FunctionCallingSample */ = { + isa = PBXGroup; + children = ( + 86FBBA262BBF0594006031A1 /* ViewModels */, + 86FBBA062BBE0D49006031A1 /* Screens */, + ); + path = FunctionCallingSample; + sourceTree = ""; + }; + 86FBBA062BBE0D49006031A1 /* Screens */ = { + isa = PBXGroup; + children = ( + 86FBBA052BBE0D49006031A1 /* FunctionCallingScreen.swift */, + ); + path = Screens; + sourceTree = ""; + }; + 86FBBA262BBF0594006031A1 /* ViewModels */ = { + isa = PBXGroup; + children = ( + 86FBBA252BBF0594006031A1 /* FunctionCallingViewModel.swift */, + ); + path = ViewModels; + sourceTree = ""; + }; 8802666E2B0FC39000CF7CB6 /* ViewModels */ = { isa = PBXGroup; children = ( @@ -188,6 +217,7 @@ 8848C8452B0D051E007B434F /* GenerativeAITextSample */, 8848C8572B0D056C007B434F /* GenerativeAIMultimodalSample */, 88E10F432B110D5300C08E95 /* ChatSample */, + 8610C1BB2BBE09BC00987CF2 /* FunctionCallingSample */, 8848C8302B0D04BC007B434F /* Products */, 88209C222B0FBE1700F64795 /* Frameworks */, ); @@ -601,6 +631,7 @@ 88263BF12B239C11008AB09B /* ErrorDetailsView.swift in Sources */, 8848C8352B0D04BC007B434F /* ContentView.swift in Sources */, 886F95D52B17BA010036F07A /* SummarizeScreen.swift in Sources */, + 86FBBA072BBE0D49006031A1 /* FunctionCallingScreen.swift in Sources */, 881B753B2B0FDCE600528058 /* APIKey.swift in Sources */, 8848C8332B0D04BC007B434F /* GenerativeAISampleApp.swift in Sources */, 886F95E02B17D5010036F07A /* ConversationViewModel.swift in Sources */, @@ -608,6 +639,7 @@ 886F95DC2B17BAEF0036F07A /* PhotoReasoningScreen.swift in Sources */, 886F95DB2B17BAEF0036F07A /* PhotoReasoningViewModel.swift in Sources */, 886F95E12B17D5010036F07A /* ConversationScreen.swift in Sources */, + 86FBBA272BBF0594006031A1 /* FunctionCallingViewModel.swift in Sources */, 88263BF02B239C09008AB09B /* ErrorView.swift in Sources */, 886F95D62B17BA010036F07A /* SummarizeViewModel.swift in Sources */, ); diff --git a/Examples/GenerativeAISample/GenerativeAISample/ContentView.swift b/Examples/GenerativeAISample/GenerativeAISample/ContentView.swift index 34331bf..6684350 100644 --- a/Examples/GenerativeAISample/GenerativeAISample/ContentView.swift +++ b/Examples/GenerativeAISample/GenerativeAISample/ContentView.swift @@ -17,6 +17,9 @@ import SwiftUI struct ContentView: View { @StateObject var viewModel = ConversationViewModel() + + @StateObject + var functionCallingViewModel = FunctionCallingViewModel() var body: some View { NavigationStack { @@ -37,6 +40,11 @@ struct ContentView: View { } label: { Label("Chat", systemImage: "ellipsis.message.fill") } + NavigationLink { + FunctionCallingScreen().environmentObject(functionCallingViewModel) + } label: { + Label("Function Calling", systemImage: "function") + } } .navigationTitle("Generative AI Samples") } From 458a59b64ec53f74fc15fb41db2f6d98e380f56d Mon Sep 17 00:00:00 2001 From: Andrew Heard Date: Tue, 9 Apr 2024 21:10:56 -0400 Subject: [PATCH 2/4] Fix formatting --- .../ViewModels/FunctionCallingViewModel.swift | 152 +++++++++--------- .../GenerativeAISample/ContentView.swift | 2 +- 2 files changed, 77 insertions(+), 77 deletions(-) diff --git a/Examples/GenerativeAISample/FunctionCallingSample/ViewModels/FunctionCallingViewModel.swift b/Examples/GenerativeAISample/FunctionCallingSample/ViewModels/FunctionCallingViewModel.swift index 31a541a..899dd6b 100644 --- a/Examples/GenerativeAISample/FunctionCallingSample/ViewModels/FunctionCallingViewModel.swift +++ b/Examples/GenerativeAISample/FunctionCallingSample/ViewModels/FunctionCallingViewModel.swift @@ -20,10 +20,10 @@ import UIKit class FunctionCallingViewModel: ObservableObject { /// This array holds both the user's and the system's chat messages @Published var messages = [ChatMessage]() - + /// Indicates we're waiting for the model to finish @Published var busy = false - + @Published var error: Error? var hasError: Bool { return error != nil @@ -70,21 +70,21 @@ class FunctionCallingViewModel: ObservableObject { func sendMessage(_ text: String, streaming: Bool = true) async { error = nil chatTask?.cancel() - + chatTask = Task { busy = true defer { busy = false } - + // first, add the user's message to the chat let userMessage = ChatMessage(message: text, participant: .user) messages.append(userMessage) - + // add a pending message while we're waiting for a response from the backend let systemMessage = ChatMessage.pending(participant: .system) messages.append(systemMessage) - + print(messages) do { repeat { @@ -102,14 +102,14 @@ class FunctionCallingViewModel: ObservableObject { } } } - + func startNewChat() { stop() error = nil chat = model.startChat() messages.removeAll() } - + func stop() { chatTask?.cancel() error = nil @@ -145,81 +145,81 @@ class FunctionCallingViewModel: ObservableObject { processResponseContent(content: response) messages[pendingMessageIndex()].pending = false } - - func processResponseContent(content: GenerateContentResponse) { - guard let candidate = content.candidates.first else { - fatalError("No candidate.") - } - - for part in candidate.content.parts { - switch part { - case let .text(text): - // replace pending message with backend response - messages[pendingMessageIndex()].message += text - case let .functionCall(functionCall): - messages.insert(functionCall.chatMessage(), at: pendingMessageIndex()) - functionCalls.append(functionCall) - case .data, .functionResponse: - fatalError("Unsupported response content.") - } + + func processResponseContent(content: GenerateContentResponse) { + guard let candidate = content.candidates.first else { + fatalError("No candidate.") + } + + for part in candidate.content.parts { + switch part { + case let .text(text): + // replace pending message with backend response + messages[pendingMessageIndex()].message += text + case let .functionCall(functionCall): + messages.insert(functionCall.chatMessage(), at: pendingMessageIndex()) + functionCalls.append(functionCall) + case .data, .functionResponse: + fatalError("Unsupported response content.") } } - - func processFunctionCalls() async throws -> [FunctionResponse] { - var functionResponses = [FunctionResponse]() - for functionCall in functionCalls { - switch functionCall.name { - case "get_exchange_rate": - let exchangeRates = getExchangeRate(args: functionCall.args) - functionResponses.append(FunctionResponse( - name: "get_exchange_rate", - response: exchangeRates - )) - default: - fatalError("Unknown function named \"\(functionCall.name)\".") - } + } + + func processFunctionCalls() async throws -> [FunctionResponse] { + var functionResponses = [FunctionResponse]() + for functionCall in functionCalls { + switch functionCall.name { + case "get_exchange_rate": + let exchangeRates = getExchangeRate(args: functionCall.args) + functionResponses.append(FunctionResponse( + name: "get_exchange_rate", + response: exchangeRates + )) + default: + fatalError("Unknown function named \"\(functionCall.name)\".") } - functionCalls = [] - - return functionResponses } - + functionCalls = [] + + return functionResponses + } + private func pendingMessageIndex() -> Int { return messages.lastIndex(where: { chatMessage in chatMessage.participant == .system && chatMessage.pending }) ?? messages.endIndex } - - // MARK: - Callable Functions - - func getExchangeRate(args: JSONObject) -> JSONObject { - // 1. Validate and extract the parameters provided by the model (from a `FunctionCall`) - guard case let .string(from) = args["currency_from"] else { - fatalError("Missing `currency_from` parameter.") - } - guard case let .string(to) = args["currency_to"] else { - fatalError("Missing `currency_to` parameter.") - } - - // 2. Get the exchange rate - let allRates: [String: [String: Double]] = [ - "AUD": ["CAD": 0.89265, "EUR": 0.6072, "GBP": 0.51714, "JPY": 97.75, "USD": 0.66379], - "CAD": ["AUD": 1.1203, "EUR": 0.68023, "GBP": 0.57933, "JPY": 109.51, "USD": 0.74362], - "EUR": ["AUD": 1.6469, "CAD": 1.4701, "GBP": 0.85168, "JPY": 160.99, "USD": 1.0932], - "GBP": ["AUD": 1.9337, "CAD": 1.7261, "EUR": 1.1741, "JPY": 189.03, "USD": 1.2836], - "JPY": ["AUD": 0.01023, "CAD": 0.00913, "EUR": 0.00621, "GBP": 0.00529, "USD": 0.00679], - "USD": ["AUD": 1.5065, "CAD": 1.3448, "EUR": 0.91475, "GBP": 0.77907, "JPY": 147.26], - ] - guard let fromRates = allRates[from] else { - return ["error": .string("No data for currency \(from).")] - } - guard let toRate = fromRates[to] else { - return ["error": .string("No data for currency \(to).")] - } - - // 3. Return the exchange rates as a JSON object (returned to the model in a `FunctionResponse`) - return ["rates": .number(toRate)] + + // MARK: - Callable Functions + + func getExchangeRate(args: JSONObject) -> JSONObject { + // 1. Validate and extract the parameters provided by the model (from a `FunctionCall`) + guard case let .string(from) = args["currency_from"] else { + fatalError("Missing `currency_from` parameter.") + } + guard case let .string(to) = args["currency_to"] else { + fatalError("Missing `currency_to` parameter.") } + + // 2. Get the exchange rate + let allRates: [String: [String: Double]] = [ + "AUD": ["CAD": 0.89265, "EUR": 0.6072, "GBP": 0.51714, "JPY": 97.75, "USD": 0.66379], + "CAD": ["AUD": 1.1203, "EUR": 0.68023, "GBP": 0.57933, "JPY": 109.51, "USD": 0.74362], + "EUR": ["AUD": 1.6469, "CAD": 1.4701, "GBP": 0.85168, "JPY": 160.99, "USD": 1.0932], + "GBP": ["AUD": 1.9337, "CAD": 1.7261, "EUR": 1.1741, "JPY": 189.03, "USD": 1.2836], + "JPY": ["AUD": 0.01023, "CAD": 0.00913, "EUR": 0.00621, "GBP": 0.00529, "USD": 0.00679], + "USD": ["AUD": 1.5065, "CAD": 1.3448, "EUR": 0.91475, "GBP": 0.77907, "JPY": 147.26], + ] + guard let fromRates = allRates[from] else { + return ["error": .string("No data for currency \(from).")] + } + guard let toRate = fromRates[to] else { + return ["error": .string("No data for currency \(to).")] + } + + // 3. Return the exchange rates as a JSON object (returned to the model in a `FunctionResponse`) + return ["rates": .number(toRate)] + } } private extension FunctionCall { @@ -237,7 +237,7 @@ private extension FunctionCall { fatalError("Failed to convert JSON data to a String.") } let messageText = "Function call requested by model:\n```\n\(json)\n```" - + return ChatMessage(message: messageText, participant: .system) } } @@ -246,7 +246,7 @@ private extension FunctionResponse { func chatMessage() -> ChatMessage { let encoder = JSONEncoder() encoder.outputFormatting = .prettyPrinted - + let jsonData: Data do { jsonData = try encoder.encode(self) @@ -257,7 +257,7 @@ private extension FunctionResponse { fatalError("Failed to convert JSON data to a String.") } let messageText = "Function response returned by app:\n```\n\(json)\n```" - + return ChatMessage(message: messageText, participant: .user) } } diff --git a/Examples/GenerativeAISample/GenerativeAISample/ContentView.swift b/Examples/GenerativeAISample/GenerativeAISample/ContentView.swift index 6684350..2841d63 100644 --- a/Examples/GenerativeAISample/GenerativeAISample/ContentView.swift +++ b/Examples/GenerativeAISample/GenerativeAISample/ContentView.swift @@ -17,7 +17,7 @@ import SwiftUI struct ContentView: View { @StateObject var viewModel = ConversationViewModel() - + @StateObject var functionCallingViewModel = FunctionCallingViewModel() From 95141beb54230f3b2512fcab3c0894cc67e33ee2 Mon Sep 17 00:00:00 2001 From: Andrew Heard Date: Wed, 10 Apr 2024 15:29:28 -0400 Subject: [PATCH 3/4] Fix comment formatting --- .../FunctionCallingSample/Screens/FunctionCallingScreen.swift | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Examples/GenerativeAISample/FunctionCallingSample/Screens/FunctionCallingScreen.swift b/Examples/GenerativeAISample/FunctionCallingSample/Screens/FunctionCallingScreen.swift index a210978..4848ec5 100644 --- a/Examples/GenerativeAISample/FunctionCallingSample/Screens/FunctionCallingScreen.swift +++ b/Examples/GenerativeAISample/FunctionCallingSample/Screens/FunctionCallingScreen.swift @@ -46,7 +46,7 @@ struct FunctionCallingScreen: View { .listStyle(.plain) .onChange(of: viewModel.messages, perform: { newValue in if viewModel.hasError { - // wait for a short moment to make sure we can actually scroll to the bottom + // Wait for a short moment to make sure we can actually scroll to the bottom. DispatchQueue.main.asyncAfter(deadline: .now() + 0.05) { withAnimation { scrollViewProxy.scrollTo("errorView", anchor: .bottom) @@ -56,7 +56,7 @@ struct FunctionCallingScreen: View { } else { guard let lastMessage = viewModel.messages.last else { return } - // wait for a short moment to make sure we can actually scroll to the bottom + // Wait for a short moment to make sure we can actually scroll to the bottom. DispatchQueue.main.asyncAfter(deadline: .now() + 0.05) { withAnimation { scrollViewProxy.scrollTo(lastMessage.id, anchor: .bottom) From 6ec5ddd6ce19b5a20928a0c6ac304f39d3f642e7 Mon Sep 17 00:00:00 2001 From: Andrew Heard Date: Wed, 10 Apr 2024 15:30:16 -0400 Subject: [PATCH 4/4] Fix pending messages --- .../ViewModels/FunctionCallingViewModel.swift | 17 +++++------------ 1 file changed, 5 insertions(+), 12 deletions(-) diff --git a/Examples/GenerativeAISample/FunctionCallingSample/ViewModels/FunctionCallingViewModel.swift b/Examples/GenerativeAISample/FunctionCallingSample/ViewModels/FunctionCallingViewModel.swift index 899dd6b..0fd64a4 100644 --- a/Examples/GenerativeAISample/FunctionCallingSample/ViewModels/FunctionCallingViewModel.swift +++ b/Examples/GenerativeAISample/FunctionCallingSample/ViewModels/FunctionCallingViewModel.swift @@ -94,7 +94,6 @@ class FunctionCallingViewModel: ObservableObject { try await internalSendMessage(text) } } while !functionCalls.isEmpty - messages[pendingMessageIndex()].pending = false } catch { self.error = error print(error.localizedDescription) @@ -122,7 +121,7 @@ class FunctionCallingViewModel: ObservableObject { responseStream = chat.sendMessageStream(text) } else { for functionResponse in functionResponses { - messages.insert(functionResponse.chatMessage(), at: pendingMessageIndex()) + messages.insert(functionResponse.chatMessage(), at: messages.count - 1) } responseStream = chat.sendMessageStream(functionResponses.modelContent()) } @@ -138,12 +137,11 @@ class FunctionCallingViewModel: ObservableObject { response = try await chat.sendMessage(text) } else { for functionResponse in functionResponses { - messages.insert(functionResponse.chatMessage(), at: pendingMessageIndex()) + messages.insert(functionResponse.chatMessage(), at: messages.count - 1) } response = try await chat.sendMessage(functionResponses.modelContent()) } processResponseContent(content: response) - messages[pendingMessageIndex()].pending = false } func processResponseContent(content: GenerateContentResponse) { @@ -155,9 +153,10 @@ class FunctionCallingViewModel: ObservableObject { switch part { case let .text(text): // replace pending message with backend response - messages[pendingMessageIndex()].message += text + messages[messages.count - 1].message += text + messages[messages.count - 1].pending = false case let .functionCall(functionCall): - messages.insert(functionCall.chatMessage(), at: pendingMessageIndex()) + messages.insert(functionCall.chatMessage(), at: messages.count - 1) functionCalls.append(functionCall) case .data, .functionResponse: fatalError("Unsupported response content.") @@ -184,12 +183,6 @@ class FunctionCallingViewModel: ObservableObject { return functionResponses } - private func pendingMessageIndex() -> Int { - return messages.lastIndex(where: { chatMessage in - chatMessage.participant == .system && chatMessage.pending - }) ?? messages.endIndex - } - // MARK: - Callable Functions func getExchangeRate(args: JSONObject) -> JSONObject {