Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update gradio demo to support text/voice conversation #75

Merged
merged 34 commits into from
Aug 16, 2024
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
947c151
Update poetry.lock file
Aug 7, 2024
6961767
Update gradio_demo.py to support conversation
Aug 7, 2024
27f2090
Update
Aug 7, 2024
fd29cd8
Update
Aug 7, 2024
85a8105
Update
Aug 8, 2024
71b8b09
Support audio input in conversation
Aug 8, 2024
5d163f6
Cleanup code
Aug 9, 2024
ae24225
Update
Aug 9, 2024
499e775
Update
Aug 9, 2024
919839c
Fix formatting
Aug 9, 2024
a2d0c2d
Fix format and signature issues
Aug 9, 2024
e57eb62
Update
Aug 9, 2024
d016229
Update config
Aug 9, 2024
47586db
Update LocalInference to support history
Aug 13, 2024
533234c
Rename max_new_tokens to max_tokens in LocalInference
Aug 13, 2024
461fbf9
Remove num_beams from gradio due to error
Aug 13, 2024
41baf4c
Remove poetry config from Justfile
Aug 13, 2024
a458b77
Rename max_new_tokens to max_tokens in infer_stream
Aug 13, 2024
d7074a3
Update
Aug 13, 2024
fdd53ff
Update to include conversation_mode flag
Aug 14, 2024
bdfdfbd
Update
Aug 14, 2024
29c44e5
Update pyproject.toml
zqhuang211 Aug 14, 2024
b8ec886
Update pyproject.toml
zqhuang211 Aug 14, 2024
4e288c0
Move cache_position to function signature
Aug 14, 2024
32fc18d
Update
Aug 14, 2024
04781ce
Update
Aug 14, 2024
6a370d9
Update
Aug 14, 2024
be3cd15
Update
Aug 14, 2024
0160332
Include max_new_tokens in demo page, update to use v0.3 as the defaul…
Aug 14, 2024
3a3d749
Add max_new_tokens and temperature to input arguments
Aug 14, 2024
2811cde
Update
Aug 14, 2024
9ef8240
Address Justin's comments
Aug 16, 2024
fcf85d1
Update
Aug 16, 2024
717dcbe
Merge branch 'main' into zhuang/support-gradio-conversation
zqhuang211 Aug 16, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ ipython_config.py
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
poetry.toml

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
Expand Down
3,377 changes: 1,601 additions & 1,776 deletions poetry.lock

Large diffs are not rendered by default.

8 changes: 4 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@ readme = "README.md"

[tool.poetry.dependencies]
python = "^3.11"
torch = "2.2.2"
torch = "2.4"
zqhuang211 marked this conversation as resolved.
Show resolved Hide resolved
transformers = {version = ">=4.43.1", extras = ["torch"]}
bitsandbytes = "~0.42.0"
peft = "~0.11.1"
simple-parsing = "~0.1.5"
librosa = "~0.10.2.post1"
requests = "~2.26.0"
requests = "~2.31.0"
datasets = "~2.19.1"
mosaicml-streaming = "~0.7.6"
nltk = "~3.8.1"
Expand All @@ -39,8 +39,8 @@ fsspec = "~2024.3.1"
gcsfs = "~2024.3.1"
sounddevice = "~0.4.7"
mosaicml-cli = "~0.6.31"
gradio-client = "~1.0.1"
gradio = "~3.40.1"
gradio-client = "~0.16.1"
gradio = "~4.29.0"
zqhuang211 marked this conversation as resolved.
Show resolved Hide resolved
gpustat = "~1.1.1"
types-requests = "~2.26.0"
types-pyyaml = "^6.0.12.20240724"
Expand Down
3 changes: 3 additions & 0 deletions ultravox/data/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,9 @@ def __post_init__(self):
), f"Unexpected audio dtype: {self.audio.dtype}"
assert self.audio.ndim == 1, f"Unexpected audio shape: {self.audio.shape}"

def add_past_messages(self, past_messages: List[Dict[str, str]]):
    """Prepend earlier conversation turns to this sample's messages.

    ``self.messages`` is rebound to a new list containing ``past_messages``
    followed by the current messages; ``past_messages`` is not modified.
    """
    history_then_current = past_messages + self.messages
    self.messages = history_then_current

messages: List[Dict[str, str]]
"""List of messages, each with a "role" and "content" field."""
audio: Optional[np.typing.NDArray[np.float32]] = None
Expand Down
1 change: 1 addition & 0 deletions ultravox/inference/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ class VoiceOutput:
text: str
input_tokens: int
output_tokens: int
audio_token_len: int = 0
zqhuang211 marked this conversation as resolved.
Show resolved Hide resolved


