Add BitNet #85
Conversation
@DustinWang1 Hello, thanks for this PR. What are the differences between SeptNet and BitNet?
Ah, I did not mean to add SeptNet to this pull request. Is there any way for you to pull from "init changes"?
@DustinWang1 Feel free to add new commits to delete the undesired parts. I will squash them before merging :-)
Also, it would be better to use isort to rearrange the imported modules in your new commits.
@DustinWang1 Another minor reminder: we have updated the attn layer and Xfmr++ implementations recently, primarily the Cache updates and parameter inits. Please ensure your code aligns with these latest changes.
Thanks for letting me know about isort. There are so many nice utilities out there :0. I've rearranged the imports and synced my changes with the cache update and param inits.
@DustinWang1 Thank you for your quick fix. Could you check out my latest comments again?
Are you talking about the failed check? I'm currently fixing the style errors on my side, but keep in mind that I haven't changed many of the files that flake8 is citing; they were already in the main repo. There is one error where S is not defined: "fla\ops\generalized_delta_rule\iplr\naive.py:43:9: F821 undefined name 'S'". I'm not sure how this part of the code works; could you take a look?
@DustinWang1
https://github.com/DustinWang1/flash-linear-attention/blob/a56e65801f7adbc519edae8c72bfd891a1ddf836/fla/models/bitnet/modeling_bitnet.py#L63-L69
Be careful: once you call rms_norm_linear or swiglu_linear, the F.linear is performed internally, so the quant layers are never actually invoked.
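For illustration, here is a minimal, self-contained sketch of the pitfall described above. It uses plain PyTorch stand-ins (FakeBitLinear, fused_norm_linear are made up for this example) rather than the actual fla helpers, so treat it as an analogy, not the library's API: a fused norm-plus-linear helper only ever sees raw weight tensors, so the quantization in a custom linear layer's forward never runs.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FakeBitLinear(nn.Linear):
    """Hypothetical stand-in for a quantized linear layer (not the fla BitLinear)."""

    def forward(self, x):
        # Pretend-quantize the weight; a real BitLinear would do its quantization
        # here. This is exactly the code path that gets skipped.
        w_q = torch.sign(self.weight)
        return F.linear(x, w_q, self.bias)


def fused_norm_linear(x, norm_weight, linear_weight, bias=None, eps=1e-6):
    # Stand-in for a fused rms_norm_linear-style helper: it receives only
    # tensors, so FakeBitLinear.forward (and its quantization) is never called.
    x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * norm_weight
    return F.linear(x, linear_weight, bias)


x = torch.randn(2, 8)
proj = FakeBitLinear(8, 16, bias=False)
norm_w = torch.ones(8)

# Fused path: quantization silently skipped.
y_fused = fused_norm_linear(x, norm_w, proj.weight)
# Module path: quantization applied.
x_normed = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6) * norm_w
y_quant = proj(x_normed)
print(torch.allclose(y_fused, y_quant))  # False in general
```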
I replaced rms_norm_linear with a wrapper around layer_norm_linear_quant_fn in modules/fused_bitlinear. For the swiglu function, I added an alternate version in modules/activations.py that uses a functional form of BitLinear. That function, bit_linear, is also just a wrapper around layer_norm_linear_quant_fn with rms_norm set to True. Let me know if this fixes the issue.
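For reference, a hedged sketch of what that bit_linear wrapper might look like. The argument order and keyword names (is_rms_norm in particular) are assumptions rather than taken from the PR diff, so check the real signature of layer_norm_linear_quant_fn in fla/modules/fused_bitlinear.py before reusing this.

```python
# Sketch only: argument order and keyword names below are assumptions.
from fla.modules.fused_bitlinear import layer_norm_linear_quant_fn


def bit_linear(x, weight, bias=None, norm_weight=None, norm_bias=None):
    # RMS-normalize the input and apply the quantized linear in one fused call,
    # mirroring the wrapper described in the comment above.
    return layer_norm_linear_quant_fn(
        x, norm_weight, norm_bias, weight, bias, is_rms_norm=True
    )
```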
Created the BitNet implementation by copying the transformer code and replacing nn.Linear with the fused BitLinear. Added a "bit" version of the attention module to pair with BitNet.
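As an illustration of that substitution, a hedged sketch of the pattern: the q/k/v/o projections that were plain nn.Linear become fused BitLinear layers. The class name FusedBitLinear and its constructor arguments are assumptions based on this description, not copied from the PR diff.

```python
import torch.nn as nn

# Assumed import path and class name; see fla/modules/fused_bitlinear.py
# in the PR branch for the actual definitions.
from fla.modules.fused_bitlinear import FusedBitLinear


class BitAttentionProjections(nn.Module):
    """Illustrative q/k/v/o projection block with quantized linears."""

    def __init__(self, hidden_size: int, num_heads: int):
        super().__init__()
        head_dim = hidden_size // num_heads
        # Each projection that was nn.Linear in the copied transformer code
        # is swapped for the fused, quantized equivalent.
        self.q_proj = FusedBitLinear(hidden_size, num_heads * head_dim, bias=False)
        self.k_proj = FusedBitLinear(hidden_size, num_heads * head_dim, bias=False)
        self.v_proj = FusedBitLinear(hidden_size, num_heads * head_dim, bias=False)
        self.o_proj = FusedBitLinear(num_heads * head_dim, hidden_size, bias=False)
```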