Gemma #630
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/630
Note: Links to docs will display an error until the docs builds have been completed.
❌ 8 New Failures as of commit 321f59e with merge base aacaadd. The following jobs have failed:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
This is great - thanks for the contribution! I left a couple comments, but generally looks good.
Can you add a screenshot of running a distributed full finetune with Gemma to confirm it works?
return FeedForward(gate_proj=gate_proj, down_proj=down_proj, up_proj=up_proj, activation=activation)

def lora_gemma(
Looks like you included the LoRA version of Gemma for this PR. Are you planning on including LoRA, as well, or just starting with the full fine-tuning version?
It may take a longer time to complete the LoRA version of Gemma. Could I PR the full fine-tuning version first?
I think it should be fine to start with just full fine-tune for now
@solitude-alive Can you remove all the LoRA code since we won't be addressing it in this PR?
Yeah, I removed them in the latest version.
):
    super().__init__()
    self.w1 = gate_proj
    self.w2 = down_proj
    self.w3 = up_proj
    self.activation = F.silu
Good abstraction!
@@ -11,14 +11,17 @@
import torch
import torch.nn as nn
import torch.optim as optim
from safetensors import safe_open
Can you add this to requirements.txt?
Yeah, thank you for your suggestion, I added it in the latest version.
        TransformerDecoder: Instantiation of Gemma 2B model
    """
    return gemma(
        vocab_size=256_000,
Still shocked by this vocab size - so large!
Yeah, 😂
Does this mean the embedding(/output projection since they're tied) constitutes a full 25% of their params?!
I calculated it with count_trainable_parameters; the params of embed_tokens are about 21%.
def count_trainable_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
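As a rough sanity check (assuming vocab_size=256_000 and embed_dim=2048 from the builder above, plus an approximate ~2.5B total parameter count for Gemma 2B, which is not stated in this PR), the embedding share works out to roughly 21%:

embed_params = 256_000 * 2048          # tok_embeddings: ~524M parameters
total_params = 2.5e9                   # approximate total parameter count for Gemma 2B
print(f"embedding share: {embed_params / total_params:.1%}")  # roughly 21%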
recipes/full_finetune_distributed.py
Outdated
@@ -203,6 +210,7 @@ def _setup_model(
    cfg_model: DictConfig,
    enable_activation_checkpointing: bool,
    model_state_dict: Dict[str, Any],
    mode_tie: bool = False,
Suggested change:
- mode_tie: bool = False,
+ model_tie: bool = False,
Thank you for your suggestion, I have fixed it in the latest version.
recipes/full_finetune_distributed.py
Outdated
@@ -259,6 +267,10 @@ def _setup_model(
        ),
    )

    if mode_tie:  # Tie the weights of the model if required
Suggested change:
- if mode_tie:  # Tie the weights of the model if required
+ if model_tie:  # Tie the weights of the model if required
Thank you for your suggestion, I have fixed it in the latest version.
Thanks for this PR! Really excited to see how nicely this is shaping up.
Re testing, aside from making sure training runs, let's try to get a sanity check that the model forward here lines up with the one from the original implementation on some dummy data (assuming you haven't done so already). We have a bunch of scripts we've used in the past for this with various components in the library, so you can use these as a reference if it helps. For example (Note: you do not have to actually write a script like this and check it in, this is meant more as a reference if it helps you)
recipes/configs/gemma/2B_full.yaml
Outdated
#   --config gemma/2B_full \
#   checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
#
# This config works best when the model is being fine-tuned on 2+ GPUs.
nit: I think if we are running with the full_finetune_distributed recipe it will only work on 2+ GPUs.
Yeah, I updated it in the latest version.
def gemma_2b() -> TransformerDecoder:
    """
    Builder for creating a Gemma 2B model initialized w/ the default 2b parameter values
nit: add pointer to the paper or blog post here
Yeah, I added it in the latest version.
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

from torchtune.utils._distributed import contains_fsdp
from transformers.utils import is_safetensors_available
I think we shouldn't import from transformers here as it's not in our core dependencies. If you've added safetensors to our core dependencies (based on the above comment) probably don't need to do this check anyways.
Yeah, I removed it in the latest version.
torchtune/modules/feed_forward.py
Outdated
@@ -25,12 +26,13 @@ def __init__(
    gate_proj: nn.Module,
    down_proj: nn.Module,
    up_proj: nn.Module,
    activation: nn.Module = F.silu,
nit: technically F.silu is a Callable, not an nn.Module
Yeah, I replaced it with nn.SiLU() in the latest version.
        Default: False

    Returns:
        FeedForward: instantiation of the MLP module with LoRA applied to
I think the second line of this docstring got lost somewhere along the way
Yeah, I fixed it in the latest version.
recipes/full_finetune_distributed.py
Outdated
if cfg.checkpointer.model_type == "GEMMA":
    model_tie = True
else:
    model_tie = False
We could also consider associating a weight tying config explicitly with the model type and using that in the checkpointer. E.g.
@dataclass
class ModelType:
    name: str
    weight_tying_config: Dict[str, str] = field(default_factory=dict)
Then Gemma would be ModelType(name="GEMMA", weight_tying_config={"tok_embeddings.weight": "output.weight"}).
(Anyways, not a blocker for this PR as it's more of a design question)
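For illustration, a hedged sketch of how a checkpointer utility might consume such a config (the apply_weight_tying helper and its exact shape are hypothetical, not existing torchtune code):

from dataclasses import dataclass, field
from typing import Dict

import torch

@dataclass
class ModelType:
    name: str
    weight_tying_config: Dict[str, str] = field(default_factory=dict)

GEMMA = ModelType(name="GEMMA", weight_tying_config={"tok_embeddings.weight": "output.weight"})

def apply_weight_tying(state_dict: Dict[str, torch.Tensor], model_type: ModelType) -> None:
    # For each (source, target) pair, point the target key at the source tensor,
    # e.g. output.weight <- tok_embeddings.weight for Gemma.
    for src_key, dst_key in model_type.weight_tying_config.items():
        state_dict[dst_key] = state_dict[src_key]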
Thank you for your suggestion, I updated this in gemma_full_finetune.py.
recipes/full_finetune_distributed.py
Outdated
@@ -259,6 +267,10 @@ def _setup_model(
        ),
    )

    if model_tie:  # Tie the weights of the model if required
        model.output.weight = model.tok_embeddings.weight
It was pointed out by @rohan-varma that this may not actually do what we expect because FSDP has already sharded the params, so let's double-confirm via testing that the weights are tied correctly here.
Thank you for pointing it out. I checked the model weights after training and they are not the same. Is there any solution? I'm not familiar with that. This can cause some problems if the weights are tied before FSDP (issue).
OK sorry for the back and forth on this. Confirmed with @rohan-varma that we should not tie weights after FSDP wrapping after all. The main issue was not FSDP but the initialization on meta device. Unfortunately, weight tying + meta device is tricky because the usage of to_empty breaks existing references.
Instead, for Gemma we can do everything on CPU without using meta device at all, basically initializing the model on CPU for every rank and then defining a more vanilla FSDP without the param_init_fn we currently have. This should work fine for smaller models (at least up to 7B). @kartikayk put together a snippet on what this can look like, you can find it here.
We need to decide what the best way to expose this is, but for now feel free to create a separate recipe for Gemma, e.g. gemma_full_finetune.py. It should look pretty much the same as the existing full_finetune_distributed.py, but with the changes needed to initialize everything on CPU and perform weight tying there before wrapping with FSDP.
Thanks also to @awgu for helping debug this.
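For reference, a minimal sketch of what that recipe change could look like (assuming the gemma_2b() builder from this PR and a plain FSDP wrap; the function name, import path, and exact arguments here are illustrative, not the final recipe code):

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

from torchtune.models.gemma import gemma_2b

def setup_gemma_model(model_state_dict):
    # Build and load the model on CPU on every rank (no meta device), so the
    # tied parameters share a real reference before FSDP shards them.
    model = gemma_2b()
    model.load_state_dict(model_state_dict)

    # Tie the output projection to the token embedding *before* FSDP wrapping.
    model.output.weight = model.tok_embeddings.weight

    # Vanilla FSDP wrap: no param_init_fn, since params are already materialized.
    model = FSDP(model, device_id=torch.cuda.current_device())
    return model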
Thank you for your suggestion, it works well.
Co-authored-by: ebsmothers <ebs@meta.com>
Thanks for adding this PR @solitude-alive! This would be an awesome contribution to the repo!
Similar to @joecummings I have some questions in the code. My biggest question though is correctness. The loss from the screenshot seems to be much higher than what we've seen with Mistral/Llama2. Have you compared this loss for Gemma with the official implementation or some other implementation? Or have you seen some issues/blogs which showcase the loss value during training that we can compare against?
Also, when adding models we provide some evidence of model numerical correctness - this is really important to build confidence with our users. Please see how we did this for Llama2 13B and Mistral 7B in the context section of this PR: #571. It would be great if you can add a similar check for Gemma 2B. This check would look something like (see the sketch after this list):
- Load official implementation of Gemma2B and take a random tensor, run forward and get output
- Load torchtune implementation, take same tensor, run forward and get output
- Compare outputs with torch.allclose and make sure this returns a True.
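A rough sketch of such a parity check (assuming the Hugging Face google/gemma-2b checkpoint as the reference and the gemma_2b() builder from this PR; loading the converted state dict into the torchtune model is omitted here):

import torch
from transformers import AutoModelForCausalLM

from torchtune.models.gemma import gemma_2b

# Reference implementation (Hugging Face), in eval mode.
hf_model = AutoModelForCausalLM.from_pretrained("google/gemma-2b")
hf_model.eval()

# torchtune implementation; loading the converted state dict is omitted in this sketch.
tune_model = gemma_2b()
tune_model.eval()

tokens = torch.randint(0, 256_000, (1, 16))
with torch.no_grad():
    hf_logits = hf_model(tokens).logits
    tune_logits = tune_model(tokens)

# The two outputs should match within a small tolerance.
print(torch.allclose(hf_logits, tune_logits, atol=1e-4, rtol=1e-4))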
recipes/full_finetune_distributed.py
Outdated
@@ -259,6 +267,10 @@ def _setup_model(
        ),
    )

    if model_tie:  # Tie the weights of the model if required
I find the use of model_tie to be a bit unintuitive. Can we rename this to something like share_weights or share_embed since I don't think we'll have other modules we share?
Yeah, I updated them in the latest version.
@@ -1,6 +1,7 @@
# Hugging Face Integration Reqs
datasets
huggingface_hub
safetensors
@joecummings do we need to explicitly add this if it's a part of huggingface_hub? I guess it's good practice to explicitly call it out?
Tie the weights of the output embeddings and the token embeddings in the model.

Args:
    model (TransformerDecoder): The to tie the weights of the output embeddings and the token embeddings.
This sentence is missing some info: "the to tie" reads a bit weird
Sorry, I modified it in the latest version.
        num_kv_heads=1,
        embed_dim=2048,
        intermediate_dim=16384,
        max_seq_len=32768,
Is this right? I thought this was 8192 for Gemma 2B
Sorry, that was my mistake; I fixed it in the latest version.
@@ -383,6 +383,14 @@ def load_checkpoint(self) -> Dict[str, Any]:
    dim=self._config["hidden_size"],
)

if (
    self._model_type == "GEMMA"
Hmm so I have a question about this code. hf_to_tune makes an assumption that head_dim * num_heads = dim (see here). But this isn't true for Gemma 7B, where num_heads=16 and head_dim=256 but dim=3072 and not 4096. So we will need to differentiate between Gemma 2B and 7B here.
Also, please move to a utility function in checkpointer_utils so we can keep this code clean.
If I allow explicitly passing the parameter num_heads to the function hf_to_tune, is this allowed?
And I moved them to a utility function in checkpointer_utils.
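For context, a tiny illustration of the head_dim assumption being discussed (the resolve_head_dim helper is hypothetical; the numbers are the Gemma configs mentioned above):

from typing import Optional

def resolve_head_dim(num_heads: int, dim: int, head_dim: Optional[int] = None) -> int:
    # Default assumption in hf_to_tune: head_dim * num_heads == dim.
    # Gemma 7B breaks this, so the caller needs a way to pass head_dim explicitly.
    return head_dim if head_dim is not None else dim // num_heads

# Gemma 2B: the default assumption holds (2048 // 8 == 256).
assert resolve_head_dim(num_heads=8, dim=2048) == 256
# Gemma 7B: explicit override needed, since 3072 // 16 == 192, not 256.
assert resolve_head_dim(num_heads=16, dim=3072, head_dim=256) == 256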
print(f"======={self._model_type}==========") | ||
if ( | ||
self._model_type == "GEMMA" |
Can we move this to a separate utility function in checkpointer_utils? We should keep the checkpointer as clean as possible.
Yeah, I did it in the latest version.
"because it is the same as the model embed_tokens weight" | ||
) | ||
else: | ||
self._weight_map["lm_head.weight"] = "0002" |
When will this else block hit? If we know the checkpoints don't contain this key, let's just work with that assumption? Anyways we're hard coding a bunch of stuff like the name of the key etc.
It was because the parameters were not really tied before. Now I removed the else block.
state_dict = torch.load(
    str(checkpoint_path), map_location="cpu", mmap=True, weights_only=True
)
if str(checkpoint_path).endswith(".safetensors") and is_safetensors_available():
Not a fan of this approach. Can we just add a key to the config, something like is_safetensors_file, and then based on the value determine if we use torch.load or not? Also please break this down into a sub-function (e.g. load_from_safetensor or something similar).
@joecummings WDYT?
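A hedged sketch of what that sub-function might look like (the names load_from_safetensor and load_checkpoint_file just follow the suggestion above and are illustrative; the safe_open usage follows the safetensors API):

from pathlib import Path
from typing import Any, Dict

import torch
from safetensors import safe_open

def load_from_safetensor(checkpoint_path: Path) -> Dict[str, Any]:
    # Load a .safetensors checkpoint into a plain state dict on CPU.
    state_dict: Dict[str, Any] = {}
    with safe_open(str(checkpoint_path), framework="pt", device="cpu") as f:
        for key in f.keys():
            state_dict[key] = f.get_tensor(key)
    return state_dict

def load_checkpoint_file(checkpoint_path: Path, weights_only: bool = True) -> Dict[str, Any]:
    # Dispatch on the file extension instead of branching inline in the checkpointer.
    if str(checkpoint_path).endswith(".safetensors"):
        return load_from_safetensor(checkpoint_path)
    return torch.load(
        str(checkpoint_path), map_location="cpu", mmap=True, weights_only=weights_only
    )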
…the same as model embed_tokens weight
A huge thank you @solitude-alive for adding this functionality and also patiently addressing the many review comments. This functionality makes TorchTune better and we really appreciate all of your hard work. I'll merge this into MAIN, make a few small changes to the core recipes based on some upcoming changes and then add this to our README and cite you as the author. Thanks so much for all of the hard work!
Co-authored-by: ebsmothers <ebs@meta.com>
    state_dict = result
else:
    state_dict = torch.load(
        str(checkpoint_path), map_location="cpu", mmap=True, weights_only=True
Looks like the weights_only arg is not passed around here?
Oh yes, but I just looked at the latest version and it has been updated.
state_dict = torch.load(
    str(checkpoint_path),
    map_location="cpu",
    mmap=True,
    weights_only=weights_only,
)
    mmap=True,
    weights_only=weights_only,
is_safetensors_file = (
    True if str(checkpoint_path).endswith(".safetensors") else False
nit: btw this seems to be the same as:
is_safetensors_file = str(...).endswith(".safetensors")
Yeah.
Context
Changelog
Test plan