diff --git a/llm/run_pretrain.py b/llm/run_pretrain.py
index d0df32321e18..2d8521bcfbcf 100644
--- a/llm/run_pretrain.py
+++ b/llm/run_pretrain.py
@@ -483,6 +483,12 @@ def main():
                 config.num_attention_heads % config.sep_parallel_degree == 0
             ), f"num_attention_heads:{config.num_attention_heads} must be divisible by sep_parallel_degree {config.sep_parallel_degree}"
 
+    if paddle.is_compiled_with_xpu() and training_args.gradient_accumulation_steps > 1:
+        from paddle_xpu.layers.nn.linear import LinearConfig  # noqa: F401
+
+        LinearConfig.enable_accumulate_steps_opt()
+        LinearConfig.set_accumulate_steps(training_args.gradient_accumulation_steps)
+
     print("Final pre-training config:", config)
 
     # Set the dtype for loading model
diff --git a/paddlenlp/transformers/llama/modeling.py b/paddlenlp/transformers/llama/modeling.py
index d70e63ffa484..1b511602421e 100755
--- a/paddlenlp/transformers/llama/modeling.py
+++ b/paddlenlp/transformers/llama/modeling.py
@@ -174,7 +174,7 @@ def assign_kv_heads(num_kv_heads: int, num_gpus: int):
     return assignment_list
 
 
-def parallel_matmul(x: Tensor, y: Tensor, tensor_parallel_output=True):
+def parallel_matmul(matmul_op, x: Tensor, y: Tensor, tensor_parallel_output=True):
     is_fleet_init = True
     tensor_parallel_degree = 1
     try:
@@ -192,7 +192,7 @@ def parallel_matmul(x: Tensor, y: Tensor, tensor_parallel_output=True):
     if is_fleet_init and tensor_parallel_degree > 1 and y_is_distributed:
         # if not running under distributed.launch, it will raise AttributeError: 'Fleet' object has no attribute '_hcg'
         input_parallel = paddle.distributed.collective._c_identity(x, group=model_parallel_group)
-        logits = paddle.matmul(input_parallel, y, transpose_y=False)
+        logits = matmul_op(input_parallel, y, transpose_y=False)
 
         if tensor_parallel_output:
             return logits
@@ -200,7 +200,7 @@ def parallel_matmul(x: Tensor, y: Tensor, tensor_parallel_output=True):
         return paddle.distributed.collective._c_concat(logits, group=model_parallel_group)
 
     else:
-        logits = paddle.matmul(x, y, transpose_y=False)
+        logits = matmul_op(x, y, transpose_y=False)
         return logits
 
 
@@ -413,6 +413,10 @@ def forward(self, hidden_states):
         if self.config.use_fused_rms_norm:
             if get_env_device() == "npu":
                 return core.eager._run_custom_op("rms_norm_npu", hidden_states, self.weight, self.variance_epsilon)[0]
+            elif get_env_device() == "xpu":
+                import paddle_xpu_nn
+
+                return paddle_xpu_nn.xpu_rms_norm(hidden_states, self.weight, self.variance_epsilon)[0]
             return rms_norm_fused(hidden_states, self.weight, self.variance_epsilon)
 
         if paddle.in_dynamic_mode():
@@ -582,6 +586,14 @@ def __init__(self, config):
 
                 ColumnParallelLinear = MC2ColumnSeqParallelLinear
                 RowParallelLinear = MC2RowSeqParallelLinear
+            elif get_env_device() == "xpu":
+                from paddle_xpu.layers.nn.sequence_parallel import (  # noqa: F401
+                    XPUColumnSequenceParallelLinear,
+                    XPURowSequenceParallelLinear,
+                )
+
+                ColumnParallelLinear = XPUColumnSequenceParallelLinear
+                RowParallelLinear = XPURowSequenceParallelLinear
             else:
                 ColumnParallelLinear = ColumnSequenceParallelLinear
                 RowParallelLinear = RowSequenceParallelLinear
@@ -589,6 +601,11 @@ def __init__(self, config):
             ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear
             RowParallelLinear = fleet.meta_parallel.RowParallelLinear
 
+        if get_env_device() == "xpu":
+            Linear = paddle_xpu.layers.nn.Linear  # noqa: F821
+        else:
+            Linear = nn.Linear
+
         if config.tensor_parallel_degree > 1:
             if config.fuse_attention_ffn:
                 self.gate_up_fused_proj = ColumnParallelLinear(
@@ -619,12 +636,12 @@ def __init__(self, config):
                 )
         else:
            if config.fuse_attention_ffn:
-                self.gate_up_fused_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias_attr=False)
+                self.gate_up_fused_proj = Linear(self.hidden_size, self.intermediate_size * 2, bias_attr=False)
             else:
-                self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias_attr=False)
-                self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias_attr=False)
+                self.gate_proj = Linear(self.hidden_size, self.intermediate_size, bias_attr=False)
+                self.up_proj = Linear(self.hidden_size, self.intermediate_size, bias_attr=False)
 
-        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias_attr=False)
+        self.down_proj = Linear(self.intermediate_size, self.hidden_size, bias_attr=False)
 
     def forward(self, x):
         if self.fuse_attention_ffn:
@@ -689,7 +706,11 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):
 
         self.use_fused_rope = config.use_fused_rope
         if self.use_fused_rope and get_env_device() != "npu":
-            if "gpu" not in paddle.device.get_device() or fused_rotary_position_embedding is None:
+            if (
+                "gpu" not in paddle.device.get_device()
+                and "xpu" not in paddle.device.get_device()
+                or fused_rotary_position_embedding is None
+            ):
                 warnings.warn(
                     "Enable fuse rope in the config, but fuse rope is not available. "
                     "Will disable fuse rope. Try using latest gpu version of Paddle."
@@ -705,6 +726,14 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):
 
                 ColumnParallelLinear = MC2ColumnSeqParallelLinear
                 RowParallelLinear = MC2RowSeqParallelLinear
+            elif get_env_device() == "xpu":
+                from paddle_xpu.layers.nn.sequence_parallel import (  # noqa: F401
+                    XPUColumnSequenceParallelLinear,
+                    XPURowSequenceParallelLinear,
+                )
+
+                ColumnParallelLinear = XPUColumnSequenceParallelLinear
+                RowParallelLinear = XPURowSequenceParallelLinear
             else:
                 ColumnParallelLinear = ColumnSequenceParallelLinear
                 RowParallelLinear = RowSequenceParallelLinear
@@ -712,6 +741,11 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):
             ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear
             RowParallelLinear = fleet.meta_parallel.RowParallelLinear
 
+        if get_env_device() == "xpu":
+            Linear = paddle_xpu.layers.nn.Linear  # noqa: F821
+        else:
+            Linear = nn.Linear
+
         if config.tensor_parallel_degree > 1:
             if self.fuse_attention_qkv:
                 self.qkv_proj = ColumnParallelLinear(
@@ -741,12 +775,12 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):
                     gather_output=False,
                 )
             else:
-                self.k_proj = nn.Linear(
+                self.k_proj = Linear(
                     self.hidden_size,
                     self.config.num_key_value_heads * self.head_dim,
                     bias_attr=False,
                 )
-                self.v_proj = nn.Linear(
+                self.v_proj = Linear(
                     self.hidden_size,
                     self.config.num_key_value_heads * self.head_dim,
                     bias_attr=False,
@@ -754,23 +788,23 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):
                 )
         else:
             if self.fuse_attention_qkv:
-                self.qkv_proj = nn.Linear(
+                self.qkv_proj = Linear(
                     self.hidden_size,
                     self.hidden_size + 2 * self.config.num_key_value_heads * self.head_dim,
                     bias_attr=False,
                 )
             else:
-                self.q_proj = nn.Linear(
+                self.q_proj = Linear(
                     self.hidden_size,
                     self.hidden_size,
                     bias_attr=False,
                 )
-                self.k_proj = nn.Linear(
+                self.k_proj = Linear(
                     self.hidden_size,
                     self.config.num_key_value_heads * self.head_dim,
                     bias_attr=False,
                 )
-                self.v_proj = nn.Linear(
+                self.v_proj = Linear(
                     self.hidden_size,
                     self.config.num_key_value_heads * self.head_dim,
                     bias_attr=False,
@@ -784,7 +818,7 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):
                 input_is_parallel=True,
             )
         else:
-            self.o_proj = nn.Linear(
+            self.o_proj = Linear(
                 self.hidden_size,
                 self.hidden_size,
                 bias_attr=False,
@@ -1428,6 +1462,11 @@ def _prepare_decoder_attention_mask(attention_mask, input_shape, past_key_values
             y = paddle.to_tensor(paddle.finfo(dtype).min, dtype="float16")
             expanded_attn_mask = expanded_attn_mask.astype("float16")
             expanded_attn_mask = paddle.where(expanded_attn_mask, x, y).astype(dtype)
+        elif get_env_device() == "xpu":
+            x = paddle.to_tensor(0.0, dtype=dtype)
+            y = paddle.to_tensor(paddle.finfo(dtype).min, dtype=dtype)
+            expanded_attn_mask = expanded_attn_mask.astype(dtype)
+            expanded_attn_mask = paddle.where(expanded_attn_mask, x, y).astype(dtype)
         else:
             expanded_attn_mask = paddle.where(expanded_attn_mask, 0.0, paddle.finfo(dtype).min).astype(dtype)
         return expanded_attn_mask
@@ -1708,6 +1747,13 @@ def __init__(self, config: LlamaConfig):
             self.weight.is_distributed = True if (vocab_size != config.vocab_size) else False
             if self.weight.is_distributed:
                 self.weight.split_axis = 1
+        if paddle.is_compiled_with_xpu():
+            from paddle_xpu.layers.nn import xpu_matmul  # noqa: F401
+
+            self._xpu_matmul = xpu_matmul()
+            self.matmul_op = self._xpu_matmul.forward
+        else:
+            self.matmul_op = paddle.matmul
 
     def forward(self, hidden_states, tensor_parallel_output=None):
         if self.config.sequence_parallel:
@@ -1721,7 +1767,13 @@ def forward(self, hidden_states, tensor_parallel_output=None):
         if tensor_parallel_output is None:
             tensor_parallel_output = self.config.tensor_parallel_output
 
-        logits = parallel_matmul(hidden_states, self.weight, tensor_parallel_output=tensor_parallel_output)
+        matmul_op = self.matmul_op
+        if paddle.is_compiled_with_xpu():
+            from functools import partial
+
+            matmul_op = partial(matmul_op, training=self.training)
+
+        logits = parallel_matmul(matmul_op, hidden_states, self.weight, tensor_parallel_output=tensor_parallel_output)
         return logits
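# --------------------------------------------------------------------------
# Illustration only -- not part of the patch. A minimal sketch of how the new
# `parallel_matmul(matmul_op, x, y, ...)` contract is exercised, assuming the
# single-card path (tensor_parallel_degree == 1), where parallel_matmul reduces
# to `matmul_op(x, y, transpose_y=False)`. Non-XPU builds fall back to
# `paddle.matmul`; XPU builds bind `training=` onto the paddle_xpu operator via
# functools.partial, mirroring LlamaLMHead above. The helper name
# `pick_matmul_op` is hypothetical.
# --------------------------------------------------------------------------
from functools import partial

import paddle


def pick_matmul_op(training: bool):
    """Select the matmul operator the same way LlamaLMHead does."""
    if paddle.is_compiled_with_xpu():
        # Requires the paddle_xpu package; import path taken from this patch.
        from paddle_xpu.layers.nn import xpu_matmul

        return partial(xpu_matmul().forward, training=training)
    return paddle.matmul


if __name__ == "__main__":
    hidden_states = paddle.randn([2, 8, 64], dtype="float32")    # [batch, seq, hidden]
    lm_head_weight = paddle.randn([64, 32000], dtype="float32")  # [hidden, vocab]

    matmul_op = pick_matmul_op(training=False)
    # With tensor_parallel_degree == 1, parallel_matmul boils down to this call.
    logits = matmul_op(hidden_states, lm_head_weight, transpose_y=False)
    print(logits.shape)  # [2, 8, 32000]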