From f85d31c3d28967ec5d15268e986ccf34129f09b2 Mon Sep 17 00:00:00 2001 From: Paul Beusterien Date: Tue, 26 Mar 2024 13:24:29 -0700 Subject: [PATCH] Manage location on VertexAI instead of model (#12630) --- .../ViewModels/ConversationViewModel.swift | 5 +-- .../ViewModels/PhotoReasoningViewModel.swift | 5 +-- .../ViewModels/SummarizeViewModel.swift | 5 +-- FirebaseVertexAI/Sources/VertexAI.swift | 40 +++++++++++-------- .../Sources/VertexAIComponent.swift | 6 +-- .../Tests/Unit/VertexAIAPITests.swift | 9 ++--- 6 files changed, 33 insertions(+), 37 deletions(-) diff --git a/FirebaseVertexAI/Sample/ChatSample/ViewModels/ConversationViewModel.swift b/FirebaseVertexAI/Sample/ChatSample/ViewModels/ConversationViewModel.swift index 05cbe11250f..81667be2514 100644 --- a/FirebaseVertexAI/Sample/ChatSample/ViewModels/ConversationViewModel.swift +++ b/FirebaseVertexAI/Sample/ChatSample/ViewModels/ConversationViewModel.swift @@ -36,10 +36,7 @@ class ConversationViewModel: ObservableObject { private var chatTask: Task? init() { - model = VertexAI.vertexAI().generativeModel( - modelName: "gemini-1.0-pro", - location: "us-central1" - ) + model = VertexAI.vertexAI(region: "us-central1").generativeModel(modelName: "gemini-1.0-pro") chat = model.startChat() } diff --git a/FirebaseVertexAI/Sample/GenerativeAIMultimodalSample/ViewModels/PhotoReasoningViewModel.swift b/FirebaseVertexAI/Sample/GenerativeAIMultimodalSample/ViewModels/PhotoReasoningViewModel.swift index 2f2ed88d4a1..b4880bdc825 100644 --- a/FirebaseVertexAI/Sample/GenerativeAIMultimodalSample/ViewModels/PhotoReasoningViewModel.swift +++ b/FirebaseVertexAI/Sample/GenerativeAIMultimodalSample/ViewModels/PhotoReasoningViewModel.swift @@ -44,10 +44,7 @@ class PhotoReasoningViewModel: ObservableObject { private var model: GenerativeModel? init() { - model = VertexAI.vertexAI().generativeModel( - modelName: "gemini-1.0-pro-vision", - location: "us-central1" - ) + model = VertexAI.vertexAI(region: "us-central1").generativeModel(modelName: "gemini-1.0-pro") } func reason() async { diff --git a/FirebaseVertexAI/Sample/GenerativeAITextSample/ViewModels/SummarizeViewModel.swift b/FirebaseVertexAI/Sample/GenerativeAITextSample/ViewModels/SummarizeViewModel.swift index 0e3073d6da2..e3c78d09060 100644 --- a/FirebaseVertexAI/Sample/GenerativeAITextSample/ViewModels/SummarizeViewModel.swift +++ b/FirebaseVertexAI/Sample/GenerativeAITextSample/ViewModels/SummarizeViewModel.swift @@ -32,10 +32,7 @@ class SummarizeViewModel: ObservableObject { private var model: GenerativeModel? init() { - model = VertexAI.vertexAI().generativeModel( - modelName: "gemini-1.0-pro", - location: "us-central1" - ) + model = VertexAI.vertexAI(region: "us-central1").generativeModel(modelName: "gemini-1.0-pro") } func summarize(inputText: String) async { diff --git a/FirebaseVertexAI/Sources/VertexAI.swift b/FirebaseVertexAI/Sources/VertexAI.swift index 64e76d3b4f5..2abb4abf7b0 100644 --- a/FirebaseVertexAI/Sources/VertexAI.swift +++ b/FirebaseVertexAI/Sources/VertexAI.swift @@ -25,26 +25,35 @@ public class VertexAI: NSObject { /// The default `VertexAI` instance. /// + /// - Parameter region: The region identifier, e.g., `us-central1`; see + /// [Vertex AI + /// regions](https://cloud.google.com/vertex-ai/docs/general/locations#vertex-ai-regions) + /// for a list of supported regions. /// - Returns: An instance of `VertexAI`, configured with the default `FirebaseApp`. - public static func vertexAI() -> VertexAI { + public static func vertexAI(region: String) -> VertexAI { guard let app = FirebaseApp.app() else { fatalError("No instance of the default Firebase app was found.") } - return vertexAI(app: app) + return vertexAI(app: app, region: region) } /// Creates an instance of `VertexAI` configured with a custom `FirebaseApp`. /// - /// - Parameter app: The custom `FirebaseApp` used for initialization. + /// - Parameters: + /// - app: The custom `FirebaseApp` used for initialization. + /// - region: The region identifier, e.g., `us-central1`; see + /// [Vertex AI + /// regions](https://cloud.google.com/vertex-ai/docs/general/locations#vertex-ai-regions) + /// for a list of supported regions. /// - Returns: A `VertexAI` instance, configured with the custom `FirebaseApp`. - public static func vertexAI(app: FirebaseApp) -> VertexAI { + public static func vertexAI(app: FirebaseApp, region: String) -> VertexAI { guard let provider = ComponentType.instance(for: VertexAIProvider.self, in: app.container) else { fatalError("No \(VertexAIProvider.self) instance found for Firebase app: \(app.name)") } - return provider.vertexAI() + return provider.vertexAI(region) } /// Initializes a generative model with the given parameters. @@ -54,19 +63,15 @@ public class VertexAI: NSObject { /// [Gemini /// models](https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#gemini-models) /// for a list of supported model names. - /// - location: The location identifier, e.g., `us-central1`; see - /// [Vertex AI - /// regions](https://cloud.google.com/vertex-ai/docs/general/locations#vertex-ai-regions) - /// for a list of supported locations. /// - generationConfig: The content generation parameters your model should use. /// - safetySettings: A value describing what types of harmful content your model should allow. /// - requestOptions: Configuration parameters for sending requests to the backend. - public func generativeModel(modelName: String, location: String, + public func generativeModel(modelName: String, generationConfig: GenerationConfig? = nil, safetySettings: [SafetySetting]? = nil, requestOptions: RequestOptions = RequestOptions()) -> GenerativeModel { - let modelResourceName = modelResourceName(modelName: modelName, location: location) + let modelResourceName = modelResourceName(modelName: modelName, region: region) guard let apiKey = app.options.apiKey else { fatalError("The Firebase app named \"\(app.name)\" has no API key in its configuration.") @@ -89,26 +94,29 @@ public class VertexAI: NSObject { private let appCheck: AppCheckInterop? - init(app: FirebaseApp) { + private let region: String + + init(app: FirebaseApp, region: String) { self.app = app + self.region = region appCheck = ComponentType.instance(for: AppCheckInterop.self, in: app.container) } - private func modelResourceName(modelName: String, location: String) -> String { + private func modelResourceName(modelName: String, region: String) -> String { if modelName.contains("/") { return modelName } guard let projectID = app.options.projectID else { fatalError("The Firebase app named \"\(app.name)\" has no project ID in its configuration.") } - guard !location.isEmpty else { + guard !region.isEmpty else { fatalError(""" - No location specified; see + No region specified; see https://cloud.google.com/vertex-ai/generative-ai/docs/learn/locations#available-regions for a list of available regions. """) } - return "projects/\(projectID)/locations/\(location)/publishers/google/models/\(modelName)" + return "projects/\(projectID)/locations/\(region)/publishers/google/models/\(modelName)" } } diff --git a/FirebaseVertexAI/Sources/VertexAIComponent.swift b/FirebaseVertexAI/Sources/VertexAIComponent.swift index 1378f812626..a9b7aa669a6 100644 --- a/FirebaseVertexAI/Sources/VertexAIComponent.swift +++ b/FirebaseVertexAI/Sources/VertexAIComponent.swift @@ -22,7 +22,7 @@ import Foundation @available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *) @objc(FIRVertexAIProvider) protocol VertexAIProvider { - @objc func vertexAI() -> VertexAI + @objc func vertexAI(_ location: String) -> VertexAI } @available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *) @@ -64,7 +64,7 @@ class VertexAIComponent: NSObject, Library, VertexAIProvider { // MARK: - VertexAIProvider conformance - func vertexAI() -> VertexAI { + func vertexAI(_ region: String) -> VertexAI { os_unfair_lock_lock(&instancesLock) // Unlock before the function returns. @@ -73,7 +73,7 @@ class VertexAIComponent: NSObject, Library, VertexAIProvider { if let instance = instances[app.name] { return instance } - let newInstance = VertexAI(app: app) + let newInstance = VertexAI(app: app, region: region) instances[app.name] = newInstance return newInstance } diff --git a/FirebaseVertexAI/Tests/Unit/VertexAIAPITests.swift b/FirebaseVertexAI/Tests/Unit/VertexAIAPITests.swift index a1fd27ab4c5..a4983eb8834 100644 --- a/FirebaseVertexAI/Tests/Unit/VertexAIAPITests.swift +++ b/FirebaseVertexAI/Tests/Unit/VertexAIAPITests.swift @@ -34,31 +34,28 @@ final class VertexAIAPITests: XCTestCase { let filters = [SafetySetting(harmCategory: .dangerousContent, threshold: .blockOnlyHigh)] // Instantiate Vertex AI SDK - Default App - let vertexAI = VertexAI.vertexAI() + let vertexAI = VertexAI.vertexAI(region: "my-region") // Instantiate Vertex AI SDK - Custom App - let _ = VertexAI.vertexAI(app: app!) + let _ = VertexAI.vertexAI(app: app!, region: "my-region") // Permutations without optional arguments. - let _ = vertexAI.generativeModel(modelName: "gemini-1.0-pro", location: "us-central1") + let _ = vertexAI.generativeModel(modelName: "gemini-1.0-pro") let _ = vertexAI.generativeModel( modelName: "gemini-1.0-pro", - location: "us-central1", safetySettings: filters ) let _ = vertexAI.generativeModel( modelName: "gemini-1.0-pro", - location: "us-central1", generationConfig: config ) // All arguments passed. let genAI = vertexAI.generativeModel( modelName: "gemini-1.0-pro", - location: "us-central1", generationConfig: config, // Optional safetySettings: filters // Optional )