
Add Phi3 Mini 4K Instruct Model to torchtune #876

Merged: 13 commits merged from phi3 into main on Apr 28, 2024
Conversation

@kartikayk (Contributor) commented Apr 26, 2024

Important note: The tokenizer still needs some work; this will be a follow-up PR.

Context

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

Please link to any issues this PR addresses.

Changelog

This PR adds support for the Phi3 Mini 4K Instruct model to torchtune. Specifically, we add the following:

  • Phi3's RoPE module. This is not numerically equivalent to the RoPE used by the Llama2 or Llama3 models, and care needs to be taken to get correct behavior for bf16 training
  • State-dict conversion logic accounting for Phi3's fused qkv and gate_up projection matrices (see the sketch after this list)
  • A config for multi-GPU full finetuning
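
As a concrete illustration of the state-dict bullet above, here is a minimal sketch of how the fused matrices might be split during checkpoint conversion. This is not the exact helper added in this PR: the function and argument names are hypothetical, and it assumes the HF checkpoint stacks q, k, v (and gate, up) along the output dimension.

import torch

def split_fused_phi3_weights(qkv: torch.Tensor, gate_up: torch.Tensor,
                             num_heads: int, num_kv_heads: int, dim: int):
    # Hypothetical helper, not the PR's converter.
    # qkv:     [(num_heads + 2 * num_kv_heads) * head_dim, dim]
    # gate_up: [2 * intermediate_dim, dim]
    head_dim = dim // num_heads
    kv_dim = num_kv_heads * head_dim
    # The first `dim` rows hold the query projection, then keys, then values.
    q, k, v = torch.split(qkv, [dim, kv_dim, kv_dim], dim=0)
    # gate_up stacks the gate projection on top of the up projection.
    gate, up = torch.chunk(gate_up, 2, dim=0)
    return q, k, v, gate, up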

Test plan

Please make sure to do each of the following if applicable to your PR. (If you're not sure about any one of these, just ask and we will happily help.)

  • run pre-commit hooks and linters (make sure you've first installed via pre-commit install)
  • add unit tests for any new functionality
  • update docstrings for any new or updated methods or classes
  • run unit tests via pytest tests
  • run recipe tests via pytest tests -m integration_test
  • manually run any new or modified recipes with sufficient proof of correctness
    • include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)

Unit Tests

Detailed comparisons with the reference implementation

pytest tests/torchtune

Full-finetune Recipe

[Screenshot: full-finetune recipe run, 2024-04-25]

pytorch-bot (bot) commented Apr 26, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/876

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 872911a with merge base 3890200:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label on Apr 26, 2024.
# pip install bitsandbytes
#
# To launch on a single device, run the following command from root:
# tune run full_finetune_single_device --config recipes/config/phi3/mini_full_low_memory.yaml
Contributor:

Suggested change
# tune run full_finetune_single_device --config recipes/config/phi3/mini_full_low_memory.yaml
# tune run full_finetune_single_device --config phi3/mini_full_low_memory.yaml

Contributor (Author):

Oh, I haven't added this to the registry, so this won't work. I'll let you take care of that.

# You can add specific overrides through the command line. For example
# to override the checkpointer directory while launching training
# you can run:
# tune run full_finetune_single_device --config recipes/config/phi3/mini_full_low_memory.yaml checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
Contributor:

Suggested change
# tune run full_finetune_single_device --config recipes/config/phi3/mini_full_low_memory.yaml checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
# tune run full_finetune_single_device --config phi3/mini_full_low_memory.yaml checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>

Contributor (Author):

Same as above.

x_out = rope_phi3(input)

# check the numerics of the computed tensor
assert_expected(x_out.mean(), tensor(-0.0005), atol=1e-4)
Contributor:

Can you include in the PR where these numbers come from?

Contributor (Author):

The info's in the gist! I don't really want to replicate all of it in the PR context again.
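
For readers without access to the gist, this is roughly how such expected values are produced: run the torchtune module and the reference implementation on the same seeded input, check that they match, and record a summary statistic for the unit test. The sketch below is illustrative; reference_rope stands in for the gist's implementation, the constructor arguments are assumed, and the dimensions are not Phi3-mini's actual config.

import torch
from torchtune.models.phi3._position_embeddings import Phi3RotaryPositionalEmbeddings

torch.manual_seed(0)
bsz, seq_len, num_heads, head_dim = 2, 16, 4, 32

# Constructor args assumed to mirror torchtune's other RoPE modules.
rope = Phi3RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=seq_len)
x = torch.randn(bsz, seq_len, num_heads, head_dim)
x_out = rope(x)

# Compare against the reference (hypothetical name), then record summary
# stats such as the mean as the expected values in the unit test.
# torch.testing.assert_close(x_out, reference_rope(x))
print(x_out.mean())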

from torch import nn, Tensor


class Phi3RotaryPositionalEmbeddings(nn.Module):
Contributor:

You're a hero
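
For context on why this module is not numerically interchangeable with the Llama RoPE: one likely source of the difference is that Phi3 (following the HF implementation) uses the "rotate half" convention rather than the interleaved-pair convention, and the cos/sin cache should be built in fp32 so bf16 training does not lose precision in the frequencies. The sketch below is illustrative only, not the code added in this PR; names and defaults are assumptions.

import torch
from torch import nn, Tensor

class RotateHalfRoPE(nn.Module):
    # Minimal "rotate half" RoPE sketch (HF convention), illustrative only.
    def __init__(self, dim: int, max_seq_len: int = 4096, base: int = 10000):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        t = torch.arange(max_seq_len, dtype=torch.float32)
        freqs = torch.outer(t, inv_freq)          # [max_seq_len, dim // 2]
        emb = torch.cat([freqs, freqs], dim=-1)   # [max_seq_len, dim]
        # Cache is kept in fp32 regardless of the training dtype.
        self.register_buffer("cos", emb.cos(), persistent=False)
        self.register_buffer("sin", emb.sin(), persistent=False)

    @staticmethod
    def _rotate_half(x: Tensor) -> Tensor:
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat([-x2, x1], dim=-1)

    def forward(self, x: Tensor) -> Tensor:
        # x: [b, s, n_h, h_d]
        seq_len = x.size(1)
        cos = self.cos[:seq_len].view(1, seq_len, 1, -1)
        sin = self.sin[:seq_len].view(1, seq_len, 1, -1)
        return (x * cos + self._rotate_half(x) * sin).to(x.dtype)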

    num_kv_heads=self._config["num_key_value_heads"],
    dim=self._config["hidden_size"],
)
if self._model_type == ModelType.PHI3_MINI:
Contributor:

Does capitalization change this? Because in the config it's PHI_MINI and in the enum it's phi_mini.

Contributor (Author):

Yup, it matters!
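
To illustrate why capitalization matters here: Enum lookup by name in Python is case-sensitive, so the string coming from a config has to match the member name exactly (or be normalized before lookup). The ModelType below is a stand-in, not torchtune's actual definition.

from enum import Enum

class ModelType(Enum):
    PHI3_MINI = "phi3_mini"

print(ModelType["PHI3_MINI"])   # ModelType.PHI3_MINI
try:
    ModelType["phi3_mini"]      # lowercase name: raises KeyError
except KeyError as err:
    print("lookup failed:", err)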

@joecummings (Contributor) left a comment:

Mostly nits! Thanks @kartikayk


def phi3_tokenizer(path: str) -> SentencePieceTokenizer:
    tokenizer = SentencePieceTokenizer(path)
    tokenizer.pad_id = 32000
Contributor:

Looking at the HF config, it says that eos_id and pad_id are the same; is that right?

Contributor (Author):

Yes, but also will let @joecummings confirm that!
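
One quick way to confirm, assuming config.json from the Phi-3 Mini 4K Instruct HF repo has been downloaded locally (the path below is a placeholder):

import json

with open("/path/to/Phi-3-mini-4k-instruct/config.json") as f:
    cfg = json.load(f)
# Print the two ids the question above is asking about.
print(cfg.get("eos_token_id"), cfg.get("pad_token_id"))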

torchtune/models/phi3/_position_embeddings.py: two more review threads (one outdated, both resolved)
TODO: The implementation below can be made more efficient
for inference.
"""
# input tensor has shape [b, s, n_h, n_d]
Contributor:

(based on the above docstring)

Suggested change
# input tensor has shape [b, s, n_h, n_d]
# input tensor has shape [b, s, n_h, h_d]

Contributor (Author):

oh good catch!

@joecummings (Contributor) left a comment:

Looks good!

@kartikayk merged commit 23c5585 into main on Apr 28, 2024
27 checks passed
@kartikayk deleted the phi3 branch on April 28, 2024 at 15:44
@fyabc mentioned this pull request on Jul 9, 2024