How to enable Mamba2 to see all tokens when predicting the current token? #624
You don't have to see all previous tokens; the essence of Mamba(2) is to store all of that information in its most recent hidden state. To predict the next token you only need that most recent hidden state, because the information about all preceding tokens is stored there. Causal conv1d has a filter size for its convolution; with filter size k, each output position depends only on the k - 1 preceding tokens (plus the current one). Your convolution is local, i.e. it only processes groups of tokens whose extent is set by your filter size - no need to see all tokens.
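The locality claim above can be sketched in a few lines of numpy (a toy illustration, not the optimized `causal_conv1d` CUDA kernel; the filter values are made up):

```python
import numpy as np

def causal_conv1d(x, w):
    """Causal 1D convolution: output[t] depends only on x[t-k+1 .. t].

    x: (seq_len,) input sequence
    w: (k,) filter weights; w[-1] multiplies the current token.
    """
    k = len(w)
    # Pad k-1 zeros on the LEFT so position t never sees future tokens.
    x_padded = np.concatenate([np.zeros(k - 1), x])
    return np.array([x_padded[t:t + k] @ w for t in range(len(x))])

x = np.array([1.0, 2.0, 3.0, 4.0])
w = np.array([0.5, 1.0])        # filter size k=2: depends on 1 preceding token
y = causal_conv1d(x, w)         # -> [1.0, 2.5, 4.0, 5.5]

# Causality check: changing a later token leaves earlier outputs untouched.
x2 = x.copy(); x2[3] = 100.0
assert np.allclose(causal_conv1d(x2, w)[:3], y[:3])
```

Note the left padding with zeros: that is exactly the "initial hidden state of 0" behaviour discussed below for the first tokens of a sequence.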
Hi @vasqu, thanks for your reply. According to your interpretation, is what is said here (hp-l33/AiM#5 (comment)) actually incorrect?
I think it's technically not incorrect, since he says "previous tokens" but not how many previous tokens ^^
The background here is that this paper aims to generate images in a way similar to an LLM. In other words, when predicting the current token, the model may only look at the previous tokens. So the "how many" here means all of the tokens before the current one.
So the conv size is basically all tokens/images? Yeah, then it's a global convolution over all input. Sorry, I wasn't familiar with the context. Are we predicting multiple images or one image? Imo, the same principle applies when we predict more than one image autoregressively, one by one. When we predict the first image, we have an initial hidden state of 0 and the convolution is padded accordingly with 0s on the left side (to account for the sliding convolution).
We can predict one image, or we can predict multiple images at the same time: as long as the batch size is greater than 1, multiple images are predicted simultaneously. I am not very familiar with the causal conv1d in Mamba2, but the AiM code should use the original Mamba2 code, i.e. its causal conv1d part is the same as Mamba2's. Combined with the content of this review (hp-l33/AiM#5 (comment)), can causal conv1d help the model realize the causal reasoning mode?
So we have an autoregressive model with images as input and output, if I understand correctly. The same principle applies as for language: your convolution only depends on the (filter size - 1) preceding tokens (and the current token). I think you have a misconception about causal convolution. The causal convolution is not dependent on all previous input. Can you define what is meant by "causal reasoning mode"? On another note, if you want to infer with batches, then it is broken in this code base unless you use it with packed sequences. See #66 (comment)
Causal reasoning refers to the fact that when predicting the current token, the model attends to all previous tokens, consistent with how inference works in LLMs. In Transformer attention this is realized through a lower-triangular mask. The original Mamba2 is used in NLP, so how does Mamba2 achieve this kind of causal reasoning? Thanks for the reminder, I will look at the comment in #66.
Ah ok, I was a bit put off by the term "reasoning", but that makes sense. The causal property of causal convolution is realized by only looking at the current token plus a fixed set of preceding tokens. In attention we only need that mask because otherwise we would get the influence of future tokens too. The convolution does not process future tokens though, and hence it does not need to block them out, as we only process n preceding tokens. It is causal by nature. For Mamba(2), I think it would help if you looked into RNNs as a base principle of why RNNs are causal. It's the same as for the convolution: we only process preceding tokens one by one, which makes it unnecessary to block out (future) information. For the batching, we only have a problem there because we need to insert padding, which may introduce information we want to block. You can imagine Mamba working like BERT (encoder transformer) without padding: we would also get a full attention mask. Edit: the BERT comparison isn't quite right; see it as a technical perspective to show what needed to be done.
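The "causal by nature" point for the RNN/SSM view can be demonstrated with a scalar recurrence (a toy numpy sketch with made-up parameters A, B, C; real Mamba uses learned, input-dependent matrices):

```python
import numpy as np

def ssm_scan(x, A=0.9, B=1.0, C=1.0):
    """Scalar SSM viewed as an RNN: h[t] = A*h[t-1] + B*x[t], y[t] = C*h[t].

    The hidden state h summarizes ALL preceding tokens, so no mask is
    needed: future tokens simply haven't been processed yet.
    """
    h, ys = 0.0, []
    for xt in x:
        h = A * h + B * xt
        ys.append(C * h)
    return np.array(ys)

x1 = np.array([1.0, 2.0, 3.0, 4.0])
x2 = x1.copy()
x2[3] = 100.0                       # perturb only a "future" token
y1, y2 = ssm_scan(x1), ssm_scan(x2)

# Outputs before t=3 are unchanged: causality holds by construction,
# without any attention-style lower-triangular mask.
assert np.allclose(y1[:3], y2[:3])
```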
Thank you for your reply. I'll keep in mind how an RNN works, and it's true that an RNN can only see the previous state. So the "causal reasoning" nature of Mamba1 and Mamba2 essentially has little to do with the causal conv1d itself. Is Mamba adopting an RNN-like design concept? And in that case, couldn't Mamba have an attention-like mechanism, where it predicts the current token by attending to all the other tokens (not just the previous ones)?
Any Mamba block consists of linear projections, causal convolution, and state space models (SSMs). All of these operations are causal in essence. I referred to RNNs because SSMs can be viewed as RNNs (I thought that's what you referred to by Mamba). So yes, it's an RNN-like design, or at least it can be viewed as one. However, Mamba is still parallelized for its initial step (the first step). In principle, you could run that parallelized version over all input at each inference step, but that's very inefficient when we only need the most recent hidden state, which is already given (after the first step).
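The efficiency argument above can be made concrete with the same toy scalar recurrence (illustrative constants; Mamba's actual prefill uses a parallel scan / chunked kernel rather than this loop):

```python
import numpy as np

A, B, C = 0.9, 1.0, 1.0  # toy SSM parameters, chosen arbitrarily

def full_scan(x):
    """Process a whole sequence; returns all outputs and the final state."""
    h, ys = 0.0, []
    for xt in x:
        h = A * h + B * xt
        ys.append(C * h)
    return np.array(ys), h

def step(h, xt):
    """One O(1) generation step: only the carried hidden state is needed."""
    h = A * h + B * xt
    return C * h, h

prompt = np.array([1.0, 2.0, 3.0])
_, h = full_scan(prompt)            # "first step": parallelizable prefill
y_next, h = step(h, 4.0)            # per-token decode: no recomputation

# Sanity check: the cheap incremental step matches rerunning the full scan.
y_full, _ = full_scan(np.array([1.0, 2.0, 3.0, 4.0]))
assert np.isclose(y_next, y_full[-1])
```

This is why caching the hidden state is enough: recomputing the whole sequence per generated token gives an identical result at far greater cost.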
Thanks for your reply. In summary, the response here (hp-l33/AiM#5 (comment)) is not very accurate; the "causal property" is a natural property of Mamba.
Yes, at least that's my interpretation. It's not through the convolution alone; rather, all operations are causal. A small detail I missed.
If we want Mamba to understand the global information of the image sequence, we should probably follow the setup in Vim (https://github.com/hustvl/Vim). "Global information" here means something like the attention mechanism, which can see all the tokens.
Yeah, that seems correct. That's an encoder then, like BERT.
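The Vim-style bidirectional idea can be sketched as two causal scans, one forward and one over the reversed sequence, whose outputs are combined so every position sees the whole sequence (a toy sketch; the real Vim uses separate learned parameters per direction and combines inside the block):

```python
import numpy as np

def scan(x, A=0.9):
    """One causal scan: h[t] = A*h[t-1] + x[t]."""
    h, ys = 0.0, []
    for xt in x:
        h = A * h + xt
        ys.append(h)
    return np.array(ys)

def bidirectional_scan(x):
    """Vim-style bidirectionality: forward scan + backward scan.

    The backward pass runs the same causal scan on the reversed sequence
    and flips the result back, so position t also aggregates x[t:].
    """
    return scan(x) + scan(x[::-1])[::-1]

x = np.array([1.0, 2.0, 3.0])
y = bidirectional_scan(x)           # -> [6.23, 7.6, 8.61]

# Unlike the causal scan, EARLY outputs now change with LATER inputs:
x2 = x.copy(); x2[2] = 100.0
assert not np.allclose(bidirectional_scan(x2)[0], y[0])
```

The result is no longer causal, which is exactly the point: it trades autoregressive generation for full-sequence (encoder-style) context.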
Thank you very much!
Hi everyone, I'm new to Mamba. I see that Mamba2 uses a causal 1D conv, so that Mamba2 can only focus on the previous tokens when predicting the current token. But if I want Mamba2 to be able to see all tokens when predicting the current token, how should I set up the Mamba2 model?