class InferenceMessage:
Expand Down
68 changes: 59 additions & 9 deletions ultravox/inference/infer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import copy
import threading
from typing import Optional
from typing import Dict, List, Optional, Tuple, Union

import librosa
import numpy as np
Expand All @@ -11,7 +12,7 @@
from ultravox.model import ultravox_processing

SAMPLE_RATE = 16000
MAX_TOKENS = 1024
MAX_NEW_TOKENS = 1024
zqhuang211 marked this conversation as resolved.
Show resolved Hide resolved
# Without this penalty, the model tends to repeat itself.
REPETITION_PENALTY = 1.1

Expand All @@ -29,20 +30,61 @@ def __init__(
self.tokenizer = tokenizer
self.processor = processor
self.dtype = dtype
self.past_messages: List[Dict[str, str]] = []
self.past_key_values: Optional[Union[Tuple, transformers.cache_utils.Cache]] = (
None
)

def reset_history(self):
    """Forget the conversation so far.

    Drops the cached ``past_key_values`` and replaces ``past_messages`` with
    a fresh empty list.
    """
    self.past_key_values = None
    self.past_messages = []

def _add_past_message(self, message: Dict[str, str], audio_token_len: int):
    """Append a copy of ``message`` to ``self.past_messages``.

    When ``audio_token_len`` is positive, the single ``<|audio|>`` placeholder
    in the message content is replaced by ``audio_token_len`` repetitions of
    the tokenizer's EOS token string before the message is stored. The
    caller's ``message`` dict is left untouched (a shallow copy is stored).

    Raises:
        ValueError: if ``audio_token_len`` is positive and the content does
            not contain exactly one ``<|audio|>`` placeholder.
    """
    entry = copy.copy(message)
    text = entry["content"]
    if audio_token_len > 0:
        placeholder_count = text.count("<|audio|>")
        if placeholder_count != 1:
            raise ValueError(
                f"Expected 1 audio placeholder, found {text.count('<|audio|>')}"
            )
        # Expand the placeholder so the stored text spans as many EOS strings
        # as audio_token_len indicates.
        filler = self.tokenizer.eos_token * audio_token_len
        entry["content"] = text.replace("<|audio|>", filler)

    self.past_messages.append(entry)

def _get_sample_with_past(
    self, sample: datasets.VoiceSample
) -> datasets.VoiceSample:
    """Return a shallow copy of ``sample`` with the stored history prepended.

    The copy's ``add_past_messages`` is called with ``self.past_messages``;
    the returned object is the copy, not the ``sample`` passed in.
    """
    extended = copy.copy(sample)
    extended.add_past_messages(self.past_messages)
    return extended

def infer(
self,
sample: datasets.VoiceSample,
max_tokens: Optional[int] = None,
temperature: Optional[float] = None,
) -> base.VoiceOutput:
inputs = self._dataproc(sample)
extended_sample = self._get_sample_with_past(sample)
inputs = self._dataproc(extended_sample)
input_len = inputs["input_ids"].shape[1]
output = self._generate(inputs, max_tokens, temperature)
output_tokens = output[0][input_len:]
output = self._generate(
inputs, max_tokens, temperature, self.past_key_values
)
output_tokens = output.sequences[0][input_len:]
zqhuang211 marked this conversation as resolved.
Show resolved Hide resolved
output_text = self.tokenizer.decode(output_tokens, skip_special_tokens=True)
output_len = len(output_tokens)
return base.VoiceOutput(output_text, input_len, output_len)

# update history
audio_token_len = (
0 if "audio_token_len" not in inputs else inputs["audio_token_len"][0]
)
self._add_past_message(extended_sample.messages[-1], audio_token_len)
self._add_past_message({"role": "assistant", "content": output_text}, 0)
self.past_key_values = output.past_key_values

return base.VoiceOutput(output_text, input_len, output_len, audio_token_len)

def infer_stream(
zqhuang211 marked this conversation as resolved.
Show resolved Hide resolved
self,
Expand All @@ -57,7 +99,12 @@ def infer_stream(
self.tokenizer, skip_prompt=True, decode_kwargs=decode_kwargs
)

thread_args = (inputs, max_tokens, temperature, streamer)
thread_args = (
inputs,
max_tokens,
temperature,
streamer,
)
thread = threading.Thread(target=self._generate, args=thread_args)
thread.start()
output_tokens = 0
Expand Down Expand Up @@ -108,8 +155,9 @@ def _dataproc(self, sample: datasets.VoiceSample):
def _generate(
self,
inputs: torch.Tensor,
max_tokens: Optional[int] = None,
max_new_tokens: Optional[int] = None,
temperature: Optional[float] = None,
past_key_values: Optional[Union[Tuple, transformers.cache_utils.Cache]] = None,
streamer: Optional[transformers.TextStreamer] = None,
):
temperature = temperature or None
Expand All @@ -122,10 +170,12 @@ def _generate(
return self.model.generate(
**inputs,
do_sample=do_sample,
max_new_tokens=max_tokens or MAX_TOKENS,
max_new_tokens=max_new_tokens or MAX_NEW_TOKENS,
temperature=temperature,
repetition_penalty=REPETITION_PENALTY,
pad_token_id=self.tokenizer.eos_token_id,
eos_token_id=terminators,
streamer=streamer,
past_key_values=past_key_values,
return_dict_in_generate=True,
)
12 changes: 9 additions & 3 deletions ultravox/model/ultravox_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,10 +247,16 @@ def prepare_inputs_for_generation(
**kwargs,
)

zqhuang211 marked this conversation as resolved.
Show resolved Hide resolved
if is_cache_empty(past_key_values) and audio_values is not None:
# We only want to use audio features in the 1st generation step
prefill_start_idx = kwargs["cache_position"][0]
zqhuang211 marked this conversation as resolved.
Show resolved Hide resolved
if (
audio_values is not None
and audio_token_start_idx is not None
and prefill_start_idx <= torch.max(audio_token_start_idx)
):
model_input["audio_values"] = audio_values
model_input["audio_token_start_idx"] = audio_token_start_idx
model_input["audio_token_start_idx"] = (
audio_token_start_idx - prefill_start_idx
)
model_input["audio_token_len"] = audio_token_len

return model_input
Expand Down
120 changes: 104 additions & 16 deletions ultravox/tools/gradio_demo.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
from dataclasses import dataclass
from typing import Tuple
from typing import Optional

import gradio as gr
import numpy as np
import simple_parsing

from ultravox.data import datasets
from ultravox.inference import ultravox_infer

demo_instruction: str = """Enter your prompt here (audio will be inserted at the end or at <|audio|>).

Text mode: Shift+Enter to submit.
Voice mode: Click the recording button to start, then click again to stop and submit.
"""


@dataclass
class DemoConfig:
Expand All @@ -16,27 +21,110 @@ class DemoConfig:
# runs/llama2_asr_gigaspeech/checkpoint-1000/
# wandb://fixie/ultravox/model-llama2_asr_gigaspeech:v0
model_path: str = "fixie-ai/ultravox"
default_prompt: str = "Transcribe\n<|audio|>"
# Use <|audio|> to specify where to insert audio, otherwise, audio is inserted at the end in voice mode.
default_prompt: str = ""
max_new_tokens: int = 256
device: str = "mps"
zqhuang211 marked this conversation as resolved.
Show resolved Hide resolved
data_type: str = "float16"


def main():
args = simple_parsing.parse(config_class=DemoConfig)
inference = ultravox_infer.UltravoxInference(args.model_path)
inference = ultravox_infer.UltravoxInference(
args.model_path, device=args.device, data_type=args.data_type
)

def wrapper(text: str, audio: Tuple[int, np.ndarray]) -> str:
sample = datasets.VoiceSample.from_prompt_and_raw(text, audio[1], audio[0])
return inference.infer(sample, max_tokens=64).text
def add_text(chatbot: gr.Chatbot, text: str) -> gr.Chatbot:
    """Return the chat history with a new ``(text, None)`` turn appended."""
    user_turn = (text, None)
    return chatbot + [user_turn]

inputs = [
gr.Textbox(label="Prompt", value=args.default_prompt),
gr.Audio(label="Audio", show_download_button=True),
]
outputs = [gr.Textbox(label="Output")]
examples = [["Transcribe\n<|audio|>", "examples/test16.wav"]]
def add_audio(chatbot: gr.Chatbot, audio: str) -> gr.Chatbot:
    """Return the chat history with a new audio turn appended.

    The audio value is wrapped in a one-element tuple and paired with ``None``
    (presumably Gradio's convention for a file message with no reply yet;
    confirm against the Chatbot docs).
    """
    audio_turn = ((audio,), None)
    return chatbot + [audio_turn]

gr.Interface(fn=wrapper, inputs=inputs, outputs=outputs, examples=examples).launch(
share=True
)
def process_turn(
    chatbot: gr.Chatbot,
    prompt: str,
    audio: Optional[str] = None,
    temperature: float = 0,
):
    """Run one conversation turn and append the model reply to the chat.

    Args:
        chatbot: Current chat history.
        prompt: Text prompt; may contain an ``<|audio|>`` placeholder.
        audio: Audio file path for voice mode, or a falsy value for text mode.
        temperature: Sampling temperature forwarded to ``inference.infer``.

    Returns:
        A pair of the extended chat history and a ``gr.update`` carrying the
        value to put back in the prompt box.

    Raises:
        ValueError: if the built sample does not contain exactly one message.
    """
    if audio:
        # Voice mode: keep the (mixed audio/text) prompt in the box as is.
        prompt_to_return = prompt
        if "<|audio|>" not in prompt:
            # No explicit placeholder, so the audio goes at the end.
            prompt += "<|audio|>"
        sample = datasets.VoiceSample.from_prompt_and_file(prompt, audio)
    else:
        # Text mode: clear the box in anticipation of the next prompt.
        sample = datasets.VoiceSample.from_prompt(prompt)
        prompt_to_return = ""

    if len(sample.messages) != 1:
        raise ValueError(
            f"Expected exactly 1 message in sample but got {len(sample.messages)}"
        )

    output = inference.infer(
        sample,
        max_tokens=args.max_new_tokens,
        temperature=temperature,
    )

    updated_chat = chatbot + [(None, output.text)]
    prompt_update = gr.update(value=prompt_to_return)
    return updated_chat, prompt_update

def process_text(chatbot, prompt, temperature):
    """Text-mode handler: run ``process_turn`` with no audio."""
    return process_turn(chatbot, prompt, audio=None, temperature=temperature)

def process_audio(chatbot, prompt, audio, temperature):
    """Voice-mode handler: run ``process_turn`` with the given audio."""
    return process_turn(chatbot, prompt, audio=audio, temperature=temperature)

def gradio_reset():
    """Reset the inference history and return cleared UI values.

    Returns an empty list, an empty string and ``None``; the reset button
    wires these to the chatbot, prompt and audio components respectively.
    """
    inference.reset_history()
    cleared_chat, cleared_prompt, cleared_audio = [], "", None
    return cleared_chat, cleared_prompt, cleared_audio

with gr.Blocks() as demo:
chatbot = gr.Chatbot(scale=10, height=1000)

with gr.Row():
with gr.Column(scale=1):
reset = gr.Button("Reset")
audio = gr.Audio(
label="🎤",
sources=["microphone"],
type="filepath",
visible=True,
)
with gr.Column(scale=8):
prompt = gr.Textbox(
show_label=False,
lines=5,
placeholder=demo_instruction,
value=args.default_prompt,
container=True,
)
with gr.Column(scale=1):
temperature = gr.Slider(
minimum=0,
maximum=5.0,
value=0,
step=0.1,
interactive=True,
label="temperature",
)

prompt.submit(add_text, [chatbot, prompt], [chatbot], queue=False).then(
process_text,
[chatbot, prompt, temperature],
[chatbot, prompt],
queue=False,
)
audio.stop_recording(add_audio, [chatbot, audio], [chatbot], queue=False).then(
process_audio,
[chatbot, prompt, audio, temperature],
[chatbot, prompt],
queue=False,
)
reset.click(gradio_reset, [], [chatbot, prompt, audio], queue=False)

demo.launch(share=True)


if __name__ == "__main__":
Expand Down
7 changes: 6 additions & 1 deletion ultravox/tools/infer_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@
import json
import os
import tempfile
from typing import Any, List, Optional
from typing import Any, List, Optional, Tuple, Union

import gradio_client
import numpy as np
import requests
import transformers

from ultravox.data import datasets
from ultravox.inference import base
Expand All @@ -23,6 +24,7 @@ def infer(
sample: datasets.VoiceSample,
max_tokens: Optional[int] = None,
temperature: Optional[float] = None,
past_key_values: Optional[Union[Tuple, transformers.cache_utils.Cache]] = None,
zqhuang211 marked this conversation as resolved.
Show resolved Hide resolved
) -> base.VoiceOutput:
text = ""
stats = None
Expand All @@ -41,6 +43,7 @@ def infer_stream(
sample: datasets.VoiceSample,
max_tokens: Optional[int] = None,
temperature: Optional[float] = None,
past_key_values: Optional[Union[Tuple, transformers.cache_utils.Cache]] = None,
) -> base.InferenceGenerator:
url = f"{self._base_url}/chat/completions"
headers = {"Content-Type": "application/json"}
Expand Down Expand Up @@ -104,6 +107,7 @@ def infer(
sample: datasets.VoiceSample,
max_tokens: Optional[int] = None,
temperature: Optional[float] = None,
past_key_values: Optional[Union[Tuple, transformers.cache_utils.Cache]] = None,
) -> base.VoiceOutput:
headers = {"Content-Type": "application/json"}
response = requests.post(
Expand All @@ -127,6 +131,7 @@ def infer(
sample: datasets.VoiceSample,
max_tokens: Optional[int] = None,
temperature: Optional[float] = None,
past_key_values: Optional[Union[Tuple, transformers.cache_utils.Cache]] = None,
) -> base.VoiceOutput:
# For some reason the most recent Gradio endpoint only accepts
# audio as a file, not as a base64-encoded string. There's probably
Expand Down
Loading