Skip to content

Commit

Permalink
Manage location on VertexAI instead of model (#12630)
Browse files Browse the repository at this point in the history
  • Loading branch information
paulb777 authored Mar 26, 2024
1 parent a315fdf commit f85d31c
Show file tree
Hide file tree
Showing 6 changed files with 33 additions and 37 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,7 @@ class ConversationViewModel: ObservableObject {
private var chatTask: Task<Void, Never>?

init() {
model = VertexAI.vertexAI().generativeModel(
modelName: "gemini-1.0-pro",
location: "us-central1"
)
model = VertexAI.vertexAI(region: "us-central1").generativeModel(modelName: "gemini-1.0-pro")
chat = model.startChat()
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,7 @@ class PhotoReasoningViewModel: ObservableObject {
private var model: GenerativeModel?

init() {
model = VertexAI.vertexAI().generativeModel(
modelName: "gemini-1.0-pro-vision",
location: "us-central1"
)
model = VertexAI.vertexAI(region: "us-central1").generativeModel(modelName: "gemini-1.0-pro")
}

func reason() async {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,7 @@ class SummarizeViewModel: ObservableObject {
private var model: GenerativeModel?

init() {
model = VertexAI.vertexAI().generativeModel(
modelName: "gemini-1.0-pro",
location: "us-central1"
)
model = VertexAI.vertexAI(region: "us-central1").generativeModel(modelName: "gemini-1.0-pro")
}

func summarize(inputText: String) async {
Expand Down
40 changes: 24 additions & 16 deletions FirebaseVertexAI/Sources/VertexAI.swift
Original file line number Diff line number Diff line change
Expand Up @@ -25,26 +25,35 @@ public class VertexAI: NSObject {

/// The default `VertexAI` instance.
///
/// - Parameter region: The region identifier, e.g., `us-central1`; see
/// [Vertex AI
/// regions](https://cloud.google.com/vertex-ai/docs/general/locations#vertex-ai-regions)
/// for a list of supported regions.
/// - Returns: An instance of `VertexAI`, configured with the default `FirebaseApp`.
public static func vertexAI() -> VertexAI {
public static func vertexAI(region: String) -> VertexAI {
guard let app = FirebaseApp.app() else {
fatalError("No instance of the default Firebase app was found.")
}

return vertexAI(app: app)
return vertexAI(app: app, region: region)
}

/// Creates an instance of `VertexAI` configured with a custom `FirebaseApp`.
///
/// - Parameter app: The custom `FirebaseApp` used for initialization.
/// - Parameters:
/// - app: The custom `FirebaseApp` used for initialization.
/// - region: The region identifier, e.g., `us-central1`; see
/// [Vertex AI
/// regions](https://cloud.google.com/vertex-ai/docs/general/locations#vertex-ai-regions)
/// for a list of supported regions.
/// - Returns: A `VertexAI` instance, configured with the custom `FirebaseApp`.
public static func vertexAI(app: FirebaseApp) -> VertexAI {
public static func vertexAI(app: FirebaseApp, region: String) -> VertexAI {
guard let provider = ComponentType<VertexAIProvider>.instance(for: VertexAIProvider.self,
in: app.container) else {
fatalError("No \(VertexAIProvider.self) instance found for Firebase app: \(app.name)")
}

return provider.vertexAI()
return provider.vertexAI(region)
}

/// Initializes a generative model with the given parameters.
Expand All @@ -54,19 +63,15 @@ public class VertexAI: NSObject {
/// [Gemini
/// models](https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#gemini-models)
/// for a list of supported model names.
/// - location: The location identifier, e.g., `us-central1`; see
/// [Vertex AI
/// regions](https://cloud.google.com/vertex-ai/docs/general/locations#vertex-ai-regions)
/// for a list of supported locations.
/// - generationConfig: The content generation parameters your model should use.
/// - safetySettings: A value describing what types of harmful content your model should allow.
/// - requestOptions: Configuration parameters for sending requests to the backend.
public func generativeModel(modelName: String, location: String,
public func generativeModel(modelName: String,
generationConfig: GenerationConfig? = nil,
safetySettings: [SafetySetting]? = nil,
requestOptions: RequestOptions = RequestOptions())
-> GenerativeModel {
let modelResourceName = modelResourceName(modelName: modelName, location: location)
let modelResourceName = modelResourceName(modelName: modelName, region: region)

guard let apiKey = app.options.apiKey else {
fatalError("The Firebase app named \"\(app.name)\" has no API key in its configuration.")
Expand All @@ -89,26 +94,29 @@ public class VertexAI: NSObject {

private let appCheck: AppCheckInterop?

init(app: FirebaseApp) {
private let region: String

init(app: FirebaseApp, region: String) {
self.app = app
self.region = region
appCheck = ComponentType<AppCheckInterop>.instance(for: AppCheckInterop.self, in: app.container)
}

private func modelResourceName(modelName: String, location: String) -> String {
private func modelResourceName(modelName: String, region: String) -> String {
if modelName.contains("/") {
return modelName
}
guard let projectID = app.options.projectID else {
fatalError("The Firebase app named \"\(app.name)\" has no project ID in its configuration.")
}
guard !location.isEmpty else {
guard !region.isEmpty else {
fatalError("""
No location specified; see
No region specified; see
https://cloud.google.com/vertex-ai/generative-ai/docs/learn/locations#available-regions for a
list of available regions.
""")
}

return "projects/\(projectID)/locations/\(location)/publishers/google/models/\(modelName)"
return "projects/\(projectID)/locations/\(region)/publishers/google/models/\(modelName)"
}
}
6 changes: 3 additions & 3 deletions FirebaseVertexAI/Sources/VertexAIComponent.swift
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import Foundation
@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *)
@objc(FIRVertexAIProvider)
protocol VertexAIProvider {
@objc func vertexAI() -> VertexAI
@objc func vertexAI(_ location: String) -> VertexAI
}

@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *)
Expand Down Expand Up @@ -64,7 +64,7 @@ class VertexAIComponent: NSObject, Library, VertexAIProvider {

// MARK: - VertexAIProvider conformance

func vertexAI() -> VertexAI {
func vertexAI(_ region: String) -> VertexAI {
os_unfair_lock_lock(&instancesLock)

// Unlock before the function returns.
Expand All @@ -73,7 +73,7 @@ class VertexAIComponent: NSObject, Library, VertexAIProvider {
if let instance = instances[app.name] {
return instance
}
let newInstance = VertexAI(app: app)
let newInstance = VertexAI(app: app, region: region)
instances[app.name] = newInstance
return newInstance
}
Expand Down
9 changes: 3 additions & 6 deletions FirebaseVertexAI/Tests/Unit/VertexAIAPITests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -34,31 +34,28 @@ final class VertexAIAPITests: XCTestCase {
let filters = [SafetySetting(harmCategory: .dangerousContent, threshold: .blockOnlyHigh)]

// Instantiate Vertex AI SDK - Default App
let vertexAI = VertexAI.vertexAI()
let vertexAI = VertexAI.vertexAI(region: "my-region")

// Instantiate Vertex AI SDK - Custom App
let _ = VertexAI.vertexAI(app: app!)
let _ = VertexAI.vertexAI(app: app!, region: "my-region")

// Permutations without optional arguments.

let _ = vertexAI.generativeModel(modelName: "gemini-1.0-pro", location: "us-central1")
let _ = vertexAI.generativeModel(modelName: "gemini-1.0-pro")

let _ = vertexAI.generativeModel(
modelName: "gemini-1.0-pro",
location: "us-central1",
safetySettings: filters
)

let _ = vertexAI.generativeModel(
modelName: "gemini-1.0-pro",
location: "us-central1",
generationConfig: config
)

// All arguments passed.
let genAI = vertexAI.generativeModel(
modelName: "gemini-1.0-pro",
location: "us-central1",
generationConfig: config, // Optional
safetySettings: filters // Optional
)
Expand Down

0 comments on commit f85d31c

Please sign in to comment.