SFTTrainer support #682
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Thanks for the PR, it looks good! Just one question: don't we need to add trl to setup.py, add NeuronSFTTrainer to the API doc, and perhaps, if possible, have a minimal test?
I did not add
As @JingyaHuang said, it would be nice to have some unit tests before integrating this in the demo, to speed up integration by identifying issues early on.
    args = NeuronSFTConfig(output_dir=output_dir)
elif args is not None and args.__class__.__name__ == "NeuronTrainingArguments":
    args_as_dict = args.to_dict()
    # Manually copy token values as TrainingArguments.to_dict() redacts them
This comes from the original trl, but I have no idea what this means ...
Basically the SFTConfig replaces the TrainingArguments. You can still provide training args, and the SFTTrainer converts them to an SFTConfig.
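For context, this is roughly the pattern the snippet above follows (a paraphrased sketch, not the exact trl or PR code): plain training arguments are dumped to a dict and rebuilt as an SFTConfig, and because TrainingArguments.to_dict() masks fields ending in "_token" (e.g. hub_token), the real values are copied back from the live object first. The Neuron version does the same thing with NeuronTrainingArguments and NeuronSFTConfig.

# Paraphrased sketch of the conversion discussed above; not the exact trl or PR code.
from transformers import TrainingArguments
from trl import SFTConfig

def to_sft_config(args: TrainingArguments) -> SFTConfig:
    args_as_dict = args.to_dict()
    # TrainingArguments.to_dict() redacts sensitive fields ending in "_token"
    # (e.g. hub_token), so copy the real values back from the live object.
    args_as_dict.update(
        {k: getattr(args, k) for k in args_as_dict if k.endswith("_token")}
    )
    return SFTConfig(**args_as_dict)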
optimum/neuron/trainers.py
Outdated
@@ -1465,3 +1503,345 @@ class Seq2SeqNeuronTrainer(AugmentTrainerForNeuronMixin, Seq2SeqTrainer):
    """
    Seq2SeqTrainer that is suited for performing training on AWS Trainium instances.
    """


class NeuronSFTTrainer(AugmentTrainerForNeuronMixin, SFTTrainer):
Maybe add a comment here indicating how this differs from the original (i.e. what the Neuron specifics are).
Done!
I have added tests. They do not check anything specific but run a small training job with both packed and unpacked datasets. If the training job succeeds, the test passes; otherwise it fails.
tests/test_trainers.py
Outdated
output_dir = Path(tmpdir)

dataset = load_dataset("databricks/databricks-dolly-15k", split="train")
# dataset = dataset.select(range(1000))
nit: remove commented line
    args=sft_config,
)

trainer.train()
Can't we verify that the loss goes down or something?
It's a tiny random model. The SFTTrainer does not do anything related to the loss anyway; it's just Trainer with dataset preparation abilities.
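Since the packed/unpacked distinction is exactly what the two tests exercise, here is a simplified illustration of what packing means in this context (not trl's actual implementation, and the helper name is made up): formatted samples are tokenized, concatenated, and split into fixed-length blocks so that every training sequence is full.

# Simplified illustration of dataset packing; not trl's actual implementation.
from transformers import AutoTokenizer

def pack_examples(texts, tokenizer, seq_length=128):
    # Tokenize every sample and concatenate the streams, separated by EOS tokens.
    token_ids = []
    for text in texts:
        token_ids.extend(tokenizer(text, add_special_tokens=False)["input_ids"])
        token_ids.append(tokenizer.eos_token_id)
    # Split the stream into full, fixed-length blocks and drop the remainder.
    n_blocks = len(token_ids) // seq_length
    return [token_ids[i * seq_length : (i + 1) * seq_length] for i in range(n_blocks)]

With packing disabled, each sample is instead tokenized on its own and padded or truncated to the maximum sequence length.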
The new tests are failing:
Thank you for this pull request!
What does this PR do?
This PR adds two classes:
- NeuronSFTConfig
- NeuronSFTTrainer

Both of these classes achieve the same goal as their trl counterparts.
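A minimal usage sketch of the two classes together, mirroring the trl SFTTrainer API (the model id, the dataset flattening, and the hyperparameters below are illustrative assumptions, not taken from the PR):

# Illustrative sketch; model id, dataset handling, and hyperparameters are assumptions.
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from optimum.neuron import NeuronSFTConfig, NeuronSFTTrainer

model_id = "hf-internal-testing/tiny-random-LlamaForCausalLM"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

# Flatten each dolly record into a single "text" field for supervised fine-tuning.
dataset = load_dataset("databricks/databricks-dolly-15k", split="train")
dataset = dataset.map(
    lambda s: {"text": f"{s['instruction']}\n{s['context']}\n{s['response']}"},
    remove_columns=dataset.column_names,
)

sft_config = NeuronSFTConfig(
    output_dir="sft_output",
    dataset_text_field="text",
    max_seq_length=512,
    packing=True,  # set to False to train on unpacked samples
    max_steps=10,
    per_device_train_batch_size=1,
)

trainer = NeuronSFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset,
    args=sft_config,
)
trainer.train()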