Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add eot token to ICL generate kwargs #2782

Merged
merged 14 commits into from
Dec 20, 2023
8 changes: 7 additions & 1 deletion composer/datasets/in_context_learning_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,8 @@ def __init__(
fewshot_random_seed: int,
cot_delimiter: str = '',
):
if tokenizer.eos_token_id is None:
raise ValueError('`InContextLearningQATaskDataset` tokenizer must have non-null `eos_token_id`')
try:
from datasets import load_dataset # pyright: ignore [reportGeneralTypeIssues]
except ImportError as e:
Expand Down Expand Up @@ -306,7 +308,8 @@ def collate_fn(self, data):
'generation_length': self.max_answer_length,
'generation_kwargs': {
'pad_token_id': self.pad_tok_id,
'use_cache': True
'use_cache': True,
'eos_token_id': self.tokenizer.eos_token_id
bmosaicml marked this conversation as resolved.
Show resolved Hide resolved
}
}

Expand Down Expand Up @@ -948,6 +951,8 @@ def __init__(
top_p: Optional[float] = 0.95,
top_k: Optional[int] = 40,
):
if tokenizer.eos_token_id is None:
raise ValueError('`InContextLearningCodeEvalDataset` tokenizer must have non-null `eos_token_id`')
try:
from datasets import load_dataset # pyright: ignore [reportGeneralTypeIssues]
except ImportError as e:
Expand Down Expand Up @@ -1116,6 +1121,7 @@ def collate_fn(self, data):
'top_p': self.top_p,
'top_k': self.top_k,
'use_cache': True,
'eos_token_id': self.tokenizer.eos_token_id
}
}
batch['attention_mask'] = ~(batch['input_ids'] == self.pad_tok_id)
Expand Down
31 changes: 31 additions & 0 deletions tests/datasets/test_in_context_learning_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,6 +510,35 @@ def test_qa_split_batch(tiny_opt_tokenizer, dataset_uri, tmp_path):
assert isinstance(split2['generation_kwargs'], dict)


@pytest.mark.parametrize('dataset_uri', ['triviaqa_small.jsonl'])
@pytest.mark.parametrize('num_fewshot', [0])
@pytest.mark.parametrize('prompt_string', ['I am a prompt', ''])
def test_qa_task_dataloader_w_null_eos(dataset_uri, tiny_gpt2_tokenizer, tmp_path, num_fewshot, prompt_string):
    """Constructing a QA ICL dataloader must raise ValueError when the tokenizer has a null ``eos_token_id``.

    The QA task dataset passes ``eos_token_id`` through its generation kwargs, so a
    tokenizer without one is rejected at construction time.
    """
    pytest.importorskip('datasets')

    local_data = os.path.join(os.path.dirname(__file__), 'local_data')

    tokenizer = tiny_gpt2_tokenizer
    dataset_uri = f'{local_data}/{dataset_uri}'
    batch_size = 4
    seqlen = 512
    # Null out the eos token id so the dataset's validation check fires.
    tokenizer.eos_token_id = None
    with pytest.raises(ValueError):
        _ = get_icl_task_dataloader('question_answering',
                                    dataset_uri,
                                    tokenizer,
                                    batch_size,
                                    max_seq_len=seqlen,
                                    pad_tok_id=tokenizer.eos_token_id,
                                    num_fewshot=num_fewshot,
                                    prompt_string=prompt_string,
                                    example_delimiter='\n',
                                    question_prelimiter='Q: ',
                                    continuation_delimiter='\nA:',
                                    destination_path=str(tmp_path / f'icl_{num_fewshot}.jsonl'))


@pytest.mark.parametrize('dataset_uri', ['triviaqa_small.jsonl'])
@pytest.mark.parametrize('num_fewshot', [0, 2])
@pytest.mark.parametrize('prompt_string', ['I am a prompt', ''])
Expand Down Expand Up @@ -545,6 +574,7 @@ def test_qa_task_dataloader(dataset_uri, tiny_gpt2_tokenizer, tmp_path, num_fews
assert tuple(batch['attention_mask'].shape) == (batch_size, seqlen - maximum_answer_length)
assert batch['mode'] == 'generate'
# the maximum generation length from the small test data

assert batch['generation_length'] == maximum_answer_length
assert all(item[0] == tokenizer.eos_token_id for item in batch['input_ids'])

Expand All @@ -559,6 +589,7 @@ def test_qa_task_dataloader(dataset_uri, tiny_gpt2_tokenizer, tmp_path, num_fews
for found, expected in zip(batch['labels'], [['David Seville'], ['Skorpio', 'Scorpio']]))
assert decoded_batch[0].endswith('Q: Who was the man behind The Chipmunks?\nA:')
assert decoded_batch[1].endswith('Q: What star sign is Jamie Lee Curtis?\nA:')
assert 'eos_token_id' in batch['generation_kwargs']


@pytest.mark.parametrize('dataset_uri', ['gsm8k_small.jsonl'])
Expand Down
Loading