Skip to content

Commit

Permalink
Add multiple image assertion
Browse files Browse the repository at this point in the history
  • Loading branch information
takahirom committed Feb 7, 2025
1 parent 9d1abd1 commit cb8f9e1
Show file tree
Hide file tree
Showing 3 changed files with 127 additions and 33 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,26 @@ INPUT_PROMPT
aiAssertionOptions: AiAssertionOptions
): AiAssertionResults

fun assert(
assertionTargetImages: AssertionTargetImages,
aiAssertionOptions: AiAssertionOptions
): AiAssertionResults

companion object {
const val DefaultMaxOutputTokens = 300
const val DefaultTemperature = 0.4F
}
}

class AssertionTargetImages(
val images: List<AssertionTargetImage>,
)

class AssertionTargetImage(
val filePath: String,
)


sealed interface AssertionImageType {
class Comparison : AssertionImageType
class Actual : AssertionImageType
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,12 @@ import com.github.takahirom.roborazzi.AiAssertionOptions.AiAssertionModel
import com.github.takahirom.roborazzi.AiAssertionOptions.AiAssertionModel.Companion.DefaultMaxOutputTokens
import com.github.takahirom.roborazzi.AiAssertionOptions.AiAssertionModel.Companion.DefaultTemperature
import dev.shreyaspatil.ai.client.generativeai.GenerativeModel
import dev.shreyaspatil.ai.client.generativeai.type.*
import dev.shreyaspatil.ai.client.generativeai.type.FunctionType
import dev.shreyaspatil.ai.client.generativeai.type.GenerationConfig
import dev.shreyaspatil.ai.client.generativeai.type.PlatformImage
import dev.shreyaspatil.ai.client.generativeai.type.Schema
import dev.shreyaspatil.ai.client.generativeai.type.content
import dev.shreyaspatil.ai.client.generativeai.type.generationConfig
import kotlinx.coroutines.runBlocking
import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable
Expand All @@ -25,6 +30,45 @@ class GeminiAiAssertionModel(
aiAssertionOptions: AiAssertionOptions
): AiAssertionResults {
val systemPrompt = aiAssertionOptions.systemPrompt

val template = aiAssertionOptions.promptTemplate

val inputPrompt = aiAssertionOptions.inputPrompt(aiAssertionOptions)
val imageFilePath = when (aiAssertionOptions.assertionImageType) {
is AiAssertionOptions.AssertionImageType.Comparison -> comparisonImageFilePath
is AiAssertionOptions.AssertionImageType.Actual -> actualImageFilePath
}
return assert(
AiAssertionOptions.AssertionTargetImages(listOf(AiAssertionOptions.AssertionTargetImage(imageFilePath))),
template,
inputPrompt,
systemPrompt,
aiAssertionOptions
)
}

private fun assert(
assertionTargetImages: AiAssertionOptions.AssertionTargetImages,
template: String,
inputPrompt: String,
systemPrompt: String,
aiAssertionOptions: AiAssertionOptions
): AiAssertionResults {
val imageByteArrays =
assertionTargetImages.images.map { image -> readByteArrayFromFile(image.filePath) }

val inputContent = content {
imageByteArrays.forEach { imageByteArray ->
image(imageByteArray)
}
val prompt = template.replace("INPUT_PROMPT", inputPrompt)
text(prompt)

debugLog {
"RoborazziAi: prompt:$prompt"
}
}

val generativeModel = GenerativeModel(
modelName = modelName,
apiKey = apiKey,
Expand Down Expand Up @@ -59,24 +103,6 @@ class GeminiAiAssertionModel(
generationConfigBuilder()
},
)

val template = aiAssertionOptions.promptTemplate

val inputPrompt = aiAssertionOptions.inputPrompt(aiAssertionOptions)
val imageFilePath = when (aiAssertionOptions.assertionImageType) {
is AiAssertionOptions.AssertionImageType.Comparison -> comparisonImageFilePath
is AiAssertionOptions.AssertionImageType.Actual -> actualImageFilePath
}
val inputContent = content {
image(readByteArrayFromFile(imageFilePath))
val prompt = template.replace("INPUT_PROMPT", inputPrompt)
text(prompt)

debugLog {
"RoborazziAi: prompt:$prompt"
}
}

val response = runBlocking { generativeModel.generateContent(inputContent) }
debugLog {
"RoborazziAi: response: ${response.text}"
Expand All @@ -102,6 +128,12 @@ class GeminiAiAssertionModel(
}
)
}

override fun assert(
assertionTargetImages: AiAssertionOptions.AssertionTargetImages,
aiAssertionOptions: AiAssertionOptions
): AiAssertionResults {
}
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,25 @@ package com.github.takahirom.roborazzi

import com.github.takahirom.roborazzi.AiAssertionOptions.AiAssertionModel.Companion.DefaultMaxOutputTokens
import com.github.takahirom.roborazzi.AiAssertionOptions.AiAssertionModel.Companion.DefaultTemperature
import com.github.takahirom.roborazzi.AiAssertionOptions.AssertionTargetImage
import com.github.takahirom.roborazzi.CaptureResults.Companion.json
import io.ktor.client.*
import io.ktor.client.plugins.*
import io.ktor.client.HttpClient
import io.ktor.client.plugins.HttpTimeout
import io.ktor.client.plugins.HttpTimeout.Plugin.INFINITE_TIMEOUT_MS
import io.ktor.client.plugins.contentnegotiation.*
import io.ktor.client.plugins.logging.*
import io.ktor.client.request.*
import io.ktor.client.statement.*
import io.ktor.http.*
import io.ktor.serialization.kotlinx.json.*
import io.ktor.client.plugins.contentnegotiation.ContentNegotiation
import io.ktor.client.plugins.logging.LogLevel
import io.ktor.client.plugins.logging.Logger
import io.ktor.client.plugins.logging.Logging
import io.ktor.client.plugins.logging.SIMPLE
import io.ktor.client.request.HttpRequestBuilder
import io.ktor.client.request.header
import io.ktor.client.request.post
import io.ktor.client.request.setBody
import io.ktor.client.statement.HttpResponse
import io.ktor.client.statement.bodyAsText
import io.ktor.http.ContentType
import io.ktor.http.contentType
import io.ktor.serialization.kotlinx.json.json
import kotlinx.coroutines.runBlocking
import kotlinx.io.buffered
import kotlinx.io.files.Path
Expand Down Expand Up @@ -73,8 +82,48 @@ class OpenAiAiAssertionModel(
is AiAssertionOptions.AssertionImageType.Comparison -> comparisonImageFilePath
is AiAssertionOptions.AssertionImageType.Actual -> actualImageFilePath
}
val imageBytes = readByteArrayFromFile(imageFilePath)
val imageBase64 = imageBytes.encodeBase64()
return assert(
assertionTargetImages = AiAssertionOptions.AssertionTargetImages(
listOf(
AssertionTargetImage(
imageFilePath
)
)
),
systemPrompt = systemPrompt,
template = template,
inputPrompt = inputPrompt,
aiAssertionOptions = aiAssertionOptions
)
}

