Lower sequence generation length on code gen to be dependent on max canonical solution length #2682

Merged · 15 commits · Dec 4, 2023
8 changes: 7 additions & 1 deletion composer/datasets/in_context_learning_evaluation.py
@@ -986,6 +986,7 @@ def __init__(
self.max_prompt_length = 0
self.top_p = top_p
self.top_k = top_k
+ self.max_answer_length = 0
fewshot_rng = random.Random(fewshot_random_seed)
self.encoded_dataset = self.prep_examples(num_fewshot, prompt_string, example_delimiter, code_prelimiter,
fewshot_rng)
@@ -1009,6 +1010,7 @@ def prep_examples(self, num_fewshot: int, prompt_string: str, example_delimiter:
"""
max_prompt_length = 0
examples = []
+ max_answer_length = 0
for sample_idx in tqdm(range(len(self.samples))):
encoded_example = {}

@@ -1050,8 +1052,12 @@ def prep_examples(self, num_fewshot: int, prompt_string: str, example_delimiter:
max_prompt_length = max(
max_prompt_length,
len(encoded_example['preamble']['input_ids'] + encoded_example['prompt']['input_ids']))
+ max_answer_length = max(
+     max_answer_length,
+     len(self.tokenizer(encoded_example['canonical_solution'], add_special_tokens=False)['input_ids']))

self.max_prompt_length = max_prompt_length
+ self.max_answer_length = max_answer_length + _MAX_ANSWER_BUFFER_LENGTH
return examples

def __getitem__(self, index):
@@ -1101,7 +1107,7 @@ def collate_fn(self, data):
'test_outputs': test_outputs, # list of test outputs
'languages': languages, # list of languages
'pass_at_k': self.pass_at_k,
- 'generation_length': self.max_seq_len - self.max_prompt_length,
+ 'generation_length': min(self.max_answer_length, self.max_seq_len - self.max_prompt_length),
'generation_kwargs': {
'pad_token_id': self.pad_tok_id,
'num_beams': 1, # single beam
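The core of the change can be summarized with a minimal sketch (this is not composer's actual code: `tokenizer` is assumed to be a Hugging Face-style tokenizer that returns a dict with 'input_ids', and BUFFER is an illustrative stand-in for the _MAX_ANSWER_LENGTH buffer constant, whose real value is defined elsewhere in the library). The idea is to track the tokenized length of the longest canonical solution, add a small buffer, and never request more new tokens than fit after the longest prompt.

# Minimal sketch of the new generation-length logic; not composer's code.
from typing import List

BUFFER = 10  # placeholder for _MAX_ANSWER_BUFFER_LENGTH (real value defined in composer)


def compute_generation_length(canonical_solutions: List[str],
                              tokenizer,
                              max_seq_len: int,
                              max_prompt_length: int) -> int:
    """Cap generation at the longest reference solution instead of the full context."""
    max_answer_length = 0
    for solution in canonical_solutions:
        token_ids = tokenizer(solution, add_special_tokens=False)['input_ids']
        max_answer_length = max(max_answer_length, len(token_ids))

    # Allow generations slightly longer than the longest reference solution.
    max_answer_length += BUFFER

    # Never request more tokens than fit after the longest prompt.
    return min(max_answer_length, max_seq_len - max_prompt_length)

Previously generation_length was always self.max_seq_len - self.max_prompt_length, so tasks with short reference solutions spent the entire remaining context window on generation; clamping to the longest canonical solution plus a buffer keeps correct completions reachable while presumably cutting evaluation time.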
6 changes: 3 additions & 3 deletions tests/datasets/test_in_context_learning_datasets.py
@@ -775,7 +775,7 @@ def test_code_eval_sentpiece_dataloader(dataset_uri, tmp_path, num_fewshot, prom
assert tuple(batch['attention_mask'].shape) == (batch_size, max_prompt_length)
assert batch['mode'] == 'generate'
# the maximum generation length from the small test data
- assert batch['generation_length'] == seqlen - max_prompt_length
+ assert batch['generation_length'] == 129
assert any(item[0] != tokenizer.eos_token_id for item in batch['input_ids']) # longest should be pushed left

decoded_batch = tokenizer.batch_decode(batch['input_ids'])
@@ -860,7 +860,7 @@ def test_code_eval_test_cases(dataset_uri, tmp_path):
assert tuple(batch['attention_mask'].shape) == (batch_size, max_prompt_length)
assert batch['mode'] == 'generate'
# the maximum generation length from the small test data
- assert batch['generation_length'] == seqlen - max_prompt_length
+ assert batch['generation_length'] == 129
assert any(item[0] != tokenizer.eos_token_id for item in batch['input_ids']) # longest should be pushed left

mod = types.ModuleType('test_module')
@@ -938,7 +938,7 @@ def test_code_eval_task_dataloader(dataset_uri, tmp_path, num_fewshot, prompt_st
assert tuple(batch['attention_mask'].shape) == (batch_size, max_prompt_length)
assert batch['mode'] == 'generate'
# the maximum generation length from the small test data
- assert batch['generation_length'] == seqlen - max_prompt_length
+ assert batch['generation_length'] == 122
assert any(item[0] != tokenizer.eos_token_id for item in batch['input_ids']) # longest should be pushed left

decoded_batch = tokenizer.batch_decode(batch['input_ids'])
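For reference, a self-contained check of the clamping behavior could look like the sketch below. This is illustrative only and reuses the compute_generation_length sketch from above; it is not one of the PR's tests, which instead assert the exact values 129 and 122 produced by the repository's small test fixtures.

# Illustrative-only check of the clamping logic, reusing the sketch above.
def test_generation_length_is_clamped_by_answer_length():

    class FakeTokenizer:

        def __call__(self, text, add_special_tokens=False):
            # One "token" per whitespace-separated word is enough to exercise the logic.
            return {'input_ids': text.split()}

    gen_len = compute_generation_length(
        canonical_solutions=['return a + b', 'return a'],
        tokenizer=FakeTokenizer(),
        max_seq_len=2048,
        max_prompt_length=100,
    )
    # Longest solution is 4 "tokens"; with the placeholder BUFFER of 10 that gives 14,
    # which is far below the 1948 tokens of remaining context.
    assert gen_len == 14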