
Implement Quiet Softmax (Attention Is Off By One) #692

Merged: 3 commits into tracel-ai:main on Nov 30, 2023

Conversation

wbrickner (Contributor)

Pull Request Template

Checklist

  • Confirm that run-checks script has been executed.

Related Issues/PRs

Changes

  • I propose a quiet_softmax activation function (motivated by Attention Is Off By One); a rough numerical sketch is shown after this list.
  • I propose a quiet_softmax configuration field for MultiHeadAttention (on MultiHeadAttentionConfig), and to expose this field as well on all the layers that use MultiHeadAttention internally (such as TransformerEncoder).
  • There is a case to be made that this should be the default softmax implementation (it is not enabled by default in this PR).
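Conceptually, the only change is the denominator: quiet softmax computes exp(x_i) / (1 + Σ_k exp(x_k)), so an attention head can assign near-zero weight everywhere when no key matches, instead of being forced to emit a probability distribution. Below is a minimal plain-Rust sketch of the numerics (illustrative only; the function name and the max-subtraction stability trick are assumptions, not the burn Tensor implementation):

```rust
/// Quiet ("off by one") softmax over a slice: exp(x_i) / (1 + sum_k exp(x_k)).
/// Outputs may sum to less than 1; very negative logits push every output
/// toward 0 rather than forcing a distribution. Illustrative sketch only.
fn quiet_softmax(x: &[f32]) -> Vec<f32> {
    // Subtract m = max(0, max(x)) for numerical stability; the "+1" in the
    // denominator then becomes exp(-m).
    let m = x.iter().cloned().fold(0.0_f32, f32::max);
    let exps: Vec<f32> = x.iter().map(|&v| (v - m).exp()).collect();
    let denom = (-m).exp() + exps.iter().sum::<f32>();
    exps.iter().map(|&e| e / denom).collect()
}

fn main() {
    // All-negative logits: outputs near zero (standard softmax would still sum to 1).
    println!("{:?}", quiet_softmax(&[-10.0, -12.0, -11.0]));
    // Ordinary logits: close to standard softmax, since the +1 is small
    // relative to the sum of exponentials.
    println!("{:?}", quiet_softmax(&[2.0, 1.0, 0.5]));
}
```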

Testing

Run checks.

@nathanielsimard (Member) left a comment


Just a small change, but thanks a lot for implementing this.

The paper is indeed very interesting, and it's nice to provide that option as well. Similarly to using normalization first, it's not the default, but it's probably the better option.

[Review thread on burn-tensor/src/tensor/activation/base.rs (outdated, resolved)]
@louisfd (Member) left a comment


Hi @wbrickner
I'm happy to see how quickly we as a community can respond to new developments in our field. Thanks a lot!
I have only one request: that the new function be tested; see my comment.

[Review thread on burn-tensor/src/tensor/activation/base.rs (outdated, resolved)]
[Review thread on burn-tensor/src/tensor/activation/base.rs (resolved)]
@nathanielsimard (Member)

@wbrickner Can you fix the comments and the merge conflicts? I think it would be ready to be merged :)

@antimora (Collaborator)

@wbrickner, let us know if you need help with this. We do not wish to lose such an important addition to Burn.

@wbrickner (Contributor, Author)

Got busy; I will implement this again today.

@wbrickner (Contributor, Author)

Hey, so it turns out I do need help! Sorry for messing up the issue format by accidentally closing it. I am having a lot of trouble getting the burn checks to pass locally after pulling a fresh copy of burn. The implementation has been rewritten in the new copy (as of a few days ago); it's fairly simple. My other roadblock is computing the derivatives by hand to get correct values to test against, since the expressions get quite complicated. I suppose I should just write it out by hand and refresh on the differentiation rules, haha. Sorry this was left open for so many months! I would like to move forward with it!
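A note on the derivative roadblock above (an illustrative aside, not part of the original thread): writing quiet softmax as s_i = exp(x_i) / (1 + Σ_k exp(x_k)), the Jacobian works out to the same form as for ordinary softmax, only with s computed using the "+1" denominator (so s no longer sums to 1):

```latex
s_i = \frac{e^{x_i}}{1 + \sum_k e^{x_k}}, \qquad
\frac{\partial s_i}{\partial x_j}
  = \frac{\delta_{ij}\, e^{x_i}\left(1 + \sum_k e^{x_k}\right) - e^{x_i} e^{x_j}}
         {\left(1 + \sum_k e^{x_k}\right)^2}
  = s_i \left(\delta_{ij} - s_j\right).
```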

@wbrickner reopened this on Nov 22, 2023
@antimora (Collaborator)

Sounds good! We will look into this.

@nathanielsimard (Member) left a comment


I think you need to run the correct formatting: cargo fmt --all

@wbrickner (Contributor, Author)

> I think you need to run the correct formatting: cargo fmt --all

To clarify, I should format my changes and update the remote repo attached to this PR?
The checks pass in GitHub Actions; is this blocked from merging?

@antimora (Collaborator)

> I think you need to run the correct formatting: cargo fmt --all
>
> To clarify, I should format my changes and update the remote repo attached to this PR? The checks pass in GitHub Actions; is this blocked from merging?

OK, I have rebased it and run the formatting. Now it's passing the checks. We probably need an approval for it to be merged.

@nathanielsimard (Member) left a comment


There are formatting issues in burn-derive; cargo fmt has a hard time fixing them because of macros, but no changes should actually be made in burn-derive.

Comment on lines +75 to +79
let variant_name = &variant.ident;
let (variant_input, variant_output) = self.gen_variant_field(variant);

quote! { Self::#variant_name #variant_input => #enum_name::#variant_name #variant_output }
});
Member


That formatting seems odd to me and somehow fmt can't update it 🤔

Contributor (Author)


What would you like me to do to resolve this issue and merge the PR? cargo fmt --all results in a bit-identical repository. Is the issue these two lines of whitespace? Should I modify them manually? You mention that no changes should be made to burn-derive. I would like to get this closed out; apologies for the silliness of these problems.

Member


You can just reset all the changes under the burn-derive directory to origin/main. :)

Contributor


It's due to the quote! macro. rust-lang/rustfmt#8

If you comment out quote!, format, and uncomment, it'll do the right thing.

@AlexErrant (Contributor) commented on Nov 27, 2023


aaaaaactshully it's due to the default max_width = 100. If you add a file called rustfmt.toml containing max_width = 110 and then format, it works. Related. I'll open another PR about this tomorrow since it causes changes elsewhere and deserves its own discussion.

@AlexErrant mentioned this pull request on Nov 29, 2023.
@louisfd (Member) left a comment


This has been pending long enough 😅 I'll merge it; I think the formatting is fine.

@louisfd merged commit 03af140 into tracel-ai:main on Nov 30, 2023 (6 checks passed).