Could it support Gemma? #616
Comments
We are considering many new model additions and will keep you posted!
@solitude-alive would you be open to adding this model? I'm happy to share specific pointers and review code if you're interested. We'd love the contribution.
@kartikayk Yeah, I'm happy to do that. I'll give it a try.
@solitude-alive awesome! As a starting point, take a look at the Mistral 7B model builder in: We expose specific models through model builders, which basically stitch together components (e.g. Attention, RoPE, RMS Norm, etc.). You can find some examples here: I think adding support for gemma_2b would be similar. You just need to make sure the components line up with what Gemma is doing.
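To make that builder pattern concrete, here is a minimal sketch (not torchtune's actual API; the component classes and hyperparameter values below are illustrative, not Gemma's exact architecture) of a builder function that stitches standard components together for a 2B-scale model:

```python
import torch
from torch import nn


class DecoderLayer(nn.Module):
    """One decoder block: pre-norm attention plus a pre-norm MLP."""

    def __init__(self, embed_dim: int, num_heads: int, mlp_dim: int) -> None:
        super().__init__()
        self.attn_norm = nn.RMSNorm(embed_dim)  # nn.RMSNorm requires PyTorch >= 2.4
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.mlp_norm = nn.RMSNorm(embed_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, mlp_dim), nn.GELU(), nn.Linear(mlp_dim, embed_dim)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.attn_norm(x)
        attn_out, _ = self.attn(h, h, h, need_weights=False)
        x = x + attn_out
        return x + self.mlp(self.mlp_norm(x))


def gemma_2b_sketch(vocab_size: int = 256_000) -> nn.Module:
    """Model builder: fix 2B-scale hyperparameters and stitch the components."""
    embed_dim, num_layers, num_heads, mlp_dim = 2048, 18, 8, 16384
    return nn.Sequential(
        nn.Embedding(vocab_size, embed_dim),
        *[DecoderLayer(embed_dim, num_heads, mlp_dim) for _ in range(num_layers)],
        nn.RMSNorm(embed_dim),
        nn.Linear(embed_dim, vocab_size, bias=False),
    )


# Example usage: logits = gemma_2b_sketch()(torch.randint(0, 256_000, (1, 16)))
```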
@kartikayk Hi, _model_builders.py and _component_builders.py are mostly complete, except for some components that still need to be confirmed. Is there documentation on how to load the weights file? It seems that Gemma only provides [model-00001-of-00002.safetensors, model-00002-of-00002.safetensors] rather than .bin or .pth files.
@solitude-alive great catch! Right now, TorchTune supports only PyTorch-native .bin or .pt formats. In order to add Gemma, we need to think about functionality to support loading safetensors.
Is this something you feel comfortable adding? This would be an incredible feature because there are a lot of other models on the HF Hub that only ship safetensors, too.
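For illustration, one way such loading could look (a sketch, not torchtune's checkpointing code; it only assumes the safetensors package and shard file names like those mentioned above):

```python
import torch
from safetensors.torch import load_file  # pip install safetensors


def load_weights_file(path: str) -> dict:
    """Load a single checkpoint file, dispatching on its extension."""
    if path.endswith(".safetensors"):
        return load_file(path)  # returns a flat {param_name: tensor} dict
    return torch.load(path, map_location="cpu", weights_only=True)


def load_sharded_checkpoint(shard_paths: list) -> dict:
    """Merge shards such as model-0000X-of-0000Y.safetensors into one state dict."""
    state_dict = {}
    for shard in shard_paths:
        state_dict.update(load_weights_file(shard))
    return state_dict
```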
Also, @solitude-alive, we would love for you to join the Discord channel (see our README for the invite link) so we can quickly answer any questions you may have as you work on this!
@joecummings Yeah, thanks.
@solitude-alive Awesome! As @joecummings said, it would be awesome to add safetensors support to TorchTune. I verified that the weights loaded from the safetensors files match those loaded from the .bin checkpoint, and here's the output: Given that these are numerically equivalent, I think the best way forward would be for you to add a flag that enables loading safetensors checkpoints. Does this make sense to you?
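The verification script and its output were not captured in this thread; a sketch of the kind of equivalence check being described (with placeholder file names) could look like:

```python
import torch
from safetensors.torch import load_file

# Placeholder file names: the actual checkpoints referred to above were not captured.
bin_state = torch.load("pytorch_model.bin", map_location="cpu", weights_only=True)
sft_state = {}
for shard in ("model-00001-of-00002.safetensors", "model-00002-of-00002.safetensors"):
    sft_state.update(load_file(shard))

assert bin_state.keys() == sft_state.keys(), "parameter names differ"
for name, tensor in bin_state.items():
    assert torch.equal(tensor, sft_state[name]), f"values differ for {name}"
print("safetensors and .bin checkpoints are numerically equivalent")
```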
@kartikayk Thank you for your suggestion.
Hi, it seems I hit errors on my device when I set the
Seems like this is actively being discussed on Discord. Once the discussion is over, we can come back and summarize it here. cc: @ebsmothers
Thanks for the discussion. There is a temporary solution: remove any weight tying that occurs before FSDP wrapping and perform the weight tying here.
Yeah, to summarize the discussion on Discord: when training with FSDP, the way we initialize the model undoes the weight tying. Specifically, I suspect it's because we initialize on the meta device. Not only that, but we cannot tie weights prior to FSDP wrapping, or else we will hit a hang at our first sync point. You can see e.g. here for some discussion on the topic. We can get around this by instead tying weights after FSDP wrapping. I believe @solitude-alive already has a
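A rough sketch of that ordering (illustrative only: the attribute names tok_embeddings and output are hypothetical, and whether a plain parameter reassignment suffices after wrapping depends on FSDP internals):

```python
import torch
from torch import nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def build_wrap_then_tie(model_builder, device: torch.device) -> nn.Module:
    # 1) Build the model WITHOUT tying weights: tying before FSDP wrapping
    #    can hang at the first sync point, per the discussion above.
    model = model_builder()

    # 2) Wrap with FSDP; use_orig_params=True keeps the original nn.Parameter
    #    objects addressable after wrapping.
    model = FSDP(model, device_id=device, use_orig_params=True)

    # 3) Only now tie the output projection to the token embedding, so the
    #    FSDP/meta-device initialization cannot silently undo the tie.
    model.module.output.weight = model.module.tok_embeddings.weight
    return model
```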
Google's Gemma release includes a 2B model, so it seems we could do a full-parameter fine-tune on fewer than 4x24GB GPUs. Do you plan to support it?