From eae0721fe317de244451d0d1888abb841a01a2e4 Mon Sep 17 00:00:00 2001 From: Andrew Heard Date: Tue, 28 May 2024 18:10:45 -0400 Subject: [PATCH 1/2] Send `GenerateContentRequest` in `CountTokensRequest` --- Sources/GoogleAI/CountTokensRequest.swift | 4 +- Sources/GoogleAI/GenerateContentRequest.swift | 1 + Sources/GoogleAI/GenerativeModel.swift | 13 +- .../GenerateContentRequestTests.swift | 144 ++++++++++++++++++ 4 files changed, 158 insertions(+), 4 deletions(-) create mode 100644 Tests/GoogleAITests/GenerateContentRequestTests.swift diff --git a/Sources/GoogleAI/CountTokensRequest.swift b/Sources/GoogleAI/CountTokensRequest.swift index de852ae..d8bfc0e 100644 --- a/Sources/GoogleAI/CountTokensRequest.swift +++ b/Sources/GoogleAI/CountTokensRequest.swift @@ -17,7 +17,7 @@ import Foundation @available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *) struct CountTokensRequest { let model: String - let contents: [ModelContent] + let generateContentRequest: GenerateContentRequest let options: RequestOptions } @@ -42,7 +42,7 @@ public struct CountTokensResponse { @available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *) extension CountTokensRequest: Encodable { enum CodingKeys: CodingKey { - case contents + case generateContentRequest } } diff --git a/Sources/GoogleAI/GenerateContentRequest.swift b/Sources/GoogleAI/GenerateContentRequest.swift index 05abadf..c360583 100644 --- a/Sources/GoogleAI/GenerateContentRequest.swift +++ b/Sources/GoogleAI/GenerateContentRequest.swift @@ -31,6 +31,7 @@ struct GenerateContentRequest { @available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *) extension GenerateContentRequest: Encodable { enum CodingKeys: String, CodingKey { + case model case contents case generationConfig case safetySettings diff --git a/Sources/GoogleAI/GenerativeModel.swift b/Sources/GoogleAI/GenerativeModel.swift index ed1aecd..fc9c985 100644 --- a/Sources/GoogleAI/GenerativeModel.swift +++ b/Sources/GoogleAI/GenerativeModel.swift @@ -325,9 +325,18 @@ public final class GenerativeModel { public func countTokens(_ content: @autoclosure () throws -> [ModelContent]) async throws -> CountTokensResponse { do { - let countTokensRequest = try CountTokensRequest( + let generateContentRequest = try GenerateContentRequest(model: modelResourceName, + contents: content(), + generationConfig: generationConfig, + safetySettings: safetySettings, + tools: tools, + toolConfig: toolConfig, + systemInstruction: systemInstruction, + isStreaming: false, + options: requestOptions) + let countTokensRequest = CountTokensRequest( model: modelResourceName, - contents: content(), + generateContentRequest: generateContentRequest, options: requestOptions ) return try await generativeAIService.loadRequest(request: countTokensRequest) diff --git a/Tests/GoogleAITests/GenerateContentRequestTests.swift b/Tests/GoogleAITests/GenerateContentRequestTests.swift new file mode 100644 index 0000000..a808799 --- /dev/null +++ b/Tests/GoogleAITests/GenerateContentRequestTests.swift @@ -0,0 +1,144 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import Foundation +import XCTest + +@testable import GoogleGenerativeAI + +@available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *) +final class GenerateContentRequestTests: XCTestCase { + let encoder = JSONEncoder() + let role = "test-role" + let prompt = "test-prompt" + let modelName = "test-model-name" + + override func setUp() { + encoder.outputFormatting = .init( + arrayLiteral: .prettyPrinted, .sortedKeys, .withoutEscapingSlashes + ) + } + + // MARK: GenerateContentRequest Encoding + + func testEncodeRequest_allFieldsIncluded() throws { + let content = [ModelContent(role: role, parts: prompt)] + let request = GenerateContentRequest( + model: modelName, + contents: content, + generationConfig: GenerationConfig(temperature: 0.5), + safetySettings: [SafetySetting( + harmCategory: .dangerousContent, + threshold: .blockLowAndAbove + )], + tools: [Tool(functionDeclarations: [FunctionDeclaration( + name: "test-function-name", + description: "test-function-description", + parameters: nil + )])], + toolConfig: ToolConfig(functionCallingConfig: FunctionCallingConfig(mode: .auto)), + systemInstruction: ModelContent(role: "system", parts: "test-system-instruction"), + isStreaming: false, + options: RequestOptions() + ) + + let jsonData = try encoder.encode(request) + + let json = try XCTUnwrap(String(data: jsonData, encoding: .utf8)) + XCTAssertEqual(json, """ + { + "contents" : [ + { + "parts" : [ + { + "text" : "\(prompt)" + } + ], + "role" : "\(role)" + } + ], + "generationConfig" : { + "temperature" : 0.5 + }, + "model" : "\(modelName)", + "safetySettings" : [ + { + "category" : "HARM_CATEGORY_DANGEROUS_CONTENT", + "threshold" : "BLOCK_LOW_AND_ABOVE" + } + ], + "systemInstruction" : { + "parts" : [ + { + "text" : "test-system-instruction" + } + ], + "role" : "system" + }, + "toolConfig" : { + "functionCallingConfig" : { + "mode" : "AUTO" + } + }, + "tools" : [ + { + "functionDeclarations" : [ + { + "description" : "test-function-description", + "name" : "test-function-name", + "parameters" : { + "type" : "OBJECT" + } + } + ] + } + ] + } + """) + } + + func testEncodeRequest_optionalFieldsOmitted() throws { + let content = [ModelContent(role: role, parts: prompt)] + let request = GenerateContentRequest( + model: modelName, + contents: content, + generationConfig: nil, + safetySettings: nil, + tools: nil, + toolConfig: nil, + systemInstruction: nil, + isStreaming: false, + options: RequestOptions() + ) + + let jsonData = try encoder.encode(request) + + let json = try XCTUnwrap(String(data: jsonData, encoding: .utf8)) + XCTAssertEqual(json, """ + { + "contents" : [ + { + "parts" : [ + { + "text" : "\(prompt)" + } + ], + "role" : "\(role)" + } + ], + "model" : "\(modelName)" + } + """) + } +} From 8a0c9bbf49233d672037a0245fd18dae2dab5ff9 Mon Sep 17 00:00:00 2001 From: Andrew Heard Date: Thu, 16 May 2024 14:15:02 -0400 Subject: [PATCH 2/2] Encode `model` in `GenerateContentRequest` only when needed --- Sources/GoogleAI/GenerateContentRequest.swift | 29 ++++++++++++- Sources/GoogleAI/GenerativeModel.swift | 43 +++++++++++-------- .../GenerateContentRequestTests.swift | 38 +++++++++++++++- 3 files changed, 90 insertions(+), 20 deletions(-) diff --git a/Sources/GoogleAI/GenerateContentRequest.swift b/Sources/GoogleAI/GenerateContentRequest.swift index c360583..ace2795 100644 --- a/Sources/GoogleAI/GenerateContentRequest.swift +++ b/Sources/GoogleAI/GenerateContentRequest.swift @@ -16,8 +16,11 @@ import Foundation @available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *) struct GenerateContentRequest { - /// Model name. + // Model name. let model: String + // If true, the `model` field above is encoded in requests; currently only required when nested in + // a `CountTokensRequest`. + let isModelEncoded: Bool let contents: [ModelContent] let generationConfig: GenerationConfig? let safetySettings: [SafetySetting]? @@ -39,6 +42,30 @@ extension GenerateContentRequest: Encodable { case toolConfig case systemInstruction } + + func encode(to encoder: any Encoder) throws { + var container = encoder.container(keyedBy: CodingKeys.self) + + if isModelEncoded { + try container.encode(model, forKey: .model) + } + try container.encode(contents, forKey: .contents) + if let generationConfig { + try container.encode(generationConfig, forKey: .generationConfig) + } + if let safetySettings { + try container.encode(safetySettings, forKey: .safetySettings) + } + if let tools { + try container.encode(tools, forKey: .tools) + } + if let toolConfig { + try container.encode(toolConfig, forKey: .toolConfig) + } + if let systemInstruction { + try container.encode(systemInstruction, forKey: .systemInstruction) + } + } } @available(iOS 15.0, macOS 11.0, macCatalyst 15.0, *) diff --git a/Sources/GoogleAI/GenerativeModel.swift b/Sources/GoogleAI/GenerativeModel.swift index fc9c985..438ac2d 100644 --- a/Sources/GoogleAI/GenerativeModel.swift +++ b/Sources/GoogleAI/GenerativeModel.swift @@ -175,15 +175,18 @@ public final class GenerativeModel { -> GenerateContentResponse { let response: GenerateContentResponse do { - let generateContentRequest = try GenerateContentRequest(model: modelResourceName, - contents: content(), - generationConfig: generationConfig, - safetySettings: safetySettings, - tools: tools, - toolConfig: toolConfig, - systemInstruction: systemInstruction, - isStreaming: false, - options: requestOptions) + let generateContentRequest = try GenerateContentRequest( + model: modelResourceName, + isModelEncoded: false, + contents: content(), + generationConfig: generationConfig, + safetySettings: safetySettings, + tools: tools, + toolConfig: toolConfig, + systemInstruction: systemInstruction, + isStreaming: false, + options: requestOptions + ) response = try await generativeAIService.loadRequest(request: generateContentRequest) } catch { if let imageError = error as? ImageConversionError { @@ -249,15 +252,18 @@ public final class GenerativeModel { } } - let generateContentRequest = GenerateContentRequest(model: modelResourceName, - contents: evaluatedContent, - generationConfig: generationConfig, - safetySettings: safetySettings, - tools: tools, - toolConfig: toolConfig, - systemInstruction: systemInstruction, - isStreaming: true, - options: requestOptions) + let generateContentRequest = GenerateContentRequest( + model: modelResourceName, + isModelEncoded: false, + contents: evaluatedContent, + generationConfig: generationConfig, + safetySettings: safetySettings, + tools: tools, + toolConfig: toolConfig, + systemInstruction: systemInstruction, + isStreaming: true, + options: requestOptions + ) var responseIterator = generativeAIService.loadRequestStream(request: generateContentRequest) .makeAsyncIterator() @@ -326,6 +332,7 @@ public final class GenerativeModel { -> CountTokensResponse { do { let generateContentRequest = try GenerateContentRequest(model: modelResourceName, + isModelEncoded: true, contents: content(), generationConfig: generationConfig, safetySettings: safetySettings, diff --git a/Tests/GoogleAITests/GenerateContentRequestTests.swift b/Tests/GoogleAITests/GenerateContentRequestTests.swift index a808799..72f6783 100644 --- a/Tests/GoogleAITests/GenerateContentRequestTests.swift +++ b/Tests/GoogleAITests/GenerateContentRequestTests.swift @@ -36,6 +36,7 @@ final class GenerateContentRequestTests: XCTestCase { let content = [ModelContent(role: role, parts: prompt)] let request = GenerateContentRequest( model: modelName, + isModelEncoded: true, contents: content, generationConfig: GenerationConfig(temperature: 0.5), safetySettings: [SafetySetting( @@ -108,10 +109,11 @@ final class GenerateContentRequestTests: XCTestCase { """) } - func testEncodeRequest_optionalFieldsOmitted() throws { + func testEncodeRequest_optionalFieldsOmitted_modelNameEncoded() throws { let content = [ModelContent(role: role, parts: prompt)] let request = GenerateContentRequest( model: modelName, + isModelEncoded: true, contents: content, generationConfig: nil, safetySettings: nil, @@ -141,4 +143,38 @@ final class GenerateContentRequestTests: XCTestCase { } """) } + + func testEncodeRequest_optionalFieldsOmitted_modelNameNotEncoded() throws { + let content = [ModelContent(role: role, parts: prompt)] + let request = GenerateContentRequest( + model: modelName, + isModelEncoded: false, + contents: content, + generationConfig: nil, + safetySettings: nil, + tools: nil, + toolConfig: nil, + systemInstruction: nil, + isStreaming: false, + options: RequestOptions() + ) + + let jsonData = try encoder.encode(request) + + let json = try XCTUnwrap(String(data: jsonData, encoding: .utf8)) + XCTAssertEqual(json, """ + { + "contents" : [ + { + "parts" : [ + { + "text" : "\(prompt)" + } + ], + "role" : "\(role)" + } + ] + } + """) + } }