Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[LLM] Add pipeline and flashmask for Qwen2Moe and Deepseek #9827

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 29 additions & 3 deletions llm/run_finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,18 @@
AutoModelForCausalLM,
AutoModelForCausalLMPipe,
AutoTokenizer,
DeepseekV2ForCausalLM,
DeepseekV2ForCausalLMPipe,
DeepseekV3ForCausalLM,
DeepseekV3ForCausalLMPipe,
Llama3Tokenizer,
LlamaForCausalLM,
LlamaForCausalLMPipe,
LlamaTokenizer,
Qwen2ForCausalLM,
Qwen2ForCausalLMPipe,
Qwen2MoeForCausalLM,
Qwen2MoeForCausalLMPipe,
)
from paddlenlp.transformers.configuration_utils import LlmMetaConfig
from paddlenlp.trl import DataConfig, ModelConfig, SFTConfig, SFTTrainer
Expand All @@ -74,7 +80,18 @@
# Fine-tune Environment Variables to support sharding stage1 overlap optimization.
os.environ["USE_CASUAL_MASK"] = "False"

flash_mask_support_list = [LlamaForCausalLM, LlamaForCausalLMPipe, Qwen2ForCausalLM, Qwen2ForCausalLMPipe]
flash_mask_support_list = [
DeepseekV2ForCausalLM,
DeepseekV2ForCausalLMPipe,
DeepseekV3ForCausalLM,
DeepseekV3ForCausalLMPipe,
LlamaForCausalLM,
LlamaForCausalLMPipe,
Qwen2ForCausalLM,
Qwen2ForCausalLMPipe,
Qwen2MoeForCausalLM,
Qwen2MoeForCausalLMPipe,
]


def paddlenlp_verison_check():
Expand Down Expand Up @@ -151,7 +168,11 @@ def main():
quantization_config=quantization_config,
)

if "Qwen2Moe" in str(model_config.architectures) and training_args.data_parallel_degree > 1:
architectures_to_check = {"Qwen2Moe", "DeepseekV2", "DeepseekV3"}
if (
any(architecture in str(model_config.architectures) for architecture in architectures_to_check)
and training_args.data_parallel_degree > 1
):
training_args.use_expert_parallel = True

LlmMetaConfig.set_llm_config(model_config, training_args)
Expand Down Expand Up @@ -585,7 +606,12 @@ def create_peft_model(model_args, reft_args, training_args, dtype, model_config,
def trans_dataset_to_ids(train_ds, dev_ds, test_ds, model_args, data_args, trans_func, eval_zero_padding):
if train_ds is not None:
train_ds = train_ds.map(
partial(trans_func, is_test=False, zero_padding=data_args.zero_padding, flash_mask=model_args.flash_mask)
partial(
trans_func,
is_test=False,
zero_padding=data_args.zero_padding,
flash_mask=model_args.flash_mask,
)
)
if dev_ds is not None:
dev_ds = dev_ds.map(
Expand Down
6 changes: 5 additions & 1 deletion llm/run_pretrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,7 +478,11 @@ def main():
except:
print("Not register llama pp reshard information.")

if "Qwen2Moe" in str(config.architectures) and training_args.data_parallel_degree > 1:
architectures_to_check = {"Qwen2Moe", "DeepseekV2", "DeepseekV3"}
if (
any(architecture in str(config.architectures) for architecture in architectures_to_check)
and training_args.data_parallel_degree > 1
):
training_args.use_expert_parallel = True

if model_args.continue_training:
Expand Down
4 changes: 3 additions & 1 deletion llm/utils/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,13 @@ def get_convert_example(model):
"gpt",
"yuan",
"jamba",
"deepseek_v2",
"deepseek_v3",
]:
return convert_example_common
else:
raise ValueError(
f"Unknown base_model_prefix: {model.base_model_prefix}. Supported base_model_prefix list: chatglm, bloom, llama, qwen, mixtral, gemma, qwen2, qwen2_moe, yuan, jamba",
f"Unknown base_model_prefix: {model.base_model_prefix}. Supported base_model_prefix list: chatglm, bloom, llama, qwen, mixtral, gemma, qwen2, qwen2_moe, yuan, jamba,deepseek_v2, deepseek_v3",
)


Expand Down
7 changes: 2 additions & 5 deletions paddlenlp/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,11 +143,8 @@
from .deberta_v2.configuration import *
from .deberta_v2.modeling import *
from .deberta_v2.tokenizer import *
from .deepseek_v2.configuration import *
from .deepseek_v2.modeling import *
from .deepseek_v2.tokenizer_fast import *
from .deepseek_v3.configuration import *
from .deepseek_v3.modeling import *
from .deepseek_v2 import *
from .deepseek_v3 import *
from .distilbert.configuration import *
from .distilbert.modeling import *
from .distilbert.tokenizer import *
Expand Down
1 change: 1 addition & 0 deletions paddlenlp/transformers/deepseek_v2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,5 @@

from .configuration import *
from .modeling import *
from .modeling_pp import *
from .tokenizer_fast import *
Loading