Skip to content

Commit

Permalink
Add code execution tool
Browse files Browse the repository at this point in the history
  • Loading branch information
longseespace committed Jun 29, 2024
1 parent 5478400 commit dad49b6
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 8 deletions.
2 changes: 1 addition & 1 deletion Sources/GoogleAI/Chat.swift
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ public class Chat {
case let .text(str):
combinedText += str

case .data, .fileData, .functionCall, .functionResponse:
case .data, .fileData, .functionCall, .functionResponse, .executableCode, .codeExecutionResult:
// Don't combine it, just add to the content. If there's any text pending, add that as
// a part.
if !combinedText.isEmpty {
Expand Down
71 changes: 69 additions & 2 deletions Sources/GoogleAI/FunctionCalling.swift
Original file line number Diff line number Diff line change
Expand Up @@ -153,13 +153,20 @@ public struct FunctionDeclaration {
}
}

public struct CodeExecutionDeclaration {

}

/// Helper tools that the model may use to generate response.
///
/// A `Tool` is a piece of code that enables the system to interact with external systems to
/// perform an action, or set of actions, outside of knowledge and scope of the model.
public struct Tool {
/// A list of `FunctionDeclarations` available to the model.
let functionDeclarations: [FunctionDeclaration]?

/// Code Execution Tool
let codeExecution: CodeExecutionDeclaration?

/// Constructs a new `Tool`.
///
Expand All @@ -172,8 +179,9 @@ public struct Tool {
/// populating ``FunctionCall`` in the response. The next conversation turn may contain a
/// ``FunctionResponse`` in ``ModelContent/Part/functionResponse(_:)`` with the
/// ``ModelContent/role`` "function", providing generation context for the next model turn.
public init(functionDeclarations: [FunctionDeclaration]?) {
public init(functionDeclarations: [FunctionDeclaration]?, codeExecutionDeclaration: CodeExecutionDeclaration? = nil) {
self.functionDeclarations = functionDeclarations
self.codeExecution = codeExecutionDeclaration
}
}

Expand Down Expand Up @@ -244,6 +252,50 @@ public struct FunctionResponse: Equatable {
}
}


public struct ExecutableCode: Equatable {
/// The language of the executable code.
public let language: String

/// The source code.
public let code: String

/// Constructs a new `ExecutableCode`.
///
/// - Parameters:
/// - language: The language of the executable code.
/// - code: The source code.
public init(language: String, code: String) {
self.language = language
self.code = code
}
}

public struct CodeExecutionResult: Equatable {
/// Outcome of the code execution
public let outcome: Outcome

/// Contains stdout when code execution is successful, stderr or other description otherwise.
public let output: String?

public enum Outcome: String, CaseIterable {
case unspecified = "OUTCOME_UNSPECIFIED"
case ok = "OUTCOME_OK"
case failed = "OUTCOME_FAILED"
case deadlineExceeded = "OUTCOME_DEADLINE_EXCEEDED"
}

/// Constructs a new `CodeExecutionResult`.
///
/// - Parameters:
/// - outcome: Outcome of the code execution
/// - output: Contains stdout when code execution is successful, stderr or other description otherwise.
public init(outcome: Outcome, output: String?) {
self.outcome = outcome
self.output = output
}
}

// MARK: - Codable Conformance

extension FunctionCall: Decodable {
Expand Down Expand Up @@ -284,7 +336,14 @@ extension Schema: Encodable {}

extension DataType: Encodable {}

extension Tool: Encodable {}
extension Tool: Encodable {
public static let codeExecution = Tool(functionDeclarations: nil, codeExecutionDeclaration: .init())

enum CodingKeys: String, CodingKey {
case functionDeclarations
case codeExecution
}
}

extension FunctionCallingConfig: Encodable {}

Expand All @@ -293,3 +352,11 @@ extension FunctionCallingConfig.Mode: Encodable {}
extension ToolConfig: Encodable {}

extension FunctionResponse: Encodable {}

extension CodeExecutionDeclaration: Encodable {}

extension ExecutableCode: Codable {}

extension CodeExecutionResult: Codable {}

extension CodeExecutionResult.Outcome: Codable {}
16 changes: 16 additions & 0 deletions Sources/GoogleAI/ModelContent.swift
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,12 @@ public struct ModelContent: Equatable {

/// A response to a function call.
case functionResponse(FunctionResponse)

/// Executable code
case executableCode(ExecutableCode)

/// Code Execution Result
case codeExecutionResult(CodeExecutionResult)

// MARK: Convenience Initializers

Expand Down Expand Up @@ -129,6 +135,8 @@ extension ModelContent.Part: Codable {
case fileData
case functionCall
case functionResponse
case executableCode
case codeExecutionResult
}

enum InlineDataKeys: String, CodingKey {
Expand Down Expand Up @@ -164,6 +172,10 @@ extension ModelContent.Part: Codable {
try container.encode(functionCall, forKey: .functionCall)
case let .functionResponse(functionResponse):
try container.encode(functionResponse, forKey: .functionResponse)
case let .executableCode(executableCode):
try container.encode(executableCode, forKey: .executableCode)
case let .codeExecutionResult(codeExecutionResult):
try container.encode(codeExecutionResult, forKey: .codeExecutionResult)
}
}

Expand All @@ -181,6 +193,10 @@ extension ModelContent.Part: Codable {
self = .data(mimetype: mimetype, bytes)
} else if values.contains(.functionCall) {
self = try .functionCall(values.decode(FunctionCall.self, forKey: .functionCall))
} else if values.contains(.executableCode) {
self = try .executableCode(values.decode(ExecutableCode.self, forKey: .executableCode))
} else if values.contains(.codeExecutionResult) {
self = try .codeExecutionResult(values.decode(CodeExecutionResult.self, forKey: .codeExecutionResult))
} else {
throw DecodingError.dataCorrupted(.init(
codingPath: [CodingKeys.text, CodingKeys.inlineData],
Expand Down
18 changes: 13 additions & 5 deletions Tests/GoogleAITests/GenerateContentRequestTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,14 @@ final class GenerateContentRequestTests: XCTestCase {
harmCategory: .dangerousContent,
threshold: .blockLowAndAbove
)],
tools: [Tool(functionDeclarations: [FunctionDeclaration(
name: "test-function-name",
description: "test-function-description",
parameters: nil
)])],
tools: [
Tool(functionDeclarations: [FunctionDeclaration(
name: "test-function-name",
description: "test-function-description",
parameters: nil
)]),
Tool.codeExecution,
],
toolConfig: ToolConfig(functionCallingConfig: FunctionCallingConfig(mode: .auto)),
systemInstruction: ModelContent(role: "system", parts: "test-system-instruction"),
isStreaming: false,
Expand Down Expand Up @@ -102,6 +105,11 @@ final class GenerateContentRequestTests: XCTestCase {
}
}
]
},
{
"codeExecution" : {
}
}
]
}
Expand Down

0 comments on commit dad49b6

Please sign in to comment.