
Mistral testing #888

Merged
merged 8 commits into pytorch:main from mistral-tests May 1, 2024

Conversation

SalmanMohammadi (Collaborator) commented Apr 26, 2024

Context

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

Please link to any issues this PR addresses.

#848

Changelog

I've started adding scripts to verify our Mistral implementation against the reference implementation from the official repo. There's another implementation in that repo which uses xformers for the attention mechanism, but it isn't straightforward to replicate - I ran into lots of issues when I first tried.

So far, I've added a script to compare the attention implementations, and I've verified that the attention implementation produces consistent outputs using `python -m tests.torchtune.models.mistral.scripts.compare_attention`. I'll be keeping the reference implementation in `tests/torchtune/models/mistral/scripts/mistral_reference.py`.
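
For reference, the comparison script follows roughly this pattern: instantiate the reference attention and the torchtune attention, copy the weights across, run both on the same fixed-seed input, and check that the outputs agree. A minimal, self-contained sketch of that pattern (not the actual compare_attention.py; `nn.Linear` stands in for the two attention implementations so the snippet runs on its own):

```python
# Sketch of the comparison pattern only: nn.Linear is a stand-in for the
# reference attention and the torchtune attention, which share weights here.
import torch
from torch import nn


def compare_outputs(reference: nn.Module, candidate: nn.Module, x: torch.Tensor) -> None:
    """Run both modules on the same input and check that the outputs agree."""
    reference.eval()
    candidate.eval()
    with torch.no_grad():
        out_ref = reference(x)
        out_new = candidate(x)
    torch.testing.assert_close(out_new, out_ref, rtol=1e-5, atol=1e-5)
    print(f"outputs match, shape={tuple(out_ref.shape)}")


if __name__ == "__main__":
    torch.manual_seed(16)  # fixed seed so the comparison is reproducible
    reference = nn.Linear(64, 64)  # stand-in for the reference attention
    candidate = nn.Linear(64, 64)  # stand-in for the torchtune attention
    # In the real scripts the parameter names differ between implementations
    # and need remapping; these stand-ins happen to share the same layout.
    candidate.load_state_dict(reference.state_dict())
    x = torch.randn(2, 16, 64)  # (batch, seq_len, embed_dim)
    compare_outputs(reference, candidate, x)
```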

Next steps

I'm generally following this process - the plan is to continue copying and testing the components of the mistral implementation, then test the models as a whole by mapping `torchtune.models.mistral` into the reference implementation. Finally, I'll add unit tests to integrate into CI.

It'd be good to make sure I'm not too far off the mark :)


pytorch-bot bot commented Apr 26, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/888

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 190cc8a with merge base bec7bab:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

facebook-github-bot added the CLA Signed label Apr 26, 2024
SalmanMohammadi marked this pull request as draft April 26, 2024 23:23
SalmanMohammadi changed the title from "Adding initial script for testing mistral reference attention" to "Mistral testing" Apr 27, 2024
SalmanMohammadi (Collaborator, Author) commented

I've updated scripts for the rest of the mistral components. I need to write the comparison involving mapping state dicts, update the unit test, and (potentially) add LoRA comparisons.
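
To make the state-dict mapping step concrete, here's a rough sketch of the kind of helper it needs. The key names below are hypothetical placeholders, not the real reference or torchtune layouts:

```python
# Hypothetical key renames from reference-style names to torchtune-style names;
# the real mapping will depend on how the two implementations name their modules.
import re
from typing import Dict

import torch

_KEY_MAP = {
    r"^layers\.(\d+)\.attention\.wq\.weight$": r"layers.\1.attn.q_proj.weight",
    r"^layers\.(\d+)\.attention\.wk\.weight$": r"layers.\1.attn.k_proj.weight",
    r"^layers\.(\d+)\.attention\.wv\.weight$": r"layers.\1.attn.v_proj.weight",
    r"^layers\.(\d+)\.attention\.wo\.weight$": r"layers.\1.attn.output_proj.weight",
}


def map_state_dict(reference_sd: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """Rename reference keys to their torchtune counterparts; leave the rest untouched."""
    mapped = {}
    for key, value in reference_sd.items():
        new_key = key
        for pattern, replacement in _KEY_MAP.items():
            if re.match(pattern, key):
                new_key = re.sub(pattern, replacement, key)
                break
        mapped[new_key] = value
    return mapped
```

A model-level comparison can then load `map_state_dict(reference.state_dict())` into the torchtune model before checking that both models produce the same outputs.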

SalmanMohammadi (Collaborator, Author) commented Apr 27, 2024

Okay, all seems good. We now have a unit test for the base mistral model using the copied implementation from the mistral repo.

For the unfortunate reviewer seeing my +1160 line PR (I hope you read this first!):
I'm hoping the mistral/scripts/compare_{component}.py scripts for the individual components aren't unnecessary - I'm realising that we already test each component in llama2/scripts/compare_{component}.py. The only component I'm testing that wasn't compared in llama2/scripts/ is mistral_mlp. Maybe it's good that they've been verified with two implementations? If they're not useful I can take them out - mistral/scripts/compare_mistral.py is the main one.
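
As a rough illustration of how the full-model comparison folds into a unit test for CI (not the actual test file - `nn.Linear` again stands in for the copied reference model and the torchtune model):

```python
# Minimal pytest sketch of the comparison-as-unit-test pattern; the two
# nn.Linear modules are stand-ins for the reference copy and the torchtune model.
import pytest
import torch
from torch import nn


class TestMistralAgainstReference:
    @pytest.fixture
    def inputs(self) -> torch.Tensor:
        torch.manual_seed(16)  # deterministic input so failures are reproducible
        return torch.randn(2, 16, 64)  # (batch, seq_len, embed_dim)

    def test_forward_matches_reference(self, inputs: torch.Tensor) -> None:
        reference = nn.Linear(64, 64)  # stand-in: copied reference implementation
        candidate = nn.Linear(64, 64)  # stand-in: torchtune model
        candidate.load_state_dict(reference.state_dict())  # share weights
        with torch.no_grad():
            expected = reference(inputs)
            actual = candidate(inputs)
        torch.testing.assert_close(actual, expected, rtol=1e-5, atol=1e-5)
```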

@@ -0,0 +1,186 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

SalmanMohammadi (Collaborator, Author) commented on the diff:

This is the main file used for comparing implementations.

ebsmothers (Contributor) commented

Thanks for all this extensive testing!

> I'm hoping the mistral/scripts/compare_{component}.py scripts for the individual components aren't unnecessary - I'm realising that we already test each component in llama2/scripts/compare_{component}.py. The only component I'm testing that wasn't compared in llama2/scripts/ is mistral_mlp.

I think we wanna find the right balance of rigorous testing and maintenance here. So while I don't want your work to be in vain, I wonder if we should just add those comparison scripts that differ nontrivially from the Llama2 ones, and for other components point to the Llama2 ones. So in this case that would mean keep compare_mistral and compare_feedforward (since you mentioned it's not tested under llama2). Then you can add a readme to tests/torchtune/models/mistral/scripts (similar to this one in the llama2 scripts directory) and state that components X, Y, and Z are identical to the llama2 ones and their comparison scripts can be found in that directory. (If you want you can even move the MLP comparison under llama2 so that everything is colocated, but tbh I have no strong preference here.)

SalmanMohammadi (Collaborator, Author) commented

Not in vain at all - I learnt lots! I've made the updates and added a README.

SalmanMohammadi marked this pull request as ready for review April 30, 2024 09:38
ebsmothers (Contributor) left a review

Looks good! Two small nits; with green CI this is good to merge.

tests/torchtune/models/mistral/scripts/README.md (review comment, outdated, resolved)
tests/torchtune/models/mistral/scripts/compare_mistral.py (review comment, outdated, resolved)
SalmanMohammadi (Collaborator, Author) commented

Thanks again for your review @ebsmothers :)

ebsmothers merged commit 06c5fcb into pytorch:main May 1, 2024
29 checks passed
SalmanMohammadi deleted the mistral-tests branch July 20, 2024 22:02
Labels: CLA Signed
Projects: none
3 participants