diff --git a/app/client/api.ts b/app/client/api.ts
index cecc453baa2..9f2cb23053e 100644
--- a/app/client/api.ts
+++ b/app/client/api.ts
@@ -26,6 +26,7 @@ export const ROLES = ["system", "user", "assistant"] as const;
 export type MessageRole = (typeof ROLES)[number];

 export const Models = ["gpt-3.5-turbo", "gpt-4"] as const;
+export const TTSModels = ["tts-1", "tts-1-hd"] as const;
 export type ChatModel = ModelType;

 export interface MultimodalContent {
@@ -54,6 +55,25 @@ export interface LLMConfig {
   style?: DalleRequestPayload["style"];
 }

+export interface SpeechOptions {
+  model: string;
+  input: string;
+  voice: string;
+  response_format?: string;
+  speed?: number;
+  onController?: (controller: AbortController) => void;
+}
+
+export interface TranscriptionOptions {
+  model?: "whisper-1";
+  file: Blob;
+  language?: string;
+  prompt?: string;
+  response_format?: "json" | "text" | "srt" | "verbose_json" | "vtt";
+  temperature?: number;
+  onController?: (controller: AbortController) => void;
+}
+
 export interface ChatOptions {
   messages: RequestMessage[];
   config: LLMConfig;
@@ -88,6 +108,8 @@ export interface LLMModelProvider {

 export abstract class LLMApi {
   abstract chat(options: ChatOptions): Promise<void>;
+  abstract speech(options: SpeechOptions): Promise<ArrayBuffer>;
+  abstract transcription(options: TranscriptionOptions): Promise<string>;
   abstract usage(): Promise<LLMUsage>;
   abstract models(): Promise<LLMModel[]>;
 }
@@ -206,13 +228,16 @@ export function validString(x: string): boolean {
   return x?.length > 0;
 }

-export function getHeaders() {
+export function getHeaders(ignoreHeaders?: boolean) {
   const accessStore = useAccessStore.getState();
   const chatStore = useChatStore.getState();
-  const headers: Record<string, string> = {
-    "Content-Type": "application/json",
-    Accept: "application/json",
-  };
+  let headers: Record<string, string> = {};
+  if (!ignoreHeaders) {
+    headers = {
+      "Content-Type": "application/json",
+      Accept: "application/json",
+    };
+  }

   const clientConfig = getClientConfig();
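A note on the `getHeaders(ignoreHeaders?)` change above: the new transcription endpoint uploads audio as `multipart/form-data`, so `fetch` must be allowed to set the `Content-Type` (with its multipart boundary) itself, and the usual JSON headers have to be skipped. A minimal sketch of the intended call pattern; the endpoint URL, import path, and `audioBlob` variable are illustrative, not taken from this diff:

```ts
// Sketch only: why audio uploads pass getHeaders(true).
// With a FormData body, a hard-coded "Content-Type: application/json"
// would break the multipart encoding; getHeaders(true) returns only the
// auth/access headers and leaves Content-Type to the browser.
import { getHeaders } from "@/app/client/api"; // path assumed for this sketch

async function transcribeBlob(audioBlob: Blob): Promise<string> {
  const formData = new FormData();
  formData.append("file", audioBlob, "audio.wav");
  formData.append("model", "whisper-1");

  const res = await fetch("/api/openai/v1/audio/transcriptions", {
    // URL is illustrative; the real client resolves OpenaiPath.TranscriptionPath
    method: "POST",
    body: formData,
    headers: getHeaders(true),
  });
  const json = await res.json();
  return json.text;
}
```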
""; } + speech(options: SpeechOptions): Promise { + throw new Error("Method not implemented."); + } + transcription(options: TranscriptionOptions): Promise { + throw new Error("Method not implemented."); + } + async chat(options: ChatOptions) { const messages = options.messages.map((v) => ({ role: v.role, diff --git a/app/client/platforms/anthropic.ts b/app/client/platforms/anthropic.ts index 7dd39c9cddc..2ab67ed1371 100644 --- a/app/client/platforms/anthropic.ts +++ b/app/client/platforms/anthropic.ts @@ -1,5 +1,12 @@ import { ACCESS_CODE_PREFIX, Anthropic, ApiPath } from "@/app/constant"; -import { ChatOptions, getHeaders, LLMApi, MultimodalContent } from "../api"; +import { + ChatOptions, + getHeaders, + LLMApi, + MultimodalContent, + SpeechOptions, + TranscriptionOptions, +} from "../api"; import { useAccessStore, useAppConfig, @@ -80,6 +87,13 @@ const ClaudeMapper = { const keys = ["claude-2, claude-instant-1"]; export class ClaudeApi implements LLMApi { + speech(options: SpeechOptions): Promise { + throw new Error("Method not implemented."); + } + transcription(options: TranscriptionOptions): Promise { + throw new Error("Method not implemented."); + } + extractMessage(res: any) { console.log("[Response] claude response: ", res); diff --git a/app/client/platforms/baidu.ts b/app/client/platforms/baidu.ts index 3be147f4985..0c2be5fb14b 100644 --- a/app/client/platforms/baidu.ts +++ b/app/client/platforms/baidu.ts @@ -14,6 +14,8 @@ import { LLMApi, LLMModel, MultimodalContent, + SpeechOptions, + TranscriptionOptions, } from "../api"; import Locale from "../../locales"; import { @@ -75,6 +77,13 @@ export class ErnieApi implements LLMApi { return [baseUrl, path].join("/"); } + speech(options: SpeechOptions): Promise { + throw new Error("Method not implemented."); + } + transcription(options: TranscriptionOptions): Promise { + throw new Error("Method not implemented."); + } + async chat(options: ChatOptions) { const messages = options.messages.map((v) => ({ // "error_code": 336006, "error_msg": "the role of message with even index in the messages must be user or function", diff --git a/app/client/platforms/bytedance.ts b/app/client/platforms/bytedance.ts index 7677cafe12b..5a0c9b8b12e 100644 --- a/app/client/platforms/bytedance.ts +++ b/app/client/platforms/bytedance.ts @@ -13,6 +13,8 @@ import { LLMApi, LLMModel, MultimodalContent, + SpeechOptions, + TranscriptionOptions, } from "../api"; import Locale from "../../locales"; import { @@ -77,6 +79,13 @@ export class DoubaoApi implements LLMApi { return res.choices?.at(0)?.message?.content ?? 
""; } + speech(options: SpeechOptions): Promise { + throw new Error("Method not implemented."); + } + transcription(options: TranscriptionOptions): Promise { + throw new Error("Method not implemented."); + } + async chat(options: ChatOptions) { const messages = options.messages.map((v) => ({ role: v.role, diff --git a/app/client/platforms/google.ts b/app/client/platforms/google.ts index 12d8846357a..c8d3658b350 100644 --- a/app/client/platforms/google.ts +++ b/app/client/platforms/google.ts @@ -1,5 +1,13 @@ import { ApiPath, Google, REQUEST_TIMEOUT_MS } from "@/app/constant"; -import { ChatOptions, getHeaders, LLMApi, LLMModel, LLMUsage } from "../api"; +import { + ChatOptions, + getHeaders, + LLMApi, + LLMModel, + LLMUsage, + SpeechOptions, + TranscriptionOptions, +} from "../api"; import { useAccessStore, useAppConfig, useChatStore } from "@/app/store"; import { getClientConfig } from "@/app/config/client"; import { DEFAULT_API_HOST } from "@/app/constant"; @@ -56,6 +64,12 @@ export class GeminiProApi implements LLMApi { "" ); } + speech(options: SpeechOptions): Promise { + throw new Error("Method not implemented."); + } + transcription(options: TranscriptionOptions): Promise { + throw new Error("Method not implemented."); + } async chat(options: ChatOptions): Promise { const apiClient = this; let multimodal = false; diff --git a/app/client/platforms/iflytek.ts b/app/client/platforms/iflytek.ts index 73cea5ba0e7..6463e052e40 100644 --- a/app/client/platforms/iflytek.ts +++ b/app/client/platforms/iflytek.ts @@ -7,7 +7,14 @@ import { } from "@/app/constant"; import { useAccessStore, useAppConfig, useChatStore } from "@/app/store"; -import { ChatOptions, getHeaders, LLMApi, LLMModel } from "../api"; +import { + ChatOptions, + getHeaders, + LLMApi, + LLMModel, + SpeechOptions, + TranscriptionOptions, +} from "../api"; import Locale from "../../locales"; import { EventStreamContentType, @@ -53,6 +60,13 @@ export class SparkApi implements LLMApi { return res.choices?.at(0)?.message?.content ?? ""; } + speech(options: SpeechOptions): Promise { + throw new Error("Method not implemented."); + } + transcription(options: TranscriptionOptions): Promise { + throw new Error("Method not implemented."); + } + async chat(options: ChatOptions) { const messages: ChatOptions["messages"] = []; for (const v of options.messages) { diff --git a/app/client/platforms/moonshot.ts b/app/client/platforms/moonshot.ts index cd10d2f6c15..173ecd14c9d 100644 --- a/app/client/platforms/moonshot.ts +++ b/app/client/platforms/moonshot.ts @@ -26,6 +26,8 @@ import { LLMModel, LLMUsage, MultimodalContent, + SpeechOptions, + TranscriptionOptions, } from "../api"; import Locale from "../../locales"; import { @@ -72,6 +74,13 @@ export class MoonshotApi implements LLMApi { return res.choices?.at(0)?.message?.content ?? 
""; } + speech(options: SpeechOptions): Promise { + throw new Error("Method not implemented."); + } + transcription(options: TranscriptionOptions): Promise { + throw new Error("Method not implemented."); + } + async chat(options: ChatOptions) { const messages: ChatOptions["messages"] = []; for (const v of options.messages) { diff --git a/app/client/platforms/openai.ts b/app/client/platforms/openai.ts index 664ff872ba3..71b7731fa0b 100644 --- a/app/client/platforms/openai.ts +++ b/app/client/platforms/openai.ts @@ -33,6 +33,8 @@ import { LLMModel, LLMUsage, MultimodalContent, + SpeechOptions, + TranscriptionOptions, } from "../api"; import Locale from "../../locales"; import { @@ -84,7 +86,7 @@ export interface DalleRequestPayload { export class ChatGPTApi implements LLMApi { private disableListModels = true; - path(path: string): string { + path(path: string, model?: string): string { const accessStore = useAccessStore.getState(); let baseUrl = ""; @@ -147,6 +149,85 @@ export class ChatGPTApi implements LLMApi { return res.choices?.at(0)?.message?.content ?? res; } + async speech(options: SpeechOptions): Promise { + const requestPayload = { + model: options.model, + input: options.input, + voice: options.voice, + response_format: options.response_format, + speed: options.speed, + }; + + console.log("[Request] openai speech payload: ", requestPayload); + + const controller = new AbortController(); + options.onController?.(controller); + + try { + const speechPath = this.path(OpenaiPath.SpeechPath, options.model); + const speechPayload = { + method: "POST", + body: JSON.stringify(requestPayload), + signal: controller.signal, + headers: getHeaders(), + }; + + // make a fetch request + const requestTimeoutId = setTimeout( + () => controller.abort(), + REQUEST_TIMEOUT_MS, + ); + + const res = await fetch(speechPath, speechPayload); + clearTimeout(requestTimeoutId); + return await res.arrayBuffer(); + } catch (e) { + console.log("[Request] failed to make a speech request", e); + throw e; + } + } + + async transcription(options: TranscriptionOptions): Promise { + const formData = new FormData(); + formData.append("file", options.file, "audio.wav"); + formData.append("model", options.model ?? 
"whisper-1"); + if (options.language) formData.append("language", options.language); + if (options.prompt) formData.append("prompt", options.prompt); + if (options.response_format) + formData.append("response_format", options.response_format); + if (options.temperature) + formData.append("temperature", options.temperature.toString()); + + console.log("[Request] openai audio transcriptions payload: ", options); + + const controller = new AbortController(); + options.onController?.(controller); + + try { + const path = this.path(OpenaiPath.TranscriptionPath, options.model); + const headers = getHeaders(true); + const payload = { + method: "POST", + body: formData, + signal: controller.signal, + headers: headers, + }; + + // make a fetch request + const requestTimeoutId = setTimeout( + () => controller.abort(), + REQUEST_TIMEOUT_MS, + ); + const res = await fetch(path, payload); + clearTimeout(requestTimeoutId); + const json = await res.json(); + return json.text; + } catch (e) { + console.log("[Request] failed to make a audio transcriptions request", e); + throw e; + } + } + async chat(options: ChatOptions) { const modelConfig = { ...useAppConfig.getState().modelConfig, diff --git a/app/client/platforms/tencent.ts b/app/client/platforms/tencent.ts index 579008a9b9d..1739b7a142b 100644 --- a/app/client/platforms/tencent.ts +++ b/app/client/platforms/tencent.ts @@ -8,6 +8,8 @@ import { LLMApi, LLMModel, MultimodalContent, + SpeechOptions, + TranscriptionOptions, } from "../api"; import Locale from "../../locales"; import { @@ -89,6 +91,13 @@ export class HunyuanApi implements LLMApi { return res.Choices?.at(0)?.Message?.Content ?? ""; } + speech(options: SpeechOptions): Promise { + throw new Error("Method not implemented."); + } + transcription(options: TranscriptionOptions): Promise { + throw new Error("Method not implemented."); + } + async chat(options: ChatOptions) { const visionModel = isVisionModel(options.config.model); const messages = options.messages.map((v, index) => ({ diff --git a/app/components/chat.tsx b/app/components/chat.tsx index 3cc02d48672..cb03440775e 100644 --- a/app/components/chat.tsx +++ b/app/components/chat.tsx @@ -10,11 +10,14 @@ import React, { } from "react"; import SendWhiteIcon from "../icons/send-white.svg"; +import VoiceWhiteIcon from "../icons/voice-white.svg"; import BrainIcon from "../icons/brain.svg"; import RenameIcon from "../icons/rename.svg"; import ExportIcon from "../icons/share.svg"; import ReturnIcon from "../icons/return.svg"; import CopyIcon from "../icons/copy.svg"; +import SpeakIcon from "../icons/speak.svg"; +import SpeakStopIcon from "../icons/speak-stop.svg"; import LoadingIcon from "../icons/three-dots.svg"; import LoadingButtonIcon from "../icons/loading.svg"; import PromptIcon from "../icons/prompt.svg"; @@ -70,6 +73,7 @@ import { isDalle3, showPlugins, safeLocalStorage, + isFirefox, } from "../utils"; import { uploadImage as uploadImageRemote } from "@/app/utils/chat"; @@ -79,7 +83,7 @@ import dynamic from "next/dynamic"; import { ChatControllerPool } from "../client/controller"; import { DalleSize, DalleQuality, DalleStyle } from "../typing"; import { Prompt, usePromptStore } from "../store/prompt"; -import Locale from "../locales"; +import Locale, { getLang, getSTTLang } from "../locales"; import { IconButton } from "./button"; import styles from "./chat.module.scss"; @@ -96,6 +100,10 @@ import { import { useNavigate } from "react-router-dom"; import { CHAT_PAGE_SIZE, + DEFAULT_STT_ENGINE, + DEFAULT_TTS_ENGINE, + 
diff --git a/app/components/chat.tsx b/app/components/chat.tsx
index 3cc02d48672..cb03440775e 100644
--- a/app/components/chat.tsx
+++ b/app/components/chat.tsx
@@ -10,11 +10,14 @@ import React, {
 } from "react";

 import SendWhiteIcon from "../icons/send-white.svg";
+import VoiceWhiteIcon from "../icons/voice-white.svg";
 import BrainIcon from "../icons/brain.svg";
 import RenameIcon from "../icons/rename.svg";
 import ExportIcon from "../icons/share.svg";
 import ReturnIcon from "../icons/return.svg";
 import CopyIcon from "../icons/copy.svg";
+import SpeakIcon from "../icons/speak.svg";
+import SpeakStopIcon from "../icons/speak-stop.svg";
 import LoadingIcon from "../icons/three-dots.svg";
 import LoadingButtonIcon from "../icons/loading.svg";
 import PromptIcon from "../icons/prompt.svg";
@@ -70,6 +73,7 @@ import {
   isDalle3,
   showPlugins,
   safeLocalStorage,
+  isFirefox,
 } from "../utils";

 import { uploadImage as uploadImageRemote } from "@/app/utils/chat";
@@ -79,7 +83,7 @@ import dynamic from "next/dynamic";
 import { ChatControllerPool } from "../client/controller";
 import { DalleSize, DalleQuality, DalleStyle } from "../typing";
 import { Prompt, usePromptStore } from "../store/prompt";
-import Locale from "../locales";
+import Locale, { getLang, getSTTLang } from "../locales";
 import { IconButton } from "./button";

 import styles from "./chat.module.scss";
@@ -96,6 +100,10 @@ import {
 import { useNavigate } from "react-router-dom";
 import {
   CHAT_PAGE_SIZE,
+  DEFAULT_STT_ENGINE,
+  DEFAULT_TTS_ENGINE,
+  FIREFOX_DEFAULT_STT_ENGINE,
+  ModelProvider,
   LAST_INPUT_KEY,
   Path,
   REQUEST_TIMEOUT_MS,
@@ -113,6 +121,16 @@ import { useAllModels } from "../utils/hooks";
 import { MultimodalContent } from "../client/api";

 const localStorage = safeLocalStorage();
+import { ClientApi } from "../client/api";
+import { createTTSPlayer } from "../utils/audio";
+import {
+  OpenAITranscriptionApi,
+  SpeechApi,
+  WebTranscriptionApi,
+} from "../utils/speech";
+import { MsEdgeTTS, OUTPUT_FORMAT } from "../utils/ms_edge_tts";
+
+const ttsPlayer = createTTSPlayer();

 const Markdown = dynamic(async () => (await import("./markdown")).Markdown, {
   loading: () => <LoadingIcon />,
@@ -443,6 +461,7 @@ export function ChatActions(props: {
   hitBottom: boolean;
   uploading: boolean;
   setShowShortcutKeyModal: React.Dispatch<React.SetStateAction<boolean>>;
+  setUserInput: (input: string) => void;
 }) {
   const config = useAppConfig();
   const navigate = useNavigate();
@@ -537,6 +556,44 @@ export function ChatActions(props: {
     }
   }, [chatStore, currentModel, models]);

+  const [isListening, setIsListening] = useState(false);
+  const [isTranscription, setIsTranscription] = useState(false);
+  const [speechApi, setSpeechApi] = useState<any>(null);
+
+  useEffect(() => {
+    if (isFirefox()) config.sttConfig.engine = FIREFOX_DEFAULT_STT_ENGINE;
+    setSpeechApi(
+      config.sttConfig.engine === DEFAULT_STT_ENGINE
+        ? new WebTranscriptionApi((transcription) =>
+            onRecognitionEnd(transcription),
+          )
+        : new OpenAITranscriptionApi((transcription) =>
+            onRecognitionEnd(transcription),
+          ),
+    );
+  }, []);
+
+  const startListening = async () => {
+    if (speechApi) {
+      await speechApi.start();
+      setIsListening(true);
+    }
+  };
+  const stopListening = async () => {
+    if (speechApi) {
+      if (config.sttConfig.engine !== DEFAULT_STT_ENGINE)
+        setIsTranscription(true);
+      await speechApi.stop();
+      setIsListening(false);
+    }
+  };
+  const onRecognitionEnd = (finalTranscript: string) => {
+    console.log(finalTranscript);
+    if (finalTranscript) props.setUserInput(finalTranscript);
+    if (config.sttConfig.engine !== DEFAULT_STT_ENGINE)
+      setIsTranscription(false);
+  };
+
   return (
     <div className={styles["chat-input-actions"]}>
       {couldStop && (
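The default STT engine (`DEFAULT_STT_ENGINE`) maps to `WebTranscriptionApi`, which is not shown in this diff; it presumably wraps the browser's built-in speech recognition, which is why the `isFirefox()` check above forces Firefox (no `webkitSpeechRecognition`) onto the Whisper-based engine. A speculative sketch of what such a wrapper looks like, using `getSTTLang()` as imported above; the real class lives in `../utils/speech` and may differ:

```ts
// Speculative sketch of a WebTranscriptionApi-style wrapper over the
// Web Speech API; names and details are assumptions.
import { getSTTLang } from "@/app/locales"; // path assumed

class BrowserTranscription {
  private recognition: any;

  constructor(onTranscription: (text: string) => void) {
    const Ctor =
      (window as any).SpeechRecognition ??
      (window as any).webkitSpeechRecognition;
    this.recognition = new Ctor();
    this.recognition.lang = getSTTLang(); // recognize in the UI language
    this.recognition.interimResults = false;
    this.recognition.continuous = true;
    this.recognition.onresult = (e: any) => {
      // report the latest final transcript to the caller
      const last = e.results[e.results.length - 1];
      onTranscription(last[0].transcript);
    };
  }

  start() {
    this.recognition.start();
  }

  stop() {
    this.recognition.stop();
  }
}
```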
{couldStop && ( @@ -771,6 +828,16 @@ export function ChatActions(props: { icon={} /> )} + + {config.sttConfig.enable && ( + + isListening ? await stopListening() : await startListening() + } + text={isListening ? Locale.Chat.StopSpeak : Locale.Chat.StartSpeak} + icon={} + /> + )}
@@ -1184,10 +1251,55 @@ function _Chat() {
     });
   };

+  const accessStore = useAccessStore();
+  const [speechStatus, setSpeechStatus] = useState(false);
+  const [speechLoading, setSpeechLoading] = useState(false);
+  async function openaiSpeech(text: string) {
+    if (speechStatus) {
+      ttsPlayer.stop();
+      setSpeechStatus(false);
+    } else {
+      var api: ClientApi;
+      api = new ClientApi(ModelProvider.GPT);
+      const config = useAppConfig.getState();
+      setSpeechLoading(true);
+      ttsPlayer.init();
+      let audioBuffer: ArrayBuffer;
+      const { markdownToTxt } = require("markdown-to-txt");
+      const textContent = markdownToTxt(text);
+      if (config.ttsConfig.engine !== DEFAULT_TTS_ENGINE) {
+        const edgeVoiceName = accessStore.edgeVoiceName();
+        const tts = new MsEdgeTTS();
+        await tts.setMetadata(
+          edgeVoiceName,
+          OUTPUT_FORMAT.AUDIO_24KHZ_96KBITRATE_MONO_MP3,
+        );
+        audioBuffer = await tts.toArrayBuffer(textContent);
+      } else {
+        audioBuffer = await api.llm.speech({
+          model: config.ttsConfig.model,
+          input: textContent,
+          voice: config.ttsConfig.voice,
+          speed: config.ttsConfig.speed,
+        });
+      }
+      setSpeechStatus(true);
+      ttsPlayer
+        .play(audioBuffer, () => {
+          setSpeechStatus(false);
+        })
+        .catch((e) => {
+          console.error("[OpenAI Speech]", e);
+          showToast(prettyObject(e));
+          setSpeechStatus(false);
+        })
+        .finally(() => setSpeechLoading(false));
+    }
+  }
+
   const context: RenderMessage[] = useMemo(() => {
     return session.mask.hideContext ? [] : session.mask.context.slice();
   }, [session.mask.context, session.mask.hideContext]);

-  const accessStore = useAccessStore();
-
   if (
     context.length === 0 &&
@@ -1724,6 +1836,25 @@ function _Chat() {
                       )
                     }
                   />
+                  {config.ttsConfig.enable && (
+                    <ChatAction
+                      text={
+                        speechStatus
+                          ? Locale.Chat.Actions.StopSpeech
+                          : Locale.Chat.Actions.Speech
+                      }
+                      icon={
+                        speechStatus ? (
+                          <SpeakStopIcon />
+                        ) : (
+                          <SpeakIcon />
+                        )
+                      }
+                      onClick={() =>
+                        openaiSpeech(getMessageTextContent(message))
+                      }
+                    />
+                  )}
                 )}
@@ -1842,6 +1973,7 @@ function _Chat() {
               onSearch("");
             }}
             setShowShortcutKeyModal={setShowShortcutKeyModal}
+            setUserInput={setUserInput}
           />
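One small detail in `openaiSpeech` worth calling out: messages are markdown, so the text is run through `markdown-to-txt` before being sent to either TTS engine, which keeps the voice from reading out formatting characters. Roughly, with illustrative input:

```ts
// Sketch: markdown is flattened to plain text before synthesis.
const { markdownToTxt } = require("markdown-to-txt");

const message = "**Hello!** Here is `inline code` and a [link](https://example.com).";
const spoken = markdownToTxt(message);
// spoken is plain text along the lines of:
// "Hello! Here is inline code and a link"
```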