Implement bi-directionality #52
Conversation
Force-pushed from 7692f59 to 52f57d6.
Left some nits
What if the sequences have padding? E.g. …
@sentialx, agreed. That's a good catch.
How does the speed compare to uni-directional?
@jimmieliu, it's about 2x.
@yair-schiff I am just curious, did you solve the padding issue?
Just curious, is this problem solved?
I came up with a solution to the padding issue. Say we have a tensor [1,2,3,0,0], where 0 is the padding token. We flip it to get [0,0,3,2,1], pass it through the network, and flip the output back. Because we apply a double flip, the output is aligned with the original token order again.
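A minimal PyTorch sketch of this double flip; backward_mamba below is just an identity stand-in for whatever block processes the reversed sequence:

```python
import torch

x = torch.tensor([[1, 2, 3, 0, 0]])     # right-padded batch, 0 = pad token

backward_mamba = lambda t: t            # stand-in for the backward-direction Mamba pass

x_rev = torch.flip(x, dims=[1])         # [[0, 0, 3, 2, 1]]: padding now leads
y_rev = backward_mamba(x_rev)           # process the reversed sequence
y = torch.flip(y_rev, dims=[1])         # second flip restores the original token order
```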
Force-pushed from 6d45666 to 41d30ce.
Hi, your approach is clever! But I have a question: if you flip the input to [0,0,3,2,1], does the padding in front of it affect learning of the sequence's hidden features? I.e., does it produce a different result (a bad representation of the sequence) than an input of [3,2,1,0,0]?
@xuanwuji well, you can remove the leading padding by shifting each row of x before flipping x. As for its effect: since the hidden state is initialized to 0, it should still be 0 after scanning through the padding tokens, so the padding shouldn't have any effect on the result. However, you can use the following function just to be sure.
To check the effect of padding:
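A minimal sketch of such a check, assuming a model that maps (batch, seq_len) token ids to (batch, seq_len, d) hidden states; it compares the outputs with and without leading padding:

```python
import torch

@torch.no_grad()
def check_padding_effect(model, tokens, n_pad, pad_id=0):
    """Max absolute difference between outputs with and without leading padding.

    tokens: 1-D tensor of token ids (no padding).
    model:  assumed to map (B, L) ids to (B, L, d) hidden states.
    """
    x = tokens.unsqueeze(0)                                   # (1, L)
    pad = torch.full((1, n_pad), pad_id,
                     dtype=tokens.dtype, device=tokens.device)
    x_padded = torch.cat([pad, x], dim=1)                     # (1, n_pad + L)

    y = model(x)                                              # (1, L, d)
    y_padded = model(x_padded)[:, n_pad:]                     # drop outputs at padded positions

    return (y - y_padded).abs().max()
```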
However, these errors do stack after multiple layers, so you should use the …
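For the per-row shift mentioned above, a minimal sketch, assuming right-padded inputs and known sequence lengths (the function name is illustrative):

```python
import torch

def flip_without_leading_pad(x, lengths):
    """Roll each row right by its pad count, then flip, so the reversed
    sequence carries its padding at the end instead of at the front.

    x: (batch, seq_len) right-padded token ids; lengths: (batch,) true lengths.
    E.g. [1, 2, 3, 0, 0] -> roll -> [0, 0, 1, 2, 3] -> flip -> [3, 2, 1, 0, 0]
    """
    seq_len = x.shape[1]
    pad_counts = (seq_len - lengths).unsqueeze(1)                        # (batch, 1)
    idx = (torch.arange(seq_len, device=x.device) - pad_counts) % seq_len
    rolled = torch.gather(x, 1, idx)                                     # per-row right roll
    return torch.flip(rolled, dims=[1])
```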
Edit:
The Mamba module is applied twice: (1) to the forward sequence and (2) to the backward sequence. There are 3 strategies for combining the forward / backward Mamba hidden states (a sketch follows this list):
- add: add the states.
- concat: concatenate the states. This doubles the hidden dimension, d_model, which also prevents weight tying between the embedding and lm_head weights.
- ew_multiply: perform element-wise multiplication between the states.
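A minimal sketch of these three combination strategies, with fwd_block and bwd_block as placeholders for the forward- and backward-direction Mamba blocks (names are illustrative, not the PR's actual API):

```python
import torch
import torch.nn as nn

class BiCombine(nn.Module):
    """Run a forward and a backward block over a sequence and combine the
    two hidden-state streams with one of three strategies."""

    def __init__(self, fwd_block, bwd_block, strategy="add"):
        super().__init__()
        assert strategy in ("add", "concat", "ew_multiply")
        self.fwd_block = fwd_block
        self.bwd_block = bwd_block
        self.strategy = strategy

    def forward(self, x):                                  # x: (batch, seq_len, d_model)
        h_fwd = self.fwd_block(x)
        h_bwd = torch.flip(self.bwd_block(torch.flip(x, dims=[1])), dims=[1])

        if self.strategy == "add":
            return h_fwd + h_bwd                           # keeps d_model
        if self.strategy == "concat":
            return torch.cat([h_fwd, h_bwd], dim=-1)       # doubles d_model, breaks weight tying
        return h_fwd * h_bwd                               # ew_multiply
```

With identity blocks as stand-ins, BiCombine(nn.Identity(), nn.Identity(), "concat")(x) returns a tensor whose last dimension is 2 * d_model, which is why the concat strategy prevents tying the embedding and lm_head weights.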