Skip to content

Commit

Permalink
Fix TargetImage's position of class
Browse files Browse the repository at this point in the history
  • Loading branch information
takahirom committed Feb 10, 2025
1 parent 8409881 commit 633220e
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -57,20 +57,20 @@ INPUT_PROMPT
aiAssertionOptions: AiAssertionOptions
): AiAssertionResults

class TargetImages(
val images: List<TargetImage>,
)

class TargetImage(
val filePath: String,
)

companion object {
const val DefaultMaxOutputTokens = 300
const val DefaultTemperature = 0.4F
}
}

class TargetImages(
val images: List<TargetImage>,
)

class TargetImage(
val filePath: String,
)


sealed interface AssertionImageType {
class Comparison : AssertionImageType
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package com.github.takahirom.roborazzi
import com.github.takahirom.roborazzi.AiAssertionOptions.AiAssertionModel
import com.github.takahirom.roborazzi.AiAssertionOptions.AiAssertionModel.Companion.DefaultMaxOutputTokens
import com.github.takahirom.roborazzi.AiAssertionOptions.AiAssertionModel.Companion.DefaultTemperature
import com.github.takahirom.roborazzi.AiAssertionOptions.AiAssertionModel.TargetImage
import com.github.takahirom.roborazzi.AiAssertionOptions.AiAssertionModel.TargetImages
import dev.shreyaspatil.ai.client.generativeai.GenerativeModel
import dev.shreyaspatil.ai.client.generativeai.type.FunctionType
import dev.shreyaspatil.ai.client.generativeai.type.GenerationConfig
Expand Down Expand Up @@ -39,16 +41,16 @@ class GeminiAiAssertionModel(
is AiAssertionOptions.AssertionImageType.Actual -> actualImageFilePath
}
return assert(
AiAssertionOptions.TargetImages(listOf(AiAssertionOptions.TargetImage(imageFilePath))),
template,
inputPrompt,
systemPrompt,
aiAssertionOptions
targetImages = TargetImages(listOf(TargetImage(imageFilePath))),
template = template,
inputPrompt = inputPrompt,
systemPrompt = systemPrompt,
aiAssertionOptions = aiAssertionOptions
)
}

private fun assert(
targetImages: AiAssertionOptions.TargetImages,
targetImages: TargetImages,
template: String,
inputPrompt: String,
systemPrompt: String,
Expand Down Expand Up @@ -130,7 +132,7 @@ class GeminiAiAssertionModel(
}

override fun assert(
targetImages: AiAssertionOptions.TargetImages,
targetImages: TargetImages,
aiAssertionOptions: AiAssertionOptions
): AiAssertionResults {
val systemPrompt = aiAssertionOptions.systemPrompt
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@ package com.github.takahirom.roborazzi

import com.github.takahirom.roborazzi.AiAssertionOptions.AiAssertionModel.Companion.DefaultMaxOutputTokens
import com.github.takahirom.roborazzi.AiAssertionOptions.AiAssertionModel.Companion.DefaultTemperature
import com.github.takahirom.roborazzi.AiAssertionOptions.TargetImage
import com.github.takahirom.roborazzi.AiAssertionOptions.AiAssertionModel.TargetImage
import com.github.takahirom.roborazzi.AiAssertionOptions.AiAssertionModel.TargetImages
import com.github.takahirom.roborazzi.CaptureResults.Companion.json
import io.ktor.client.HttpClient
import io.ktor.client.plugins.HttpTimeout
Expand Down Expand Up @@ -83,7 +84,7 @@ class OpenAiAiAssertionModel(
is AiAssertionOptions.AssertionImageType.Actual -> actualImageFilePath
}
return assert(
targetImages = AiAssertionOptions.TargetImages(
targetImages = TargetImages(
listOf(
TargetImage(
imageFilePath
Expand All @@ -98,7 +99,7 @@ class OpenAiAiAssertionModel(
}

override fun assert(
targetImages: AiAssertionOptions.TargetImages,
targetImages: TargetImages,
aiAssertionOptions: AiAssertionOptions
): AiAssertionResults {
val systemPrompt = aiAssertionOptions.systemPrompt
Expand All @@ -114,7 +115,7 @@ class OpenAiAiAssertionModel(
}

private fun assert(
targetImages: AiAssertionOptions.TargetImages,
targetImages: TargetImages,
systemPrompt: String,
template: String,
inputPrompt: String,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import androidx.compose.ui.test.junit4.createAndroidComposeRule
import androidx.compose.ui.test.onRoot
import androidx.test.ext.junit.runners.AndroidJUnit4
import com.github.takahirom.roborazzi.AiAssertionOptions
import com.github.takahirom.roborazzi.AiAssertionOptions.AiAssertionModel.TargetImages
import com.github.takahirom.roborazzi.AiAssertionResult
import com.github.takahirom.roborazzi.AiAssertionResults
import com.github.takahirom.roborazzi.DEFAULT_ROBORAZZI_OUTPUT_DIR_PATH
Expand Down Expand Up @@ -69,7 +70,7 @@ class AiManualTest {
aiAssertionOptions = AiAssertionOptions(
aiAssertionModel = object : AiAssertionOptions.AiAssertionModel {
override fun assert(
targetImages: AiAssertionOptions.TargetImages,
targetImages: TargetImages,
aiAssertionOptions: AiAssertionOptions
): AiAssertionResults {
return AiAssertionResults(
Expand Down

0 comments on commit 633220e

Please sign in to comment.