Skip to content

Commit

Permalink
feat: Add Edits API
Browse files Browse the repository at this point in the history
  • Loading branch information
renatoarg committed Mar 16, 2023
1 parent 9b6440b commit 9cb929f
Show file tree
Hide file tree
Showing 18 changed files with 345 additions and 1 deletion.
16 changes: 16 additions & 0 deletions ychat/src/commonMain/kotlin/co/yml/ychat/YChat.kt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package co.yml.ychat

import co.yml.ychat.entrypoint.features.ChatCompletions
import co.yml.ychat.entrypoint.features.Completion
import co.yml.ychat.entrypoint.features.Edits
import co.yml.ychat.entrypoint.features.ImageGenerations
import co.yml.ychat.entrypoint.impl.YChatImpl
import kotlin.jvm.JvmStatic
Expand Down Expand Up @@ -93,6 +94,21 @@ interface YChat {
*/
fun imageGenerations(): ImageGenerations

/**
* The edits api is used to edit prompts and re-generate. Given a prompt and an instruction,
* the model will return an edited version of the prompt.
*
* You can configure the parameters of the edits before executing it. Example:
* ```
* val result = YChat.create(apiKey).edits()
* .setInput("As Descartes said, I think, therefore")
* .setInstruction("Fix spelling mistakes")
* .set...
* .execute()
* ```
*/
fun edits(): Edits

/**
* Callback is an interface used for handling the results of an operation.
* It provides two methods, `onSuccess` and `onError`, for handling the success
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import co.yml.ychat.data.dto.ChatCompletionParamsDto
import co.yml.ychat.data.dto.ChatCompletionsDto
import co.yml.ychat.data.dto.CompletionDto
import co.yml.ychat.data.dto.CompletionParamsDto
import co.yml.ychat.data.dto.EditsDto
import co.yml.ychat.data.dto.EditsParamsDto
import co.yml.ychat.data.dto.ImageGenerationsDto
import co.yml.ychat.data.dto.ImageGenerationsParamsDto
import co.yml.ychat.data.infrastructure.ApiResult
Expand All @@ -15,4 +17,6 @@ internal interface ChatGptApi {
suspend fun chatCompletions(paramsDto: ChatCompletionParamsDto): ApiResult<ChatCompletionsDto>

suspend fun imageGenerations(paramsDto: ImageGenerationsParamsDto): ApiResult<ImageGenerationsDto>

suspend fun edits(paramsDto: EditsParamsDto): ApiResult<EditsDto>
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import co.yml.ychat.data.dto.ChatCompletionParamsDto
import co.yml.ychat.data.dto.ChatCompletionsDto
import co.yml.ychat.data.dto.CompletionDto
import co.yml.ychat.data.dto.CompletionParamsDto
import co.yml.ychat.data.dto.EditsDto
import co.yml.ychat.data.dto.EditsParamsDto
import co.yml.ychat.data.dto.ImageGenerationsDto
import co.yml.ychat.data.dto.ImageGenerationsParamsDto
import co.yml.ychat.data.infrastructure.ApiExecutor
Expand Down Expand Up @@ -36,4 +38,12 @@ internal class ChatGptApiImpl(private val apiExecutor: ApiExecutor) : ChatGptApi
.setBody(paramsDto)
.execute()
}

override suspend fun edits(paramsDto: EditsParamsDto): ApiResult<EditsDto> {
return apiExecutor
.setEndpoint("v1/edits")
.setHttpMethod(HttpMethod.Post)
.setBody(paramsDto)
.execute()
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package co.yml.ychat.data.dto

import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable

@Serializable
internal data class EditsChoiceDto(
@SerialName("text")
val text: String,
@SerialName("index")
val index: Int
)
16 changes: 16 additions & 0 deletions ychat/src/commonMain/kotlin/co/yml/ychat/data/dto/EditsDto.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package co.yml.ychat.data.dto

import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable

@Serializable
internal data class EditsDto(
@SerialName("object")
val objectType: String,
@SerialName("created")
val created: Long,
@SerialName("choices")
val choices: List<EditsChoiceDto>,
@SerialName("usage")
val usage: UsageDto,
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package co.yml.ychat.data.dto

import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable

@Serializable
internal data class EditsParamsDto(
@SerialName("model")
val model: String,
@SerialName("input")
val input: String,
@SerialName("instruction")
val instruction: String,
@SerialName("n")
val results: Int = 1,
@SerialName("temperature")
val temperature: Double,
@SerialName("top_p")
val topP: Double,
)
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,15 @@ import co.yml.ychat.data.storage.ChatLogStorage
import co.yml.ychat.di.provider.NetworkProvider
import co.yml.ychat.domain.usecases.ChatCompletionsUseCase
import co.yml.ychat.domain.usecases.CompletionUseCase
import co.yml.ychat.domain.usecases.EditsUseCase
import co.yml.ychat.domain.usecases.ImageGenerationsUseCase
import co.yml.ychat.entrypoint.features.ChatCompletions
import co.yml.ychat.entrypoint.features.Completion
import co.yml.ychat.entrypoint.features.Edits
import co.yml.ychat.entrypoint.features.ImageGenerations
import co.yml.ychat.entrypoint.impl.ChatCompletionsImpl
import co.yml.ychat.entrypoint.impl.CompletionImpl
import co.yml.ychat.entrypoint.impl.EditsImpl
import co.yml.ychat.entrypoint.impl.ImageGenerationsImpl
import kotlinx.coroutines.Dispatchers
import org.koin.core.module.Module
Expand All @@ -27,12 +30,14 @@ internal class LibraryModule(private val apiKey: String) {
factory<Completion> { CompletionImpl(Dispatchers.Default, get()) }
factory<ChatCompletions> { ChatCompletionsImpl(Dispatchers.Default, get()) }
factory<ImageGenerations> { ImageGenerationsImpl(Dispatchers.Default, get()) }
factory<Edits> { EditsImpl(Dispatchers.Default, get()) }
}

private val domainModule = module {
factory { CompletionUseCase(get(), get()) }
factory { ChatCompletionsUseCase(get()) }
factory { ImageGenerationsUseCase(get()) }
factory { EditsUseCase(get()) }
}

private val dataModule = module {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package co.yml.ychat.domain.mapper

import co.yml.ychat.data.dto.EditsDto
import co.yml.ychat.data.dto.EditsParamsDto
import co.yml.ychat.domain.model.EditsParams

internal fun EditsDto.toEditsModel(): List<String> {
return this.choices.map { it.text }
}

internal fun EditsParams.toEditsParamsDto(): EditsParamsDto {
return EditsParamsDto(
model = this.model,
input = this.input,
instruction = this.instruction,
results = this.results,
temperature = this.temperature,
topP = this.topP
)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package co.yml.ychat.domain.model

internal data class EditsParams(
var model: String = "text-davinci-edit-001",
var input: String = "",
var instruction: String = "",
var results: Int = 1,
var temperature: Double = 1.0,
var topP: Double = 1.0,
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package co.yml.ychat.domain.usecases

import co.yml.ychat.data.api.ChatGptApi
import co.yml.ychat.domain.mapper.toEditsModel
import co.yml.ychat.domain.mapper.toEditsParamsDto
import co.yml.ychat.domain.model.EditsParams

internal data class EditsUseCase(private val chatGptApi: ChatGptApi) {

suspend fun requestEdits(params: EditsParams): List<String> {
val requestDto = params.toEditsParamsDto()
val response = chatGptApi.edits(requestDto)
return response.getBodyOrThrow().toEditsModel()
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package co.yml.ychat.entrypoint.features

import co.yml.ychat.YChat
import co.yml.ychat.data.exception.ChatGptException
import kotlin.coroutines.cancellation.CancellationException

interface Edits {

fun setInput(input: String): Edits

fun setResults(results: Int): Edits

fun setModel(model: String): Edits

fun setTemperature(temperature: Double): Edits

fun setTopP(topP: Double): Edits

@Throws(CancellationException::class, ChatGptException::class)
suspend fun execute(instruction: String): List<String>

fun execute(instruction: String, callback: YChat.Callback<List<String>>)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package co.yml.ychat.entrypoint.impl

import co.yml.ychat.YChat
import co.yml.ychat.domain.model.EditsParams
import co.yml.ychat.domain.usecases.EditsUseCase
import co.yml.ychat.entrypoint.features.Edits
import kotlinx.coroutines.CoroutineDispatcher
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.SupervisorJob
import kotlinx.coroutines.launch

internal class EditsImpl(
private val dispatcher: CoroutineDispatcher,
private val editsUseCase: EditsUseCase
) : Edits {

private val scope by lazy { CoroutineScope(SupervisorJob() + dispatcher) }

private var params: EditsParams = EditsParams()

override fun setInput(input: String): Edits {
params.input = input
return this
}

override fun setResults(results: Int): Edits {
params.results = results
return this
}

override fun setModel(model: String): Edits {
params.model = model
return this
}

override fun setTemperature(temperature: Double): Edits {
params.temperature = temperature
return this
}

override fun setTopP(topP: Double): Edits {
params.topP = topP
return this
}

override suspend fun execute(instruction: String): List<String> {
params.instruction = instruction
return editsUseCase.requestEdits(params)
}

override fun execute(instruction: String, callback: YChat.Callback<List<String>>) {
scope.launch {
runCatching { execute(instruction) }
.onSuccess { callback.onSuccess(it) }
.onFailure { callback.onError(it) }
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import co.yml.ychat.YChat
import co.yml.ychat.di.module.LibraryModule
import co.yml.ychat.entrypoint.features.ChatCompletions
import co.yml.ychat.entrypoint.features.Completion
import co.yml.ychat.entrypoint.features.Edits
import co.yml.ychat.entrypoint.features.ImageGenerations
import org.koin.core.KoinApplication

Expand All @@ -27,4 +28,8 @@ internal class YChatImpl(apiKey: String) : YChat {
override fun imageGenerations(): ImageGenerations {
return koinApp.koin.get()
}

override fun edits(): Edits {
return koinApp.koin.get()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import co.yml.ychat.domain.usecases.CompletionUseCase
import co.yml.ychat.domain.usecases.ImageGenerationsUseCase
import co.yml.ychat.entrypoint.features.ChatCompletions
import co.yml.ychat.entrypoint.features.Completion
import co.yml.ychat.entrypoint.features.Edits
import co.yml.ychat.entrypoint.features.ImageGenerations
import io.ktor.client.HttpClient
import kotlin.test.AfterTest
Expand Down Expand Up @@ -43,5 +44,6 @@ class LibraryModuleTest : KoinTest {
get<ChatCompletions>()
get<ImageGenerationsUseCase>()
get<ImageGenerations>()
get<Edits>()
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package co.yml.ychat.domain.mapper

import co.yml.ychat.data.dto.EditsChoiceDto
import co.yml.ychat.data.dto.EditsDto
import co.yml.ychat.data.dto.ImageGenerationsDto
import co.yml.ychat.data.dto.UsageDto
import co.yml.ychat.domain.model.EditsParams
import co.yml.ychat.domain.model.ImageGeneratedDto
import co.yml.ychat.domain.model.ImageGenerationsParams
import kotlin.test.Test
import kotlin.test.assertEquals

class EditsMapperTest {

@Test
fun `on convert EditsDto to EditsModel`() {
val listOfChoicesDto = listOf(EditsChoiceDto("text 1", 1), EditsChoiceDto("text 2", 2))
val editsDto = EditsDto(
created = 12345,
objectType = "edit",
choices = listOfChoicesDto,
usage = UsageDto(1, 1, 1)
)
assertEquals(listOfChoicesDto.map { it.text }, editsDto.toEditsModel())
}

@Test
fun `on convert EditsParams to EditsDto`() {
val editsParams = EditsParams(input = "this is a test")
assertEquals("text-davinci-edit-001", editsParams.toEditsParamsDto().model)
assertEquals("this is a test", editsParams.toEditsParamsDto().input)
assertEquals("", editsParams.toEditsParamsDto().instruction)
assertEquals(1, editsParams.toEditsParamsDto().results)
assertEquals(1.0, editsParams.toEditsParamsDto().temperature)
assertEquals(1.0, editsParams.toEditsParamsDto().topP)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package co.yml.ychat.domain.model

import kotlin.test.Test
import kotlin.test.assertEquals

class EditsParamsTest {

@Test
fun `on EditsParams verify default values`() {
// arrange
val params = EditsParams()

// assert
assertEquals("text-davinci-edit-001", params.model)
assertEquals("", params.input)
assertEquals("", params.instruction)
assertEquals(1, params.results)
assertEquals(1.0, params.temperature)
assertEquals(1.0, params.topP)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import kotlin.test.assertEquals
class ImageGenerationsParamsTest {

@Test
fun `on ChatCompletionsParams verify default values`() {
fun `on ImageGenerationsParams verify default values`() {
// arrange
val params = ImageGenerationsParams()

Expand Down
Loading

0 comments on commit 9cb929f

Please sign in to comment.