From cb8f9e141f2a61f40aecd389a58d348b99f42641 Mon Sep 17 00:00:00 2001 From: takahirom Date: Fri, 7 Feb 2025 19:48:38 +0900 Subject: [PATCH] Add multiple image assertion --- .../takahirom/roborazzi/AiAssertionOptions.kt | 14 ++++ .../takahirom/roborazzi/GeminiRoborazziAi.kt | 70 ++++++++++++----- .../roborazzi/OpenAiAiAssertionModel.kt | 76 +++++++++++++++---- 3 files changed, 127 insertions(+), 33 deletions(-) diff --git a/include-build/roborazzi-core/src/commonMain/kotlin/com/github/takahirom/roborazzi/AiAssertionOptions.kt b/include-build/roborazzi-core/src/commonMain/kotlin/com/github/takahirom/roborazzi/AiAssertionOptions.kt index 2c02e6a8..f5042d45 100644 --- a/include-build/roborazzi-core/src/commonMain/kotlin/com/github/takahirom/roborazzi/AiAssertionOptions.kt +++ b/include-build/roborazzi-core/src/commonMain/kotlin/com/github/takahirom/roborazzi/AiAssertionOptions.kt @@ -52,12 +52,26 @@ INPUT_PROMPT aiAssertionOptions: AiAssertionOptions ): AiAssertionResults + fun assert( + assertionTargetImages: AssertionTargetImages, + aiAssertionOptions: AiAssertionOptions + ): AiAssertionResults + companion object { const val DefaultMaxOutputTokens = 300 const val DefaultTemperature = 0.4F } } + class AssertionTargetImages( + val images: List, + ) + + class AssertionTargetImage( + val filePath: String, + ) + + sealed interface AssertionImageType { class Comparison : AssertionImageType class Actual : AssertionImageType diff --git a/roborazzi-ai-gemini/src/commonMain/kotlin/com/github/takahirom/roborazzi/GeminiRoborazziAi.kt b/roborazzi-ai-gemini/src/commonMain/kotlin/com/github/takahirom/roborazzi/GeminiRoborazziAi.kt index f7e73eb7..1734e257 100644 --- a/roborazzi-ai-gemini/src/commonMain/kotlin/com/github/takahirom/roborazzi/GeminiRoborazziAi.kt +++ b/roborazzi-ai-gemini/src/commonMain/kotlin/com/github/takahirom/roborazzi/GeminiRoborazziAi.kt @@ -4,7 +4,12 @@ import com.github.takahirom.roborazzi.AiAssertionOptions.AiAssertionModel import com.github.takahirom.roborazzi.AiAssertionOptions.AiAssertionModel.Companion.DefaultMaxOutputTokens import com.github.takahirom.roborazzi.AiAssertionOptions.AiAssertionModel.Companion.DefaultTemperature import dev.shreyaspatil.ai.client.generativeai.GenerativeModel -import dev.shreyaspatil.ai.client.generativeai.type.* +import dev.shreyaspatil.ai.client.generativeai.type.FunctionType +import dev.shreyaspatil.ai.client.generativeai.type.GenerationConfig +import dev.shreyaspatil.ai.client.generativeai.type.PlatformImage +import dev.shreyaspatil.ai.client.generativeai.type.Schema +import dev.shreyaspatil.ai.client.generativeai.type.content +import dev.shreyaspatil.ai.client.generativeai.type.generationConfig import kotlinx.coroutines.runBlocking import kotlinx.serialization.SerialName import kotlinx.serialization.Serializable @@ -25,6 +30,45 @@ class GeminiAiAssertionModel( aiAssertionOptions: AiAssertionOptions ): AiAssertionResults { val systemPrompt = aiAssertionOptions.systemPrompt + + val template = aiAssertionOptions.promptTemplate + + val inputPrompt = aiAssertionOptions.inputPrompt(aiAssertionOptions) + val imageFilePath = when (aiAssertionOptions.assertionImageType) { + is AiAssertionOptions.AssertionImageType.Comparison -> comparisonImageFilePath + is AiAssertionOptions.AssertionImageType.Actual -> actualImageFilePath + } + return assert( + AiAssertionOptions.AssertionTargetImages(listOf(AiAssertionOptions.AssertionTargetImage(imageFilePath))), + template, + inputPrompt, + systemPrompt, + aiAssertionOptions + ) + } + + private fun assert( + assertionTargetImages: AiAssertionOptions.AssertionTargetImages, + template: String, + inputPrompt: String, + systemPrompt: String, + aiAssertionOptions: AiAssertionOptions + ): AiAssertionResults { + val imageByteArrays = + assertionTargetImages.images.map { image -> readByteArrayFromFile(image.filePath) } + + val inputContent = content { + imageByteArrays.forEach { imageByteArray -> + image(imageByteArray) + } + val prompt = template.replace("INPUT_PROMPT", inputPrompt) + text(prompt) + + debugLog { + "RoborazziAi: prompt:$prompt" + } + } + val generativeModel = GenerativeModel( modelName = modelName, apiKey = apiKey, @@ -59,24 +103,6 @@ class GeminiAiAssertionModel( generationConfigBuilder() }, ) - - val template = aiAssertionOptions.promptTemplate - - val inputPrompt = aiAssertionOptions.inputPrompt(aiAssertionOptions) - val imageFilePath = when (aiAssertionOptions.assertionImageType) { - is AiAssertionOptions.AssertionImageType.Comparison -> comparisonImageFilePath - is AiAssertionOptions.AssertionImageType.Actual -> actualImageFilePath - } - val inputContent = content { - image(readByteArrayFromFile(imageFilePath)) - val prompt = template.replace("INPUT_PROMPT", inputPrompt) - text(prompt) - - debugLog { - "RoborazziAi: prompt:$prompt" - } - } - val response = runBlocking { generativeModel.generateContent(inputContent) } debugLog { "RoborazziAi: response: ${response.text}" @@ -102,6 +128,12 @@ class GeminiAiAssertionModel( } ) } + + override fun assert( + assertionTargetImages: AiAssertionOptions.AssertionTargetImages, + aiAssertionOptions: AiAssertionOptions + ): AiAssertionResults { + } } diff --git a/roborazzi-ai-openai/src/commonMain/kotlin/com/github/takahirom/roborazzi/OpenAiAiAssertionModel.kt b/roborazzi-ai-openai/src/commonMain/kotlin/com/github/takahirom/roborazzi/OpenAiAiAssertionModel.kt index a85850e7..fea70228 100644 --- a/roborazzi-ai-openai/src/commonMain/kotlin/com/github/takahirom/roborazzi/OpenAiAiAssertionModel.kt +++ b/roborazzi-ai-openai/src/commonMain/kotlin/com/github/takahirom/roborazzi/OpenAiAiAssertionModel.kt @@ -2,16 +2,25 @@ package com.github.takahirom.roborazzi import com.github.takahirom.roborazzi.AiAssertionOptions.AiAssertionModel.Companion.DefaultMaxOutputTokens import com.github.takahirom.roborazzi.AiAssertionOptions.AiAssertionModel.Companion.DefaultTemperature +import com.github.takahirom.roborazzi.AiAssertionOptions.AssertionTargetImage import com.github.takahirom.roborazzi.CaptureResults.Companion.json -import io.ktor.client.* -import io.ktor.client.plugins.* +import io.ktor.client.HttpClient +import io.ktor.client.plugins.HttpTimeout import io.ktor.client.plugins.HttpTimeout.Plugin.INFINITE_TIMEOUT_MS -import io.ktor.client.plugins.contentnegotiation.* -import io.ktor.client.plugins.logging.* -import io.ktor.client.request.* -import io.ktor.client.statement.* -import io.ktor.http.* -import io.ktor.serialization.kotlinx.json.* +import io.ktor.client.plugins.contentnegotiation.ContentNegotiation +import io.ktor.client.plugins.logging.LogLevel +import io.ktor.client.plugins.logging.Logger +import io.ktor.client.plugins.logging.Logging +import io.ktor.client.plugins.logging.SIMPLE +import io.ktor.client.request.HttpRequestBuilder +import io.ktor.client.request.header +import io.ktor.client.request.post +import io.ktor.client.request.setBody +import io.ktor.client.statement.HttpResponse +import io.ktor.client.statement.bodyAsText +import io.ktor.http.ContentType +import io.ktor.http.contentType +import io.ktor.serialization.kotlinx.json.json import kotlinx.coroutines.runBlocking import kotlinx.io.buffered import kotlinx.io.files.Path @@ -73,8 +82,48 @@ class OpenAiAiAssertionModel( is AiAssertionOptions.AssertionImageType.Comparison -> comparisonImageFilePath is AiAssertionOptions.AssertionImageType.Actual -> actualImageFilePath } - val imageBytes = readByteArrayFromFile(imageFilePath) - val imageBase64 = imageBytes.encodeBase64() + return assert( + assertionTargetImages = AiAssertionOptions.AssertionTargetImages( + listOf( + AssertionTargetImage( + imageFilePath + ) + ) + ), + systemPrompt = systemPrompt, + template = template, + inputPrompt = inputPrompt, + aiAssertionOptions = aiAssertionOptions + ) + } + + override fun assert( + assertionTargetImages: AiAssertionOptions.AssertionTargetImages, + aiAssertionOptions: AiAssertionOptions + ): AiAssertionResults { + val systemPrompt = aiAssertionOptions.systemPrompt + val template = aiAssertionOptions.promptTemplate + val inputPrompt = aiAssertionOptions.inputPrompt(aiAssertionOptions) + return assert( + assertionTargetImages = assertionTargetImages, + systemPrompt = systemPrompt, + template = template, + inputPrompt = inputPrompt, + aiAssertionOptions = aiAssertionOptions + ) + } + + private fun assert( + assertionTargetImages: AiAssertionOptions.AssertionTargetImages, + systemPrompt: String, + template: String, + inputPrompt: String, + aiAssertionOptions: AiAssertionOptions + ): AiAssertionResults { + val imageBase64s = assertionTargetImages.images.map { image -> + val imageBytes = readByteArrayFromFile(image.filePath) + imageBytes.encodeBase64() + } val messages = listOf( Message( role = "system", @@ -92,13 +141,12 @@ class OpenAiAiAssertionModel( type = "text", text = template.replace("INPUT_PROMPT", inputPrompt) ), + ) + imageBase64s.map { imageBase64 -> Content( type = "image_url", - imageUrl = ImageUrl( - url = "data:image/png;base64,$imageBase64" - ) + imageUrl = ImageUrl(url = "data:image/png;base64,$imageBase64") ) - ) + } ) ) val responseText = runBlocking {