diff --git a/Examples/GenerativeAICLI/Sources/GenerateContent.swift b/Examples/GenerativeAICLI/Sources/GenerateContent.swift index aace076..c2a8e46 100644 --- a/Examples/GenerativeAICLI/Sources/GenerateContent.swift +++ b/Examples/GenerativeAICLI/Sources/GenerateContent.swift @@ -18,9 +18,12 @@ import GoogleGenerativeAI @main struct GenerateContent: AsyncParsableCommand { - @Option(help: "The API key to use when calling the Generative Language API.") + @Option(help: "The access token to use when calling the Vertex AI API.") var apiKey: String + @Option() + var projectID: String + @Option(name: .customLong("model"), help: "The name of the model to use (e.g., \"gemini-pro\").") var modelName: String? @@ -34,11 +37,6 @@ struct GenerateContent: AsyncParsableCommand { ) var imageURL: URL? - @Flag( - name: .customLong("streaming"), - help: "Stream response data, printing it incrementally as it's received." - ) var isStreaming = false - @Flag( name: .customLong("GoogleGenerativeAIDebugLogEnabled", withSingleDash: true), help: "Enable additional debug logging." @@ -55,22 +53,12 @@ struct GenerateContent: AsyncParsableCommand { mutating func run() async throws { do { - let safetySettings = [SafetySetting(harmCategory: .dangerousContent, threshold: .blockNone)] - // Let the server pick the default config. - let config = GenerationConfig( - temperature: 0.2, - topP: 0.1, - topK: 16, - candidateCount: 1, - maxOutputTokens: isStreaming ? nil : 256, - stopSequences: nil - ) - let model = GenerativeModel( name: modelNameOrDefault(), apiKey: apiKey, - generationConfig: config, - safetySettings: safetySettings + projectID: projectID, + generationConfig: nil, + safetySettings: nil ) var parts = [ModelContent.Part]() @@ -95,18 +83,14 @@ struct GenerateContent: AsyncParsableCommand { let input = [ModelContent(parts: parts)] - if isStreaming { - let contentStream = model.generateContentStream(input) - print("Generated Content :") - for try await content in contentStream { - if let text = content.text { - print(text) - } - } - } else { - let content = try await model.generateContent(input) + let countTokensResponse = try await model.countTokens(input) + print("Total Token Count: \(countTokensResponse.totalTokens)") + + let contentStream = model.generateContentStream(input) + print("Generated Content :") + for try await content in contentStream { if let text = content.text { - print("Generated Content:\n\(text)") + print(text) } } } catch { diff --git a/Examples/GenerativeAISample/APIKey/APIKey.swift b/Examples/GenerativeAISample/APIKey/APIKey.swift index 3f458ca..b8ca347 100644 --- a/Examples/GenerativeAISample/APIKey/APIKey.swift +++ b/Examples/GenerativeAISample/APIKey/APIKey.swift @@ -34,3 +34,22 @@ enum APIKey { return value } } + +enum ProjectID { + /// Fetch the Project ID from `GenerativeAI-Info.plist` + /// This is just *one* way how you can retrieve the API key for your app. + static var `default`: String { + guard let filePath = Bundle.main.path(forResource: "GenerativeAI-Info", ofType: "plist") + else { + fatalError("Couldn't find file 'GenerativeAI-Info.plist'.") + } + let plist = NSDictionary(contentsOfFile: filePath) + guard let value = plist?.object(forKey: "PROJECT_ID") as? String else { + fatalError("Couldn't find key 'PROJECT_ID' in 'GenerativeAI-Info.plist'.") + } + if value.starts(with: "_") || value.isEmpty { + fatalError("Invalid Project ID for Vertex AI.") + } + return value + } +} diff --git a/Examples/GenerativeAISample/ChatSample/ViewModels/ConversationViewModel.swift b/Examples/GenerativeAISample/ChatSample/ViewModels/ConversationViewModel.swift index 0c59e82..7d80663 100644 --- a/Examples/GenerativeAISample/ChatSample/ViewModels/ConversationViewModel.swift +++ b/Examples/GenerativeAISample/ChatSample/ViewModels/ConversationViewModel.swift @@ -36,7 +36,11 @@ class ConversationViewModel: ObservableObject { private var chatTask: Task? init() { - model = GenerativeModel(name: "gemini-pro", apiKey: APIKey.default) + model = GenerativeModel( + name: "gemini-pro", + apiKey: APIKey.default, + projectID: ProjectID.default + ) chat = model.startChat() } diff --git a/Examples/GenerativeAISample/GenerativeAIMultimodalSample/ViewModels/PhotoReasoningViewModel.swift b/Examples/GenerativeAISample/GenerativeAIMultimodalSample/ViewModels/PhotoReasoningViewModel.swift index dc98bb8..2e579ef 100644 --- a/Examples/GenerativeAISample/GenerativeAIMultimodalSample/ViewModels/PhotoReasoningViewModel.swift +++ b/Examples/GenerativeAISample/GenerativeAIMultimodalSample/ViewModels/PhotoReasoningViewModel.swift @@ -44,7 +44,11 @@ class PhotoReasoningViewModel: ObservableObject { private var model: GenerativeModel? init() { - model = GenerativeModel(name: "gemini-pro-vision", apiKey: APIKey.default) + model = GenerativeModel( + name: "gemini-pro-vision", + apiKey: APIKey.default, + projectID: ProjectID.default + ) } func reason() async { diff --git a/Examples/GenerativeAISample/GenerativeAITextSample/ViewModels/SummarizeViewModel.swift b/Examples/GenerativeAISample/GenerativeAITextSample/ViewModels/SummarizeViewModel.swift index 55289e4..211b662 100644 --- a/Examples/GenerativeAISample/GenerativeAITextSample/ViewModels/SummarizeViewModel.swift +++ b/Examples/GenerativeAISample/GenerativeAITextSample/ViewModels/SummarizeViewModel.swift @@ -32,7 +32,11 @@ class SummarizeViewModel: ObservableObject { private var model: GenerativeModel? init() { - model = GenerativeModel(name: "gemini-pro", apiKey: APIKey.default) + model = GenerativeModel( + name: "gemini-pro", + apiKey: APIKey.default, + projectID: ProjectID.default + ) } func summarize(inputText: String) async { diff --git a/Sources/GoogleAI/CountTokensRequest.swift b/Sources/GoogleAI/CountTokensRequest.swift index ddd146a..d81bb47 100644 --- a/Sources/GoogleAI/CountTokensRequest.swift +++ b/Sources/GoogleAI/CountTokensRequest.swift @@ -18,6 +18,7 @@ struct CountTokensRequest { let model: String let contents: [ModelContent] let options: RequestOptions + let projectID: String } extension CountTokensRequest: Encodable { @@ -30,7 +31,8 @@ extension CountTokensRequest: GenerativeAIRequest { typealias Response = CountTokensResponse var url: URL { - URL(string: "\(GenerativeAISwift.baseURL)/\(model):countTokens")! + let modelResource = "projects/\(projectID)/locations/us-central1/publishers/google/\(model)" + return URL(string: "\(GenerativeAISwift.baseURL)/\(modelResource):countTokens")! } } diff --git a/Sources/GoogleAI/GenerateContentRequest.swift b/Sources/GoogleAI/GenerateContentRequest.swift index 5dc8f11..d99b07b 100644 --- a/Sources/GoogleAI/GenerateContentRequest.swift +++ b/Sources/GoogleAI/GenerateContentRequest.swift @@ -22,6 +22,7 @@ struct GenerateContentRequest { let safetySettings: [SafetySetting]? let isStreaming: Bool let options: RequestOptions + let projectID: String } extension GenerateContentRequest: Encodable { @@ -36,10 +37,13 @@ extension GenerateContentRequest: GenerativeAIRequest { typealias Response = GenerateContentResponse var url: URL { + let modelResource = "projects/\(projectID)/locations/us-central1/publishers/google/\(model)" if isStreaming { - URL(string: "\(GenerativeAISwift.baseURL)/\(model):streamGenerateContent?alt=sse")! + return URL( + string: "\(GenerativeAISwift.baseURL)/\(modelResource):streamGenerateContent?alt=sse" + )! } else { - URL(string: "\(GenerativeAISwift.baseURL)/\(model):generateContent")! + return URL(string: "\(GenerativeAISwift.baseURL)/\(modelResource):generateContent")! } } } diff --git a/Sources/GoogleAI/GenerativeAIService.swift b/Sources/GoogleAI/GenerativeAIService.swift index a92095d..fb184d9 100644 --- a/Sources/GoogleAI/GenerativeAIService.swift +++ b/Sources/GoogleAI/GenerativeAIService.swift @@ -148,7 +148,7 @@ struct GenerativeAIService { private func urlRequest(request: T) throws -> URLRequest { var urlRequest = URLRequest(url: request.url) urlRequest.httpMethod = "POST" - urlRequest.setValue(apiKey, forHTTPHeaderField: "x-goog-api-key") + urlRequest.setValue("Bearer \(apiKey)", forHTTPHeaderField: "Authorization") urlRequest.setValue("genai-swift/\(GenerativeAISwift.version)", forHTTPHeaderField: "x-goog-api-client") urlRequest.setValue("application/json", forHTTPHeaderField: "Content-Type") diff --git a/Sources/GoogleAI/GenerativeAISwift.swift b/Sources/GoogleAI/GenerativeAISwift.swift index 0c2cd31..290ffcd 100644 --- a/Sources/GoogleAI/GenerativeAISwift.swift +++ b/Sources/GoogleAI/GenerativeAISwift.swift @@ -21,5 +21,5 @@ import Foundation public enum GenerativeAISwift { /// String value of the SDK version public static let version = "0.4.7" - static let baseURL = "https://generativelanguage.googleapis.com/v1" + static let baseURL = "https://us-central1-aiplatform.googleapis.com/v1" } diff --git a/Sources/GoogleAI/GenerativeModel.swift b/Sources/GoogleAI/GenerativeModel.swift index c43ca14..4ee27f3 100644 --- a/Sources/GoogleAI/GenerativeModel.swift +++ b/Sources/GoogleAI/GenerativeModel.swift @@ -26,6 +26,8 @@ public final class GenerativeModel { /// The backing service responsible for sending and receiving model requests to the backend. let generativeAIService: GenerativeAIService + let projectID: String + /// Configuration parameters used for the MultiModalModel. let generationConfig: GenerationConfig? @@ -45,12 +47,14 @@ public final class GenerativeModel { /// - requestOptions Configuration parameters for sending requests to the backend. public convenience init(name: String, apiKey: String, + projectID: String, generationConfig: GenerationConfig? = nil, safetySettings: [SafetySetting]? = nil, requestOptions: RequestOptions = RequestOptions()) { self.init( name: name, apiKey: apiKey, + projectID: projectID, generationConfig: generationConfig, safetySettings: safetySettings, requestOptions: requestOptions, @@ -61,12 +65,14 @@ public final class GenerativeModel { /// The designated initializer for this class. init(name: String, apiKey: String, + projectID: String, generationConfig: GenerationConfig? = nil, safetySettings: [SafetySetting]? = nil, requestOptions: RequestOptions = RequestOptions(), urlSession: URLSession) { modelResourceName = GenerativeModel.modelResourceName(name: name) generativeAIService = GenerativeAIService(apiKey: apiKey, urlSession: urlSession) + self.projectID = projectID self.generationConfig = generationConfig self.safetySettings = safetySettings self.requestOptions = requestOptions @@ -112,7 +118,8 @@ public final class GenerativeModel { generationConfig: generationConfig, safetySettings: safetySettings, isStreaming: false, - options: requestOptions) + options: requestOptions, + projectID: projectID) let response: GenerateContentResponse do { response = try await generativeAIService.loadRequest(request: generateContentRequest) @@ -166,7 +173,8 @@ public final class GenerativeModel { generationConfig: generationConfig, safetySettings: safetySettings, isStreaming: true, - options: requestOptions) + options: requestOptions, + projectID: projectID) var responseIterator = generativeAIService.loadRequestStream(request: generateContentRequest) .makeAsyncIterator() @@ -233,7 +241,8 @@ public final class GenerativeModel { let countTokensRequest = CountTokensRequest( model: modelResourceName, contents: content, - options: requestOptions + options: requestOptions, + projectID: projectID ) do { diff --git a/Tests/GoogleAITests/ChatTests.swift b/Tests/GoogleAITests/ChatTests.swift index 4020d4b..75e361d 100644 --- a/Tests/GoogleAITests/ChatTests.swift +++ b/Tests/GoogleAITests/ChatTests.swift @@ -46,7 +46,12 @@ final class ChatTests: XCTestCase { return (response, fileURL.lines) } - let model = GenerativeModel(name: "my-model", apiKey: "API_KEY", urlSession: urlSession) + let model = GenerativeModel( + name: "my-model", + apiKey: "API_KEY", + projectID: "test-project-id", + urlSession: urlSession + ) let chat = Chat(model: model, history: []) let input = "Test input" let stream = chat.sendMessageStream(input) diff --git a/Tests/GoogleAITests/GenerativeModelTests.swift b/Tests/GoogleAITests/GenerativeModelTests.swift index 142341b..9841035 100644 --- a/Tests/GoogleAITests/GenerativeModelTests.swift +++ b/Tests/GoogleAITests/GenerativeModelTests.swift @@ -32,7 +32,12 @@ final class GenerativeModelTests: XCTestCase { let configuration = URLSessionConfiguration.default configuration.protocolClasses = [MockURLProtocol.self] urlSession = try XCTUnwrap(URLSession(configuration: configuration)) - model = GenerativeModel(name: "my-model", apiKey: "API_KEY", urlSession: urlSession) + model = GenerativeModel( + name: "my-model", + apiKey: "API_KEY", + projectID: "test-project-id", + urlSession: urlSession + ) } override func tearDown() { @@ -162,7 +167,7 @@ final class GenerativeModelTests: XCTestCase { let model = GenerativeModel( // Model name is prefixed with "models/". name: "models/test-model", - apiKey: "API_KEY", + apiKey: "API_KEY", projectID: "test-project-id", urlSession: urlSession ) @@ -467,6 +472,7 @@ final class GenerativeModelTests: XCTestCase { model = GenerativeModel( name: "my-model", apiKey: "API_KEY", + projectID: "test-project", requestOptions: requestOptions, urlSession: urlSession ) @@ -779,6 +785,7 @@ final class GenerativeModelTests: XCTestCase { model = GenerativeModel( name: "my-model", apiKey: "API_KEY", + projectID: "test-project", requestOptions: requestOptions, urlSession: urlSession ) @@ -836,6 +843,7 @@ final class GenerativeModelTests: XCTestCase { model = GenerativeModel( name: "my-model", apiKey: "API_KEY", + projectID: "test-project", requestOptions: requestOptions, urlSession: urlSession ) diff --git a/Tests/GoogleAITests/GoogleAITests.swift b/Tests/GoogleAITests/GoogleAITests.swift index 6389ddb..2c92012 100644 --- a/Tests/GoogleAITests/GoogleAITests.swift +++ b/Tests/GoogleAITests/GoogleAITests.swift @@ -31,13 +31,23 @@ final class GoogleGenerativeAITests: XCTestCase { let filters = [SafetySetting(harmCategory: .dangerousContent, threshold: .blockOnlyHigh)] // Permutations without optional arguments. - let _ = GenerativeModel(name: "gemini-pro@001", apiKey: "API_KEY") - let _ = GenerativeModel(name: "gemini-pro@001", apiKey: "API_KEY", safetySettings: filters) - let _ = GenerativeModel(name: "gemini-pro@001", apiKey: "API_KEY", generationConfig: config) + let _ = GenerativeModel(name: "gemini-pro@001", apiKey: "API_KEY", projectID: "test-project-id") + let _ = GenerativeModel( + name: "gemini-pro@001", + apiKey: "API_KEY", + projectID: "test-project-id", + safetySettings: filters + ) + let _ = GenerativeModel( + name: "gemini-pro@001", + apiKey: "API_KEY", + projectID: "test-project-id", + generationConfig: config + ) // All arguments passed. let genAI = GenerativeModel(name: "gemini-pro@001", - apiKey: "API_KEY", + apiKey: "API_KEY", projectID: "test-project-id", generationConfig: config, // Optional safetySettings: filters // Optional )