Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Model] DeepseekV2 Support #499

Open
wants to merge 9 commits into
base: main
Choose a base branch
from

Conversation

saurabhkoshatwar
Copy link

@saurabhkoshatwar saurabhkoshatwar commented Dec 26, 2024

Summary

Resolves #129 Add monkeypatch to support deepseepV2 model.

Details

Ops patched:

  • rms_norm
  • swiglu
  • cross_entropy
  • fused_linear_cross_entropy

Testing Done

  • Hardware Type: NVIDIA A100-SXM4-40GB
  • run make test to ensure correctness
  • run make checkstyle to ensure code style
  • run make test-convergence to ensure convergence

@saurabhkoshatwar saurabhkoshatwar marked this pull request as draft December 26, 2024 00:58
@saurabhkoshatwar saurabhkoshatwar marked this pull request as ready for review December 31, 2024 20:42
@saurabhkoshatwar
Copy link
Author

saurabhkoshatwar commented Jan 7, 2025

@ByronHsu @yundai424 @Tcc0403 @qingquansong
As discussed in the issue, the rope implementation is different in DeepSeek.

deepseek:

    cos = cos[position_ids].unsqueeze(unsqueeze_dim)
    sin = sin[position_ids].unsqueeze(unsqueeze_dim)

    b, h, s, d = q.shape
    q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)

    b, h, s, d = k.shape
    k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)

    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

llama:

    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed`

I will create a separate PR to implement the DeepSeek rope.

modeling_mod.DeepseekV2MLP.forward = LigerSwiGLUMLP.forward
if cross_entropy:
if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
from transformers.loss.loss_utils import nn
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: since its so common to use import torch.nn as nn, perhaps we import loss_utils under a different symbol?

Maybe even just from transformers.loss import loss_utils?

import sys

# Ensure the model is a DeepSeek model
if "deepseek" not in model.__class__.__module__:
Copy link
Collaborator

@tyler-romero tyler-romero Jan 15, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do deepseek and deepseek-v3 share the same architecture? If so, perhaps this function should be called apply_liger_kernel_to_deepseek, if not, perhaps we should strengthen this check.

if transformer_version >= version.parse(SUPPORTED_TRANSFORMER_VERSION):
from transformers.loss.loss_utils import nn

nn.functional.cross_entropy = liger_cross_entropy
Copy link
Collaborator

@tyler-romero tyler-romero Jan 15, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will globally patch from transformers.loss.loss_utils.function.cross_entropy, which is a pretty undersireable / unexpected side effect of applying this deepseek-specific monkey patch.

See this issue: #315

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This could be fixed if deepseekv2 is added to the transformers library (see below comment about trust_remote_code)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

heavy plus^

if model_name[:6] == "remote":
revert_kwargs["remote_model_module"] = MINI_MODEL_SETUPS[model_name].remote_model_module

model = create_model(model_name).to(dtype).to(device)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why the change to create the model before applying the patch?

model_class = MINI_MODEL_SETUPS[model_name].model_class
return model_class(model_config)
if model_name[:6] == "remote":
config = AutoConfig.from_pretrained(MINI_MODEL_SETUPS[model_name].remote_model_path, trust_remote_code=True)
Copy link
Collaborator

@tyler-romero tyler-romero Jan 15, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you explain why this is necessary? Its it because the model cannot be run without trust_remote_code? As is, this default opts-in anyone who runs these unit tests into running remote code on their machine, which is a red flag.

I think a preferable path would be to add deepseekv2 to the transformers library, then add it to Liger, so that trust_remote_code is not necessary.

This also has the benefit of making it easier to follow changes that are made to the underlying model, which is a common source of bugs in Liger.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like support for deepseekv2 is underway (maybe stalled though): huggingface/transformers#31976

Comment on lines +14 to +15

DeepseekV2_INPUTS_DOCSTRING = r"""
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: it'll be helpful to document where this part of docstring is ported from -- at least need a link to https://huggingface.co/deepseek-ai/DeepSeek-V2/blob/main/modeling_deepseek.py

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[feat] support for DeepseekV2
4 participants