Skip to content

Commit

Permalink
feat(ui): add abort button's logic, fix rehypePrism type error
Browse files Browse the repository at this point in the history
  • Loading branch information
yzh990918 committed Apr 26, 2023
1 parent 36da48c commit aae7830
Show file tree
Hide file tree
Showing 8 changed files with 82 additions and 58 deletions.
1 change: 1 addition & 0 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
"@evan-yang/eslint-config": "^1.0.9",
"@iconify-json/carbon": "^1.1.16",
"@iconify-json/simple-icons": "^1.1.48",
"@types/mapbox__rehype-prism": "^0.8.0",
"@typescript-eslint/parser": "^5.57.1",
"@unocss/preset-attributify": "^0.50.6",
"@unocss/preset-icons": "^0.50.6",
Expand Down
33 changes: 16 additions & 17 deletions pnpm-lock.yaml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

27 changes: 20 additions & 7 deletions src/components/Send.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ export default () => {
const $currentErrorMessage = useStore(currentErrorMessage)
const $streamsMap = useStore(streamsMap)
const $loadingStateMap = useStore(loadingStateMap)
const [controller, setController] = createSignal<AbortController>()

const [inputPrompt, setInputPrompt] = createSignal('')
const isEditing = () => inputPrompt() || $isSendBoxFocus()
Expand Down Expand Up @@ -93,15 +94,26 @@ export default () => {
</div>
)

const clearPrompt = () => {
setInputPrompt('')
inputRef.value = ''
isSendBoxFocus.set(false)
}

const handleAbortFetch = () => {
controller()!.abort()
clearPrompt()
}

const LoadingState = () => (
<div class="max-w-base h-full fi flex-row gap-2">
<div class="flex-1 op-50">Thinking...</div>
{/* <div
<div
class="border border-darker px-2 py-1 rounded-md text-sm op-40 hv-base hover:bg-white"
onClick={() => { }}
onClick={() => { handleAbortFetch() }}
>
Abort
</div> */}
</div>
</div>
)

Expand All @@ -110,10 +122,11 @@ export default () => {
return
if (!currentConversation())
addConversation()
handlePrompt(currentConversation(), inputRef.value)
setInputPrompt('')
inputRef.value = ''
isSendBoxFocus.set(false)

const controller = new AbortController()
setController(controller)
handlePrompt(currentConversation(), inputRef.value, controller.signal)
clearPrompt()
scrollController().scrollToBottom()
}

Expand Down
31 changes: 17 additions & 14 deletions src/logics/conversation.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import type { CallProviderPayload, HandlerPayload, PromptResponse } from '@/type
import type { Conversation } from '@/types/conversation'
import type { ErrorMessage } from '@/types/message'

export const handlePrompt = async(conversation: Conversation, prompt: string) => {
export const handlePrompt = async(conversation: Conversation, prompt: string, signal?: AbortSignal) => {
const generalSettings = getGeneralSettings()
const provider = getProviderById(conversation?.providerId)
if (!provider) return
Expand Down Expand Up @@ -40,15 +40,17 @@ export const handlePrompt = async(conversation: Conversation, prompt: string) =>
historyMessages: getMessagesByConversationId(conversation.id),
}
try {
providerResponse = await getProviderResponse(callMethod, providerPayload)
providerResponse = await getProviderResponse(callMethod, providerPayload, signal)
} catch (e) {
const error = e as Error
const cause = error?.cause as ErrorMessage
console.error(e)
currentErrorMessage.set({
code: cause?.code || 'provider_error',
message: cause?.message || error.message || 'Unknown error',
})
setLoadingStateByConversationId(conversation.id, false)
if (error.name !== 'AbortError') {
currentErrorMessage.set({
code: cause?.code || 'provider_error',
message: cause?.message || error.message || 'Unknown error',
})
}
}

if (providerResponse) {
Expand All @@ -72,20 +74,21 @@ export const handlePrompt = async(conversation: Conversation, prompt: string) =>
// Update conversation title
if (providerResponse && conversation.conversationType === 'continuous' && !conversation.name) {
const rapidPayload = generateRapidProviderPayload(promptHelper.summarizeText(prompt), conversation.providerId)
const generatedTitle = await getProviderResponse(callMethod, rapidPayload).catch(() => {}) as string || prompt
const generatedTitle = await getProviderResponse(callMethod, rapidPayload, signal).catch(() => {}) as string || prompt
updateConversationById(conversation.id, {
name: generatedTitle,
})
}
}

const getProviderResponse = async(caller: 'frontend' | 'backend', payload: CallProviderPayload) => {
const getProviderResponse = async(caller: 'frontend' | 'backend', payload: CallProviderPayload, signal?: AbortSignal) => {
if (caller === 'frontend') {
return callProviderHandler(payload)
return callProviderHandler(payload, signal)
} else {
const backendResponse = await fetch('/api/handle', {
method: 'POST',
body: JSON.stringify(payload),
signal,
})
if (!backendResponse.ok) {
const error = await backendResponse.json()
Expand All @@ -101,7 +104,7 @@ const getProviderResponse = async(caller: 'frontend' | 'backend', payload: CallP
}

// Called by both client and server
export const callProviderHandler = async(payload: CallProviderPayload) => {
export const callProviderHandler = async(payload: CallProviderPayload, signal?: AbortSignal) => {
console.log('callProviderHandler', payload)

const { conversationMeta, providerId, prompt, historyMessages } = payload
Expand All @@ -117,15 +120,15 @@ export const callProviderHandler = async(payload: CallProviderPayload) => {
mockMessages: [],
}
if (conversationMeta.conversationType === 'single') {
response = await provider.handleSinglePrompt?.(prompt, handlerPayload)
response = await provider.handleSinglePrompt?.(prompt, handlerPayload, signal)
} else if (conversationMeta.conversationType === 'continuous') {
const messages = historyMessages.map(message => ({
role: message.role,
content: message.content,
}))
response = await provider.handleContinuousPrompt?.(messages, handlerPayload)
response = await provider.handleContinuousPrompt?.(messages, handlerPayload, signal)
} else if (conversationMeta.conversationType === 'image') {
response = await provider.handleImagePrompt?.(prompt, handlerPayload)
response = await provider.handleImagePrompt?.(prompt, handlerPayload, signal)
} else if (conversationMeta.conversationType === 'rapid') {
response = await provider.handleRapidPrompt?.(prompt, handlerPayload.globalSettings)
}
Expand Down
28 changes: 16 additions & 12 deletions src/logics/stream.ts
Original file line number Diff line number Diff line change
@@ -1,20 +1,24 @@
import type { Setter } from 'solid-js'

export const convertReadableStreamToAccessor = async(stream: ReadableStream, setter: Setter<string>) => {
const reader = stream.getReader()
const decoder = new TextDecoder('utf-8')
let done = false
let text = ''
while (!done) {
const { value, done: readerDone } = await reader.read()
if (value) {
const char = decoder.decode(value)
if (char) {
text += char
setter(text)
try {
const reader = stream.getReader()
const decoder = new TextDecoder('utf-8')
let done = false
while (!done) {
const { value, done: readerDone } = await reader.read()
if (value) {
const char = decoder.decode(value)
if (char) {
text += char
setter(text)
}
}
done = readerDone
}
done = readerDone
return text
} catch (error) {
return text
}
return text
}
3 changes: 3 additions & 0 deletions src/providers/openai/api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ export interface OpenAIFetchPayload {
apiKey: string
baseUrl: string
body: Record<string, any>
signal?: AbortSignal
}

export const fetchChatCompletion = async(payload: OpenAIFetchPayload) => {
Expand All @@ -12,6 +13,7 @@ export const fetchChatCompletion = async(payload: OpenAIFetchPayload) => {
},
method: 'POST',
body: JSON.stringify(payload.body),
signal: payload.signal,
}
return fetch(`${payload.baseUrl}/v1/chat/completions`, initOptions)
}
Expand All @@ -24,6 +26,7 @@ export const fetchImageGeneration = async(payload: OpenAIFetchPayload) => {
},
method: 'POST',
body: JSON.stringify(payload.body),
signal: payload.signal,
}
return fetch(`${payload.baseUrl}/v1/images/generations`, initOptions)
}
Loading

1 comment on commit aae7830

@vercel
Copy link

@vercel vercel bot commented on aae7830 Apr 26, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@yzh990918 is attempting to deploy a commit to the Anse Team on Vercel.

To accomplish this, @yzh990918 needs to request access to the Team.

Afterwards, an owner of the Team is required to accept their membership request.

If you're already a member of the respective Vercel Team, make sure that your Personal Vercel Account is connected to your GitHub account.

Please sign in to comment.