Skip to content

Commit

Permalink
Merge pull request #48 from yml-org/feature/CM-1218/image-generations
Browse files Browse the repository at this point in the history
feat: Add ImageGenerations API
  • Loading branch information
osugikoji authored Mar 15, 2023
2 parents d34eb8f + 1f69fb9 commit 9b6440b
Show file tree
Hide file tree
Showing 33 changed files with 613 additions and 12 deletions.
2 changes: 2 additions & 0 deletions buildSrc/src/main/kotlin/Dependencies.kt
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ object Versions {
const val COMPOSE_ACTIVITY = "1.6.1"
const val COMPOSE_NAVIGATION = "2.5.3"
const val COMPOSE_LIVEDATA = "1.3.3"
const val COIL = "2.2.2"
const val KTOR = "2.2.2"
const val KOIN = "3.2.0"
const val MATERIAL_DESIGN = "1.6.1"
Expand Down Expand Up @@ -52,6 +53,7 @@ object Dependencies {
const val COMPOSE_ACTIVITY = "androidx.activity:activity-compose:${Versions.COMPOSE_ACTIVITY}"
const val COMPOSE_NAVIGATION = "androidx.navigation:navigation-compose:${Versions.COMPOSE_NAVIGATION}"
const val COMPOSE_LIVEDATA = "androidx.compose.runtime:runtime-livedata:${Versions.COMPOSE_LIVEDATA}"
const val COIL = "io.coil-kt:coil-compose:${Versions.COIL}"
}

object Test {
Expand Down
1 change: 1 addition & 0 deletions sample/android/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ dependencies {
implementation(Dependencies.UI.COMPOSE_ACTIVITY)
implementation(Dependencies.UI.COMPOSE_NAVIGATION)
implementation(Dependencies.UI.COMPOSE_LIVEDATA)
implementation(Dependencies.UI.COIL)
implementation(Dependencies.DI.KOIN_CORE)
implementation(Dependencies.DI.KOIN_ANDROID)
implementation(Dependencies.DI.KOIN_COMPOSE)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import androidx.lifecycle.MutableLiveData
import androidx.lifecycle.ViewModel
import androidx.lifecycle.viewModelScope
import co.yml.ychat.YChat
import co.yml.ychat.YChat.Callback
import kotlinx.coroutines.delay
import kotlinx.coroutines.launch

Expand All @@ -20,6 +21,10 @@ class MainViewModel(private val chatGpt: YChat) : ViewModel() {
)
}

private val imageGenerations by lazy {
chatGpt.imageGenerations()
}

private val _items = mutableStateListOf<MessageItem>()
val items = _items

Expand All @@ -30,7 +35,7 @@ class MainViewModel(private val chatGpt: YChat) : ViewModel() {
private var typingItem = mutableStateOf(MessageItem(message = typingTxt.value, isOut = false))

private fun setLoading(isLoading: Boolean) {
_isLoading.value = isLoading
_isLoading.postValue(isLoading)
}

fun onSendMessage(message: String, typingStr: String) {
Expand All @@ -41,6 +46,22 @@ class MainViewModel(private val chatGpt: YChat) : ViewModel() {
}
}

fun onImageRequest(prompt: String, typingStr: String) {
updateTypingMessage(typingStr)
viewModelScope.launch {
showTypingAnimation(prompt)
imageGenerations.execute(prompt, object : Callback<List<String>> {
override fun onSuccess(result: List<String>) {
showImages(result)
}

override fun onError(throwable: Throwable) {
writeResponse(ERROR)
}
})
}
}

private suspend fun showTypingAnimation(message: String) {
items.add(MessageItem(message = message, isOut = true))
delay((1000..2000).random().toLong())
Expand All @@ -54,6 +75,14 @@ class MainViewModel(private val chatGpt: YChat) : ViewModel() {
setLoading(false)
}

private fun showImages(result: List<String>) {
items.remove(items[items.lastIndex])
result.forEach {
items.add(MessageItem(message = IMAGE, isOut = false, url = it))
}
setLoading(false)
}

private suspend fun requestCompletion(message: String): String {
return try {
chatCompletions.execute(message).last().content
Expand All @@ -77,5 +106,6 @@ class MainViewModel(private val chatGpt: YChat) : ViewModel() {
companion object {
private const val ERROR = "Error"
private const val MAX_TOKENS = 1024
private const val IMAGE = "image"
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,5 @@ package co.yml.ychat.android
data class MessageItem(
val message: String,
val isOut: Boolean,
val url: String? = null
)
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,15 @@ fun ChatLayout(
.padding(spaceMedium),
) {
items(messages) { message ->
MessageItemLayout(
messageText = message.message, isOut = message.isOut
)
message.url?.let {
ImageItemLayout(
messageText = message.url, isOut = message.isOut
)
} ?: run {
MessageItemLayout(
messageText = message.message, isOut = message.isOut
)
}
}
coroutineScope.launch {
listState.animateScrollToItem(messages.size)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
package co.yml.ychat.android.ui

import android.content.res.Configuration.UI_MODE_NIGHT_YES
import androidx.compose.foundation.Image
import androidx.compose.foundation.background
import androidx.compose.foundation.layout.Box
import androidx.compose.foundation.layout.Column
import androidx.compose.foundation.layout.Row
import androidx.compose.foundation.layout.Spacer
import androidx.compose.foundation.layout.fillMaxWidth
import androidx.compose.foundation.layout.height
import androidx.compose.foundation.layout.padding
import androidx.compose.foundation.layout.width
import androidx.compose.foundation.shape.CircleShape
import androidx.compose.foundation.shape.RoundedCornerShape
import androidx.compose.runtime.Composable
import androidx.compose.ui.Alignment
import androidx.compose.ui.Modifier
import androidx.compose.ui.draw.clip
import androidx.compose.ui.res.colorResource
import androidx.compose.ui.res.painterResource
import androidx.compose.ui.tooling.preview.Preview
import androidx.compose.ui.unit.dp
import co.yml.ychat.android.R
import co.yml.ychat.android.ui.Dimensions.default
import co.yml.ychat.android.ui.Dimensions.robotMessageIconSize
import co.yml.ychat.android.ui.Dimensions.robotMessagePaddingSize
import co.yml.ychat.android.ui.Dimensions.spaceExtraSmall
import co.yml.ychat.android.ui.Dimensions.spaceMedium
import co.yml.ychat.android.ui.Dimensions.spaceSmall
import coil.compose.AsyncImage

@Composable
fun ImageItemLayout(
messageText: String,
isOut: Boolean
) {
Column(
modifier = Modifier.fillMaxWidth(),
horizontalAlignment = if (isOut) Alignment.End else Alignment.Start
) {
Row(
modifier = Modifier.padding(top = spaceMedium),
verticalAlignment = Alignment.Bottom
) {
if (isOut.not()) {
Image(
painterResource(R.drawable.ic_robot),
contentDescription = "",
modifier = Modifier
.width(robotMessageIconSize)
.height(robotMessageIconSize)
.clip(shape = CircleShape)
.background(colorResource(id = R.color.softGreen))
.padding(robotMessagePaddingSize),
)
Spacer(modifier = Modifier.padding(spaceExtraSmall))
}
Box(
modifier = Modifier
.clip(
shape = RoundedCornerShape(
topStart = spaceMedium,
topEnd = spaceMedium,
bottomEnd = if (isOut) default else spaceMedium,
bottomStart = if (isOut) spaceMedium else default
)
)
.background(if (isOut) colorResource(id = R.color.softBlue) else colorResource(id = R.color.opaqueWhite))
.padding(spaceSmall)
) {
AsyncImage(
modifier = Modifier.clip(RoundedCornerShape(8.dp)),
model = messageText,
contentDescription = messageText,
placeholder = painterResource(R.drawable.ic_robot),
)
}
}
}
}

@Preview(uiMode = UI_MODE_NIGHT_YES)
@Composable
fun PreviewImageItemLayout() {
MessageItemLayout(messageText = "Message", isOut = false)
}
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ fun SendMessageLayout() {
val scope = rememberCoroutineScope()
val viewModel = koinViewModel<MainViewModel>()
val isLoading: Boolean by viewModel.isLoading.observeAsState(initial = false)

Row(
modifier = Modifier
.background(color = MaterialTheme.colors.background)
Expand Down Expand Up @@ -91,7 +92,11 @@ fun SendMessageLayout() {
.background(if (textFieldState.isNotEmpty() && isLoading.not()) colorResource(id = R.color.softBlue) else colorResource(id = R.color.opaqueWhite)),
onClick = {
scope.launch {
viewModel.onSendMessage(textFieldState, typingString)
if (textFieldState.startsWith("/image ")) {
viewModel.onImageRequest(textFieldState, typingString)
} else {
viewModel.onSendMessage(textFieldState, typingString)
}
textFieldState = ""
}
},
Expand Down
33 changes: 30 additions & 3 deletions sample/ios/YChatApp/Features/Completion/CompletionView.swift
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,14 @@ private extension CompletionView {
}
case .bot:
HStack {
botChatBubble(message: chatMessage.message)
Spacer().frame(width: 60)
Spacer()
if let imageUrl = chatMessage.url {
botImageBubble(imageUrl)
Spacer()
} else {
botChatBubble(message: chatMessage.message)
Spacer().frame(width: 60)
Spacer()
}
}
case .loading:
HStack {
Expand Down Expand Up @@ -120,6 +125,28 @@ private extension CompletionView {
.cornerRadius(16, corners: [.bottomLeft, .bottomLeft, .topRight])
}
}

@ViewBuilder
private func botImageBubble(_ url: String) -> some View {
HStack(alignment: .top, spacing: 4) {
Circle()
.fill(.green)
.frame(width: 40, height: 40)
.overlay {
Image(uiImage: Icon.bot.uiImage)
.renderingMode(.template)
.foregroundColor(.white)
}
ZStack {
AsyncImage(url: URL(string: url))
.foregroundColor(.grayDark)
}
.padding(.horizontal, 16)
.padding(.vertical, 8)
.background(Color.grayLight)
.cornerRadius(16, corners: [.bottomLeft, .bottomLeft, .topRight])
}
}

@ViewBuilder
private func sendMessageSection() -> some View {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ struct ChatMessage: Identifiable, Equatable {
let id: String
var message: String = ""
var type: MessageType = .human(error: false)
var url: String?

enum MessageType: Equatable {
case human(error: Bool), bot, loading
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ internal final class CompletionViewModel: ObservableObject {
content: "You are a helpful assistant."
)

private var imageGenerations: ImageGenerations =
YChatCompanion.shared.create(apiKey: Config.apiKey)
.imageGenerations()

@Published
var message: String = ""

Expand All @@ -37,9 +41,15 @@ internal final class CompletionViewModel: ObservableObject {
cleanLastMessage()
addLoading()
do {
let result = try await chatCompletions.execute(content: input)[0].content
removeLoading()
addAIMessage(message: result)
if input.contains("/image ") {
let result = try await imageGenerations.execute(prompt: input)[0].url
removeLoading()
addAIImage(url: result)
} else {
let result = try await chatCompletions.execute(content: input)[0].content
removeLoading()
addAIMessage(message: result)
}
} catch {
removeLoading()
setError()
Expand All @@ -64,6 +74,15 @@ internal final class CompletionViewModel: ObservableObject {
)
chatMessageList.append(chatMessage)
}

private func addAIImage(url: String) {
let chatMessage = ChatMessage(
id: UUID().uuidString,
type: .bot,
url: url
)
chatMessageList.append(chatMessage)
}

private func addLoading() {
let chatMessage = ChatMessage(
Expand Down
16 changes: 15 additions & 1 deletion sample/jvm/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,4 +42,18 @@ This endpoint generates text based on the provided prompt and a specified topic.

##### Example:

`GET http://localhost:8080/api/ychat/chat-completions?input="Tell me an exercise plan"&topic=fitness`
`GET http://localhost:8080/api/ychat/chat-completions?input="Tell me an exercise plan"&topic=fitness`

### Image Generations Endpoint

This endpoint generates images based on the provided prompt.

##### Endpoint: http://localhost:[port_number]/api/ychat/generations

##### Parameters:

- `prompt`: The prompt for generating images.

##### Example:

`GET http://localhost:8080/api/ychat/generations?prompt="ocean"
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,18 @@ public ResponseEntity<String> chatCompletions(
return ResponseEntity.ok(result);
}

@GetMapping("generations")
public ResponseEntity<String> imageGenerations(
@RequestParam(value = "prompt", defaultValue = Defaults.IMAGE_GENERATION_TOPIC) String input
) throws Exception {
String result = YChatService.getImageGenerationsAnswer(input);
return ResponseEntity.ok(result);
}

private static class Defaults {
static final String COMPLETION_INPUT = "Say this is a test.";
static final String CHAT_COMPLETION_INPUT = "Tell me one strength exercise";
static final String CHAT_COMPLETION_TOPIC = "fitness";
static final String IMAGE_GENERATION_TOPIC = "ocean";
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,13 @@ public String getChatCompletionsAnswer(String input, String topic) throws Except
return future.get().get(0).getContent();
}

public String getImageGenerationsAnswer(String prompt) throws Exception {
final CompletableFuture<List<String>> future = new CompletableFuture<>();
ychat.imageGenerations()
.execute(prompt, new CompletionCallbackResult<>(future));
return future.get().get(0);
}

private static class CompletionCallbackResult<T> implements YChat.Callback<T> {

private final CompletableFuture<T> future;
Expand Down
Loading

0 comments on commit 9b6440b

Please sign in to comment.