-
Notifications
You must be signed in to change notification settings - Fork 63
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
SFTTrainer support #682
SFTTrainer support #682
Changes from 1 commit
0b88043
4fa4020
1e8473b
758d7d8
d2672cc
74e47a9
3efd548
58e3a44
1122f68
8691b9e
1c1c26b
25ab3b9
72b968c
9fc38e9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -28,7 +28,7 @@ | |
AutoModelForSequenceClassification, | ||
) | ||
|
||
from optimum.neuron import NeuronTrainer, NeuronTrainingArguments | ||
from optimum.neuron import NeuronSFTConfig, NeuronSFTTrainer, NeuronTrainer, NeuronTrainingArguments | ||
from optimum.neuron.distributed.utils import MODEL_PARALLEL_SHARDS_DIR_NAME | ||
from optimum.neuron.utils import is_neuronx_distributed_available | ||
from optimum.neuron.utils.cache_utils import ( | ||
|
@@ -300,7 +300,7 @@ def create_training_args(output_dir, resume_from_checkpoint=None, max_steps=max_ | |
per_device_train_batch_size=train_batch_size, | ||
per_device_eval_batch_size=eval_batch_size, | ||
max_steps=max_steps, | ||
logging_steps=1, | ||
logging_steps=2, | ||
save_steps=5, | ||
do_eval=do_eval, | ||
output_dir=output_dir, | ||
|
@@ -396,3 +396,70 @@ def preprocess_function(examples): | |
|
||
trainer.train(resume_from_checkpoint=True) | ||
trainer.evaluate() | ||
|
||
|
||
@is_trainium_test | ||
class TestNeuronSFTTrainer(DistributedTest): | ||
@pytest.fixture( | ||
scope="class", | ||
params=[[2, 1, 1], [2, 2, 1]], | ||
ids=["dp=2", "tp=2"], | ||
) | ||
def parallel_sizes(self, request): | ||
return request.param | ||
|
||
def _test_sft_trainer(self, parallel_sizes, tmpdir, packing): | ||
_, tp_size, pp_size = parallel_sizes | ||
|
||
output_dir = Path(tmpdir) | ||
|
||
dataset = load_dataset("databricks/databricks-dolly-15k", split="train") | ||
# dataset = dataset.select(range(1000)) | ||
|
||
def format_dolly(sample): | ||
instruction = f"### Instruction\n{sample['instruction']}" | ||
context = f"### Context\n{sample['context']}" if len(sample["context"]) > 0 else None | ||
response = f"### Answer\n{sample['response']}" | ||
# join all the parts together | ||
prompt = "\n\n".join([i for i in [instruction, context, response] if i is not None]) | ||
if packing: | ||
return prompt | ||
return [prompt] | ||
|
||
tokenizer, model = get_tokenizer_and_tiny_llama_model() | ||
tokenizer.pad_token = tokenizer.eos_token | ||
tokenizer.padding_side = "left" # to prevent warnings | ||
|
||
args = NeuronTrainingArguments( | ||
output_dir=output_dir, | ||
do_train=True, | ||
max_steps=20, | ||
per_device_train_batch_size=1, | ||
tensor_parallel_size=tp_size, | ||
pipeline_parallel_size=pp_size, | ||
logging_steps=1, | ||
) | ||
args = args.to_dict() | ||
sft_config = NeuronSFTConfig( | ||
max_seq_length=512, | ||
packing=packing, | ||
dataset_num_proc=1, | ||
**args, | ||
) | ||
|
||
# Create Trainer instance | ||
trainer = NeuronSFTTrainer( | ||
model=model, | ||
tokenizer=tokenizer, | ||
train_dataset=dataset, | ||
formatting_func=format_dolly, | ||
args=sft_config, | ||
) | ||
|
||
trainer.train() | ||
Review thread on this hunk — Reviewer: “Can't we verify that the loss goes down or something?” Author reply: “It's a tiny random model. The …” [remainder of the reply is truncated in the source]. |
||
|
||
def test_without_packing(self, parallel_sizes, tmpdir): | ||
return self._test_sft_trainer(parallel_sizes, tmpdir, False) | ||
|
||
def test_with_packing(self, parallel_sizes, tmpdir): | ||
return self._test_sft_trainer(parallel_sizes, tmpdir, True) |
Review comment on this file:
nit: remove commented line (presumably the commented-out `# dataset = dataset.select(range(1000))` line in the diff above — confirm against the PR).