Fix Mistakes with FA Padding Free #62

Merged: 3 commits merged on Aug 2, 2024

Conversation


@fabianlim (Contributor) commented on Aug 1, 2024

PR #57 had a couple of mistakes that needed to be fixed. This is because of two things:

  1. _flash_attention_forward was moved out into a standalone function earlier
  2. the actual padding-free fix was done later, and is still not yet released (probably 4.44)

The strategy now is simple (a minimal detection sketch follows the list):

  • if we can import DataCollatorWithFlattening successfully, it means the padding-free fix is in place
  • if we can import _flash_attention_forward, it means the function has been separated out
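
The detection amounts to two guarded imports. Below is a minimal sketch of that check; the two imported symbols are real transformers names, but the flag names are illustrative:

```python
# Detect which transformers features are available; flag names are illustrative.
try:
    # only importable once the upstream padding-free fix has landed
    from transformers import DataCollatorWithFlattening  # noqa: F401
    HAS_PADDING_FREE_FIX = True
except ImportError:
    HAS_PADDING_FREE_FIX = False

try:
    # only importable once flash attention was separated into a standalone function
    from transformers.modeling_flash_attention_utils import _flash_attention_forward  # noqa: F401
    HAS_STANDALONE_FA_FORWARD = True
except ImportError:
    HAS_STANDALONE_FA_FORWARD = False
```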

Augmentation

  1. If the padding-free fix is in place, then there is nothing to do; otherwise some patching is required (a sketch follows the list).
  2. Patch _flash_attention_forward, either the standalone function or the method, depending on the version.
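
Roughly, the branching could look like the sketch below, which builds on the flags from the previous sketch; `apply_padding_free_patches` and `patched_fa_forward` are illustrative stand-ins, not the plugin's actual API:

```python
# Illustrative branching over the detected transformers version.
def apply_padding_free_patches(model, patched_fa_forward):
    if HAS_PADDING_FREE_FIX:
        # transformers already threads position_ids through flash attention
        return model
    if HAS_STANDALONE_FA_FORWARD:
        # newer versions: replace the standalone module-level function once
        import transformers.modeling_flash_attention_utils as fa_utils
        fa_utils._flash_attention_forward = patched_fa_forward
    else:
        # older versions: rebind the method on every flash-attention module
        for module in model.modules():
            if hasattr(module, "_flash_attention_forward"):
                module._flash_attention_forward = patched_fa_forward.__get__(module)
    return model
```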

Some redesign was done: since _flash_attention_forward could be either a method or a function, the previous approach of binding _flash_attention_forward by closure no longer holds. So we need to install a method on the backbone to intercept the position ids, then modify _flash_attention_forward to be able to access the position ids and bind them. A rough sketch of this is below.
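
This is a rough illustrative sketch of the intercept idea, not the PR's exact code; the attribute name `_intercepted_position_ids` is made up for the example:

```python
# Wrap the backbone forward to stash position_ids so that a patched
# _flash_attention_forward can read them back instead of receiving them by closure.
def install_position_ids_intercept(backbone):
    original_forward = backbone.forward

    def forward(*args, position_ids=None, **kwargs):
        # stash the position ids where the patched attention path can find them
        backbone._intercepted_position_ids = position_ids
        return original_forward(*args, position_ids=position_ids, **kwargs)

    backbone.forward = forward
    return backbone
```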

The bad news is that once this is done properly, the speed dropped. However, we verified that the speed is consistent when we upgrade transformers to the latest main, which means our implementation is correct.

{'loss': 0.8762, 'grad_norm': 69.0, 'learning_rate': 2e-05, 'epoch': 0.0}
{'loss': 0.9877, 'grad_norm': 29.40625, 'learning_rate': 1.7777777777777777e-05, 'epoch': 0.0}
{'loss': 1.0518, 'grad_norm': 38.90625, 'learning_rate': 1.555555555555556e-05, 'epoch': 0.0}
{'loss': 1.1429, 'grad_norm': 85.625, 'learning_rate': 1.3333333333333333e-05, 'epoch': 0.0}
{'loss': 1.0771, 'grad_norm': 22.890625, 'learning_rate': 1.1111111111111113e-05, 'epoch': 0.0}
{'loss': 0.9842, 'grad_norm': 33.5, 'learning_rate': 8.888888888888888e-06, 'epoch': 0.0}
{'loss': 2.4449, 'grad_norm': 19.9375, 'learning_rate': 6.666666666666667e-06, 'epoch': 0.01}
{'loss': 0.9717, 'grad_norm': 35.5625, 'learning_rate': 4.444444444444444e-06, 'epoch': 0.01}
{'loss': 0.8958, 'grad_norm': 25.203125, 'learning_rate': 2.222222222222222e-06, 'epoch': 0.01}
{'loss': 0.9145, 'grad_norm': 18.296875, 'learning_rate': 0.0, 'epoch': 0.01}
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [01:07<00:00,  1.41it/s]

Training completed. Do not forget to share your model on huggingface.co/models =)


{'train_runtime': 67.8947, 'train_samples_per_second': 5.891, 'train_steps_per_second': 1.473, 'train_tokens_per_second': 2029.615, 'train_loss': 1.1346958923339843, 'init_mem_cpu_alloc_delta': -14387679232, 'init_mem_gpu_alloc_delta': 14483611648, 'init_mem_cpu_peaked_delta': 14483382272, 'init_mem_gpu_peaked_delta': 0, 'train_mem_cpu_alloc_delta': 691978240, 'train_mem_gpu_alloc_delta': 28984245248, 'train_mem_cpu_peaked_delta': 0, 'train_mem_gpu_peaked_delta': 28990169600, 'before_init_mem_cpu': 15096680448, 'before_init_mem_gpu': 0, 'epoch': 0.01}

Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com>
Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com>
Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com>
@fabianlim fabianlim requested a review from achew010 August 1, 2024 16:10
@fabianlim (Contributor, Author):

Potentially, this can be improved by having the backbone function compute the cumsum once for all layers (a sketch of the idea follows).
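
A hedged sketch of that idea, assuming a padding-free (flattened) batch where position_ids restart at 0 at each packed-sequence boundary; `cu_seqlens_from_position_ids` is a hypothetical helper, not existing plugin code:

```python
import torch
import torch.nn.functional as F

def cu_seqlens_from_position_ids(position_ids: torch.Tensor) -> torch.Tensor:
    # position_ids: shape (1, total_tokens) for a flattened, padding-free batch
    pos = position_ids.flatten()
    # each 0 marks the start of a packed sequence
    starts = torch.nonzero(pos == 0).flatten()
    lengths = torch.diff(starts, append=torch.tensor([pos.numel()], device=pos.device))
    # prepend 0 and take the running sum: the cumulative sequence lengths (int32)
    # that flash attention's varlen path expects, computed once for all layers
    return F.pad(lengths.cumsum(0), (1, 0)).to(torch.int32)
```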
