Skip to content

Commit

Permalink
Add prototype for calling Vertex AI API
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewheard committed Feb 1, 2024
1 parent ccf2c94 commit ee7174e
Show file tree
Hide file tree
Showing 13 changed files with 98 additions and 48 deletions.
44 changes: 14 additions & 30 deletions Examples/GenerativeAICLI/Sources/GenerateContent.swift
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,12 @@ import GoogleGenerativeAI

@main
struct GenerateContent: AsyncParsableCommand {
@Option(help: "The API key to use when calling the Generative Language API.")
@Option(help: "The access token to use when calling the Vertex AI API.")
var apiKey: String

@Option()
var projectID: String

@Option(name: .customLong("model"), help: "The name of the model to use (e.g., \"gemini-pro\").")
var modelName: String?

Expand All @@ -34,11 +37,6 @@ struct GenerateContent: AsyncParsableCommand {
)
var imageURL: URL?

@Flag(
name: .customLong("streaming"),
help: "Stream response data, printing it incrementally as it's received."
) var isStreaming = false

@Flag(
name: .customLong("GoogleGenerativeAIDebugLogEnabled", withSingleDash: true),
help: "Enable additional debug logging."
Expand All @@ -55,22 +53,12 @@ struct GenerateContent: AsyncParsableCommand {

mutating func run() async throws {
do {
let safetySettings = [SafetySetting(harmCategory: .dangerousContent, threshold: .blockNone)]
// Let the server pick the default config.
let config = GenerationConfig(
temperature: 0.2,
topP: 0.1,
topK: 16,
candidateCount: 1,
maxOutputTokens: isStreaming ? nil : 256,
stopSequences: nil
)

let model = GenerativeModel(
name: modelNameOrDefault(),
apiKey: apiKey,
generationConfig: config,
safetySettings: safetySettings
projectID: projectID,
generationConfig: nil,
safetySettings: nil
)

var parts = [ModelContent.Part]()
Expand All @@ -95,18 +83,14 @@ struct GenerateContent: AsyncParsableCommand {

let input = [ModelContent(parts: parts)]

if isStreaming {
let contentStream = model.generateContentStream(input)
print("Generated Content <streaming>:")
for try await content in contentStream {
if let text = content.text {
print(text)
}
}
} else {
let content = try await model.generateContent(input)
let countTokensResponse = try await model.countTokens(input)
print("Total Token Count: \(countTokensResponse.totalTokens)")

let contentStream = model.generateContentStream(input)
print("Generated Content <streaming>:")
for try await content in contentStream {
if let text = content.text {
print("Generated Content:\n\(text)")
print(text)
}
}
} catch {
Expand Down
19 changes: 19 additions & 0 deletions Examples/GenerativeAISample/APIKey/APIKey.swift
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,22 @@ enum APIKey {
return value
}
}

enum ProjectID {
/// Fetch the Project ID from `GenerativeAI-Info.plist`
/// This is just *one* way how you can retrieve the API key for your app.
static var `default`: String {
guard let filePath = Bundle.main.path(forResource: "GenerativeAI-Info", ofType: "plist")
else {
fatalError("Couldn't find file 'GenerativeAI-Info.plist'.")
}
let plist = NSDictionary(contentsOfFile: filePath)
guard let value = plist?.object(forKey: "PROJECT_ID") as? String else {
fatalError("Couldn't find key 'PROJECT_ID' in 'GenerativeAI-Info.plist'.")
}
if value.starts(with: "_") || value.isEmpty {
fatalError("Invalid Project ID for Vertex AI.")
}
return value
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,11 @@ class ConversationViewModel: ObservableObject {
private var chatTask: Task<Void, Never>?

init() {
model = GenerativeModel(name: "gemini-pro", apiKey: APIKey.default)
model = GenerativeModel(
name: "gemini-pro",
apiKey: APIKey.default,
projectID: ProjectID.default
)
chat = model.startChat()
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,11 @@ class PhotoReasoningViewModel: ObservableObject {
private var model: GenerativeModel?

init() {
model = GenerativeModel(name: "gemini-pro-vision", apiKey: APIKey.default)
model = GenerativeModel(
name: "gemini-pro-vision",
apiKey: APIKey.default,
projectID: ProjectID.default
)
}

func reason() async {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,11 @@ class SummarizeViewModel: ObservableObject {
private var model: GenerativeModel?

init() {
model = GenerativeModel(name: "gemini-pro", apiKey: APIKey.default)
model = GenerativeModel(
name: "gemini-pro",
apiKey: APIKey.default,
projectID: ProjectID.default
)
}

func summarize(inputText: String) async {
Expand Down
4 changes: 3 additions & 1 deletion Sources/GoogleAI/CountTokensRequest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ struct CountTokensRequest {
let model: String
let contents: [ModelContent]
let options: RequestOptions
let projectID: String
}

extension CountTokensRequest: Encodable {
Expand All @@ -30,7 +31,8 @@ extension CountTokensRequest: GenerativeAIRequest {
typealias Response = CountTokensResponse

var url: URL {
URL(string: "\(GenerativeAISwift.baseURL)/\(model):countTokens")!
let modelResource = "projects/\(projectID)/locations/us-central1/publishers/google/\(model)"
return URL(string: "\(GenerativeAISwift.baseURL)/\(modelResource):countTokens")!
}
}

Expand Down
8 changes: 6 additions & 2 deletions Sources/GoogleAI/GenerateContentRequest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ struct GenerateContentRequest {
let safetySettings: [SafetySetting]?
let isStreaming: Bool
let options: RequestOptions
let projectID: String
}

extension GenerateContentRequest: Encodable {
Expand All @@ -36,10 +37,13 @@ extension GenerateContentRequest: GenerativeAIRequest {
typealias Response = GenerateContentResponse

var url: URL {
let modelResource = "projects/\(projectID)/locations/us-central1/publishers/google/\(model)"
if isStreaming {
URL(string: "\(GenerativeAISwift.baseURL)/\(model):streamGenerateContent?alt=sse")!
return URL(
string: "\(GenerativeAISwift.baseURL)/\(modelResource):streamGenerateContent?alt=sse"
)!
} else {
URL(string: "\(GenerativeAISwift.baseURL)/\(model):generateContent")!
return URL(string: "\(GenerativeAISwift.baseURL)/\(modelResource):generateContent")!
}
}
}
2 changes: 1 addition & 1 deletion Sources/GoogleAI/GenerativeAIService.swift
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ struct GenerativeAIService {
private func urlRequest<T: GenerativeAIRequest>(request: T) throws -> URLRequest {
var urlRequest = URLRequest(url: request.url)
urlRequest.httpMethod = "POST"
urlRequest.setValue(apiKey, forHTTPHeaderField: "x-goog-api-key")
urlRequest.setValue("Bearer \(apiKey)", forHTTPHeaderField: "Authorization")
urlRequest.setValue("genai-swift/\(GenerativeAISwift.version)",
forHTTPHeaderField: "x-goog-api-client")
urlRequest.setValue("application/json", forHTTPHeaderField: "Content-Type")
Expand Down
2 changes: 1 addition & 1 deletion Sources/GoogleAI/GenerativeAISwift.swift
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,5 @@ import Foundation
public enum GenerativeAISwift {
/// String value of the SDK version
public static let version = "0.4.7"
static let baseURL = "https://generativelanguage.googleapis.com/v1"
static let baseURL = "https://us-central1-aiplatform.googleapis.com/v1"
}
15 changes: 12 additions & 3 deletions Sources/GoogleAI/GenerativeModel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ public final class GenerativeModel {
/// The backing service responsible for sending and receiving model requests to the backend.
let generativeAIService: GenerativeAIService

let projectID: String

/// Configuration parameters used for the MultiModalModel.
let generationConfig: GenerationConfig?

Expand All @@ -45,12 +47,14 @@ public final class GenerativeModel {
/// - requestOptions Configuration parameters for sending requests to the backend.
public convenience init(name: String,
apiKey: String,
projectID: String,
generationConfig: GenerationConfig? = nil,
safetySettings: [SafetySetting]? = nil,
requestOptions: RequestOptions = RequestOptions()) {
self.init(
name: name,
apiKey: apiKey,
projectID: projectID,
generationConfig: generationConfig,
safetySettings: safetySettings,
requestOptions: requestOptions,
Expand All @@ -61,12 +65,14 @@ public final class GenerativeModel {
/// The designated initializer for this class.
init(name: String,
apiKey: String,
projectID: String,
generationConfig: GenerationConfig? = nil,
safetySettings: [SafetySetting]? = nil,
requestOptions: RequestOptions = RequestOptions(),
urlSession: URLSession) {
modelResourceName = GenerativeModel.modelResourceName(name: name)
generativeAIService = GenerativeAIService(apiKey: apiKey, urlSession: urlSession)
self.projectID = projectID
self.generationConfig = generationConfig
self.safetySettings = safetySettings
self.requestOptions = requestOptions
Expand Down Expand Up @@ -112,7 +118,8 @@ public final class GenerativeModel {
generationConfig: generationConfig,
safetySettings: safetySettings,
isStreaming: false,
options: requestOptions)
options: requestOptions,
projectID: projectID)
let response: GenerateContentResponse
do {
response = try await generativeAIService.loadRequest(request: generateContentRequest)
Expand Down Expand Up @@ -166,7 +173,8 @@ public final class GenerativeModel {
generationConfig: generationConfig,
safetySettings: safetySettings,
isStreaming: true,
options: requestOptions)
options: requestOptions,
projectID: projectID)

var responseIterator = generativeAIService.loadRequestStream(request: generateContentRequest)
.makeAsyncIterator()
Expand Down Expand Up @@ -233,7 +241,8 @@ public final class GenerativeModel {
let countTokensRequest = CountTokensRequest(
model: modelResourceName,
contents: content,
options: requestOptions
options: requestOptions,
projectID: projectID
)

do {
Expand Down
7 changes: 6 additions & 1 deletion Tests/GoogleAITests/ChatTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,12 @@ final class ChatTests: XCTestCase {
return (response, fileURL.lines)
}

let model = GenerativeModel(name: "my-model", apiKey: "API_KEY", urlSession: urlSession)
let model = GenerativeModel(
name: "my-model",
apiKey: "API_KEY",
projectID: "test-project-id",
urlSession: urlSession
)
let chat = Chat(model: model, history: [])
let input = "Test input"
let stream = chat.sendMessageStream(input)
Expand Down
9 changes: 7 additions & 2 deletions Tests/GoogleAITests/GenerativeModelTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,12 @@ final class GenerativeModelTests: XCTestCase {
let configuration = URLSessionConfiguration.default
configuration.protocolClasses = [MockURLProtocol.self]
urlSession = try XCTUnwrap(URLSession(configuration: configuration))
model = GenerativeModel(name: "my-model", apiKey: "API_KEY", urlSession: urlSession)
model = GenerativeModel(
name: "my-model",
apiKey: "API_KEY",
projectID: "test-project-id",
urlSession: urlSession
)
}

override func tearDown() {
Expand Down Expand Up @@ -162,7 +167,7 @@ final class GenerativeModelTests: XCTestCase {
let model = GenerativeModel(
// Model name is prefixed with "models/".
name: "models/test-model",
apiKey: "API_KEY",
apiKey: "API_KEY", projectID: "test-project-id",
urlSession: urlSession
)

Expand Down
18 changes: 14 additions & 4 deletions Tests/GoogleAITests/GoogleAITests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,23 @@ final class GoogleGenerativeAITests: XCTestCase {
let filters = [SafetySetting(harmCategory: .dangerousContent, threshold: .blockOnlyHigh)]

// Permutations without optional arguments.
let _ = GenerativeModel(name: "gemini-pro@001", apiKey: "API_KEY")
let _ = GenerativeModel(name: "gemini-pro@001", apiKey: "API_KEY", safetySettings: filters)
let _ = GenerativeModel(name: "gemini-pro@001", apiKey: "API_KEY", generationConfig: config)
let _ = GenerativeModel(name: "gemini-pro@001", apiKey: "API_KEY", projectID: "test-project-id")
let _ = GenerativeModel(
name: "gemini-pro@001",
apiKey: "API_KEY",
projectID: "test-project-id",
safetySettings: filters
)
let _ = GenerativeModel(
name: "gemini-pro@001",
apiKey: "API_KEY",
projectID: "test-project-id",
generationConfig: config
)

// All arguments passed.
let genAI = GenerativeModel(name: "gemini-pro@001",
apiKey: "API_KEY",
apiKey: "API_KEY", projectID: "test-project-id",
generationConfig: config, // Optional
safetySettings: filters // Optional
)
Expand Down

0 comments on commit ee7174e

Please sign in to comment.