feat: Add DataClass Arguments to Activate Padding-Free and MultiPack Plugin and FastKernels #280
Conversation
@achew010 let's gracefully handle the case when `use_flash_attn` is set to `False` and padding free is being used.

Referenced code (fms-hf-tuning/tuning/config/configs.py, line 39 in d728007):

```python
use_flash_attn: bool = field(
```
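A minimal sketch of the kind of graceful handling being requested here, with assumed argument names (the actual check added to `sft_trainer.py` may differ):

```python
def check_padding_free_requirements(use_flash_attn: bool, padding_free_enabled: bool):
    # fail fast with a clear message instead of letting a flash-attn-only
    # code path break later during training
    if padding_free_enabled and not use_flash_attn:
        raise ValueError(
            "`--padding_free` requires flash attention; "
            "please set `use_flash_attn = True`."
        )
```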
Thank you for the excellent change and description! I had a few questions... also, I am wondering: should the plugin be installed by default so users can utilize these new parameters? Looks like a very useful addition.
Also, please add some of the great description from this PR into the README.
```python
@dataclass
class MultiPack:

    num_processes: int = 16
```
Is there any guidance on what this number should be set to?
This number is a reasonable one for most datasets of reasonable size (e.g., under a million examples). The packing algorithm is relatively fast, but in the event the dataset is too large, our plugin will raise a warning
https://github.com/foundation-model-stack/fms-acceleration/blob/4e81c64453ec5d2b06a8d14a2a72374cc736098a/plugins/attention-and-distributed-packing/src/fms_acceleration_aadp/framework_plugin_multipack.py#L117-L123
that advises the user to increase this number if the process times out.
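For intuition, a rough sketch of that kind of safeguard; function and variable names here are illustrative, not the plugin's actual code:

```python
import warnings
from multiprocessing import Pool, TimeoutError


def pack_with_timeout(pack_fn, shards, num_processes=16, timeout_s=120):
    """Run a packing function over dataset shards, warning if it takes too long."""
    with Pool(num_processes) as pool:
        result = pool.map_async(pack_fn, shards)
        try:
            return result.get(timeout=timeout_s)
        except TimeoutError:
            warnings.warn(
                "Packing timed out; for very large datasets consider "
                f"increasing num_processes (currently {num_processes})."
            )
            raise
```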
```python
framework = AccelerationFrameworkConfig.from_dataclasses(
    quantized_lora_config,
    fusedops_kernels_config,
    attention_and_distributed_packing_config,
)
```
Just for my understanding: these are all model loader augmentors that change how the model is loaded based on the acceleration framework configurations, whereas padding free and multipack are both dataset augmentors? How does setting the acceleration framework here affect the dataset loading?
You are right in saying padding free and multipack affect the dataloading, but more specifically:
- padding free only requires modifications to data collation.
- multipack requires modifications to the dataloader.

Both are handled by our AccelerationPatcher, which is a component that we wrote to allow controlled replacements of the data collator and data loader.
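To make the split concrete, here is a rough sketch of the two replacement points; this is not the AccelerationPatcher API, just an illustration of the idea:

```python
from dataclasses import dataclass
from typing import Callable, Optional


@dataclass
class DataReplacements:
    # padding free: only the collate function changes
    collate_fn: Optional[Callable] = None
    # multipack: the train dataloader itself is rebuilt
    dataloader_builder: Optional[Callable] = None


def apply_data_replacements(trainer, repl: DataReplacements):
    """Controlled replacement of collator/dataloader on a Trainer-like object."""
    if repl.collate_fn is not None:
        trainer.data_collator = repl.collate_fn
    if repl.dataloader_builder is not None:
        # the new dataloader factory can use a sampler that balances
        # token counts across devices (the multipack case)
        trainer.get_train_dataloader = lambda: repl.dataloader_builder(trainer)
    return trainer
```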
thank you for the explanation
tuning/sft_trainer.py (outdated):

```python
    "ensure `use_flash_attn = True` to use padding-free flash attention"
)

if train_args.packing is True:
```
nit: can simplify to `if train_args.packing`
Small additional comments
* `fused_ops_and_kernels` works for full-finetuning, LoRA, QLoRA and GPTQ-LoRA,
  - pass `--fast_kernels True True True` for full finetuning/LoRA
  - pass `--fast_kernels True True True --auto_gptq triton_v2 --fused_lora auto_gptq True` for GPTQ-LoRA
  - pass `--fast_kernels True True True --bitsandbytes nf4 --fused_lora bitsandbytes True` for QLoRA
I'm wondering for fast-kernels if there is a better way to understand what is being set to true. `--fast_kernels True True True` feels unclear on what is being set to True. Could the user instead pass in `--fast_kernels <types of kernel to use>`, like `--fast_kernels FastCrossEntropyLoss FastRoPE FastRMSLayerNorm`? If they only want one, would they currently have to set `--fast_kernels False True False`, whereas instead setting `--fast_kernels FastRoPE` would be easier?
Yes, that is correct, but unfortunately that will be more complicated than the current implementation.
- Consider the plugin dataclass (e.g., `FusedOpsAndKernelsConfig`), see here: the plugin dataclass is a nested dataclass, because it has dataclasses as members.
- Each member dataclass (e.g., `FastKernelsConfig`) needs to be parsable by `HfArgumentParser`, which actually does not support parsing a `dataclass` type.
- Hence, we made it possible with our `parsable_dataclass` decorator, which:
  - masquerades the member dataclass as a `List`, since `HfArgumentParser` does support lists of a uniform type.
  - allows our member dataclass to contain mixed types via the casting logic implemented in `parsable_dataclass` through `EnsureTypes`.

All this logic is needed just to parse `--fast_kernels False True False` into the dataclass `FastKernelsConfig(fast_loss=False, fast_rsm_layernorm=True, fast_rope_embeddings=False)`.

To support parsing of the kind `--fast_kernels FastRoPE`, we would need to handle:
- different types: "FastRoPE" is clearly a boolean flag, but we also need to handle `str` inputs, `float` inputs, etc., where it would need to be a `key=value` pair.
- different orders: we need to be able to parse `--dataclass_key key_a=a key_b` and `--dataclass_key key_b key_a=a` equivalently.
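A minimal sketch of the list-then-cast idea described above, using simplified stand-in dataclasses; `FastKernelsArgs`, `to_config`, and the exact field layout here are illustrative, not the actual classes in fms-hf-tuning:

```python
from dataclasses import dataclass
from typing import List, Optional

from transformers import HfArgumentParser


@dataclass
class FastKernelsArgs:
    # HfArgumentParser cannot parse a nested dataclass, but it can parse a
    # list of a uniform type, so the flag arrives as ["True", "True", "True"]
    fast_kernels: Optional[List[str]] = None


@dataclass
class FastKernelsConfig:
    fast_loss: bool = False
    fast_rms_layernorm: bool = False
    fast_rope_embeddings: bool = False


def to_config(args: FastKernelsArgs) -> Optional[FastKernelsConfig]:
    if args.fast_kernels is None:
        return None
    # the casting step that parsable_dataclass/EnsureTypes performs in spirit:
    # turn the uniform list of strings back into typed dataclass fields
    flags = [s.lower() == "true" for s in args.fast_kernels]
    return FastKernelsConfig(*flags)


parser = HfArgumentParser(FastKernelsArgs)
(parsed,) = parser.parse_args_into_dataclasses(
    ["--fast_kernels", "True", "True", "True"]
)
print(to_config(parsed))
# FastKernelsConfig(fast_loss=True, fast_rms_layernorm=True, fast_rope_embeddings=True)
```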
Thank you for the details, I agree getting this merged and thinking about improvements for `fast_kernels` later makes sense. Is just the `FusedOpsAndKernelsConfig` ready to move out of experimental, or can this also be done for PaddingFree and Multipack?
Also, we added new automation that ensures PRs follow conventional commits, which you can see is failing: https://github.com/foundation-model-stack/fms-hf-tuning/actions/runs/10920573842/job/30310716778?pr=280. Please address the change.
Please update the branch with the new changes.
Note @kmehant, I think since you requested changes, an approval is needed from your side as well before this can merge.
@anhuong thanks for letting me know. It's really annoying that I am not able to dismiss my review in some way so that I do not stand as a blocker :( forcing me to push an approval.
Nonetheless, I have used most of these features as part of iLab and undoubtedly vouch for the changes. Thanks.
We can also mark PaddingFree and MultiPack as not experimental, but LGTM.
Description of the change
This PR adds two dataclass arguments to enable padding free and multipack for `sft_trainer.py`, via the new `fms-acceleration` attention-and-distributed-packing plugin, and allows the current `--fast_kernels` dataclass to support optimized full finetuning:

- `--padding_free`: technique to process multiple examples in a single batch without adding padding tokens that waste compute.
- `--multipack`: technique for multi-GPU training to balance out the number of tokens processed on each device, to minimize waiting time.
- `--fast_kernels`: previously limited to QPEFT (it used to raise if not activated with `--fast_lora`); now allows for optimized full/standard LoRA finetuning.

These are extremely effective methods to improve training throughput; see the benchmarks below.
NOTE: adhering to the design of fms-acceleration, the new plugin is optional, and separately installed.
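Because the plugin is optional, a caller can check for it up front. A small sketch: the package name `fms_acceleration_aadp` is taken from the plugin repository layout linked earlier, and the rest is illustrative rather than the actual fms-hf-tuning code:

```python
import importlib.util


def assert_aadp_installed():
    # the attention-and-distributed-packing plugin ships separately, so give
    # a clear error when --padding_free/--multipack is requested without it
    if importlib.util.find_spec("fms_acceleration_aadp") is None:
        raise ImportError(
            "--padding_free/--multipack requested, but the fms-acceleration "
            "attention-and-distributed-packing plugin is not installed."
        )
```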
Notes on Padding Free

- For older transformers versions (`<= 4.43`), where padding free is not yet integrated from our PR into Hugging Face (Enhancing SFT Training Efficiency Using Packing and FlashAttention2 with Position IDs, huggingface/transformers#31629), the plugin provides the padding-free handling.
- For newer transformers versions (`>= 4.44`), the upstream integration is available.
).Notes on Multipack
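The balancing idea behind multipack, as described in the summary above, can be illustrated with a small greedy sketch; this is not the plugin's actual packing algorithm:

```python
def balance_across_devices(seq_lengths, num_devices):
    """Greedily assign sequences so each device sees a similar token count."""
    buckets = [[] for _ in range(num_devices)]
    totals = [0] * num_devices
    # longest-first assignment keeps per-device token counts close together
    for idx in sorted(range(len(seq_lengths)), key=lambda i: -seq_lengths[i]):
        d = totals.index(min(totals))
        buckets[d].append(idx)
        totals[d] += seq_lengths[idx]
    return buckets, totals


buckets, totals = balance_across_devices([512, 1024, 128, 768, 256, 896], 2)
print(totals)  # roughly equal token counts per device, e.g. [1792, 1792]
```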
Notes on FastKernels

- pass `--fast_kernels True True True` on full finetuning/LoRA runs
- pass `--fast_kernels True True True --auto_gptq triton_v2 --fused_lora auto_gptq True` for GPTQ-LoRA
- pass `--fast_kernels True True True --bitsandbytes nf4 --fused_lora bitsandbytes True` for QLoRA
- there is a remaining issue involving `positional_ids`, but it will be addressed in the future

Benchmarks
PaddingFree and Multipack Benchmarks for Mistral 7B: results reported for per-device batch sizes 4 and 8, with similar improvements verified for the untokenized dataset.

Full Finetuning Benchmarks for Mistral 7B.
Early Version Of This Plugin

We have an unofficial version with more features than our present release, which @kmehant is currently using for iLab work. In addition to padding-free and multipack, it also has two additional plugins.

To use the early version, a quick hack of `sft_trainer` with pretokenized data + a custom tokenizer is available at https://github.com/fabianlim/fms-hf-tuning/tree/attn-plugin. This will be superseded by this PR in the near future.

Use with these command line arguments:
How to verify the PR
Additional checks/tests were added:

- parsing of `--padding_free` and `multipack` is checked in `test_dataclass_parse_successfully`
- illegal arguments for `--padding_free` are caught in `test_dataclass_will_fail_to_accept_illegal_args`
- `test_framework_initialize_and_trains_with_aadp`
- `--padding_free` must be used with flash-attn, otherwise an error is raised
- `--multi_pack` must be used with `--padding_free`, otherwise an error is raised
- `--packing True` with `--padding_free` will raise an error
- `--fast_kernels` works with full finetuning
- `--fast_lora` not called with either `--auto_gptq` or `--bitsandbytes` will raise an error

Ran the full suite of acceleration checks to verify that all fms-acceleration unit tests passed.
Was the PR tested