Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add ImageGenerations API #48

Merged
merged 12 commits into from
Mar 15, 2023
Merged
2 changes: 2 additions & 0 deletions buildSrc/src/main/kotlin/Dependencies.kt
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ object Versions {
const val COMPOSE_ACTIVITY = "1.6.1"
const val COMPOSE_NAVIGATION = "2.5.3"
const val COMPOSE_LIVEDATA = "1.3.3"
const val COIL = "2.2.2"
const val KTOR = "2.2.2"
const val KOIN = "3.2.0"
const val MATERIAL_DESIGN = "1.6.1"
Expand Down Expand Up @@ -52,6 +53,7 @@ object Dependencies {
const val COMPOSE_ACTIVITY = "androidx.activity:activity-compose:${Versions.COMPOSE_ACTIVITY}"
const val COMPOSE_NAVIGATION = "androidx.navigation:navigation-compose:${Versions.COMPOSE_NAVIGATION}"
const val COMPOSE_LIVEDATA = "androidx.compose.runtime:runtime-livedata:${Versions.COMPOSE_LIVEDATA}"
const val COIL = "io.coil-kt:coil-compose:${Versions.COIL}"
}

object Test {
Expand Down
1 change: 1 addition & 0 deletions sample/android/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ dependencies {
implementation(Dependencies.UI.COMPOSE_ACTIVITY)
implementation(Dependencies.UI.COMPOSE_NAVIGATION)
implementation(Dependencies.UI.COMPOSE_LIVEDATA)
implementation(Dependencies.UI.COIL)
implementation(Dependencies.DI.KOIN_CORE)
implementation(Dependencies.DI.KOIN_ANDROID)
implementation(Dependencies.DI.KOIN_COMPOSE)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import androidx.lifecycle.MutableLiveData
import androidx.lifecycle.ViewModel
import androidx.lifecycle.viewModelScope
import co.yml.ychat.YChat
import co.yml.ychat.YChat.Callback
import kotlinx.coroutines.delay
import kotlinx.coroutines.launch

Expand All @@ -20,6 +21,10 @@ class MainViewModel(private val chatGpt: YChat) : ViewModel() {
)
}

private val imageGenerations by lazy {
chatGpt.imageGenerations()
}

private val _items = mutableStateListOf<MessageItem>()
val items = _items

Expand All @@ -30,7 +35,7 @@ class MainViewModel(private val chatGpt: YChat) : ViewModel() {
private var typingItem = mutableStateOf(MessageItem(message = typingTxt.value, isOut = false))

private fun setLoading(isLoading: Boolean) {
_isLoading.value = isLoading
_isLoading.postValue(isLoading)
}

fun onSendMessage(message: String, typingStr: String) {
Expand All @@ -41,6 +46,22 @@ class MainViewModel(private val chatGpt: YChat) : ViewModel() {
}
}

fun onImageRequest(prompt: String, typingStr: String) {
updateTypingMessage(typingStr)
viewModelScope.launch {
showTypingAnimation(prompt)
imageGenerations.execute(prompt, object : Callback<List<String>> {
override fun onSuccess(result: List<String>) {
showImages(result)
}

override fun onError(throwable: Throwable) {
writeResponse(ERROR)
}
})
}
}

private suspend fun showTypingAnimation(message: String) {
items.add(MessageItem(message = message, isOut = true))
delay((1000..2000).random().toLong())
Expand All @@ -54,6 +75,14 @@ class MainViewModel(private val chatGpt: YChat) : ViewModel() {
setLoading(false)
}

private fun showImages(result: List<String>) {
items.remove(items[items.lastIndex])
result.forEach {
items.add(MessageItem(message = IMAGE, isOut = false, url = it))
}
setLoading(false)
}

private suspend fun requestCompletion(message: String): String {
return try {
chatCompletions.execute(message).last().content
Expand All @@ -77,5 +106,6 @@ class MainViewModel(private val chatGpt: YChat) : ViewModel() {
companion object {
private const val ERROR = "Error"
private const val MAX_TOKENS = 1024
private const val IMAGE = "image"
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,5 @@ package co.yml.ychat.android
data class MessageItem(
val message: String,
val isOut: Boolean,
val url: String? = null
)
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,15 @@ fun ChatLayout(
.padding(spaceMedium),
) {
items(messages) { message ->
MessageItemLayout(
messageText = message.message, isOut = message.isOut
)
message.url?.let {
ImageItemLayout(
messageText = message.url, isOut = message.isOut
)
} ?: run {
MessageItemLayout(
messageText = message.message, isOut = message.isOut
)
}
}
coroutineScope.launch {
listState.animateScrollToItem(messages.size)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
package co.yml.ychat.android.ui

import android.content.res.Configuration.UI_MODE_NIGHT_YES
import androidx.compose.foundation.Image
import androidx.compose.foundation.background
import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.Spacer
import androidx.compose.foundation.layout.fillMaxWidth
import androidx.compose.foundation.layout.height
import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.layout.width
import androidx.compose.foundation.shape.CircleShape
import androidx.compose.foundation.shape.RoundedCornerShape
import androidx.compose.runtime.Composable
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.clip
import androidx.compose.ui.res.colorResource
import androidx.compose.ui.res.painterResource
import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.dp
import co.yml.ychat.android.R
import co.yml.ychat.android.ui.Dimensions.default
import co.yml.ychat.android.ui.Dimensions.robotMessageIconSize
import co.yml.ychat.android.ui.Dimensions.robotMessagePaddingSize
import co.yml.ychat.android.ui.Dimensions.spaceExtraSmall
import co.yml.ychat.android.ui.Dimensions.spaceMedium
import co.yml.ychat.android.ui.Dimensions.spaceSmall
import coil.compose.AsyncImage

@Composable
fun ImageItemLayout(
messageText: String,
isOut: Boolean
) {
Column(
modifier = Modifier.fillMaxWidth(),
horizontalAlignment = if (isOut) Alignment.End else Alignment.Start
) {
Row(
modifier = Modifier.padding(top = spaceMedium),
verticalAlignment = Alignment.Bottom
) {
if (isOut.not()) {
Image(
painterResource(R.drawable.ic_robot),
contentDescription = "",
modifier = Modifier
.width(robotMessageIconSize)
.height(robotMessageIconSize)
.clip(shape = CircleShape)
.background(colorResource(id = R.color.softGreen))
.padding(robotMessagePaddingSize),
)
Spacer(modifier = Modifier.padding(spaceExtraSmall))
}
Box(
modifier = Modifier
.clip(
shape = RoundedCornerShape(
topStart = spaceMedium,
topEnd = spaceMedium,
bottomEnd = if (isOut) default else spaceMedium,
bottomStart = if (isOut) spaceMedium else default
)
)
.background(if (isOut) colorResource(id = R.color.softBlue) else colorResource(id = R.color.opaqueWhite))
.padding(spaceSmall)
) {
AsyncImage(
modifier = Modifier.clip(RoundedCornerShape(8.dp)),
model = messageText,
contentDescription = messageText,
placeholder = painterResource(R.drawable.ic_robot),
)
}
}
}
}

@Preview(uiMode = UI_MODE_NIGHT_YES)
@Composable
fun PreviewImageItemLayout() {
MessageItemLayout(messageText = "Message", isOut = false)
}
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ fun SendMessageLayout() {
val scope = rememberCoroutineScope()
val viewModel = koinViewModel<MainViewModel>()
val isLoading: Boolean by viewModel.isLoading.observeAsState(initial = false)

Row(
modifier = Modifier
.background(color = MaterialTheme.colors.background)
Expand Down Expand Up @@ -91,7 +92,11 @@ fun SendMessageLayout() {
.background(if (textFieldState.isNotEmpty() && isLoading.not()) colorResource(id = R.color.softBlue) else colorResource(id = R.color.opaqueWhite)),
onClick = {
scope.launch {
viewModel.onSendMessage(textFieldState, typingString)
if (textFieldState.startsWith("/image ")) {
viewModel.onImageRequest(textFieldState, typingString)
} else {
viewModel.onSendMessage(textFieldState, typingString)
}
textFieldState = ""
}
},
Expand Down
33 changes: 30 additions & 3 deletions sample/ios/YChatApp/Features/Completion/CompletionView.swift
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,14 @@ private extension CompletionView {
}
case .bot:
HStack {
botChatBubble(message: chatMessage.message)
Spacer().frame(width: 60)
Spacer()
if let imageUrl = chatMessage.url {
botImageBubble(imageUrl)
Spacer()
} else {
botChatBubble(message: chatMessage.message)
Spacer().frame(width: 60)
Spacer()
}
}
case .loading:
HStack {
Expand Down Expand Up @@ -120,6 +125,28 @@ private extension CompletionView {
.cornerRadius(16, corners: [.bottomLeft, .bottomLeft, .topRight])
}
}

@ViewBuilder
private func botImageBubble(_ url: String) -> some View {
HStack(alignment: .top, spacing: 4) {
Circle()
.fill(.green)
.frame(width: 40, height: 40)
.overlay {
Image(uiImage: Icon.bot.uiImage)
.renderingMode(.template)
.foregroundColor(.white)
}
ZStack {
AsyncImage(url: URL(string: url))
.foregroundColor(.grayDark)
}
.padding(.horizontal, 16)
.padding(.vertical, 8)
.background(Color.grayLight)
.cornerRadius(16, corners: [.bottomLeft, .bottomLeft, .topRight])
}
}

@ViewBuilder
private func sendMessageSection() -> some View {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ struct ChatMessage: Identifiable, Equatable {
let id: String
var message: String = ""
var type: MessageType = .human(error: false)
var url: String?

enum MessageType: Equatable {
case human(error: Bool), bot, loading
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ internal final class CompletionViewModel: ObservableObject {
content: "You are a helpful assistant."
)

private var imageGenerations: ImageGenerations =
YChatCompanion.shared.create(apiKey: Config.apiKey)
.imageGenerations()

@Published
var message: String = ""

Expand All @@ -37,9 +41,15 @@ internal final class CompletionViewModel: ObservableObject {
cleanLastMessage()
addLoading()
do {
let result = try await chatCompletions.execute(content: input)[0].content
removeLoading()
addAIMessage(message: result)
if input.contains("/image ") {
let result = try await imageGenerations.execute(prompt: input)[0].url
removeLoading()
addAIImage(url: result)
} else {
let result = try await chatCompletions.execute(content: input)[0].content
removeLoading()
addAIMessage(message: result)
}
} catch {
removeLoading()
setError()
Expand All @@ -64,6 +74,15 @@ internal final class CompletionViewModel: ObservableObject {
)
chatMessageList.append(chatMessage)
}

private func addAIImage(url: String) {
let chatMessage = ChatMessage(
id: UUID().uuidString,
type: .bot,
url: url
)
chatMessageList.append(chatMessage)
}

private func addLoading() {
let chatMessage = ChatMessage(
Expand Down
16 changes: 15 additions & 1 deletion sample/jvm/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,4 +42,18 @@ This endpoint generates text based on the provided prompt and a specified topic.

##### Example:

`GET http://localhost:8080/api/ychat/chat-completions?input="Tell me an exercise plan"&topic=fitness`
`GET http://localhost:8080/api/ychat/chat-completions?input="Tell me an exercise plan"&topic=fitness`

### Image Generations Endpoint

This endpoint generates images based on the provided prompt.

##### Endpoint: http://localhost:[port_number]/api/ychat/generations

##### Parameters:

- `prompt`: The prompt for generating images.

##### Example:

`GET http://localhost:8080/api/ychat/generations?prompt="ocean"
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,18 @@ public ResponseEntity<String> chatCompletions(
return ResponseEntity.ok(result);
}

@GetMapping("generations")
osugikoji marked this conversation as resolved.
Show resolved Hide resolved
public ResponseEntity<String> imageGenerations(
@RequestParam(value = "prompt", defaultValue = Defaults.IMAGE_GENERATION_TOPIC) String input
) throws Exception {
String result = YChatService.getImageGenerationsAnswer(input);
return ResponseEntity.ok(result);
}

private static class Defaults {
static final String COMPLETION_INPUT = "Say this is a test.";
static final String CHAT_COMPLETION_INPUT = "Tell me one strength exercise";
static final String CHAT_COMPLETION_TOPIC = "fitness";
static final String IMAGE_GENERATION_TOPIC = "ocean";
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,13 @@ public String getChatCompletionsAnswer(String input, String topic) throws Except
return future.get().get(0).getContent();
}

public String getImageGenerationsAnswer(String prompt) throws Exception {
final CompletableFuture<List<String>> future = new CompletableFuture<>();
ychat.imageGenerations()
.execute(prompt, new CompletionCallbackResult<>(future));
return future.get().get(0);
}

private static class CompletionCallbackResult<T> implements YChat.Callback<T> {

private final CompletableFuture<T> future;
Expand Down
Loading