Update gradio demo to support text/voice conversation #75

Merged: 34 commits from zhuang/support-gradio-conversation into main, Aug 16, 2024

Commits (34)
947c151
Update poetry.lock file
Aug 7, 2024
6961767
Update gradio_demo.py to support conversation
Aug 7, 2024
27f2090
Update
Aug 7, 2024
fd29cd8
Update
Aug 7, 2024
85a8105
Update
Aug 8, 2024
71b8b09
Support audio input in conversation
Aug 8, 2024
5d163f6
Cleanup code
Aug 9, 2024
ae24225
Update
Aug 9, 2024
499e775
Update
Aug 9, 2024
919839c
Fix formatting
Aug 9, 2024
a2d0c2d
Fix format and signature issues
Aug 9, 2024
e57eb62
Update
Aug 9, 2024
d016229
Update config
Aug 9, 2024
47586db
Update LocalInference to support history
Aug 13, 2024
533234c
Rename max_new_tokens to max_tokens in LocalInference
Aug 13, 2024
461fbf9
Remove num_beams from gradio due to error
Aug 13, 2024
41baf4c
Remove poetry config from Justfile
Aug 13, 2024
a458b77
Rename max_new_tokens to max_tokens in infer_stream
Aug 13, 2024
d7074a3
Update
Aug 13, 2024
fdd53ff
Update to include conversation_mode flag
Aug 14, 2024
bdfdfbd
Update
Aug 14, 2024
29c44e5
Update pyproject.toml
zqhuang211 Aug 14, 2024
b8ec886
Update pyproject.toml
zqhuang211 Aug 14, 2024
4e288c0
Move cache_position to function signature
Aug 14, 2024
32fc18d
Update
Aug 14, 2024
04781ce
Update
Aug 14, 2024
6a370d9
Update
Aug 14, 2024
be3cd15
Update
Aug 14, 2024
0160332
Include max_new_tokens in demo page, update to use v0.3 as the defaul…
Aug 14, 2024
3a3d749
Add max_new_tokens and temperature to input arguments
Aug 14, 2024
2811cde
Update
Aug 14, 2024
9ef8240
Address Justin's comments
Aug 16, 2024
fcf85d1
Update
Aug 16, 2024
717dcbe
Merge branch 'main' into zhuang/support-gradio-conversation
zqhuang211 Aug 16, 2024
1 change: 1 addition & 0 deletions .gitignore
@@ -125,6 +125,7 @@ ipython_config.py
 # commonly ignored for libraries.
 # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
 #poetry.lock
+poetry.toml

 # pdm
 # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
1 change: 1 addition & 0 deletions Justfile
@@ -8,6 +8,7 @@ default: format check test

 install:
     pip install poetry==1.7.1
+    poetry config virtualenvs.in-project true --local
     poetry install

 format:
3,377 changes: 1,601 additions & 1,776 deletions poetry.lock

Large diffs are not rendered by default.

8 changes: 4 additions & 4 deletions pyproject.toml
@@ -9,13 +9,13 @@ readme = "README.md"

 [tool.poetry.dependencies]
 python = "^3.11"
-torch = "2.2.2"
+torch = "2.4"
 transformers = {version = ">=4.43.1", extras = ["torch"]}
 bitsandbytes = "~0.42.0"
 peft = "~0.11.1"
 simple-parsing = "~0.1.5"
 librosa = "~0.10.2.post1"
-requests = "~2.26.0"
+requests = "~2.31.0"
 datasets = "~2.19.1"
 mosaicml-streaming = "~0.7.6"
 nltk = "~3.8.1"
@@ -39,8 +39,8 @@ fsspec = "~2024.3.1"
 gcsfs = "~2024.3.1"
 sounddevice = "~0.4.7"
 mosaicml-cli = "~0.6.31"
-gradio-client = "~1.0.1"
-gradio = "~3.40.1"
+gradio-client = "~0.16.1"
+gradio = "~4.29.0"
 gpustat = "~1.1.1"
 types-requests = "~2.26.0"
 types-pyyaml = "^6.0.12.20240724"
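Note: the gradio pin moves from the 3.x line to 4.x, with the matching gradio-client. The demo file itself is not part of this diff, so purely as a hypothetical sketch of the kind of text-plus-voice layout the 4.x API allows (all names and wiring below are assumptions, not the PR's actual demo code):

import gradio as gr

def respond(message, chat_history):
    # Placeholder reply; the real demo would call LocalInference here.
    chat_history = chat_history + [(message, "(model reply)")]
    return "", chat_history

with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    text_in = gr.Textbox(label="Message")
    # gradio 4.x uses `sources` (a list) instead of 3.x's `source`.
    audio_in = gr.Audio(sources=["microphone"], type="filepath", label="Or speak")
    text_in.submit(respond, [text_in, chatbot], [text_in, chatbot])

demo.launch()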
3 changes: 3 additions & 0 deletions ultravox/data/datasets.py
@@ -191,6 +191,9 @@ def __post_init__(self):
         ), f"Unexpected audio dtype: {self.audio.dtype}"
         assert self.audio.ndim == 1, f"Unexpected audio shape: {self.audio.shape}"

+    def add_past_messages(self, past_messages: List[Dict[str, str]]):
+        self.messages = past_messages + self.messages
+
     messages: List[Dict[str, str]]
     """List of messages, each with a "role" and "content" field."""
     audio: Optional[np.typing.NDArray[np.float32]] = None
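Note: a minimal sketch of how the new helper might be used. The field names come from the VoiceSample dataclass above; the construction itself is illustrative only:

import numpy as np
from ultravox.data import datasets

# Hypothetical new turn: one user message with an audio placeholder.
sample = datasets.VoiceSample(
    messages=[{"role": "user", "content": "<|audio|>"}],
    audio=np.zeros(16000, dtype=np.float32),  # 1 second of silence at 16 kHz
)
sample.add_past_messages(
    [
        {"role": "user", "content": "Hi there."},
        {"role": "assistant", "content": "Hello! How can I help?"},
    ]
)
# sample.messages now holds the two past turns followed by the new user turn.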
52 changes: 50 additions & 2 deletions ultravox/inference/base.py
@@ -1,6 +1,8 @@
 import abc
 import dataclasses
-from typing import Generator, Optional
+from typing import Dict, Generator, List, Optional, Tuple, Union
+
+import transformers

 from ultravox.data import datasets

@@ -10,6 +12,8 @@ class VoiceOutput:
     text: str
     input_tokens: int
     output_tokens: int
+    audio_token_len: int = 0
+    past_key_values: Optional[Union[Tuple, transformers.cache_utils.Cache]] = None


 class InferenceMessage:
@@ -37,6 +41,7 @@ def infer(
         sample: datasets.VoiceSample,
         max_tokens: Optional[int] = None,
         temperature: Optional[float] = None,
+        past_key_values: Optional[Union[Tuple, transformers.cache_utils.Cache]] = None,
     ) -> VoiceOutput:
         pass

@@ -45,8 +50,51 @@ def infer_stream(
         sample: datasets.VoiceSample,
         max_tokens: Optional[int] = None,
         temperature: Optional[float] = None,
+        past_key_values: Optional[Union[Tuple, transformers.cache_utils.Cache]] = None,
     ) -> InferenceGenerator:
         """Streaming polyfill, if not supported directly in derived classes."""
-        output = self.infer(sample, max_tokens, temperature)
+        output = self.infer(sample, max_tokens, temperature, past_key_values)
         yield InferenceChunk(output.text)
         yield InferenceStats(output.input_tokens, output.output_tokens)
+
+
+class History:
+    def __init__(self, audio_token_replacement: str = "<|eot_token|>"):
+        self.audio_token_replacement: str = audio_token_replacement
+        self.audio_placeholder = "<|audio|>"
+        self.messages: List[Dict[str, str]] = []
+        self.key_values: Optional[Union[Tuple, transformers.cache_utils.Cache]] = None
+
+    def add_message(self, message: Dict[str, str], audio_token_len: int):
+        message = message.copy()
+        content = message["content"]
+        if audio_token_len > 0:
+            if content.count(self.audio_placeholder) != 1:
+                raise ValueError(
+                    f"Expected 1 audio placeholder, found {content.count(self.audio_placeholder)}"
+                )
+            message["content"] = content.replace(
+                self.audio_placeholder, self.audio_token_replacement * audio_token_len
+            )
+
+        if self.messages:
+            self.messages.append(message)
+        else:
+            self.messages = [message]
+
+    def update_key_values(
+        self, key_values: Union[Tuple, transformers.cache_utils.Cache]
+    ):
+        self.key_values = key_values
+
+    @property
+    def past_messages(self) -> List[Dict[str, str]]:
+        return self.messages
+
+    @property
+    def past_key_values(self) -> Optional[Union[Tuple, transformers.cache_utils.Cache]]:
+        return self.key_values
+
+    def reset(self):
+        self.messages = []
+        self.key_values = None
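Note: to make the intended use of History concrete, a hypothetical conversation loop. The inference object, the incoming_samples iterable, and the choice of replacement token are assumptions; only the History and VoiceOutput fields come from the diffs in this PR:

# Sketch only: thread conversation state through repeated infer() calls.
history = History(audio_token_replacement="<|eot_id|>")  # replacement token is an assumption

for sample in incoming_samples:  # hypothetical source of datasets.VoiceSample turns
    sample.add_past_messages(history.past_messages)
    output = inference.infer(sample, past_key_values=history.past_key_values)
    # The last message is the new user turn; expand its "<|audio|>" placeholder
    # into audio_token_len repeated tokens so lengths line up with the KV cache.
    history.add_message(sample.messages[-1], output.audio_token_len)
    history.add_message({"role": "assistant", "content": output.text}, audio_token_len=0)
    history.update_key_values(output.past_key_values)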
43 changes: 33 additions & 10 deletions ultravox/inference/infer.py
@@ -1,5 +1,5 @@
 import threading
-from typing import Optional
+from typing import Optional, Tuple, Union

 import librosa
 import numpy as np
@@ -11,7 +11,7 @@
 from ultravox.model import ultravox_processing

 SAMPLE_RATE = 16000
-MAX_TOKENS = 1024
+MAX_NEW_TOKENS = 1024
 # Without this penalty, the model tends to repeat itself.
 REPETITION_PENALTY = 1.1

@@ -33,22 +33,33 @@ def __init__(
     def infer(
         self,
         sample: datasets.VoiceSample,
-        max_tokens: Optional[int] = None,
+        max_new_tokens: Optional[int] = None,
         temperature: Optional[float] = None,
+        past_key_values: Optional[Union[Tuple, transformers.cache_utils.Cache]] = None,
+        num_beams: int = 1,
     ) -> base.VoiceOutput:
         inputs = self._dataproc(sample)
         input_len = inputs["input_ids"].shape[1]
-        output = self._generate(inputs, max_tokens, temperature)
-        output_tokens = output[0][input_len:]
+        output = self._generate(
+            inputs, max_new_tokens, temperature, past_key_values, num_beams
+        )
+        output_tokens = output.sequences[0][input_len:]
         output_text = self.tokenizer.decode(output_tokens, skip_special_tokens=True)
         output_len = len(output_tokens)
-        return base.VoiceOutput(output_text, input_len, output_len)
+        audio_token_len = 0
+        if "audio_token_len" in inputs:
+            audio_token_len = inputs["audio_token_len"][0]
+        return base.VoiceOutput(
+            output_text, input_len, output_len, audio_token_len, output.past_key_values
+        )

     def infer_stream(
         self,
         sample: datasets.VoiceSample,
-        max_tokens: Optional[int] = None,
+        max_new_tokens: Optional[int] = None,
         temperature: Optional[float] = None,
+        past_key_values: Optional[Union[Tuple, transformers.cache_utils.Cache]] = None,
+        num_beams: int = 1,
     ) -> base.InferenceGenerator:
         inputs = self._dataproc(sample)
         input_tokens = inputs["input_ids"].shape[1]
@@ -57,7 +68,14 @@ def infer_stream(
             self.tokenizer, skip_prompt=True, decode_kwargs=decode_kwargs
         )

-        thread_args = (inputs, max_tokens, temperature, streamer)
+        thread_args = (
+            inputs,
+            max_new_tokens,
+            temperature,
+            past_key_values,
+            num_beams,
+            streamer,
+        )
         thread = threading.Thread(target=self._generate, args=thread_args)
         thread.start()
         output_tokens = 0
@@ -108,8 +126,10 @@ def _dataproc(self, sample: datasets.VoiceSample):
     def _generate(
         self,
         inputs: torch.Tensor,
-        max_tokens: Optional[int] = None,
+        max_new_tokens: Optional[int] = None,
         temperature: Optional[float] = None,
+        past_key_values: Optional[Union[Tuple, transformers.cache_utils.Cache]] = None,
+        num_beams: int = 1,
         streamer: Optional[transformers.TextStreamer] = None,
     ):
         temperature = temperature or None
@@ -122,10 +142,13 @@ def _generate(
         return self.model.generate(
             **inputs,
             do_sample=do_sample,
-            max_new_tokens=max_tokens or MAX_TOKENS,
+            max_new_tokens=max_new_tokens or MAX_NEW_TOKENS,
             temperature=temperature,
             repetition_penalty=REPETITION_PENALTY,
             pad_token_id=self.tokenizer.eos_token_id,
             eos_token_id=terminators,
             streamer=streamer,
+            past_key_values=past_key_values,
+            num_beams=num_beams,
+            return_dict_in_generate=True,
         )
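Note: a small sketch of the updated streaming call. The chunk/stats attribute names (text, input_tokens, output_tokens) mirror the constructor arguments in the base module above but are assumptions, as are the inference and sample objects:

# Hypothetical caller: stream a reply while counting tokens.
for msg in inference.infer_stream(sample, max_new_tokens=256, temperature=0.7):
    if isinstance(msg, base.InferenceChunk):
        print(msg.text, end="", flush=True)  # partial text as it is generated
    elif isinstance(msg, base.InferenceStats):
        print(f"\n[{msg.input_tokens} prompt / {msg.output_tokens} generated tokens]")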
12 changes: 9 additions & 3 deletions ultravox/model/ultravox_model.py
@@ -247,10 +247,16 @@ def prepare_inputs_for_generation(
             **kwargs,
         )

-        if is_cache_empty(past_key_values) and audio_values is not None:
-            # We only want to use audio features in the 1st generation step
+        prefill_start_idx = kwargs["cache_position"][0]
+        if (
+            audio_values is not None
+            and audio_token_start_idx is not None
+            and prefill_start_idx <= torch.max(audio_token_start_idx)
+        ):
             model_input["audio_values"] = audio_values
-            model_input["audio_token_start_idx"] = audio_token_start_idx
+            model_input["audio_token_start_idx"] = (
+                audio_token_start_idx - prefill_start_idx
+            )
             model_input["audio_token_len"] = audio_token_len

         return model_input
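Note: the new condition keeps the audio inputs for any prefill window that still overlaps the audio span, not just the very first generation step, and re-bases the start index into the current window. A toy check of the gate (the position values are made up):

import torch

audio_token_start_idx = torch.tensor([12])  # audio placeholder begins at position 12

for prefill_start_idx in (0, 10, 50):
    include_audio = prefill_start_idx <= int(torch.max(audio_token_start_idx))
    rebased = int(audio_token_start_idx[0]) - prefill_start_idx
    print(prefill_start_idx, include_audio, rebased if include_audio else None)
# 0 True 12    -> first chunk: audio features included, index unchanged
# 10 True 2    -> chunked prefill: index shifted into the current window
# 50 False None -> decoding past the audio span: audio inputs dropped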