Offline batch inference mode #82

Merged · 26 commits · Aug 20, 2024
6 changes: 6 additions & 0 deletions ultravox/data/datasets.py
@@ -88,6 +88,7 @@ def __call__(self, features, *args, **kwargs):
                }
                for f in features
            ]
        input_ids_len = torch.LongTensor([f["input_ids"].shape[-1] for f in features])
        batch = super().__call__(features, *args, **kwargs)
        if self.include_alt_fields:
            alt_batch = super().__call__(alt_features, *args, **kwargs)
@@ -101,6 +102,11 @@
batch["audio_values"] = torch.stack(
[F.pad(x, (0, max_len - x.shape[-1])) for x in audio_values]
)
if self.tokenizer.padding_side == "left":
displacement = batch["input_ids"].shape[-1] - input_ids_len
batch["audio_token_start_idx"] += displacement.to(
batch["audio_token_start_idx"].device
)

        return batch

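The collator change above is the crux of left-padded batching: each sample's unpadded input_ids length is recorded before the parent collator pads the batch, and when the tokenizer pads on the left, audio_token_start_idx is shifted right by each row's padding amount. A minimal sketch of the arithmetic, using hypothetical lengths rather than real collator output:

import torch

# Two samples with unpadded lengths 3 and 5; the batch pads both rows to 5.
input_ids_len = torch.LongTensor([3, 5])
padded_len = 5
audio_token_start_idx = torch.LongTensor([1, 2])

# Left padding prepends (padded_len - row_len) pad tokens to each row, so
# every token index in that row shifts right by the same amount.
displacement = padded_len - input_ids_len       # tensor([2, 0])
print(audio_token_start_idx + displacement)     # tensor([3, 2])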
34 changes: 33 additions & 1 deletion ultravox/inference/infer.py
@@ -37,6 +37,37 @@ def __init__(
        self.past_key_values: Optional[Union[Tuple, transformers.cache_utils.Cache]] = (
            None
        )
        self.data_collator = datasets.DataCollatorForSeq2SeqWithAudio(
            tokenizer=self.tokenizer,
            include_alt_fields=False,
        )

        assert self.tokenizer.padding_side == "left"

    def batch_infer(
        self,
        samples: List[datasets.VoiceSample],
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
    ) -> List[base.VoiceOutput]:
        inputs = [self._dataproc(s) for s in samples]
        for input in inputs:
            for key, val in input.items():
                input[key] = val.squeeze(0)

        tensors = self.data_collator(inputs)
        input_len = tensors["input_ids"].shape[1]
        output_batch = self._generate(
            tensors, max_tokens, temperature, return_dict_in_generate=False
        )
        output_texts = []
        for output in output_batch:
            output_tokens = output[input_len:]
            output_text = self.tokenizer.decode(output_tokens, skip_special_tokens=True)
            output_len = len(output_tokens)
            output_text = base.VoiceOutput(output_text, input_len, output_len)
            output_texts.append(output_text)
        return output_texts

    def update_conversation(
        self,
@@ -162,6 +193,7 @@ def _generate(
        temperature: Optional[float] = None,
        streamer: Optional[transformers.TextStreamer] = None,
        past_key_values: Optional[Union[Tuple, transformers.cache_utils.Cache]] = None,
        return_dict_in_generate: Optional[bool] = True,
    ):
        temperature = temperature or None
        do_sample = temperature is not None
@@ -180,5 +212,5 @@
            eos_token_id=terminators,
            streamer=streamer,
            past_key_values=past_key_values,
            return_dict_in_generate=True,
            return_dict_in_generate=return_dict_in_generate,
        )
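A sketch of how the new batch_infer entry point might be called; the inference object and the sample list are assumptions here, not part of this diff:

from typing import List

from ultravox.data import datasets
from ultravox.inference import infer

def transcribe_batch(inference: infer.LocalInference, samples: List[datasets.VoiceSample]):
    # One forward pass over the whole left-padded batch instead of a per-sample loop.
    outputs = inference.batch_infer(samples, max_tokens=64)
    for out in outputs:
        print(out.text)  # each VoiceOutput also records input/output token counts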
4 changes: 3 additions & 1 deletion ultravox/inference/infer_test.py
@@ -15,9 +15,11 @@
# work properly for external contributions (since Llama 3 is gated).
@pytest.fixture(scope="module")
def tokenizer():
    return transformers.AutoTokenizer.from_pretrained(
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        "./assets/hf/Meta-Llama-3-8B-Instruct", local_files_only=True
    )
    tokenizer.padding_side = "left"
    return tokenizer


@pytest.fixture(scope="module")
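The fixture now forces padding_side = "left" to satisfy the new assertion in LocalInference.__init__: with a decoder-only model, batched generation requires left padding so that every row's last prompt token sits at the right edge of the batch, where new tokens are appended.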
65 changes: 45 additions & 20 deletions ultravox/tools/infer_tool.py
@@ -1,19 +1,19 @@
#!/usr/bin/env python

import argparse
import dataclasses
import json
import os
import time
from typing import IO, List, Optional

import numpy as np
import simple_parsing
from torch.utils.data import DataLoader

from ultravox.data import datasets
from ultravox.evaluation import eval
from ultravox.evaluation import eval_types
from ultravox.inference import base
from ultravox.inference import infer
from ultravox.tools import infer_api

# There are two default modes for this tool, agent mode and ASR mode.
@@ -77,6 +77,8 @@ class InferArgs:
    verbose: bool = simple_parsing.field(default=False, alias="-v")
    # JSON output
    json: bool = simple_parsing.field(default=False)
    # Batch size
    batch_size: Optional[int] = simple_parsing.field(default=1, alias="-b")

    def __post_init__(self):
        if self.prompt and self.prompt.startswith("@"):
@@ -190,25 +192,48 @@ def dataset_infer(inference: base.VoiceInference, args: InferArgs):
    if args.seed is not None:
        ds_args.shuffle_seed = args.seed
    ds = datasets.create_dataset(args.data_sets[0], ds_args)
    scores: List[float] = []
    for i, sample in enumerate(datasets.Range(ds, args.num_samples)):
        # Store the original question and answer for JSON output.
        question_text = sample.audio_transcript
        expected_answer = sample.messages[-1]["content"]
        # Drop any assistant response from the sample.
        sample.messages = sample.messages[:-1]
        if not args.json:
            run_tui(i, inference, sample, args, expected_answer, scores)
        else:
            output = inference.infer(
                sample, max_tokens=args.max_tokens, temperature=args.temperature

    if args.json and isinstance(inference, infer.LocalInference):
        start_time = time.time()
        dl = DataLoader(
            datasets.Range(ds, args.num_samples),
            batch_size=args.batch_size,
            collate_fn=lambda x: x,
        )
        current_batch = []
        sample_index = 0
        for current_batch in dl:
            output = []
            for sample in current_batch:
                output.append(
                    {
                        "index": sample_index,
                        "question_text": sample.audio_transcript,
"expected_answer": sample.messages[-1]["content"],
}
)
sample.messages = sample.messages[:-1]
                sample_index += 1

            output_batch = inference.batch_infer(
                current_batch,
                max_tokens=args.max_tokens,
                temperature=args.temperature,
            )
            obj = {
                "question": question_text,
                "generated_answer": output.text,
                "expected_answer": expected_answer,
            }
            print(json.dumps(obj))
            for i, output_text in enumerate(output_batch):
                output[i]["output_text"] = output_text
            print(output)
print("Total time", time.time() - start_time)

    else:
        scores: List[float] = []
        for i, sample in enumerate(datasets.Range(ds, args.num_samples)):
            # Store the expected answer before dropping it from the sample.
            question_text = sample.audio_transcript
            expected_answer = sample.messages[-1]["content"]
            # Drop any assistant response from the sample.
            sample.messages = sample.messages[:-1]
            run_tui(i, inference, sample, args, expected_answer, scores)


def main(args: InferArgs):
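The batch path relies on an identity collate_fn so the DataLoader yields plain lists of VoiceSample objects rather than trying to stack them into tensors; preprocessing and padding happen later inside batch_infer. A minimal sketch of this pattern with stand-in items:

from torch.utils.data import DataLoader

items = ["s0", "s1", "s2", "s3", "s4"]  # stand-ins for VoiceSample objects
loader = DataLoader(items, batch_size=2, collate_fn=lambda x: x)
for batch in loader:
    # Identity collation: each batch is just a Python list, e.g. ['s0', 's1'].
    print(batch)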