diff --git a/FirebaseVertexAI/CHANGELOG.md b/FirebaseVertexAI/CHANGELOG.md
index 901378a12ba..d24cbc20dbd 100644
--- a/FirebaseVertexAI/CHANGELOG.md
+++ b/FirebaseVertexAI/CHANGELOG.md
@@ -1,6 +1,9 @@
 # 11.2.0
 - [fixed] Resolved a decoding error for citations without a `uri` and added
   support for decoding `title` fields, which were previously ignored. (#13518)
+- [changed] **Breaking Change**: The methods for starting streaming requests
+  (`generateContentStream` and `sendMessageStream`) and creating a chat instance
+  (`startChat`) are now asynchronous and must be called with `await`. (#13545)
 
 # 10.29.0
 - [feature] Added community support for watchOS. (#13215)
diff --git a/FirebaseVertexAI/Sample/ChatSample/Screens/ConversationScreen.swift b/FirebaseVertexAI/Sample/ChatSample/Screens/ConversationScreen.swift
index 78c903e3412..43da223aa78 100644
--- a/FirebaseVertexAI/Sample/ChatSample/Screens/ConversationScreen.swift
+++ b/FirebaseVertexAI/Sample/ChatSample/Screens/ConversationScreen.swift
@@ -104,7 +104,9 @@ struct ConversationScreen: View {
   }
 
   private func newChat() {
-    viewModel.startNewChat()
+    Task {
+      await viewModel.startNewChat()
+    }
   }
 }
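
For SDK users the call-site migration is mechanical: wrap the now-async calls in a `Task` from synchronous contexts (as the sample screen above does), or `await` them directly. A minimal sketch, assuming a configured default `FirebaseApp`; the model name and prompt are illustrative, not taken from this diff:

```swift
import FirebaseVertexAI

func streamSummary(of text: String) async throws {
  let model = VertexAI.vertexAI().generativeModel(modelName: "gemini-1.5-flash")

  // Before 11.2.0 this call was synchronous; it now requires `await`
  // because the model is actor-isolated (see GenerativeModel.swift below).
  let stream = await model.generateContentStream("Summarize: \(text)")
  for try await chunk in stream {
    if let chunkText = chunk.text {
      print(chunkText)
    }
  }
}
```
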
diff --git a/FirebaseVertexAI/Sample/ChatSample/ViewModels/ConversationViewModel.swift b/FirebaseVertexAI/Sample/ChatSample/ViewModels/ConversationViewModel.swift
index 04d8eeea33c..f3f15f35b86 100644
--- a/FirebaseVertexAI/Sample/ChatSample/ViewModels/ConversationViewModel.swift
+++ b/FirebaseVertexAI/Sample/ChatSample/ViewModels/ConversationViewModel.swift
@@ -21,8 +21,8 @@ class ConversationViewModel: ObservableObject {
   /// This array holds both the user's and the system's chat messages
   @Published var messages = [ChatMessage]()
 
-  /// Indicates we're waiting for the model to finish
-  @Published var busy = false
+  /// Indicates we're waiting for the model to finish or the UI is loading
+  @Published var busy = true
   @Published var error: Error?
 
   var hasError: Bool {
@@ -30,18 +30,20 @@ class ConversationViewModel: ObservableObject {
   }
 
   private var model: GenerativeModel
-  private var chat: Chat
+  private var chat: Chat? = nil
   private var stopGenerating = false
 
   private var chatTask: Task<Void, Never>?
 
   init() {
     model = VertexAI.vertexAI().generativeModel(modelName: "gemini-1.5-flash")
-    chat = model.startChat()
+    Task {
+      await startNewChat()
+    }
   }
 
   func sendMessage(_ text: String, streaming: Bool = true) async {
-    error = nil
+    stop()
     if streaming {
       await internalSendMessageStreaming(text)
     } else {
@@ -49,11 +51,14 @@
     }
   }
 
-  func startNewChat() {
+  func startNewChat() async {
+    busy = true
+    defer {
+      busy = false
+    }
     stop()
-    error = nil
-    chat = model.startChat()
     messages.removeAll()
+    chat = await model.startChat()
   }
 
   func stop() {
@@ -62,8 +67,6 @@
   }
 
   private func internalSendMessageStreaming(_ text: String) async {
-    chatTask?.cancel()
-
     chatTask = Task {
       busy = true
       defer {
@@ -79,7 +82,10 @@
       messages.append(systemMessage)
 
       do {
-        let responseStream = chat.sendMessageStream(text)
+        guard let chat else {
+          throw ChatError.notInitialized
+        }
+        let responseStream = await chat.sendMessageStream(text)
        for try await chunk in responseStream {
          messages[messages.count - 1].pending = false
          if let text = chunk.text {
@@ -95,8 +101,6 @@
   }
 
   private func internalSendMessage(_ text: String) async {
-    chatTask?.cancel()
-
     chatTask = Task {
       busy = true
       defer {
@@ -112,10 +116,12 @@
       messages.append(systemMessage)
 
       do {
-        var response: GenerateContentResponse?
-        response = try await chat.sendMessage(text)
+        guard let chat = chat else {
+          throw ChatError.notInitialized
+        }
+        let response = try await chat.sendMessage(text)
 
-        if let responseText = response?.text {
+        if let responseText = response.text {
           // replace pending message with backend response
           messages[messages.count - 1].message = responseText
           messages[messages.count - 1].pending = false
@@ -127,4 +133,8 @@
       }
     }
   }
+
+  enum ChatError: Error {
+    case notInitialized
+  }
 }
diff --git a/FirebaseVertexAI/Sample/FunctionCallingSample/Screens/FunctionCallingScreen.swift b/FirebaseVertexAI/Sample/FunctionCallingSample/Screens/FunctionCallingScreen.swift
index f16da39e22f..dbfd04eb52c 100644
--- a/FirebaseVertexAI/Sample/FunctionCallingSample/Screens/FunctionCallingScreen.swift
+++ b/FirebaseVertexAI/Sample/FunctionCallingSample/Screens/FunctionCallingScreen.swift
@@ -106,7 +106,9 @@ struct FunctionCallingScreen: View {
   }
 
   private func newChat() {
-    viewModel.startNewChat()
+    Task {
+      await viewModel.startNewChat()
+    }
   }
 }
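
Both sample view models converge on the same pattern: `chat` becomes optional because it is now created asynchronously after `init`, and every send path guards against a not-yet-initialized chat. A condensed sketch of that pattern, see the full diffs around it for the real view models; the `ChatHost` type is hypothetical:

```swift
import FirebaseVertexAI

@MainActor
class ChatHost: ObservableObject {  // hypothetical condensed view model
  @Published var busy = true  // starts true: the chat is still being created
  private var chat: Chat?
  private let model = VertexAI.vertexAI().generativeModel(modelName: "gemini-1.5-flash")

  init() {
    Task { await startNewChat() }  // kick off async initialization
  }

  func startNewChat() async {
    busy = true
    defer { busy = false }
    chat = await model.startChat()
  }

  func send(_ text: String) async throws {
    // Sends can race ahead of initialization, hence the guard.
    guard let chat else { throw ChatError.notInitialized }
    _ = try await chat.sendMessage(text)
  }

  enum ChatError: Error { case notInitialized }
}
```
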
diff --git a/FirebaseVertexAI/Sample/FunctionCallingSample/ViewModels/FunctionCallingViewModel.swift b/FirebaseVertexAI/Sample/FunctionCallingSample/ViewModels/FunctionCallingViewModel.swift
index 13ad5afe23c..56c4c2453a3 100644
--- a/FirebaseVertexAI/Sample/FunctionCallingSample/ViewModels/FunctionCallingViewModel.swift
+++ b/FirebaseVertexAI/Sample/FunctionCallingSample/ViewModels/FunctionCallingViewModel.swift
@@ -33,7 +33,7 @@ class FunctionCallingViewModel: ObservableObject {
   private var functionCalls = [FunctionCall]()
 
   private var model: GenerativeModel
-  private var chat: Chat
+  private var chat: Chat? = nil
 
   private var chatTask: Task<Void, Never>?
 
@@ -62,13 +62,13 @@ class FunctionCallingViewModel: ObservableObject {
         ),
       ])]
     )
-    chat = model.startChat()
+    Task {
+      await startNewChat()
+    }
   }
 
   func sendMessage(_ text: String, streaming: Bool = true) async {
-    error = nil
-    chatTask?.cancel()
-
+    stop()
     chatTask = Task {
       busy = true
       defer {
@@ -100,11 +100,14 @@
     }
   }
 
-  func startNewChat() {
+  func startNewChat() async {
+    busy = true
+    defer {
+      busy = false
+    }
     stop()
-    error = nil
-    chat = model.startChat()
     messages.removeAll()
+    chat = await model.startChat()
   }
 
   func stop() {
@@ -114,14 +117,17 @@
 
   private func internalSendMessageStreaming(_ text: String) async throws {
     let functionResponses = try await processFunctionCalls()
+    guard let chat else {
+      throw ChatError.notInitialized
+    }
     let responseStream: AsyncThrowingStream<GenerateContentResponse, Error>
     if functionResponses.isEmpty {
-      responseStream = chat.sendMessageStream(text)
+      responseStream = await chat.sendMessageStream(text)
     } else {
       for functionResponse in functionResponses {
         messages.insert(functionResponse.chatMessage(), at: messages.count - 1)
       }
-      responseStream = chat.sendMessageStream(functionResponses.modelContent())
+      responseStream = await chat.sendMessageStream(functionResponses.modelContent())
     }
     for try await chunk in responseStream {
       processResponseContent(content: chunk)
@@ -130,6 +136,9 @@
 
   private func internalSendMessage(_ text: String) async throws {
     let functionResponses = try await processFunctionCalls()
+    guard let chat else {
+      throw ChatError.notInitialized
+    }
     let response: GenerateContentResponse
     if functionResponses.isEmpty {
       response = try await chat.sendMessage(text)
@@ -181,6 +190,10 @@
     return functionResponses
   }
 
+  enum ChatError: Error {
+    case notInitialized
+  }
+
   // MARK: - Callable Functions
 
   func getExchangeRate(args: JSONObject) -> JSONObject {
diff --git a/FirebaseVertexAI/Sample/GenerativeAIMultimodalSample/ViewModels/PhotoReasoningViewModel.swift b/FirebaseVertexAI/Sample/GenerativeAIMultimodalSample/ViewModels/PhotoReasoningViewModel.swift
index d937b92f716..a8b3972561b 100644
--- a/FirebaseVertexAI/Sample/GenerativeAIMultimodalSample/ViewModels/PhotoReasoningViewModel.swift
+++ b/FirebaseVertexAI/Sample/GenerativeAIMultimodalSample/ViewModels/PhotoReasoningViewModel.swift
@@ -84,7 +84,7 @@ class PhotoReasoningViewModel: ObservableObject {
       }
     }
 
-    let outputContentStream = model.generateContentStream(prompt, images)
+    let outputContentStream = await model.generateContentStream(prompt, images)
 
     // stream response
     for try await outputContent in outputContentStream {
diff --git a/FirebaseVertexAI/Sample/GenerativeAITextSample/ViewModels/SummarizeViewModel.swift b/FirebaseVertexAI/Sample/GenerativeAITextSample/ViewModels/SummarizeViewModel.swift
index 8b08ec71682..025e30abc39 100644
--- a/FirebaseVertexAI/Sample/GenerativeAITextSample/ViewModels/SummarizeViewModel.swift
+++ b/FirebaseVertexAI/Sample/GenerativeAITextSample/ViewModels/SummarizeViewModel.swift
@@ -50,7 +50,7 @@ class SummarizeViewModel: ObservableObject {
 
     let prompt = "Summarize the following text for me: \(inputText)"
 
-    let outputContentStream = model.generateContentStream(prompt)
+    let outputContentStream = await model.generateContentStream(prompt)
 
     // stream response
     for try await outputContent in outputContentStream {
diff --git a/FirebaseVertexAI/Sources/Chat.swift b/FirebaseVertexAI/Sources/Chat.swift
index 81d6e50b3fe..9e200a47890 100644
--- a/FirebaseVertexAI/Sources/Chat.swift
+++ b/FirebaseVertexAI/Sources/Chat.swift
@@ -17,7 +17,7 @@ import Foundation
 /// An object that represents a back-and-forth chat with a model, capturing the history and saving
 /// the context in memory between each message sent.
 @available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
-public class Chat {
+public actor Chat {
   private let model: GenerativeModel
 
   /// Initializes a new chat representing a 1:1 conversation between model and user.
@@ -121,7 +121,7 @@ public class Chat {
     // Send the history alongside the new message as context.
     let request = history + newContent
 
-    let stream = model.generateContentStream(request)
+    let stream = await model.generateContentStream(request)
     do {
       for try await chunk in stream {
         // Capture any content that's streaming. This should be populated if there's no error.
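
Because `Chat` is now an actor, every interaction from outside it (method calls and `history` reads alike) hops through `await`. A sketch of the caller-side effect, assuming a configured default `FirebaseApp`; the model name and prompt are illustrative:

```swift
import FirebaseVertexAI

func chatRoundTrip() async throws {
  let model = VertexAI.vertexAI().generativeModel(modelName: "gemini-1.5-flash")

  // `startChat()` and `sendMessageStream(_:)` are now awaited.
  let chat = await model.startChat()
  let stream = await chat.sendMessageStream("Hello")
  for try await chunk in stream {
    _ = chunk.text
  }

  // Actor-isolated state is awaited too, as the updated ChatTests show.
  let history = await chat.history
  print("Turns recorded: \(history.count)")
}
```
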
diff --git a/FirebaseVertexAI/Sources/GenerativeModel.swift b/FirebaseVertexAI/Sources/GenerativeModel.swift
index fb65209cd9f..c8642e83fc1 100644
--- a/FirebaseVertexAI/Sources/GenerativeModel.swift
+++ b/FirebaseVertexAI/Sources/GenerativeModel.swift
@@ -19,7 +19,7 @@ import Foundation
 /// A type that represents a remote multimodal model (like Gemini), with the ability to generate
 /// content based on various input types.
 @available(iOS 15.0, macOS 11.0, macCatalyst 15.0, tvOS 15.0, watchOS 8.0, *)
-public final class GenerativeModel {
+public final actor GenerativeModel {
   /// The resource name of the model in the backend; has the format "models/model-name".
   let modelResourceName: String
 
@@ -217,33 +217,31 @@ public final class GenerativeModel {
       isStreaming: true,
       options: requestOptions)
 
-    var responseIterator = generativeAIService.loadRequestStream(request: generateContentRequest)
-      .makeAsyncIterator()
+    let responseStream = generativeAIService.loadRequestStream(request: generateContentRequest)
+
     return AsyncThrowingStream {
-      let response: GenerateContentResponse?
       do {
-        response = try await responseIterator.next()
-      } catch {
-        throw GenerativeModel.generateContentError(from: error)
-      }
+        for try await response in responseStream {
+          // Check the prompt feedback to see if the prompt was blocked.
+          if response.promptFeedback?.blockReason != nil {
+            throw GenerateContentError.promptBlocked(response: response)
+          }
 
-      // The responseIterator will return `nil` when it's done.
-      guard let response = response else {
+          // If the stream ended early unexpectedly, throw an error.
+          if let finishReason = response.candidates.first?.finishReason, finishReason != .stop {
+            throw GenerateContentError.responseStoppedEarly(
+              reason: finishReason,
+              response: response
+            )
+          } else {
+            // Response was valid content, pass it along and continue.
+            return response
+          }
+        }
+
         // This is the end of the stream! Signal it by sending `nil`.
         return nil
-      }
-
-      // Check the prompt feedback to see if the prompt was blocked.
-      if response.promptFeedback?.blockReason != nil {
-        throw GenerateContentError.promptBlocked(response: response)
-      }
-
-      // If the stream ended early unexpectedly, throw an error.
-      if let finishReason = response.candidates.first?.finishReason, finishReason != .stop {
-        throw GenerateContentError.responseStoppedEarly(reason: finishReason, response: response)
-      } else {
-        // Response was valid content, pass it along and continue.
-        return response
+      } catch {
+        throw GenerativeModel.generateContentError(from: error)
       }
     }
   }
diff --git a/FirebaseVertexAI/Tests/Unit/ChatTests.swift b/FirebaseVertexAI/Tests/Unit/ChatTests.swift
index 389fcec1c5f..6191eb234ee 100644
--- a/FirebaseVertexAI/Tests/Unit/ChatTests.swift
+++ b/FirebaseVertexAI/Tests/Unit/ChatTests.swift
@@ -64,19 +64,20 @@ final class ChatTests: XCTestCase {
     )
     let chat = Chat(model: model, history: [])
     let input = "Test input"
-    let stream = chat.sendMessageStream(input)
+    let stream = await chat.sendMessageStream(input)
 
     // Ensure the values are parsed correctly
     for try await value in stream {
       XCTAssertNotNil(value.text)
     }
 
-    XCTAssertEqual(chat.history.count, 2)
-    XCTAssertEqual(chat.history[0].parts[0].text, input)
+    let history = await chat.history
+    XCTAssertEqual(history.count, 2)
+    XCTAssertEqual(history[0].parts[0].text, input)
 
     let finalText = "1 2 3 4 5 6 7 8"
     let assembledExpectation = ModelContent(role: "model", parts: finalText)
-    XCTAssertEqual(chat.history[0].parts[0].text, input)
-    XCTAssertEqual(chat.history[1], assembledExpectation)
+    XCTAssertEqual(history[0].parts[0].text, input)
+    XCTAssertEqual(history[1], assembledExpectation)
   }
 }
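
The `generateContentStream` rewrite above swaps a stored iterator for an unfolding closure that vets each upstream element before emitting it. A standalone sketch of the same shape under `AsyncThrowingStream(unfolding:)` semantics; `Chunk`, `isBlocked`, and `BlockedError` are hypothetical stand-ins for the SDK types:

```swift
struct Chunk {
  let isBlocked: Bool
  let text: String
}

struct BlockedError: Error {}

/// Wraps an upstream stream, vetting each element before forwarding it.
/// The closure is re-entered on every pull; returning nil ends the stream.
func vetted(_ upstream: AsyncThrowingStream<Chunk, Error>) -> AsyncThrowingStream<Chunk, Error> {
  AsyncThrowingStream {
    do {
      for try await chunk in upstream {
        if chunk.isBlocked {
          throw BlockedError()  // mirrors GenerateContentError.promptBlocked
        }
        return chunk  // valid element: emit it; the next pull re-enters the closure
      }
      return nil  // upstream finished: end the vetted stream too
    } catch {
      throw error  // the real code maps this via generateContentError(from:)
    }
  }
}
```
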
diff --git a/FirebaseVertexAI/Tests/Unit/GenerativeModelTests.swift b/FirebaseVertexAI/Tests/Unit/GenerativeModelTests.swift
index b62d224122f..19c7a4bf80b 100644
--- a/FirebaseVertexAI/Tests/Unit/GenerativeModelTests.swift
+++ b/FirebaseVertexAI/Tests/Unit/GenerativeModelTests.swift
@@ -760,7 +760,7 @@ final class GenerativeModelTests: XCTestCase {
     )
 
     do {
-      let stream = model.generateContentStream("Hi")
+      let stream = await model.generateContentStream("Hi")
       for try await _ in stream {
         XCTFail("No content is there, this shouldn't happen.")
       }
@@ -784,7 +784,7 @@
     )
 
     do {
-      let stream = model.generateContentStream(testPrompt)
+      let stream = await model.generateContentStream(testPrompt)
       for try await _ in stream {
         XCTFail("No content is there, this shouldn't happen.")
       }
@@ -807,7 +807,7 @@
     )
 
     do {
-      let stream = model.generateContentStream("Hi")
+      let stream = await model.generateContentStream("Hi")
       for try await _ in stream {
         XCTFail("No content is there, this shouldn't happen.")
       }
@@ -827,7 +827,7 @@
     )
 
     do {
-      let stream = model.generateContentStream("Hi")
+      let stream = await model.generateContentStream("Hi")
       for try await _ in stream {
         XCTFail("Content shouldn't be shown, this shouldn't happen.")
       }
@@ -847,7 +847,7 @@
     )
 
     do {
-      let stream = model.generateContentStream("Hi")
+      let stream = await model.generateContentStream("Hi")
       for try await _ in stream {
         XCTFail("Content shouldn't be shown, this shouldn't happen.")
       }
@@ -866,7 +866,7 @@
       withExtension: "txt"
     )
 
-    let stream = model.generateContentStream("Hi")
+    let stream = await model.generateContentStream("Hi")
     do {
       for try await content in stream {
         XCTAssertNotNil(content.text)
@@ -887,7 +887,7 @@
     )
 
     var responses = 0
-    let stream = model.generateContentStream("Hi")
+    let stream = await model.generateContentStream("Hi")
     for try await content in stream {
       XCTAssertNotNil(content.text)
       responses += 1
@@ -904,7 +904,7 @@
     )
 
     var responses = 0
-    let stream = model.generateContentStream("Hi")
+    let stream = await model.generateContentStream("Hi")
     for try await content in stream {
       XCTAssertNotNil(content.text)
       responses += 1
@@ -921,7 +921,7 @@
     )
 
     var hadUnknown = false
-    let stream = model.generateContentStream("Hi")
+    let stream = await model.generateContentStream("Hi")
     for try await content in stream {
       XCTAssertNotNil(content.text)
       if let ratings = content.candidates.first?.safetyRatings,
@@ -940,7 +940,7 @@
       withExtension: "txt"
     )
 
-    let stream = model.generateContentStream("Hi")
+    let stream = await model.generateContentStream("Hi")
     var citations = [Citation]()
     var responses = [GenerateContentResponse]()
     for try await content in stream {
@@ -996,7 +996,7 @@
       appCheckToken: appCheckToken
     )
 
-    let stream = model.generateContentStream(testPrompt)
+    let stream = await model.generateContentStream(testPrompt)
     for try await _ in stream {}
   }
@@ -1018,7 +1018,7 @@
       appCheckToken: AppCheckInteropFake.placeholderTokenValue
     )
 
-    let stream = model.generateContentStream(testPrompt)
+    let stream = await model.generateContentStream(testPrompt)
     for try await _ in stream {}
   }
@@ -1030,7 +1030,7 @@
     )
 
     var responses = [GenerateContentResponse]()
-    let stream = model.generateContentStream(testPrompt)
+    let stream = await model.generateContentStream(testPrompt)
     for try await response in stream {
       responses.append(response)
     }
@@ -1056,7 +1056,7 @@
     var responseCount = 0
     do {
-      let stream = model.generateContentStream("Hi")
+      let stream = await model.generateContentStream("Hi")
      for try await content in stream {
        XCTAssertNotNil(content.text)
        responseCount += 1
@@ -1076,7 +1076,7 @@
   func testGenerateContentStream_nonHTTPResponse() async throws {
     MockURLProtocol.requestHandler = try nonHTTPRequestHandler()
 
-    let stream = model.generateContentStream("Hi")
+    let stream = await model.generateContentStream("Hi")
     do {
       for try await content in stream {
         XCTFail("Unexpected content in stream: \(content)")
@@ -1096,7 +1096,7 @@
       withExtension: "txt"
     )
 
-    let stream = model.generateContentStream(testPrompt)
+    let stream = await model.generateContentStream(testPrompt)
     do {
       for try await content in stream {
         XCTFail("Unexpected content in stream: \(content)")
@@ -1120,7 +1120,7 @@
       withExtension: "txt"
     )
 
-    let stream = model.generateContentStream(testPrompt)
+    let stream = await model.generateContentStream(testPrompt)
     do {
       for try await content in stream {
         XCTFail("Unexpected content in stream: \(content)")
@@ -1159,7 +1159,7 @@
     )
 
     var responses = 0
-    let stream = model.generateContentStream(testPrompt)
+    let stream = await model.generateContentStream(testPrompt)
     for try await content in stream {
       XCTAssertNotNil(content.text)
       responses += 1
diff --git a/FirebaseVertexAI/Tests/Unit/VertexAIAPITests.swift b/FirebaseVertexAI/Tests/Unit/VertexAIAPITests.swift
index c68b69b03ec..f2c38a03e61 100644
--- a/FirebaseVertexAI/Tests/Unit/VertexAIAPITests.swift
+++ b/FirebaseVertexAI/Tests/Unit/VertexAIAPITests.swift
@@ -170,8 +170,8 @@ final class VertexAIAPITests: XCTestCase {
     #endif
 
     // Chat
-    _ = genAI.startChat()
-    _ = genAI.startChat(history: [ModelContent(parts: "abc")])
+    _ = await genAI.startChat()
+    _ = await genAI.startChat(history: [ModelContent(parts: "abc")])
   }
 
   // Public API tests for GenerateContentResponse.
diff --git a/FirebaseVertexAI/Tests/Unit/VertexComponentTests.swift b/FirebaseVertexAI/Tests/Unit/VertexComponentTests.swift
index 9d5b9c65251..3ee12eb1c4d 100644
--- a/FirebaseVertexAI/Tests/Unit/VertexComponentTests.swift
+++ b/FirebaseVertexAI/Tests/Unit/VertexComponentTests.swift
@@ -106,15 +106,20 @@ class VertexComponentTests: XCTestCase {
     let app = try XCTUnwrap(VertexComponentTests.app)
     let vertex = VertexAI.vertexAI(app: app, location: location)
     let modelName = "test-model-name"
-    let modelResourceName = vertex.modelResourceName(modelName: modelName)
-    let systemInstruction = ModelContent(role: "system", parts: "test-system-instruction-prompt")
+    let expectedModelResourceName = vertex.modelResourceName(modelName: modelName)
+    let expectedSystemInstruction = ModelContent(
+      role: "system",
+      parts: "test-system-instruction-prompt"
+    )
     let generativeModel = vertex.generativeModel(
       modelName: modelName,
-      systemInstruction: systemInstruction
+      systemInstruction: expectedSystemInstruction
     )
 
-    XCTAssertEqual(generativeModel.modelResourceName, modelResourceName)
-    XCTAssertEqual(generativeModel.systemInstruction, systemInstruction)
+    let modelResourceName = await generativeModel.modelResourceName
+    let systemInstruction = await generativeModel.systemInstruction
+    XCTAssertEqual(modelResourceName, expectedModelResourceName)
+    XCTAssertEqual(systemInstruction, expectedSystemInstruction)
   }
 }
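
Taken together, the test changes encode the migration rule: anything actor-isolated that is touched from outside the actor gains `await`. A hedged XCTest-style sketch of that rule; the test class is illustrative, not part of this diff, and assumes `FirebaseApp.configure()` ran in test setup:

```swift
import XCTest
@testable import FirebaseVertexAI

final class ActorMigrationTests: XCTestCase {
  func testActorStateRequiresAwait() async throws {
    // Assumes a default FirebaseApp was configured in test setup.
    let model = VertexAI.vertexAI().generativeModel(modelName: "test-model-name")

    // Reads of actor-isolated state now hop onto the actor,
    // as in VertexComponentTests above.
    let resourceName = await model.modelResourceName
    XCTAssertFalse(resourceName.isEmpty)
  }
}
```
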