override fun assert(
assertionTargetImages: AiAssertionOptions.AssertionTargetImages,
aiAssertionOptions: AiAssertionOptions
): AiAssertionResults {
val systemPrompt = aiAssertionOptions.systemPrompt
val template = aiAssertionOptions.promptTemplate
val inputPrompt = aiAssertionOptions.inputPrompt(aiAssertionOptions)
return assert(
assertionTargetImages = assertionTargetImages,
systemPrompt = systemPrompt,
template = template,
inputPrompt = inputPrompt,
aiAssertionOptions = aiAssertionOptions
)
}

private fun assert(
assertionTargetImages: AiAssertionOptions.AssertionTargetImages,
systemPrompt: String,
template: String,
inputPrompt: String,
aiAssertionOptions: AiAssertionOptions
): AiAssertionResults {
val imageBase64s = assertionTargetImages.images.map { image ->
val imageBytes = readByteArrayFromFile(image.filePath)
imageBytes.encodeBase64()
}
val messages = listOf(
Message(
role = "system",
Expand All @@ -92,13 +141,12 @@ class OpenAiAiAssertionModel(
type = "text",
text = template.replace("INPUT_PROMPT", inputPrompt)
),
) + imageBase64s.map { imageBase64 ->
Content(
type = "image_url",
imageUrl = ImageUrl(
url = "data:image/png;base64,$imageBase64"
)
imageUrl = ImageUrl(url = "data:image/png;base64,$imageBase64")
)
)
}
)
)
val responseText = runBlocking {
Expand Down

0 comments on commit cb8f9e1

Please sign in to comment.