diff --git a/paddlenlp/transformers/llama/modeling.py b/paddlenlp/transformers/llama/modeling.py
index 535add40fcfc..3caa08608a8a 100755
--- a/paddlenlp/transformers/llama/modeling.py
+++ b/paddlenlp/transformers/llama/modeling.py
@@ -1353,7 +1353,13 @@ class LlamaModel(LlamaPretrainedModel):
     """
 
     def __init__(self, config: LlamaConfig):
+        #####################################
+        from paddlenlp.PaddleAPEX import Acc
+        self.checker = Acc()
+        #####################################
         super().__init__(config)
+
+
         self.vocab_size = config.vocab_size
         self.hidden_size = config.hidden_size
         self.sequence_parallel = config.sequence_parallel
@@ -1471,6 +1477,10 @@ def forward(
         return_dict=False,
         **kwargs,
     ):
+        #####################################
+        self.checker.start()
+        #####################################
+
         if self.sequence_parallel and use_cache:
             raise ValueError("We currently only support sequence parallel without cache.")
 
@@ -1615,6 +1625,11 @@ def forward(
 
         next_cache = next_decoder_cache if use_cache else None
 
+
+        #####################################
+        self.checker.stop()
+        #####################################
+
         if not return_dict:
             return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
         return BaseModelOutputWithPastAndCrossAttentions(
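
For context, the hooks above boil down to a simple pattern: create one Acc instance when the model is built, then bracket the region of the forward pass you want to audit with checker.start() and checker.stop(). Below is a minimal, self-contained sketch of that pattern on a toy Paddle layer. Only Acc(), .start(), .stop(), and the import path paddlenlp.PaddleAPEX are taken from the diff; the toy layer, its sizes, and anything about how PaddleAPEX records or dumps the captured calls are assumptions made purely for illustration.

    # Minimal sketch of the instrumentation pattern from the diff above.
    # Only Acc(), .start() and .stop() come from the diff; everything else
    # (the toy layer, sizes, usage) is illustrative and not part of the patch.
    import paddle
    from paddlenlp.PaddleAPEX import Acc


    class ToyModel(paddle.nn.Layer):
        def __init__(self, hidden_size=64):
            super().__init__()
            self.checker = Acc()  # accuracy checker, created once per model
            self.linear = paddle.nn.Linear(hidden_size, hidden_size)

        def forward(self, x):
            self.checker.start()  # begin capturing the region of interest
            out = paddle.nn.functional.relu(self.linear(x))
            self.checker.stop()   # stop capturing before returning
            return out


    if __name__ == "__main__":
        model = ToyModel()
        y = model(paddle.randn([2, 64]))

The sketch creates the checker after super().__init__() (the usual ordering for attributes on a paddle.nn.Layer); the diff itself assigns it before calling super().__init__(config), which works for a plain Python attribute but is otherwise equivalent in effect.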