Offline batch inference mode (#82)
* Batch working! Next keep track of indices in dataset

* Batching working

* Add some print statements

* Add yield

* Remove print statements

* Fixing types

* Small cleanup

* More small fixes

* Pull out additional data before running through collator

* Remove unused _process_dataset_batch

* Adding displacement to audio token start idx if padding left is true

* Address comments

* Addressing comments

* Batch without dataloader

* Clean up

* Remove tensor conversion in collator

* Remove extra lines

* Addressing comments

* Address comments

* Using dataloader without collator

* Small cleanup

* Addressing comments

* Remove text output from json

* Use zip
liPatrick authored Aug 20, 2024
1 parent e5caca9 commit b2dc7f1
Showing 5 changed files with 92 additions and 22 deletions.
6 changes: 6 additions & 0 deletions ultravox/data/datasets.py
@@ -88,6 +88,7 @@ def __call__(self, features, *args, **kwargs):
}
for f in features
]
input_ids_lens = torch.LongTensor([f["input_ids"].shape[-1] for f in features])
batch = super().__call__(features, *args, **kwargs)
if self.include_alt_fields:
alt_batch = super().__call__(alt_features, *args, **kwargs)
@@ -101,6 +102,11 @@ def __call__(self, features, *args, **kwargs):
batch["audio_values"] = torch.stack(
[F.pad(x, (0, max_len - x.shape[-1])) for x in audio_values]
)
if self.tokenizer.padding_side == "left":
displacement = batch["input_ids"].shape[-1] - input_ids_lens
batch["audio_token_start_idx"] += displacement.to(
batch["audio_token_start_idx"].device
)

return batch

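Note (illustration, not part of the diff): with padding_side == "left", padding tokens are prepended to each sequence, so every original token position, including audio_token_start_idx, shifts right by that sample's padding amount. A minimal sketch with made-up lengths:

```python
import torch

# Hypothetical batch of two samples left-padded to the longest length (8).
input_ids_lens = torch.LongTensor([5, 8])          # unpadded lengths
padded_len = 8                                      # batch["input_ids"].shape[-1]
audio_token_start_idx = torch.LongTensor([2, 3])    # positions before padding

# Left padding prepends (padded_len - len) tokens, shifting every position right.
displacement = padded_len - input_ids_lens          # tensor([3, 0])
print(audio_token_start_idx + displacement)         # tensor([5, 3])
```
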
11 changes: 10 additions & 1 deletion ultravox/inference/base.py
@@ -1,6 +1,6 @@
import abc
import dataclasses
from typing import Generator, Optional
from typing import Generator, List, Optional

from ultravox.data import datasets

@@ -40,6 +40,15 @@ def infer(
) -> VoiceOutput:
pass

# Unoptimized batch inference. Used as a fallback if the derived class doesn't implement it.
def infer_batch(
self,
samples: List[datasets.VoiceSample],
max_tokens: Optional[int] = None,
temperature: Optional[float] = None,
) -> List[VoiceOutput]:
return [self.infer(sample, max_tokens, temperature) for sample in samples]

def infer_stream(
self,
sample: datasets.VoiceSample,
35 changes: 34 additions & 1 deletion ultravox/inference/infer.py
@@ -37,6 +37,12 @@ def __init__(
self.past_key_values: Optional[Union[Tuple, transformers.cache_utils.Cache]] = (
None
)
self.data_collator = datasets.DataCollatorForSeq2SeqWithAudio(
tokenizer=self.tokenizer,
include_alt_fields=False,
)

assert self.tokenizer.padding_side == "left"

def update_conversation(
self,
@@ -53,6 +59,32 @@ def _get_sample_with_past(
sample.add_past_messages(self.past_messages)
return sample

def infer_batch(
self,
samples: List[datasets.VoiceSample],
max_tokens: Optional[int] = None,
temperature: Optional[float] = None,
) -> List[base.VoiceOutput]:
assert not self.conversation_mode
inputs = [self._dataproc(s) for s in samples]
for input in inputs:
for key, val in input.items():
input[key] = val.squeeze(0)

tensors = self.data_collator(inputs)
input_len = tensors["input_ids"].shape[1]
output_batch = self._generate(
tensors, max_tokens, temperature, return_dict_in_generate=False
)
output_texts = []
for output in output_batch:
output_tokens = output[input_len:]
output_text = self.tokenizer.decode(output_tokens, skip_special_tokens=True)
output_len = len(output_tokens)
output_text = base.VoiceOutput(output_text, input_len, output_len)
output_texts.append(output_text)
return output_texts

def infer(
self,
sample: datasets.VoiceSample,
@@ -162,6 +194,7 @@ def _generate(
temperature: Optional[float] = None,
streamer: Optional[transformers.TextStreamer] = None,
past_key_values: Optional[Union[Tuple, transformers.cache_utils.Cache]] = None,
return_dict_in_generate: Optional[bool] = True,
):
temperature = temperature or None
do_sample = temperature is not None
@@ -180,5 +213,5 @@
eos_token_id=terminators,
streamer=streamer,
past_key_values=past_key_values,
return_dict_in_generate=True,
return_dict_in_generate=return_dict_in_generate,
)
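Note (illustration, not part of the diff): because prompts are left-padded, every row of the generated batch starts its new tokens at the same offset, so slicing each output at input_len isolates the generated ids; return_dict_in_generate=False makes generate() return that plain tensor of sequences. A minimal sketch with dummy ids:

```python
import torch

# Shape [batch, input_len + new_tokens]: left-padded prompts (length 4) plus generated ids.
input_len = 4
output_batch = torch.tensor(
    [[0, 0, 11, 12, 101, 102, 103],    # 2 pad tokens + 2 prompt tokens + 3 generated
     [21, 22, 23, 24, 201, 202, 203]]  # 4 prompt tokens + 3 generated
)
for output in output_batch:
    print(output[input_len:].tolist())  # [101, 102, 103] then [201, 202, 203]
```
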
5 changes: 4 additions & 1 deletion ultravox/inference/infer_test.py
@@ -15,9 +15,12 @@
# work properly for external contributions (since Llama 3 is gated).
@pytest.fixture(scope="module")
def tokenizer():
return transformers.AutoTokenizer.from_pretrained(
tokenizer = transformers.AutoTokenizer.from_pretrained(
"./assets/hf/Meta-Llama-3-8B-Instruct", local_files_only=True
)
# Set padding_side to "left" to support batch inference.
tokenizer.padding_side = "left"
return tokenizer


@pytest.fixture(scope="module")
57 changes: 38 additions & 19 deletions ultravox/tools/infer_tool.py
@@ -1,5 +1,4 @@
#!/usr/bin/env python

import argparse
import dataclasses
import json
@@ -9,6 +8,7 @@

import numpy as np
import simple_parsing
from torch.utils import data as data_utils

from ultravox.data import datasets
from ultravox.evaluation import eval
@@ -77,6 +77,8 @@ class InferArgs:
verbose: bool = simple_parsing.field(default=False, alias="-v")
# JSON output
json: bool = simple_parsing.field(default=False)
# Batch size
batch_size: Optional[int] = simple_parsing.field(default=1, alias="-b")

def __post_init__(self):
if self.prompt and self.prompt.startswith("@"):
@@ -190,25 +192,42 @@ def dataset_infer(inference: base.VoiceInference, args: InferArgs):
if args.seed is not None:
ds_args.shuffle_seed = args.seed
ds = datasets.create_dataset(args.data_sets[0], ds_args)
scores: List[float] = []
for i, sample in enumerate(datasets.Range(ds, args.num_samples)):
# Store the original question and answer for JSON output.
question_text = sample.audio_transcript
expected_answer = sample.messages[-1]["content"]
# Drop any assistant response from the sample.
sample.messages = sample.messages[:-1]
if not args.json:
run_tui(i, inference, sample, args, expected_answer, scores)
else:
output = inference.infer(
sample, max_tokens=args.max_tokens, temperature=args.temperature

if args.json:
dl = data_utils.DataLoader(
datasets.Range(ds, args.num_samples),
batch_size=args.batch_size,
collate_fn=lambda x: x,
)
sample_index = 0
for input_batch in dl:
expected_answers = [
sample.messages.pop()["content"] for sample in input_batch
]
output_batch = inference.infer_batch(
input_batch,
max_tokens=args.max_tokens,
temperature=args.temperature,
)
obj = {
"question": question_text,
"generated_answer": output.text,
"expected_answer": expected_answer,
}
print(json.dumps(obj))
for sample, generated, expected in zip(
input_batch, output_batch, expected_answers
):
output = {
"index": sample_index,
"question": sample.audio_transcript,
"expected_answer": expected,
"generated_answer": generated.text,
}
sample_index += 1
print(json.dumps(output))
else:
scores: List[float] = []
for i, sample in enumerate(datasets.Range(ds, args.num_samples)):
# Store the answer for JSON output.
expected_answer = sample.messages[-1]["content"]
# Drop any assistant response from the sample.
sample.messages = sample.messages[:-1]
run_tui(i, inference, sample, args, expected_answer, scores)


def main(args: InferArgs):
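Note (illustration, not part of the diff): the DataLoader above uses collate_fn=lambda x: x, so it only groups samples into plain Python lists (no tensor collation), which is exactly what infer_batch expects. A standalone sketch of that behavior:

```python
from torch.utils import data as data_utils

# An identity collate_fn turns each batch into a plain list of the original items.
dl = data_utils.DataLoader(range(10), batch_size=4, collate_fn=lambda x: x)
for batch in dl:
    print(batch)  # [0, 1, 2, 3], then [4, 5, 6, 7], then [8, 9]
```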
