diff --git a/ultravox/data/datasets.py b/ultravox/data/datasets.py
index 596fe108..89f131f2 100644
--- a/ultravox/data/datasets.py
+++ b/ultravox/data/datasets.py
@@ -88,6 +88,7 @@ def __call__(self, features, *args, **kwargs):
                 }
                 for f in features
             ]
+        input_ids_lens = torch.LongTensor([f["input_ids"].shape[-1] for f in features])
         batch = super().__call__(features, *args, **kwargs)
         if self.include_alt_fields:
             alt_batch = super().__call__(alt_features, *args, **kwargs)
@@ -101,6 +102,11 @@ def __call__(self, features, *args, **kwargs):
             batch["audio_values"] = torch.stack(
                 [F.pad(x, (0, max_len - x.shape[-1])) for x in audio_values]
            )
+        if self.tokenizer.padding_side == "left":
+            displacement = batch["input_ids"].shape[-1] - input_ids_lens
+            batch["audio_token_start_idx"] += displacement.to(
+                batch["audio_token_start_idx"].device
+            )

         return batch

diff --git a/ultravox/inference/base.py b/ultravox/inference/base.py
index bef8185f..9834aae6 100644
--- a/ultravox/inference/base.py
+++ b/ultravox/inference/base.py
@@ -1,6 +1,6 @@
 import abc
 import dataclasses
-from typing import Generator, Optional
+from typing import Generator, List, Optional

 from ultravox.data import datasets

@@ -40,6 +40,15 @@ def infer(
     ) -> VoiceOutput:
         pass

+    # Unoptimized batch inference. Used as a fallback if the derived class doesn't implement it.
+    def infer_batch(
+        self,
+        samples: List[datasets.VoiceSample],
+        max_tokens: Optional[int] = None,
+        temperature: Optional[float] = None,
+    ) -> List[VoiceOutput]:
+        return [self.infer(sample, max_tokens, temperature) for sample in samples]
+
     def infer_stream(
         self,
         sample: datasets.VoiceSample,
diff --git a/ultravox/inference/infer.py b/ultravox/inference/infer.py
index 752b0c72..3c6a8e6c 100644
--- a/ultravox/inference/infer.py
+++ b/ultravox/inference/infer.py
@@ -37,6 +37,12 @@ def __init__(
         self.past_key_values: Optional[Union[Tuple, transformers.cache_utils.Cache]] = (
             None
         )
+        self.data_collator = datasets.DataCollatorForSeq2SeqWithAudio(
+            tokenizer=self.tokenizer,
+            include_alt_fields=False,
+        )
+
+        assert self.tokenizer.padding_side == "left"

     def update_conversation(
         self,
@@ -53,6 +59,32 @@ def _get_sample_with_past(
         sample.add_past_messages(self.past_messages)
         return sample

+    def infer_batch(
+        self,
+        samples: List[datasets.VoiceSample],
+        max_tokens: Optional[int] = None,
+        temperature: Optional[float] = None,
+    ) -> List[base.VoiceOutput]:
+        assert not self.conversation_mode
+        inputs = [self._dataproc(s) for s in samples]
+        for input in inputs:
+            for key, val in input.items():
+                input[key] = val.squeeze(0)
+
+        tensors = self.data_collator(inputs)
+        input_len = tensors["input_ids"].shape[1]
+        output_batch = self._generate(
+            tensors, max_tokens, temperature, return_dict_in_generate=False
+        )
+        output_texts = []
+        for output in output_batch:
+            output_tokens = output[input_len:]
+            output_text = self.tokenizer.decode(output_tokens, skip_special_tokens=True)
+            output_len = len(output_tokens)
+            output_text = base.VoiceOutput(output_text, input_len, output_len)
+            output_texts.append(output_text)
+        return output_texts
+
     def infer(
         self,
         sample: datasets.VoiceSample,
@@ -162,6 +194,7 @@ def _generate(
         temperature: Optional[float] = None,
         streamer: Optional[transformers.TextStreamer] = None,
         past_key_values: Optional[Union[Tuple, transformers.cache_utils.Cache]] = None,
+        return_dict_in_generate: Optional[bool] = True,
     ):
         temperature = temperature or None
         do_sample = temperature is not None
@@ -180,5 +213,5 @@ def _generate(
             eos_token_id=terminators,
             streamer=streamer,
             past_key_values=past_key_values,
-            return_dict_in_generate=True,
+            return_dict_in_generate=return_dict_in_generate,
         )
diff --git a/ultravox/inference/infer_test.py b/ultravox/inference/infer_test.py
index c9768937..f6a06c74 100644
--- a/ultravox/inference/infer_test.py
+++ b/ultravox/inference/infer_test.py
@@ -15,9 +15,12 @@
 # work properly for external contributions (since Llama 3 is gated).
 @pytest.fixture(scope="module")
 def tokenizer():
-    return transformers.AutoTokenizer.from_pretrained(
+    tokenizer = transformers.AutoTokenizer.from_pretrained(
         "./assets/hf/Meta-Llama-3-8B-Instruct", local_files_only=True
     )
+    # Set padding_side to "left" to support batch inference.
+    tokenizer.padding_side = "left"
+    return tokenizer


 @pytest.fixture(scope="module")
diff --git a/ultravox/tools/infer_tool.py b/ultravox/tools/infer_tool.py
index bd640068..36ea4a0d 100644
--- a/ultravox/tools/infer_tool.py
+++ b/ultravox/tools/infer_tool.py
@@ -1,5 +1,4 @@
 #!/usr/bin/env python
-
 import argparse
 import dataclasses
 import json
@@ -9,6 +8,7 @@

 import numpy as np
 import simple_parsing
+from torch.utils import data as data_utils

 from ultravox.data import datasets
 from ultravox.evaluation import eval
@@ -77,6 +77,8 @@ class InferArgs:
     verbose: bool = simple_parsing.field(default=False, alias="-v")
     # JSON output
     json: bool = simple_parsing.field(default=False)
+    # Batch size
+    batch_size: Optional[int] = simple_parsing.field(default=1, alias="-b")

     def __post_init__(self):
         if self.prompt and self.prompt.startswith("@"):
@@ -190,25 +192,42 @@ def dataset_infer(inference: base.VoiceInference, args: InferArgs):
     if args.seed is not None:
         ds_args.shuffle_seed = args.seed
     ds = datasets.create_dataset(args.data_sets[0], ds_args)
-    scores: List[float] = []
-    for i, sample in enumerate(datasets.Range(ds, args.num_samples)):
-        # Store the original question and answer for JSON output.
-        question_text = sample.audio_transcript
-        expected_answer = sample.messages[-1]["content"]
-        # Drop any assistant response from the sample.
-        sample.messages = sample.messages[:-1]
-        if not args.json:
-            run_tui(i, inference, sample, args, expected_answer, scores)
-        else:
-            output = inference.infer(
-                sample, max_tokens=args.max_tokens, temperature=args.temperature
+
+    if args.json:
+        dl = data_utils.DataLoader(
+            datasets.Range(ds, args.num_samples),
+            batch_size=args.batch_size,
+            collate_fn=lambda x: x,
+        )
+        sample_index = 0
+        for input_batch in dl:
+            expected_answers = [
+                sample.messages.pop()["content"] for sample in input_batch
+            ]
+            output_batch = inference.infer_batch(
+                input_batch,
+                max_tokens=args.max_tokens,
+                temperature=args.temperature,
             )
-            obj = {
-                "question": question_text,
-                "generated_answer": output.text,
-                "expected_answer": expected_answer,
-            }
-            print(json.dumps(obj))
+            for sample, generated, expected in zip(
+                input_batch, output_batch, expected_answers
+            ):
+                output = {
+                    "index": sample_index,
+                    "question": sample.audio_transcript,
+                    "expected_answer": expected,
+                    "generated_answer": generated.text,
+                }
+                sample_index += 1
+                print(json.dumps(output))
+    else:
+        scores: List[float] = []
+        for i, sample in enumerate(datasets.Range(ds, args.num_samples)):
+            # Store the answer before dropping it from the sample.
+            expected_answer = sample.messages[-1]["content"]
+            # Drop any assistant response from the sample.
+            sample.messages = sample.messages[:-1]
+            run_tui(i, inference, sample, args, expected_answer, scores)


 def main(args: InferArgs):
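Reviewer note (not part of the patch): the collator change in ultravox/data/datasets.py relies on the fact that left padding shifts every real token to the right by the number of pad tokens added, so any precomputed audio_token_start_idx must be shifted by the same displacement. A minimal standalone sketch of that bookkeeping, with made-up lengths and indices:

    import torch

    # Two sequences of lengths 3 and 5, left-padded to the batch max of 5.
    input_ids_lens = torch.LongTensor([3, 5])
    padded_len = 5  # corresponds to batch["input_ids"].shape[-1] after padding

    # Start index of the audio placeholder, relative to each *unpadded* sequence.
    audio_token_start_idx = torch.LongTensor([1, 2])

    # With left padding, sample i is shifted right by (padded_len - len_i) positions.
    displacement = padded_len - input_ids_lens            # tensor([2, 0])
    audio_token_start_idx = audio_token_start_idx + displacement

    print(audio_token_start_idx)                          # tensor([3, 2])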