No Query or Key Found when using nn.TransformerEncoderLayer #30
Comments
Hi @USBskycrafts, it seems that you are using the "fused attention" implementation, which merges Q, K, and V into one big matrix.
This operation will calculate vmean along the output dimension (or equivalently, by row) of the merged QKV matrix. This causes a slight mismatch with our original design (calculate vmean by head for Q and K, and calculate vmean of V as a whole). I have not tried this before, but I guess the performance would be similar. Please have a try and see if it helps.
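To make the fused-attention point concrete, here is a small standalone sketch (my own illustration, not Adam-mini's code; the sizes are arbitrary) of how the fused in_proj_weight stacks Q, K, and V along dim 0, and how the per-head partition for Q/K compares with the per-row partition of the merged matrix:

```python
import torch

embed_dim, num_heads = 64, 4
head_dim = embed_dim // num_heads

# Dummy gradient of nn.MultiheadAttention's fused in_proj_weight: rows are stacked as [Q; K; V].
grad = torch.randn(3 * embed_dim, embed_dim)
gq, gk, gv = grad.split(embed_dim, dim=0)

# Design described above: one vmean per head for Q and K, one vmean for V as a whole.
vmean_q = gq.reshape(num_heads, head_dim * embed_dim).pow(2).mean(dim=1)  # (num_heads,)
vmean_k = gk.reshape(num_heads, head_dim * embed_dim).pow(2).mean(dim=1)  # (num_heads,)
vmean_v = gv.pow(2).mean()                                                # scalar

# Fused fallback: one vmean per output row of the merged QKV matrix.
vmean_rows = grad.pow(2).mean(dim=1)                                      # (3 * embed_dim,)
```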
@zyushun : I tried the following, but it still does not help. Could you advise?

```python
optimizer_denoiser.mlp_names = {"self_attn.in_proj_weight"}
optimizer_denoiser.output_names.add('embedding.weight')  # actual output layer name; the projection layer is weight-tied with the embedding layer

for i in range(num_layers):
    optimizer_denoiser.wqk_names.add(f'transformer_encoder.layers.{i}.self_attn.in_proj_weight')  # query, key, and value combined
    optimizer_denoiser.wqk_names.add(f'transformer_decoder.layers.{i}.self_attn.in_proj_weight')  # decoder self-attention
    optimizer_denoiser.wqk_names.add(f'transformer_decoder.layers.{i}.multihead_attn.in_proj_weight')  # decoder cross-attention
```
Hi @buttercutter. Thanks for the update. Please try the following.
You will still see the log "Adam-mini found ...., 0 Querys, Keys, and Values.", but that is okay and you can ignore it. Please try :)
@buttercutter I just realized that you are already using v1.0.5, so you don't need to re-install. Just add the following two lines after the optimizer.
Please try :)
@zyushun : I added those two lines, and then I got this strange error.
Hi @buttercutter. Good to know this... A simple fix is to remove these lines.
By removing these two lines, Adam-mini will use a single learning rate for each block under the PyTorch default partition (except for the embedding and output layers, where it will use Adam). There is no guarantee that this will perform well, but you can give it a try. Usually it does not work well for pre-training, but it can work for finetuning.
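For readers following along, a conceptual sketch of what "a single learning rate for each block" means under the default PyTorch partition (my paraphrase, not Adam-mini's actual implementation):

```python
import torch

def blockwise_vmean(model: torch.nn.Module) -> dict:
    # One scalar second moment per parameter tensor (i.e., per block), instead of
    # Adam's one value per element; each block therefore gets one effective step size.
    return {name: p.grad.pow(2).mean()
            for name, p in model.named_parameters()
            if p.grad is not None}
```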
@zyushun : Thanks, and I really appreciate your prompt reply. So far, the memory consumption of my training run does not decrease.

```python
optimizer_denoiser.mlp_names = {"self_attn.in_proj_weight"}
optimizer_denoiser.output_names.add('embedding.weight')  # actual output layer name; the projection layer is weight-tied with the embedding layer

for i in range(num_layers):
    optimizer_denoiser.wqk_names.add(f'transformer_encoder.layers.{i}.self_attn.in_proj_weight')  # query, key, and value combined
    optimizer_denoiser.wqk_names.add(f'transformer_decoder.layers.{i}.self_attn.in_proj_weight')  # decoder self-attention
    optimizer_denoiser.wqk_names.add(f'transformer_decoder.layers.{i}.multihead_attn.in_proj_weight')  # decoder cross-attention
```
Hi @buttercutter, you can remove all these lines and try again.
"Memory consumption does not decrease" seems weird. Did you apply any other memory-related tricks on top of AdamW, like quantization or CPU offload? Or are you using PagedAdam (which uses quantization)?
@zyushun : Thanks again for your prompt response.
Noted. I increased the model's internal dimension size for both the encoder and decoder, yet I still see no decreasing trend in memory consumption.
No. By the way, I am using the MPS backend.
Hi @buttercutter, how large is your model? One possible reason: your model is too small, so the embedding + output layers take the major proportion of memory. In this case Adam-mini, at least for v.1.0.5, takes similar memory to Adam. The memory-reduction advantage usually becomes significant when the model size reaches 1B, where the embedding & output layers take <10% of the total params and Adam-mini saves 45% memory over AdamW.
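A back-of-envelope sketch of why this happens, with hypothetical sizes rather than the actual model in this thread (assuming, as stated above, that v.1.0.5 still keeps full Adam state for the embedding & output layers):

```python
vocab, d, layers = 30_000, 512, 6
embed_out = vocab * d        # tied embedding/output matrix
per_block = 12 * d * d       # rough size of one transformer block (attention + MLP)
share = embed_out / (embed_out + layers * per_block)
print(f"embedding+output share of params: {share:.0%}")
# ~45% for these sizes, so the part of the model Adam-mini v.1.0.5 can compress
# is only about half of the total parameters.
```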
You are right; the embedding + output layers took the most memory compared to the encoders and decoders in my code.
Would there be any plan for enabling memory reduction for the embedding & output layers as well?
@buttercutter Yes! We have developed a new version of Adam-mini (will be v.1.0.6) and it will also cut the memory for the embedding & output layers down to 50%. We will also update the paper accordingly. I will keep you posted once we release v.1.0.6. 😄
@zyushun : I noticed that Adam-mini is currently at v.1.1.0. For the new version, could I skip the following warnings before I check the traceback error?
Hi @buttercutter, yes, we have updated Adam-mini to v.1.1.0, which saves memory for the embedding & output layers.
Note that we still assume the embedding & output layers are matrices instead of long vectors, so it might raise an error if your codebase automatically reshapes the embedding & output parameters into vectors. If this is the case, then you can put the embedding and output layer to
We will try to support the case where weight matrices are stretched into long vectors in future versions, perhaps in v.1.1.1.
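A quick sanity check one can run before constructing the optimizer (my own sketch, not part of Adam-mini; the keyword list is just an example):

```python
import torch.nn as nn

def check_embed_output_shapes(model: nn.Module, keywords=("embedding", "projection")):
    # Confirm the embedding/output parameters are still 2-D matrices,
    # not flattened into long vectors by the surrounding codebase.
    for name, p in model.named_parameters():
        if any(k in name for k in keywords):
            status = "matrix" if p.dim() == 2 else f"flattened, shape={tuple(p.shape)}"
            print(f"{name}: {status}")
```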
Hi @zyushun, sorry for overwhelming you with a lot of technical questions. I added the following naming schemes to help Adam-mini locate the layers, but according to the warning log they are not being found.

```python
optimizer_ebm.embd_names.add('embedding')  # add the keyword of the embedding layer
optimizer_ebm.output_names.add('denoise_head')  # output layer of the EBM model is not using a projection layer
optimizer_denoiser.embd_names.add('embedding')  # add the keyword of the embedding layer
optimizer_denoiser.output_names.add('projection')  # projection layer is weight-tied with the embedding layer

optimizer_ebm.mlp_names = {"self_attn"}
optimizer_denoiser.mlp_names = {"self_attn"}

optimizer_ebm.mlp_names.add("attn")
optimizer_ebm.mlp_names.add("linear")
optimizer_denoiser.mlp_names.add("attn")
optimizer_denoiser.mlp_names.add("linear")

optimizer_denoiser.wqk_names.add("self_attn")  # for query, key, and value combined
optimizer_denoiser.wqk_names.add("multihead_attn")
```
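One way to debug which of these keywords actually hit a parameter name (my own helper, assuming Adam-mini matches the keywords as substrings of parameter names, as the earlier comments suggest):

```python
def report_matches(model, name_sets):
    # Print how many (and which) parameter names contain each configured keyword.
    for set_name, keywords in name_sets.items():
        hits = [n for n, _ in model.named_parameters()
                if any(k in n for k in keywords)]
        print(f"{set_name}: {len(hits)} parameter(s) matched")
        for n in hits:
            print(f"  {n}")

# Hypothetical usage with the sets configured above:
# report_matches(denoiser_model, {
#     "embd_names": optimizer_denoiser.embd_names,
#     "output_names": optimizer_denoiser.output_names,
#     "wqk_names": optimizer_denoiser.wqk_names,
#     "mlp_names": optimizer_denoiser.mlp_names,
# })
```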
I met the same warning (No XXX found) when I tried to run run_gpt2.sh (gpt2_small).
I thought the issue came from self.named_parameters in the Adam-mini class. Because it is a generator, it has no elements left after the loop in the init function, so the warning appears the second time this generator is consumed, in the count_block function. The problem was solved when I changed the code.
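A minimal standalone illustration of the generator-exhaustion behavior described here (not Adam-mini's code):

```python
import torch.nn as nn

model = nn.Linear(4, 4)
gen = model.named_parameters()   # a generator
print(len(list(gen)))            # 2 (weight, bias) -- the generator is consumed here
print(len(list(gen)))            # 0 -- a second pass over the same generator sees nothing

# Materializing it once keeps it reusable across later loops:
params = list(model.named_parameters())
print(len(params), len(params))  # 2 2
```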
@Sun2018421 Hi! Thanks for the update. Sorry for the late response, as I am traveling at the moment. I will get back to your question as soon as I am settled. Yushun
@zyushun Thank you very much for your reply, and I wish you a pleasant trip :)
Hi @Sun2018421, we have updated Adam-mini to v.1.1.1 and this issue is fixed. Please pip uninstall and install adam-mini again. Thanks a lot for mentioning this issue! We have acknowledged your help in the README.
@zyushun Thanks for getting it to v.1.1.1. I am still getting the following runtime error with this latest version. It seems that the previous error popped up again even when
@zyushun @buttercutter I get something like
Any solution?
I think with FSDP and tensor-parallel you will get
@buttercutter @zyushun @awgu How can it work with torchtitan tensor parallel and FSDP2 if that is the case?
There is a warning that says "=====>>> Warning by Adam-mini: No Query or Key found. If you are training Transformers......".
The existence of Key and Query is judged by `self.wqk_names = {"k_proj.weight", "q_proj.weight", "wq.weight", "wk.weight"}`, and there is only a `self_attn.in_proj_weight` in `nn.TransformerEncoderLayer`. So I think more work needs to be done to handle this situation.
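A quick way to confirm the naming that triggers this warning (sketch with arbitrary sizes): the fused attention inside nn.TransformerEncoderLayer exposes a single in_proj_weight, so none of the default wqk_names keywords can match.

```python
import torch.nn as nn

layer = nn.TransformerEncoderLayer(d_model=64, nhead=4)
print([name for name, _ in layer.named_parameters() if "attn" in name])
# ['self_attn.in_proj_weight', 'self_attn.in_proj_bias',
#  'self_attn.out_proj.weight', 'self_attn.out_proj.bias']
```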