Skip to content

Commit

Permalink
Add support for additional model name prefixes (#98)
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewheard authored Feb 13, 2024
1 parent 90a1c6a commit d60695f
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 2 deletions.
7 changes: 5 additions & 2 deletions Sources/GoogleAI/GenerativeModel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,11 @@ public final class GenerativeModel {
// The prefix for a model resource in the Gemini API.
private static let modelResourcePrefix = "models/"

// The prefix for a tuned model resource in the Gemini API.
private static let tunedModelResourcePrefix = "tunedModels/"

/// The resource name of the model in the backend; has the format "models/model-name".
private let modelResourceName: String
let modelResourceName: String

/// The backing service responsible for sending and receiving model requests to the backend.
let generativeAIService: GenerativeAIService
Expand Down Expand Up @@ -246,7 +249,7 @@ public final class GenerativeModel {

/// Returns a model resource name of the form "models/model-name" based on `name`.
private static func modelResourceName(name: String) -> String {
if name.hasPrefix(modelResourcePrefix) {
if name.contains("/") {
return name
} else {
return modelResourcePrefix + name
Expand Down
27 changes: 27 additions & 0 deletions Tests/GoogleAITests/GenerativeModelTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -845,6 +845,33 @@ final class GenerativeModelTests: XCTestCase {
XCTAssertEqual(response.totalTokens, 6)
}

// MARK: - Model Resource Name

func testModelResourceName_noPrefix() async throws {
let modelName = "my-model"
let modelResourceName = "models/\(modelName)"

model = GenerativeModel(name: modelName, apiKey: "API_KEY")

XCTAssertEqual(model.modelResourceName, modelResourceName)
}

func testModelResourceName_modelsPrefix() async throws {
let modelResourceName = "models/my-model"

model = GenerativeModel(name: modelResourceName, apiKey: "API_KEY")

XCTAssertEqual(model.modelResourceName, modelResourceName)
}

func testModelResourceName_tunedModelsPrefix() async throws {
let tunedModelResourceName = "tunedModels/my-model"

model = GenerativeModel(name: tunedModelResourceName, apiKey: "API_KEY")

XCTAssertEqual(model.modelResourceName, tunedModelResourceName)
}

// MARK: - Helpers

private func nonHTTPRequestHandler() throws -> ((URLRequest) -> (
Expand Down

0 comments on commit d60695f

Please sign in to comment.