
Add BitNet #85

Merged: 10 commits merged into fla-org:main on Nov 27, 2024
Conversation

DustinWang1 (Contributor):

Created the BitNet implementation by copying the transformer code and replacing nn.Linear with the fused BitLinear layer. Also added a "bit" version of the attention module to pair with BitNet.
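A minimal sketch of that substitution, assuming a BitNet-style BitLinear (mean-centered sign binarization of weights, 8-bit absmax activation quantization, straight-through estimators; scaling factors omitted). The PR's actual fused implementation lives in fla/modules/fused_bitlinear; the names below are illustrative only.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class BitLinear(nn.Linear):
    """Drop-in nn.Linear replacement with quantized weights/activations."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Binarize mean-centered weights to {-1, +1}; the straight-through
        # estimator keeps gradients flowing to the full-precision weights.
        w = self.weight - self.weight.mean()
        w_q = w + (torch.sign(w) - w).detach()
        # 8-bit absmax quantization of activations, also straight-through.
        s = 127.0 / x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-5)
        x_q = x + ((x * s).round().clamp(-128, 127) / s - x).detach()
        return F.linear(x_q, w_q, self.bias)


# In the copied transformer code, every projection swaps nn.Linear for
# BitLinear, e.g. self.q_proj = BitLinear(hidden_size, hidden_size, bias=False)
```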

yzhangcs (Member):

@DustinWang1 Hello, thanks for this PR.

What are the diffs between SeptNet and BitNet?

DustinWang1 (Contributor, Author):

Ah, I did not mean to add SeptNet to this pull request. Is there any way for you to pull from "init changes" on your side?

yzhangcs (Member):

@DustinWang1 Feel free to add new commits to delete undesired parts. I will squash them before merging :-)

yzhangcs (Member):

Also, it would be better to use isort to rearrange the imported modules in your new commits.
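For reference, isort groups imports as stdlib, then third-party, then first-party, each block alphabetized; the file below is only an example of the resulting layout:

```python
# Standard library
import math
from typing import Optional

# Third-party
import torch
import torch.nn as nn

# First-party
from fla.modules.fused_bitlinear import BitLinear
```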

yzhangcs (Member):

@DustinWang1 Another minor reminder: we have updated the attn layer and Xfmr++ impls recently, primarily the cache updates and param inits. Please ensure your code aligns with these latest changes.

DustinWang1 (Contributor, Author):

Thanks for letting me know about isort. There are so many nice utilities out there :0. I've rearranged the imports and synced my changes with the cache update and param inits.

yzhangcs self-requested a review on November 27, 2024, at 05:35.
yzhangcs (Member):

@DustinWang1 Thank you for your quick fix. Could you check out my latest comments again?

DustinWang1 (Contributor, Author):

Are you talking about the failed check? I'm currently fixing the style errors on my side, but keep in mind that I haven't changed many of the files flake8 is citing; they were already in the main repo. There is one error where S is not defined: "fla/ops/generalized_delta_rule/iplr/naive.py:43:9: F821 undefined name 'S'". I'm not sure how this part of the code works, could you take a look?
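For context on F821: flake8 raises it when a name is read that is never bound in scope. A hypothetical sketch of how that pattern arises in a recurrence (a guess at the shape of the bug, not the actual naive.py code):

```python
import torch


def naive_recurrence(q, k, v):
    # q, k: (seq_len, d_k); v: (seq_len, d_v)
    o = torch.zeros_like(v)
    for t in range(q.shape[0]):
        # F821: 'S' is read here but never initialized before the loop.
        S = S + k[t].unsqueeze(-1) * v[t].unsqueeze(0)
        o[t] = q[t] @ S
    return o


# The usual fix is to initialize the state before the loop, e.g.
#     S = k.new_zeros(k.shape[-1], v.shape[-1])
```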

yzhangcs (Member) left a review comment:

@DustinWang1
https://github.com/DustinWang1/flash-linear-attention/blob/a56e65801f7adbc519edae8c72bfd891a1ddf836/fla/models/bitnet/modeling_bitnet.py#L63-L69

Be careful: once you call rms_norm_linear or swiglu_linear, F.linear is executed internally, so you never actually invoke the quantized layers.
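To spell out the pitfall (the body below paraphrases what a fused norm+linear helper does; it is not the actual fla source): the helper ends in a plain F.linear, so passing it a BitLinear's weight silently skips quantization.

```python
import torch
import torch.nn.functional as F


def rms_norm_linear(x, norm_weight, linear_weight, bias=None, eps=1e-5):
    # RMSNorm, then a *plain* matmul: no weight/activation quantization.
    x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * norm_weight
    return F.linear(x, linear_weight, bias)


# rms_norm_linear(x, norm.weight, proj.weight) therefore runs the
# un-quantized path even when proj is a BitLinear module.
```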

(Review thread on fla/models/bitnet/modeling_bitnet.py: outdated, resolved.)
DustinWang1 (Contributor, Author):

I replaced rms_norm_linear with a wrapper around layer_norm_linear_quant_fn from modules/fused_bitlinear. For the swiglu function, I added an alternate version in modules/activations.py that uses a functional form of BitLinear. That function, bit_linear, is also just a wrapper around layer_norm_linear_quant_fn with rms_norm set to true. Let me know if this fixes the issue.
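A sketch of that wrapper as described in this thread (the keyword name for the RMSNorm switch is an assumption; see fla/modules/fused_bitlinear.py for the real signature):

```python
from fla.modules.fused_bitlinear import layer_norm_linear_quant_fn


def bit_linear(x, weight, bias=None, norm_weight=None, norm_bias=None):
    # Fused path: RMSNorm on x, quantize, then the linear op. `is_rms_norm`
    # is an assumed kwarg name ("rms_norm set to true" per the comment above).
    return layer_norm_linear_quant_fn(
        x, norm_weight, norm_bias, weight, bias, is_rms_norm=True
    )
```

The alternate swiglu in modules/activations.py then routes its projections through bit_linear rather than F.linear, so the quantized path is exercised end to end.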

yzhangcs merged commit 7cc436f into fla-org:main on Nov 27, 2024 (1 check failed).