From 257cc9b2827c07b9981bb91ee556a10a804da161 Mon Sep 17 00:00:00 2001 From: Andrew Heard Date: Tue, 30 Jan 2024 11:20:12 -0500 Subject: [PATCH] Add prototype for calling Vertex AI API --- .../Sources/GenerateContent.swift | 28 +++++--------- .../GenerativeAISample/APIKey/APIKey.swift | 19 ++++++++++ .../ViewModels/ConversationViewModel.swift | 6 ++- .../ViewModels/PhotoReasoningViewModel.swift | 6 ++- .../ViewModels/SummarizeViewModel.swift | 6 ++- Sources/GoogleAI/CountTokensRequest.swift | 4 +- Sources/GoogleAI/GenerateContentRequest.swift | 8 +++- Sources/GoogleAI/GenerativeAIService.swift | 2 +- Sources/GoogleAI/GenerativeAISwift.swift | 2 +- Sources/GoogleAI/GenerativeModel.swift | 38 +++++++++++++------ Tests/GoogleAITests/ChatTests.swift | 7 +++- .../GoogleAITests/GenerativeModelTests.swift | 9 ++++- Tests/GoogleAITests/GoogleAITests.swift | 18 +++++++-- 13 files changed, 109 insertions(+), 44 deletions(-) diff --git a/Examples/GenerativeAICLI/Sources/GenerateContent.swift b/Examples/GenerativeAICLI/Sources/GenerateContent.swift index aace076..0c0871e 100644 --- a/Examples/GenerativeAICLI/Sources/GenerateContent.swift +++ b/Examples/GenerativeAICLI/Sources/GenerateContent.swift @@ -18,9 +18,12 @@ import GoogleGenerativeAI @main struct GenerateContent: AsyncParsableCommand { - @Option(help: "The API key to use when calling the Generative Language API.") + @Option(help: "The access token to use when calling the Vertex AI API.") var apiKey: String + @Option() + var projectID: String + @Option(name: .customLong("model"), help: "The name of the model to use (e.g., \"gemini-pro\").") var modelName: String? @@ -34,11 +37,6 @@ struct GenerateContent: AsyncParsableCommand { ) var imageURL: URL? - @Flag( - name: .customLong("streaming"), - help: "Stream response data, printing it incrementally as it's received." - ) var isStreaming = false - @Flag( name: .customLong("GoogleGenerativeAIDebugLogEnabled", withSingleDash: true), help: "Enable additional debug logging." @@ -62,13 +60,14 @@ struct GenerateContent: AsyncParsableCommand { topP: 0.1, topK: 16, candidateCount: 1, - maxOutputTokens: isStreaming ? nil : 256, + maxOutputTokens: nil, stopSequences: nil ) let model = GenerativeModel( name: modelNameOrDefault(), apiKey: apiKey, + projectID: projectID, generationConfig: config, safetySettings: safetySettings ) @@ -95,18 +94,11 @@ struct GenerateContent: AsyncParsableCommand { let input = [ModelContent(parts: parts)] - if isStreaming { - let contentStream = model.generateContentStream(input) - print("Generated Content :") - for try await content in contentStream { - if let text = content.text { - print(text) - } - } - } else { - let content = try await model.generateContent(input) + let contentStream = model.generateContentStream(input) + print("Generated Content :") + for try await content in contentStream { if let text = content.text { - print("Generated Content:\n\(text)") + print(text) } } } catch { diff --git a/Examples/GenerativeAISample/APIKey/APIKey.swift b/Examples/GenerativeAISample/APIKey/APIKey.swift index 3f458ca..b8ca347 100644 --- a/Examples/GenerativeAISample/APIKey/APIKey.swift +++ b/Examples/GenerativeAISample/APIKey/APIKey.swift @@ -34,3 +34,22 @@ enum APIKey { return value } } + +enum ProjectID { + /// Fetch the Project ID from `GenerativeAI-Info.plist` + /// This is just *one* way how you can retrieve the API key for your app. + static var `default`: String { + guard let filePath = Bundle.main.path(forResource: "GenerativeAI-Info", ofType: "plist") + else { + fatalError("Couldn't find file 'GenerativeAI-Info.plist'.") + } + let plist = NSDictionary(contentsOfFile: filePath) + guard let value = plist?.object(forKey: "PROJECT_ID") as? String else { + fatalError("Couldn't find key 'PROJECT_ID' in 'GenerativeAI-Info.plist'.") + } + if value.starts(with: "_") || value.isEmpty { + fatalError("Invalid Project ID for Vertex AI.") + } + return value + } +} diff --git a/Examples/GenerativeAISample/ChatSample/ViewModels/ConversationViewModel.swift b/Examples/GenerativeAISample/ChatSample/ViewModels/ConversationViewModel.swift index 0c59e82..7d80663 100644 --- a/Examples/GenerativeAISample/ChatSample/ViewModels/ConversationViewModel.swift +++ b/Examples/GenerativeAISample/ChatSample/ViewModels/ConversationViewModel.swift @@ -36,7 +36,11 @@ class ConversationViewModel: ObservableObject { private var chatTask: Task? init() { - model = GenerativeModel(name: "gemini-pro", apiKey: APIKey.default) + model = GenerativeModel( + name: "gemini-pro", + apiKey: APIKey.default, + projectID: ProjectID.default + ) chat = model.startChat() } diff --git a/Examples/GenerativeAISample/GenerativeAIMultimodalSample/ViewModels/PhotoReasoningViewModel.swift b/Examples/GenerativeAISample/GenerativeAIMultimodalSample/ViewModels/PhotoReasoningViewModel.swift index dc98bb8..2e579ef 100644 --- a/Examples/GenerativeAISample/GenerativeAIMultimodalSample/ViewModels/PhotoReasoningViewModel.swift +++ b/Examples/GenerativeAISample/GenerativeAIMultimodalSample/ViewModels/PhotoReasoningViewModel.swift @@ -44,7 +44,11 @@ class PhotoReasoningViewModel: ObservableObject { private var model: GenerativeModel? init() { - model = GenerativeModel(name: "gemini-pro-vision", apiKey: APIKey.default) + model = GenerativeModel( + name: "gemini-pro-vision", + apiKey: APIKey.default, + projectID: ProjectID.default + ) } func reason() async { diff --git a/Examples/GenerativeAISample/GenerativeAITextSample/ViewModels/SummarizeViewModel.swift b/Examples/GenerativeAISample/GenerativeAITextSample/ViewModels/SummarizeViewModel.swift index 55289e4..211b662 100644 --- a/Examples/GenerativeAISample/GenerativeAITextSample/ViewModels/SummarizeViewModel.swift +++ b/Examples/GenerativeAISample/GenerativeAITextSample/ViewModels/SummarizeViewModel.swift @@ -32,7 +32,11 @@ class SummarizeViewModel: ObservableObject { private var model: GenerativeModel? init() { - model = GenerativeModel(name: "gemini-pro", apiKey: APIKey.default) + model = GenerativeModel( + name: "gemini-pro", + apiKey: APIKey.default, + projectID: ProjectID.default + ) } func summarize(inputText: String) async { diff --git a/Sources/GoogleAI/CountTokensRequest.swift b/Sources/GoogleAI/CountTokensRequest.swift index a8705fe..d94f446 100644 --- a/Sources/GoogleAI/CountTokensRequest.swift +++ b/Sources/GoogleAI/CountTokensRequest.swift @@ -17,6 +17,7 @@ import Foundation struct CountTokensRequest { let model: String let contents: [ModelContent] + let projectID: String } extension CountTokensRequest: Encodable { @@ -29,7 +30,8 @@ extension CountTokensRequest: GenerativeAIRequest { typealias Response = CountTokensResponse var url: URL { - URL(string: "\(GenerativeAISwift.baseURL)/\(model):countTokens")! + let modelResource = "projects/\(projectID)/locations/us-central1/publishers/google/\(model)" + return URL(string: "\(GenerativeAISwift.baseURL)/\(modelResource):countTokens")! } } diff --git a/Sources/GoogleAI/GenerateContentRequest.swift b/Sources/GoogleAI/GenerateContentRequest.swift index 6629056..72be021 100644 --- a/Sources/GoogleAI/GenerateContentRequest.swift +++ b/Sources/GoogleAI/GenerateContentRequest.swift @@ -21,6 +21,7 @@ struct GenerateContentRequest { let generationConfig: GenerationConfig? let safetySettings: [SafetySetting]? let isStreaming: Bool + let projectID: String } extension GenerateContentRequest: Encodable { @@ -35,10 +36,13 @@ extension GenerateContentRequest: GenerativeAIRequest { typealias Response = GenerateContentResponse var url: URL { + let modelResource = "projects/\(projectID)/locations/us-central1/publishers/google/\(model)" if isStreaming { - URL(string: "\(GenerativeAISwift.baseURL)/\(model):streamGenerateContent?alt=sse")! + return URL( + string: "\(GenerativeAISwift.baseURL)/\(modelResource):streamGenerateContent?alt=sse" + )! } else { - URL(string: "\(GenerativeAISwift.baseURL)/\(model):generateContent")! + return URL(string: "\(GenerativeAISwift.baseURL)/\(modelResource):generateContent")! } } } diff --git a/Sources/GoogleAI/GenerativeAIService.swift b/Sources/GoogleAI/GenerativeAIService.swift index 95d39d8..0395c01 100644 --- a/Sources/GoogleAI/GenerativeAIService.swift +++ b/Sources/GoogleAI/GenerativeAIService.swift @@ -148,7 +148,7 @@ struct GenerativeAIService { private func urlRequest(request: T) throws -> URLRequest { var urlRequest = URLRequest(url: request.url) urlRequest.httpMethod = "POST" - urlRequest.setValue(apiKey, forHTTPHeaderField: "x-goog-api-key") + urlRequest.setValue("Bearer \(apiKey)", forHTTPHeaderField: "Authorization") urlRequest.setValue("genai-swift/\(GenerativeAISwift.version)", forHTTPHeaderField: "x-goog-api-client") urlRequest.setValue("application/json", forHTTPHeaderField: "Content-Type") diff --git a/Sources/GoogleAI/GenerativeAISwift.swift b/Sources/GoogleAI/GenerativeAISwift.swift index 0c2cd31..290ffcd 100644 --- a/Sources/GoogleAI/GenerativeAISwift.swift +++ b/Sources/GoogleAI/GenerativeAISwift.swift @@ -21,5 +21,5 @@ import Foundation public enum GenerativeAISwift { /// String value of the SDK version public static let version = "0.4.7" - static let baseURL = "https://generativelanguage.googleapis.com/v1" + static let baseURL = "https://us-central1-aiplatform.googleapis.com/v1" } diff --git a/Sources/GoogleAI/GenerativeModel.swift b/Sources/GoogleAI/GenerativeModel.swift index f387378..69b72b7 100644 --- a/Sources/GoogleAI/GenerativeModel.swift +++ b/Sources/GoogleAI/GenerativeModel.swift @@ -26,6 +26,8 @@ public final class GenerativeModel { /// The backing service responsible for sending and receiving model requests to the backend. let generativeAIService: GenerativeAIService + let projectID: String + /// Configuration parameters used for the MultiModalModel. let generationConfig: GenerationConfig? @@ -42,11 +44,13 @@ public final class GenerativeModel { /// should allow. public convenience init(name: String, apiKey: String, + projectID: String, generationConfig: GenerationConfig? = nil, safetySettings: [SafetySetting]? = nil) { self.init( name: name, apiKey: apiKey, + projectID: projectID, generationConfig: generationConfig, safetySettings: safetySettings, urlSession: .shared @@ -56,11 +60,13 @@ public final class GenerativeModel { /// The designated initializer for this class. init(name: String, apiKey: String, + projectID: String, generationConfig: GenerationConfig? = nil, safetySettings: [SafetySetting]? = nil, urlSession: URLSession) { modelResourceName = GenerativeModel.modelResourceName(name: name) generativeAIService = GenerativeAIService(apiKey: apiKey, urlSession: urlSession) + self.projectID = projectID self.generationConfig = generationConfig self.safetySettings = safetySettings @@ -99,11 +105,14 @@ public final class GenerativeModel { /// - Returns: The generated content response from the model. /// - Throws: A ``GenerateContentError`` if the request failed. public func generateContent(_ content: [ModelContent]) async throws -> GenerateContentResponse { - let generateContentRequest = GenerateContentRequest(model: modelResourceName, - contents: content, - generationConfig: generationConfig, - safetySettings: safetySettings, - isStreaming: false) + let generateContentRequest = GenerateContentRequest( + model: modelResourceName, + contents: content, + generationConfig: generationConfig, + safetySettings: safetySettings, + isStreaming: false, + projectID: projectID + ) let response: GenerateContentResponse do { response = try await generativeAIService.loadRequest(request: generateContentRequest) @@ -152,11 +161,14 @@ public final class GenerativeModel { @available(macOS 12.0, *) public func generateContentStream(_ content: [ModelContent]) -> AsyncThrowingStream { - let generateContentRequest = GenerateContentRequest(model: modelResourceName, - contents: content, - generationConfig: generationConfig, - safetySettings: safetySettings, - isStreaming: true) + let generateContentRequest = GenerateContentRequest( + model: modelResourceName, + contents: content, + generationConfig: generationConfig, + safetySettings: safetySettings, + isStreaming: true, + projectID: projectID + ) var responseIterator = generativeAIService.loadRequestStream(request: generateContentRequest) .makeAsyncIterator() @@ -220,7 +232,11 @@ public final class GenerativeModel { /// - Throws: A ``CountTokensError`` if the tokenization request failed. public func countTokens(_ content: [ModelContent]) async throws -> CountTokensResponse { - let countTokensRequest = CountTokensRequest(model: modelResourceName, contents: content) + let countTokensRequest = CountTokensRequest( + model: modelResourceName, + contents: content, + projectID: projectID + ) do { return try await generativeAIService.loadRequest(request: countTokensRequest) diff --git a/Tests/GoogleAITests/ChatTests.swift b/Tests/GoogleAITests/ChatTests.swift index 4020d4b..75e361d 100644 --- a/Tests/GoogleAITests/ChatTests.swift +++ b/Tests/GoogleAITests/ChatTests.swift @@ -46,7 +46,12 @@ final class ChatTests: XCTestCase { return (response, fileURL.lines) } - let model = GenerativeModel(name: "my-model", apiKey: "API_KEY", urlSession: urlSession) + let model = GenerativeModel( + name: "my-model", + apiKey: "API_KEY", + projectID: "test-project-id", + urlSession: urlSession + ) let chat = Chat(model: model, history: []) let input = "Test input" let stream = chat.sendMessageStream(input) diff --git a/Tests/GoogleAITests/GenerativeModelTests.swift b/Tests/GoogleAITests/GenerativeModelTests.swift index 64c0398..3e604ac 100644 --- a/Tests/GoogleAITests/GenerativeModelTests.swift +++ b/Tests/GoogleAITests/GenerativeModelTests.swift @@ -32,7 +32,12 @@ final class GenerativeModelTests: XCTestCase { let configuration = URLSessionConfiguration.default configuration.protocolClasses = [MockURLProtocol.self] urlSession = try XCTUnwrap(URLSession(configuration: configuration)) - model = GenerativeModel(name: "my-model", apiKey: "API_KEY", urlSession: urlSession) + model = GenerativeModel( + name: "my-model", + apiKey: "API_KEY", + projectID: "test-project-id", + urlSession: urlSession + ) } override func tearDown() { @@ -162,7 +167,7 @@ final class GenerativeModelTests: XCTestCase { let model = GenerativeModel( // Model name is prefixed with "models/". name: "models/test-model", - apiKey: "API_KEY", + apiKey: "API_KEY", projectID: "test-project-id", urlSession: urlSession ) diff --git a/Tests/GoogleAITests/GoogleAITests.swift b/Tests/GoogleAITests/GoogleAITests.swift index 6389ddb..2c92012 100644 --- a/Tests/GoogleAITests/GoogleAITests.swift +++ b/Tests/GoogleAITests/GoogleAITests.swift @@ -31,13 +31,23 @@ final class GoogleGenerativeAITests: XCTestCase { let filters = [SafetySetting(harmCategory: .dangerousContent, threshold: .blockOnlyHigh)] // Permutations without optional arguments. - let _ = GenerativeModel(name: "gemini-pro@001", apiKey: "API_KEY") - let _ = GenerativeModel(name: "gemini-pro@001", apiKey: "API_KEY", safetySettings: filters) - let _ = GenerativeModel(name: "gemini-pro@001", apiKey: "API_KEY", generationConfig: config) + let _ = GenerativeModel(name: "gemini-pro@001", apiKey: "API_KEY", projectID: "test-project-id") + let _ = GenerativeModel( + name: "gemini-pro@001", + apiKey: "API_KEY", + projectID: "test-project-id", + safetySettings: filters + ) + let _ = GenerativeModel( + name: "gemini-pro@001", + apiKey: "API_KEY", + projectID: "test-project-id", + generationConfig: config + ) // All arguments passed. let genAI = GenerativeModel(name: "gemini-pro@001", - apiKey: "API_KEY", + apiKey: "API_KEY", projectID: "test-project-id", generationConfig: config, // Optional safetySettings: filters // Optional )