diff --git a/Examples/GenerativeAISample/FunctionCallingSample/ViewModels/FunctionCallingViewModel.swift b/Examples/GenerativeAISample/FunctionCallingSample/ViewModels/FunctionCallingViewModel.swift index 7ebb821..8d65c5f 100644 --- a/Examples/GenerativeAISample/FunctionCallingSample/ViewModels/FunctionCallingViewModel.swift +++ b/Examples/GenerativeAISample/FunctionCallingSample/ViewModels/FunctionCallingViewModel.swift @@ -157,7 +157,7 @@ class FunctionCallingViewModel: ObservableObject { case let .functionCall(functionCall): messages.insert(functionCall.chatMessage(), at: messages.count - 1) functionCalls.append(functionCall) - case .data, .fileData, .functionResponse: + case .data, .fileData, .functionResponse, .executableCode, .codeExecutionResult: fatalError("Unsupported response content.") } } diff --git a/Sources/GoogleAI/Chat.swift b/Sources/GoogleAI/Chat.swift index 6549df4..5f8dddf 100644 --- a/Sources/GoogleAI/Chat.swift +++ b/Sources/GoogleAI/Chat.swift @@ -160,7 +160,8 @@ public class Chat { case let .text(str): combinedText += str - case .data, .fileData, .functionCall, .functionResponse: + case .data, .fileData, .functionCall, .functionResponse, .executableCode, + .codeExecutionResult: // Don't combine it, just add to the content. If there's any text pending, add that as // a part. if !combinedText.isEmpty { diff --git a/Sources/GoogleAI/FunctionCalling.swift b/Sources/GoogleAI/FunctionCalling.swift index 57130eb..159c8b4 100644 --- a/Sources/GoogleAI/FunctionCalling.swift +++ b/Sources/GoogleAI/FunctionCalling.swift @@ -161,6 +161,9 @@ public struct Tool { /// A list of `FunctionDeclarations` available to the model. let functionDeclarations: [FunctionDeclaration]? + /// Enables the model to execute code as part of generation. + let codeExecution: CodeExecution? + /// Constructs a new `Tool`. /// /// - Parameters: @@ -172,8 +175,11 @@ public struct Tool { /// populating ``FunctionCall`` in the response. The next conversation turn may contain a /// ``FunctionResponse`` in ``ModelContent/Part/functionResponse(_:)`` with the /// ``ModelContent/role`` "function", providing generation context for the next model turn. - public init(functionDeclarations: [FunctionDeclaration]?) { + /// - codeExecution: Enables the model to execute code as part of generation, if provided. + public init(functionDeclarations: [FunctionDeclaration]? = nil, + codeExecution: CodeExecution? = nil) { self.functionDeclarations = functionDeclarations + self.codeExecution = codeExecution } } @@ -244,6 +250,55 @@ public struct FunctionResponse: Equatable { } } +/// Tool that executes code generated by the model, automatically returning the result to the model. +/// +/// This type has no fields. See ``ExecutableCode`` and ``CodeExecutionResult``, which are only +/// generated when using this tool. +public struct CodeExecution { + /// Constructs a new `CodeExecution` tool. + public init() {} +} + +/// Code generated by the model that is meant to be executed, and the result returned to the model. +/// +/// Only generated when using the ``CodeExecution`` tool, in which case the code will automatically +/// be executed, and a corresponding ``CodeExecutionResult`` will also be generated. +public struct ExecutableCode: Equatable { + /// The programming language of the ``code``. + public let language: String + + /// The code to be executed. + public let code: String +} + +/// Result of executing the ``ExecutableCode``. +/// +/// Only generated when using the ``CodeExecution`` tool, and always follows a part containing the +/// ``ExecutableCode``. +public struct CodeExecutionResult: Equatable { + /// Possible outcomes of the code execution. + public enum Outcome: String { + /// An unrecognized code execution outcome was provided. + case unknown = "OUTCOME_UNKNOWN" + /// Unspecified status; this value should not be used. + case unspecified = "OUTCOME_UNSPECIFIED" + /// Code execution completed successfully. + case ok = "OUTCOME_OK" + /// Code execution finished but with a failure; ``CodeExecutionResult/output`` should contain + /// the failure details from `stderr`. + case failed = "OUTCOME_FAILED" + /// Code execution ran for too long, and was cancelled. There may or may not be a partial + /// ``CodeExecutionResult/output`` present. + case deadlineExceeded = "OUTCOME_DEADLINE_EXCEEDED" + } + + /// Outcome of the code execution. + public let outcome: Outcome + + /// Contains `stdout` when code execution is successful, `stderr` or other description otherwise. + public let output: String +} + // MARK: - Codable Conformance extension FunctionCall: Decodable { @@ -293,3 +348,31 @@ extension FunctionCallingConfig.Mode: Encodable {} extension ToolConfig: Encodable {} extension FunctionResponse: Encodable {} + +extension CodeExecution: Encodable {} + +extension ExecutableCode: Codable {} + +@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *) +extension CodeExecutionResult.Outcome: Codable { + public init(from decoder: any Decoder) throws { + let value = try decoder.singleValueContainer().decode(String.self) + guard let decodedOutcome = CodeExecutionResult.Outcome(rawValue: value) else { + Logging.default + .error("[GoogleGenerativeAI] Unrecognized Outcome with value \"\(value)\".") + self = .unknown + return + } + + self = decodedOutcome + } +} + +@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *) +extension CodeExecutionResult: Codable { + public init(from decoder: any Decoder) throws { + let container = try decoder.container(keyedBy: CodingKeys.self) + outcome = try container.decode(Outcome.self, forKey: .outcome) + output = try container.decodeIfPresent(String.self, forKey: .output) ?? "" + } +} diff --git a/Sources/GoogleAI/GenerateContentResponse.swift b/Sources/GoogleAI/GenerateContentResponse.swift index 04c41f7..44083c4 100644 --- a/Sources/GoogleAI/GenerateContentResponse.swift +++ b/Sources/GoogleAI/GenerateContentResponse.swift @@ -46,16 +46,31 @@ public struct GenerateContentResponse { return nil } let textValues: [String] = candidate.content.parts.compactMap { part in - guard case let .text(text) = part else { + switch part { + case let .text(text): + return text + case let .executableCode(executableCode): + let codeBlockLanguage: String + if executableCode.language == "LANGUAGE_UNSPECIFIED" { + codeBlockLanguage = "" + } else { + codeBlockLanguage = executableCode.language.lowercased() + } + return "```\(codeBlockLanguage)\n\(executableCode.code)\n```" + case let .codeExecutionResult(codeExecutionResult): + if codeExecutionResult.output.isEmpty { + return nil + } + return "```\n\(codeExecutionResult.output)\n```" + case .data, .fileData, .functionCall, .functionResponse: return nil } - return text } guard textValues.count > 0 else { Logging.default.error("Could not get a text part from the first candidate.") return nil } - return textValues.joined(separator: " ") + return textValues.joined(separator: "\n") } /// Returns function calls found in any `Part`s of the first candidate of the response, if any. diff --git a/Sources/GoogleAI/ModelContent.swift b/Sources/GoogleAI/ModelContent.swift index 979c406..59bf1be 100644 --- a/Sources/GoogleAI/ModelContent.swift +++ b/Sources/GoogleAI/ModelContent.swift @@ -51,6 +51,12 @@ public struct ModelContent: Equatable { /// A response to a function call. case functionResponse(FunctionResponse) + /// Code generated by the model that is meant to be executed. + case executableCode(ExecutableCode) + + /// Result of executing the ``ExecutableCode``. + case codeExecutionResult(CodeExecutionResult) + // MARK: Convenience Initializers /// Convenience function for populating a Part with JPEG data. @@ -129,6 +135,8 @@ extension ModelContent.Part: Codable { case fileData case functionCall case functionResponse + case executableCode + case codeExecutionResult } enum InlineDataKeys: String, CodingKey { @@ -164,6 +172,10 @@ extension ModelContent.Part: Codable { try container.encode(functionCall, forKey: .functionCall) case let .functionResponse(functionResponse): try container.encode(functionResponse, forKey: .functionResponse) + case let .executableCode(executableCode): + try container.encode(executableCode, forKey: .executableCode) + case let .codeExecutionResult(codeExecutionResult): + try container.encode(codeExecutionResult, forKey: .codeExecutionResult) } } @@ -181,6 +193,13 @@ extension ModelContent.Part: Codable { self = .data(mimetype: mimetype, bytes) } else if values.contains(.functionCall) { self = try .functionCall(values.decode(FunctionCall.self, forKey: .functionCall)) + } else if values.contains(.executableCode) { + self = try .executableCode(values.decode(ExecutableCode.self, forKey: .executableCode)) + } else if values.contains(.codeExecutionResult) { + self = try .codeExecutionResult(values.decode( + CodeExecutionResult.self, + forKey: .codeExecutionResult + )) } else { throw DecodingError.dataCorrupted(.init( codingPath: [CodingKeys.text, CodingKeys.inlineData], diff --git a/Tests/GoogleAITests/CodeExecutionTests.swift b/Tests/GoogleAITests/CodeExecutionTests.swift new file mode 100644 index 0000000..2818fe6 --- /dev/null +++ b/Tests/GoogleAITests/CodeExecutionTests.swift @@ -0,0 +1,154 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import XCTest + +@testable import GoogleGenerativeAI + +@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *) +final class CodeExecutionTests: XCTestCase { + let decoder = JSONDecoder() + let encoder = JSONEncoder() + + let languageKey = "language" + let languageValue = "PYTHON" + let codeKey = "code" + let codeValue = "print('Hello, world!')" + let outcomeKey = "outcome" + let outcomeValue = "OUTCOME_OK" + let outputKey = "output" + let outputValue = "Hello, world!" + + override func setUp() { + encoder.outputFormatting = .init( + arrayLiteral: .prettyPrinted, .sortedKeys, .withoutEscapingSlashes + ) + } + + func testEncodeCodeExecution() throws { + let jsonData = try encoder.encode(CodeExecution()) + + let json = try XCTUnwrap(String(data: jsonData, encoding: .utf8)) + XCTAssertEqual(json, """ + { + + } + """) + } + + func testDecodeExecutableCode() throws { + let expectedExecutableCode = ExecutableCode(language: languageValue, code: codeValue) + let json = """ + { + "\(languageKey)": "\(languageValue)", + "\(codeKey)": "\(codeValue)" + } + """ + let jsonData = try XCTUnwrap(json.data(using: .utf8)) + + let executableCode = try XCTUnwrap(decoder.decode(ExecutableCode.self, from: jsonData)) + + XCTAssertEqual(executableCode, expectedExecutableCode) + } + + func testEncodeExecutableCode() throws { + let executableCode = ExecutableCode(language: languageValue, code: codeValue) + + let jsonData = try encoder.encode(executableCode) + + let json = try XCTUnwrap(String(data: jsonData, encoding: .utf8)) + XCTAssertEqual(json, """ + { + "\(codeKey)" : "\(codeValue)", + "\(languageKey)" : "\(languageValue)" + } + """) + } + + func testDecodeCodeExecutionResultOutcome_ok() throws { + let expectedOutcome = CodeExecutionResult.Outcome.ok + let json = "\"\(outcomeValue)\"" + let jsonData = try XCTUnwrap(json.data(using: .utf8)) + + let outcome = try XCTUnwrap(decoder.decode(CodeExecutionResult.Outcome.self, from: jsonData)) + + XCTAssertEqual(outcome, expectedOutcome) + } + + func testDecodeCodeExecutionResultOutcome_unknown() throws { + let expectedOutcome = CodeExecutionResult.Outcome.unknown + let json = "\"OUTCOME_NEW_VALUE\"" + let jsonData = try XCTUnwrap(json.data(using: .utf8)) + + let outcome = try XCTUnwrap(decoder.decode(CodeExecutionResult.Outcome.self, from: jsonData)) + + XCTAssertEqual(outcome, expectedOutcome) + } + + func testEncodeCodeExecutionResultOutcome() throws { + let jsonData = try encoder.encode(CodeExecutionResult.Outcome.ok) + + let json = try XCTUnwrap(String(data: jsonData, encoding: .utf8)) + XCTAssertEqual(json, "\"\(outcomeValue)\"") + } + + func testDecodeCodeExecutionResult() throws { + let expectedCodeExecutionResult = CodeExecutionResult(outcome: .ok, output: "Hello, world!") + let json = """ + { + "\(outcomeKey)": "\(outcomeValue)", + "\(outputKey)": "\(outputValue)" + } + """ + let jsonData = try XCTUnwrap(json.data(using: .utf8)) + + let codeExecutionResult = try XCTUnwrap(decoder.decode( + CodeExecutionResult.self, + from: jsonData + )) + + XCTAssertEqual(codeExecutionResult, expectedCodeExecutionResult) + } + + func testDecodeCodeExecutionResult_missingOutput() throws { + let expectedCodeExecutionResult = CodeExecutionResult(outcome: .deadlineExceeded, output: "") + let json = """ + { + "\(outcomeKey)": "OUTCOME_DEADLINE_EXCEEDED" + } + """ + let jsonData = try XCTUnwrap(json.data(using: .utf8)) + + let codeExecutionResult = try XCTUnwrap(decoder.decode( + CodeExecutionResult.self, + from: jsonData + )) + + XCTAssertEqual(codeExecutionResult, expectedCodeExecutionResult) + } + + func testEncodeCodeExecutionResult() throws { + let codeExecutionResult = CodeExecutionResult(outcome: .ok, output: outputValue) + + let jsonData = try encoder.encode(codeExecutionResult) + + let json = try XCTUnwrap(String(data: jsonData, encoding: .utf8)) + XCTAssertEqual(json, """ + { + "\(outcomeKey)" : "\(outcomeValue)", + "\(outputKey)" : "\(outputValue)" + } + """) + } +} diff --git a/Tests/GoogleAITests/GenerateContentRequestTests.swift b/Tests/GoogleAITests/GenerateContentRequestTests.swift index a808799..0ef1ac4 100644 --- a/Tests/GoogleAITests/GenerateContentRequestTests.swift +++ b/Tests/GoogleAITests/GenerateContentRequestTests.swift @@ -42,11 +42,16 @@ final class GenerateContentRequestTests: XCTestCase { harmCategory: .dangerousContent, threshold: .blockLowAndAbove )], - tools: [Tool(functionDeclarations: [FunctionDeclaration( - name: "test-function-name", - description: "test-function-description", - parameters: nil - )])], + tools: [ + Tool(functionDeclarations: [ + FunctionDeclaration( + name: "test-function-name", + description: "test-function-description", + parameters: nil + ), + ]), + Tool(codeExecution: CodeExecution()), + ], toolConfig: ToolConfig(functionCallingConfig: FunctionCallingConfig(mode: .auto)), systemInstruction: ModelContent(role: "system", parts: "test-system-instruction"), isStreaming: false, @@ -102,6 +107,11 @@ final class GenerateContentRequestTests: XCTestCase { } } ] + }, + { + "codeExecution" : { + + } } ] } diff --git a/Tests/GoogleAITests/GenerateContentResponseTests.swift b/Tests/GoogleAITests/GenerateContentResponseTests.swift new file mode 100644 index 0000000..46ee8b7 --- /dev/null +++ b/Tests/GoogleAITests/GenerateContentResponseTests.swift @@ -0,0 +1,173 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import Foundation +import XCTest + +@testable import GoogleGenerativeAI + +@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *) +final class GenerateContentResponseTests: XCTestCase { + let testText1 = "test-text-1" + let testText2 = "test-text-2" + let testLanguage = "PYTHON" + let testCode = "print('Hello, world!')" + let testOutput = "Hello, world!" + + override func setUp() {} + + func testText_textPart() throws { + let parts = [ModelContent.Part.text(testText1)] + let candidate = CandidateResponse( + content: ModelContent(role: "model", parts: parts), + safetyRatings: [], + finishReason: nil, + citationMetadata: nil + ) + let response = GenerateContentResponse(candidates: [candidate]) + + let text = try XCTUnwrap(response.text) + + XCTAssertEqual(text, "\(testText1)") + } + + func testText_textParts_concatenated() throws { + let parts = [ModelContent.Part.text(testText1), ModelContent.Part.text(testText2)] + let candidate = CandidateResponse( + content: ModelContent(role: "model", parts: parts), + safetyRatings: [], + finishReason: nil, + citationMetadata: nil + ) + let response = GenerateContentResponse(candidates: [candidate]) + + let text = try XCTUnwrap(response.text) + + XCTAssertEqual(text, """ + \(testText1) + \(testText2) + """) + } + + func testText_executableCodePart_python() throws { + let parts = [ModelContent.Part.executableCode(ExecutableCode( + language: testLanguage, + code: testCode + ))] + let candidate = CandidateResponse( + content: ModelContent(role: "model", parts: parts), + safetyRatings: [], + finishReason: nil, + citationMetadata: nil + ) + let response = GenerateContentResponse(candidates: [candidate]) + + let text = try XCTUnwrap(response.text) + + XCTAssertEqual(text, """ + ```\(testLanguage.lowercased()) + \(testCode) + ``` + """) + } + + func testText_executableCodePart_unspecifiedLanguage() throws { + let parts = [ModelContent.Part.executableCode(ExecutableCode( + language: "LANGUAGE_UNSPECIFIED", + code: "echo $SHELL" + ))] + let candidate = CandidateResponse( + content: ModelContent(role: "model", parts: parts), + safetyRatings: [], + finishReason: nil, + citationMetadata: nil + ) + let response = GenerateContentResponse(candidates: [candidate]) + + let text = try XCTUnwrap(response.text) + + XCTAssertEqual(text, """ + ``` + echo $SHELL + ``` + """) + } + + func testText_codeExecutionResultPart_hasOutput() throws { + let parts = [ModelContent.Part.codeExecutionResult(CodeExecutionResult( + outcome: .ok, + output: testOutput + ))] + let candidate = CandidateResponse( + content: ModelContent(role: "model", parts: parts), + safetyRatings: [], + finishReason: nil, + citationMetadata: nil + ) + let response = GenerateContentResponse(candidates: [candidate]) + + let text = try XCTUnwrap(response.text) + + XCTAssertEqual(text, """ + ``` + \(testOutput) + ``` + """) + } + + func testText_codeExecutionResultPart_emptyOutput() throws { + let parts = [ModelContent.Part.codeExecutionResult(CodeExecutionResult( + outcome: .deadlineExceeded, + output: "" + ))] + let candidate = CandidateResponse( + content: ModelContent(role: "model", parts: parts), + safetyRatings: [], + finishReason: nil, + citationMetadata: nil + ) + let response = GenerateContentResponse(candidates: [candidate]) + + XCTAssertNil(response.text) + } + + func testText_codeExecution_concatenated() throws { + let parts: [ModelContent.Part] = [ + .text("test-text-1"), + .executableCode(ExecutableCode(language: testLanguage, code: testCode)), + .codeExecutionResult(CodeExecutionResult(outcome: .ok, output: testOutput)), + .text("test-text-2"), + ] + let candidate = CandidateResponse( + content: ModelContent(role: "model", parts: parts), + safetyRatings: [], + finishReason: nil, + citationMetadata: nil + ) + let response = GenerateContentResponse(candidates: [candidate]) + + let text = try XCTUnwrap(response.text) + + XCTAssertEqual(text, """ + \(testText1) + ```\(testLanguage.lowercased()) + \(testCode) + ``` + ``` + \(testOutput) + ``` + \(testText2) + """) + } +} diff --git a/Tests/GoogleAITests/GenerateContentResponses/streaming-success-code-execution.txt b/Tests/GoogleAITests/GenerateContentResponses/streaming-success-code-execution.txt new file mode 100644 index 0000000..24c9ef6 --- /dev/null +++ b/Tests/GoogleAITests/GenerateContentResponses/streaming-success-code-execution.txt @@ -0,0 +1,16 @@ +data: {"candidates": [{"content": {"parts": [{"text": "Thoughts"}],"role": "model"},"finishReason": "STOP","index": 0}],"usageMetadata": {"promptTokenCount": 21,"candidatesTokenCount": 1,"totalTokenCount": 22}} + +data: {"candidates": [{"content": {"parts": [{"text": ": I can use the `print()` function in Python to print strings. "}],"role": "model"},"finishReason": "STOP","index": 0,"safetyRatings": [{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_HATE_SPEECH","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_HARASSMENT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_DANGEROUS_CONTENT","probability": "NEGLIGIBLE"}]}],"usageMetadata": {"promptTokenCount": 21,"candidatesTokenCount": 16,"totalTokenCount": 37}} + +data: {"candidates": [{"content": {"parts": [{"text": "\n\n"}],"role": "model"},"finishReason": "STOP","index": 0,"safetyRatings": [{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_HATE_SPEECH","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_HARASSMENT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_DANGEROUS_CONTENT","probability": "NEGLIGIBLE"}]}],"usageMetadata": {"promptTokenCount": 21,"candidatesTokenCount": 16,"totalTokenCount": 37}} + +data: {"candidates": [{"content": {"parts": [{"executableCode": {"language": "PYTHON","code": "\nprint(\"Hello, world!\")\n"}}],"role": "model"},"finishReason": "STOP","index": 0,"safetyRatings": [{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_HATE_SPEECH","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_HARASSMENT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_DANGEROUS_CONTENT","probability": "NEGLIGIBLE"}]}],"usageMetadata": {"promptTokenCount": 21,"candidatesTokenCount": 29,"totalTokenCount": 50}} + +data: {"candidates": [{"content": {"parts": [{"codeExecutionResult": {"outcome": "OUTCOME_OK","output": "Hello, world!\n"}}],"role": "model"},"index": 0}],"usageMetadata": {"promptTokenCount": 21,"candidatesTokenCount": 29,"totalTokenCount": 50}} + +data: {"candidates": [{"content": {"parts": [{"text": "OK"}],"role": "model"},"finishReason": "STOP","index": 0}],"usageMetadata": {"promptTokenCount": 21,"candidatesTokenCount": 1,"totalTokenCount": 22}} + +data: {"candidates": [{"content": {"parts": [{"text": ". I have printed \"Hello, world!\" using the `print()` function in"}],"role": "model"},"finishReason": "STOP","index": 0,"safetyRatings": [{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_HATE_SPEECH","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_HARASSMENT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_DANGEROUS_CONTENT","probability": "NEGLIGIBLE"}]}],"usageMetadata": {"promptTokenCount": 21,"candidatesTokenCount": 17,"totalTokenCount": 38}} + +data: {"candidates": [{"content": {"parts": [{"text": " Python. \n"}],"role": "model"},"finishReason": "STOP","index": 0,"safetyRatings": [{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_HATE_SPEECH","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_HARASSMENT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_DANGEROUS_CONTENT","probability": "NEGLIGIBLE"}]}],"usageMetadata": {"promptTokenCount": 21,"candidatesTokenCount": 19,"totalTokenCount": 40}} + diff --git a/Tests/GoogleAITests/GenerateContentResponses/unary-success-code-execution.json b/Tests/GoogleAITests/GenerateContentResponses/unary-success-code-execution.json new file mode 100644 index 0000000..0b5a955 --- /dev/null +++ b/Tests/GoogleAITests/GenerateContentResponses/unary-success-code-execution.json @@ -0,0 +1,54 @@ +{ + "candidates": [ + { + "content": { + "parts": [ + { + "text": "To print strings in Python, you use the `print()` function. Here's how you can print \"Hello, world!\":\n\n" + }, + { + "executableCode": { + "language": "PYTHON", + "code": "\nprint(\"Hello, world!\")\n" + } + }, + { + "codeExecutionResult": { + "outcome": "OUTCOME_OK", + "output": "Hello, world!\n" + } + }, + { + "text": "The code successfully prints the string \"Hello, world!\". \n" + } + ], + "role": "model" + }, + "finishReason": "STOP", + "index": 0, + "safetyRatings": [ + { + "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_HATE_SPEECH", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_HARASSMENT", + "probability": "NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_DANGEROUS_CONTENT", + "probability": "NEGLIGIBLE" + } + ] + } + ], + "usageMetadata": { + "promptTokenCount": 21, + "candidatesTokenCount": 11, + "totalTokenCount": 32 + } +} diff --git a/Tests/GoogleAITests/GenerativeModelTests.swift b/Tests/GoogleAITests/GenerativeModelTests.swift index 5a20343..3dbe7cd 100644 --- a/Tests/GoogleAITests/GenerativeModelTests.swift +++ b/Tests/GoogleAITests/GenerativeModelTests.swift @@ -285,7 +285,61 @@ final class GenerativeModelTests: XCTestCase { let functionCalls = response.functionCalls XCTAssertEqual(functionCalls.count, 2) let text = try XCTUnwrap(response.text) - XCTAssertEqual(text, "The sum of [1, 2, 3] is") + XCTAssertEqual(text, "The sum of [1, 2,\n3] is") + } + + func testGenerateContent_success_codeExecution() async throws { + MockURLProtocol + .requestHandler = try httpRequestHandler( + forResource: "unary-success-code-execution", + withExtension: "json" + ) + let expectedText1 = """ + To print strings in Python, you use the `print()` function. \ + Here's how you can print \"Hello, world!\":\n\n + """ + let expectedText2 = "The code successfully prints the string \"Hello, world!\". \n" + let expectedLanguage = "PYTHON" + let expectedCode = "\nprint(\"Hello, world!\")\n" + let expectedOutput = "Hello, world!\n" + + let response = try await model.generateContent(testPrompt) + + XCTAssertEqual(response.candidates.count, 1) + let candidate = try XCTUnwrap(response.candidates.first) + XCTAssertEqual(candidate.content.parts.count, 4) + guard case let .text(text1) = candidate.content.parts[0] else { + XCTFail("Expected first part to be text.") + return + } + XCTAssertEqual(text1, expectedText1) + guard case let .executableCode(executableCode) = candidate.content.parts[1] else { + XCTFail("Expected second part to be executable code.") + return + } + XCTAssertEqual(executableCode.language, expectedLanguage) + XCTAssertEqual(executableCode.code, expectedCode) + guard case let .codeExecutionResult(codeExecutionResult) = candidate.content.parts[2] else { + XCTFail("Expected second part to be a code execution result.") + return + } + XCTAssertEqual(codeExecutionResult.outcome, .ok) + XCTAssertEqual(codeExecutionResult.output, expectedOutput) + guard case let .text(text2) = candidate.content.parts[3] else { + XCTFail("Expected fourth part to be text.") + return + } + XCTAssertEqual(text2, expectedText2) + XCTAssertEqual(try XCTUnwrap(response.text), """ + \(expectedText1) + ```\(expectedLanguage.lowercased()) + \(expectedCode) + ``` + ``` + \(expectedOutput) + ``` + \(expectedText2) + """) } func testGenerateContent_usageMetadata() async throws { @@ -818,6 +872,59 @@ final class GenerativeModelTests: XCTestCase { })) } + func testGenerateContentStream_success_codeExecution() async throws { + MockURLProtocol + .requestHandler = try httpRequestHandler( + forResource: "streaming-success-code-execution", + withExtension: "txt" + ) + let expectedTexts1 = [ + "Thoughts", + ": I can use the `print()` function in Python to print strings. ", + "\n\n", + ] + let expectedTexts2 = [ + "OK", + ". I have printed \"Hello, world!\" using the `print()` function in", + " Python. \n", + ] + let expectedTexts = Set(expectedTexts1 + expectedTexts2) + let expectedLanguage = "PYTHON" + let expectedCode = "\nprint(\"Hello, world!\")\n" + let expectedOutput = "Hello, world!\n" + + var textValues = [String]() + let stream = model.generateContentStream(testPrompt) + for try await content in stream { + let candidate = try XCTUnwrap(content.candidates.first) + let part = try XCTUnwrap(candidate.content.parts.first) + switch part { + case let .text(textPart): + XCTAssertTrue(expectedTexts.contains(textPart)) + case let .executableCode(executableCode): + XCTAssertEqual(executableCode.language, expectedLanguage) + XCTAssertEqual(executableCode.code, expectedCode) + case let .codeExecutionResult(codeExecutionResult): + XCTAssertEqual(codeExecutionResult.outcome, .ok) + XCTAssertEqual(codeExecutionResult.output, expectedOutput) + default: + XCTFail("Unexpected part type: \(part)") + } + try textValues.append(XCTUnwrap(content.text)) + } + + XCTAssertEqual(textValues.joined(separator: "\n"), """ + \(expectedTexts1.joined(separator: "\n")) + ```\(expectedLanguage.lowercased()) + \(expectedCode) + ``` + ``` + \(expectedOutput) + ``` + \(expectedTexts2.joined(separator: "\n")) + """) + } + func testGenerateContentStream_usageMetadata() async throws { MockURLProtocol .requestHandler = try httpRequestHandler(