Skip to content

Commit

Permalink
Add system instruction support in Vertex AI (#12749)
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewheard committed May 1, 2024
1 parent f6dec4e commit f4fa5f5
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 1 deletion.
2 changes: 2 additions & 0 deletions FirebaseVertexAI/Sources/GenerateContentRequest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ struct GenerateContentRequest {
let safetySettings: [SafetySetting]?
let tools: [Tool]?
let toolConfig: ToolConfig?
let systemInstruction: ModelContent?
let isStreaming: Bool
let options: RequestOptions
}
Expand All @@ -35,6 +36,7 @@ extension GenerateContentRequest: Encodable {
case safetySettings
case tools
case toolConfig
case systemInstruction
}
}

Expand Down
9 changes: 9 additions & 0 deletions FirebaseVertexAI/Sources/GenerativeModel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ public final class GenerativeModel {
/// Tool configuration for any `Tool` specified in the request.
let toolConfig: ToolConfig?

/// Instructions that direct the model to behave a certain way.
let systemInstruction: ModelContent?

/// Configuration parameters for sending requests to the backend.
let requestOptions: RequestOptions

Expand All @@ -53,6 +56,8 @@ public final class GenerativeModel {
/// - safetySettings: A value describing what types of harmful content your model should allow.
/// - tools: A list of ``Tool`` objects that the model may use to generate the next response.
/// - toolConfig: Tool configuration for any `Tool` specified in the request.
/// - systemInstruction: Instructions that direct the model to behave a certain way; currently
/// only text content is supported.
/// - requestOptions: Configuration parameters for sending requests to the backend.
/// - urlSession: The `URLSession` to use for requests; defaults to `URLSession.shared`.
init(name: String,
Expand All @@ -61,6 +66,7 @@ public final class GenerativeModel {
safetySettings: [SafetySetting]? = nil,
tools: [Tool]?,
toolConfig: ToolConfig? = nil,
systemInstruction: ModelContent? = nil,
requestOptions: RequestOptions,
appCheck: AppCheckInterop?,
urlSession: URLSession = .shared) {
Expand All @@ -74,6 +80,7 @@ public final class GenerativeModel {
self.safetySettings = safetySettings
self.tools = tools
self.toolConfig = toolConfig
self.systemInstruction = systemInstruction
self.requestOptions = requestOptions

Logging.default.info("""
Expand Down Expand Up @@ -121,6 +128,7 @@ public final class GenerativeModel {
safetySettings: safetySettings,
tools: tools,
toolConfig: toolConfig,
systemInstruction: systemInstruction,
isStreaming: false,
options: requestOptions)
response = try await generativeAIService.loadRequest(request: generateContentRequest)
Expand Down Expand Up @@ -194,6 +202,7 @@ public final class GenerativeModel {
safetySettings: safetySettings,
tools: tools,
toolConfig: toolConfig,
systemInstruction: systemInstruction,
isStreaming: true,
options: requestOptions)

Expand Down
4 changes: 4 additions & 0 deletions FirebaseVertexAI/Sources/VertexAI.swift
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,15 @@ public class VertexAI: NSObject {
/// - safetySettings: A value describing what types of harmful content your model should allow.
/// - tools: A list of ``Tool`` objects that the model may use to generate the next response.
/// - toolConfig: Tool configuration for any `Tool` specified in the request.
/// - systemInstruction: Instructions that direct the model to behave a certain way; currently
/// only text content is supported.
/// - requestOptions: Configuration parameters for sending requests to the backend.
public func generativeModel(modelName: String,
generationConfig: GenerationConfig? = nil,
safetySettings: [SafetySetting]? = nil,
tools: [Tool]? = nil,
toolConfig: ToolConfig? = nil,
systemInstruction: ModelContent? = nil,
requestOptions: RequestOptions = RequestOptions())
-> GenerativeModel {
let modelResourceName = modelResourceName(modelName: modelName, location: location)
Expand All @@ -80,6 +83,7 @@ public class VertexAI: NSObject {
safetySettings: safetySettings,
tools: tools,
toolConfig: toolConfig,
systemInstruction: systemInstruction,
requestOptions: requestOptions,
appCheck: appCheck
)
Expand Down
9 changes: 8 additions & 1 deletion FirebaseVertexAI/Tests/Unit/VertexAIAPITests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ final class VertexAIAPITests: XCTestCase {
maxOutputTokens: 256,
stopSequences: ["..."])
let filters = [SafetySetting(harmCategory: .dangerousContent, threshold: .blockOnlyHigh)]
let systemInstruction = ModelContent(role: "system", parts: [.text("Talk like a pirate.")])

// Instantiate Vertex AI SDK - Default App
let vertexAI = VertexAI.vertexAI()
Expand All @@ -53,11 +54,17 @@ final class VertexAIAPITests: XCTestCase {
generationConfig: config
)

let _ = vertexAI.generativeModel(
modelName: "gemini-1.0-pro",
systemInstruction: systemInstruction
)

// All arguments passed.
let genAI = vertexAI.generativeModel(
modelName: "gemini-1.0-pro",
generationConfig: config, // Optional
safetySettings: filters // Optional
safetySettings: filters, // Optional
systemInstruction: systemInstruction // Optional
)

// Full Typed Usage
Expand Down

0 comments on commit f4fa5f5

Please sign in to comment.