Add general support for generation on TRN with NxD #370
Conversation
Thank you very much for this pull request.
I fully agree that the existing NeuronGenerationMixin
needs to be replaced by a more robust and easier-to-maintain implementation, and this definitely goes in the right direction.
An additional benefit is that this new class will be compatible with the logits processors, some of which are really important (like repetition penalty).
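To illustrate the point about logits-processor compatibility, here is a minimal, self-contained sketch (not code from this PR; the toy vocabulary and scores are made up) of how the repetition penalty processor from transformers adjusts scores for tokens already present in the prompt:

```python
import torch
from transformers import RepetitionPenaltyLogitsProcessor

# Toy example: vocab of 4 tokens; the prompt already contains token ids 0 and 2.
input_ids = torch.tensor([[0, 2]])
scores = torch.tensor([[2.0, 3.0, 2.0, 3.0]])

# Penalty > 1 makes previously seen tokens less likely: positive scores are
# divided by the penalty (negative ones would be multiplied by it).
processor = RepetitionPenaltyLogitsProcessor(penalty=2.0)
penalized = processor(input_ids, scores)
print(penalized)  # tokens 0 and 2 drop from 2.0 to 1.0; tokens 1 and 3 unchanged
```

A generation mixin that delegates to the stock HuggingFace loop gets this behavior for free via the `logits_processor` argument of `generate()`.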
I have made a few comments on the implementation itself, mostly related to how things are organized (I would rather put everything at the same place, if possible inside the mixin).
My main request however is related to the tests: I am not comfortable with integrating new code without proper unit tests.
These tests will not only allow us to detect bugs or unsupported configurations early, but will also help us detect regressions when bumping the transformers version.
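A minimal sketch of the shape of regression test being requested here (all names are hypothetical, not from this PR): generate on the backend under test and assert exact token-id equality with a CPU reference. A deterministic toy "model" stands in for a real LM so the sketch is self-contained; the real tests would load a transformers checkpoint.

```python
import torch

def toy_next_token(input_ids: torch.Tensor) -> torch.Tensor:
    # Deterministic stand-in for a model: next token is (sum of ids) mod 10.
    return (input_ids.sum(dim=-1) % 10).unsqueeze(-1)

def greedy_generate(input_ids: torch.Tensor, max_new_tokens: int) -> torch.Tensor:
    # Append one greedily chosen token at a time, as greedy search does.
    for _ in range(max_new_tokens):
        input_ids = torch.cat([input_ids, toy_next_token(input_ids)], dim=-1)
    return input_ids

def test_generate_matches_cpu_reference():
    prompt = torch.tensor([[1, 2, 3]])
    # In the real test, one of these runs would happen on the Neuron device.
    device_out = greedy_generate(prompt.clone(), max_new_tokens=5)
    cpu_out = greedy_generate(prompt.clone(), max_new_tokens=5)
    assert torch.equal(device_out, cpu_out)
```

Parametrizing such a test over models and search/sampling strategies is what catches regressions when the transformers version is bumped.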
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Thank you very much for this pull request, and for taking the time to fully address my comments.
Before merging, you need to:
- apply styling (`make style`),
- rebase on main (another pull request has been merged to use the latest AWS Neuron SDK, sorry about that...),
- ping us to trigger the CI tests again.
The test_runner failures in the CI are related to the fact that you don't have access to the HF internal token, and you should not worry about them.
However, two of the tests that you added are failing.
```
tests/test_cache_utils.py ......ssssssssss [ 21%]
tests/test_generate.py sssss [ 28%]
tests/test_runner.py EEEEEE [ 36%]
tests/test_trainer_callback.py ssss [ 41%]
tests/test_trainers.py sssssss [ 50%]
tests/test_utils.py . [ 52%]
tests/cli/test_neuron_cache_cli.py ssssss [ 60%]
tests/generation/test_generate.py ............FF................ [100%]
```
The tests are now passing. Could you please check whether there are any other merge blockers? If not, when can we target merging this PR?
Thank you so much for your patience and dedication!
This pull request adds general support for generative LM generation on TRN. Currently, generation can only be done separately from the training job with the `transformer_neuronx` inference feature. The limitation is that it needs to be run as a separate program from the training job. With this support added, users will be able to run model evaluation as part of the training job and choose to run it every n steps or epochs. Although we have implemented the `NeuronGenerationMixin` class, which delivers a similar feature, it has some drawbacks and functionality issues which I'll specify soon.

The generation will be done on both the Neuron device and CPU in the following way:

This enables us to support the various search/sampling methods provided by HuggingFace with minimal code changes. It has a couple of advantages compared to the existing `NeuronGenerationMixin` implementation: the new class overrides the `generate` function and calls `super().generate()` to reuse most of the HuggingFace implementation. Therefore, it will be more stable and less likely to break as HuggingFace updates the transformers package.

I've tested the functionality of text generation with GPT2 and Llama-2-7b models using this script. The output results on TRN exactly match the CPU outputs.
Reference: #108