move some function to modeling.py
zhoutianzi666 committed Sep 5, 2023
1 parent a5cf43a commit bee52b6
Showing 4 changed files with 130 additions and 191 deletions.
78 changes: 0 additions & 78 deletions llm/export_llama_for_minigpt4.py

This file was deleted.

19 changes: 14 additions & 5 deletions llm/predictor.py
@@ -64,7 +64,10 @@ class PredictorArgument:
    )
    inference_model: bool = field(default=False, metadata={"help": "whether use InferenceModel to do generation"})
    batch_size: int = field(default=1, metadata={"help": "The batch size of data."})
-    max_batch_size: int = field(default=None, metadata={"help": "The max batch size of data during serving."})
+    max_batch_size: int = field(default=1, metadata={"help": "The max batch size of data during serving."})
+    llm_for_img2txt: bool = field(
+        default=False, metadata={"help": "whether this llm model is used for img2txt, such as miniGPT4, blip2."}
+    )


@dataclass
@@ -553,13 +556,19 @@ def create_predictor(
        # TODO(wj-Mcat): complete AutoInferenceModel & AutoPredictor
        config = AutoConfig.from_pretrained(predictor_args.model_name_or_path)
        if "llama" in config.architectures[0].lower():
-            from paddlenlp.experimental.transformers import (
-                LlamaForCausalLMInferenceModel,
-            )
+            if predictor_args.llm_for_img2txt:
+                # we use llama for img2txt.
+                from paddlenlp.experimental.transformers import (
+                    LlamaForminiGPT4InferenceModel as LlamaInferenceModel,
+                )
+            else:
+                from paddlenlp.experimental.transformers import (
+                    LlamaForCausalLMInferenceModel as LlamaInferenceModel,
+                )

            config.tensor_parallel_degree = tensor_parallel_degree
            config.tensor_parallel_rank = tensor_parallel_rank
-            model = LlamaForCausalLMInferenceModel.from_pretrained(
+            model = LlamaInferenceModel.from_pretrained(
                predictor_args.model_name_or_path, config=config, dtype=predictor_args.dtype
            )
            model.eval()
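In the rewritten branch above, both inference classes are imported under one local alias, so everything after the import stays branch-free: both classes expose the same `from_pretrained` interface. A minimal sketch of the pattern, with the flag inlined as a plain variable for illustration (in predictor.py the value comes from `PredictorArgument.llm_for_img2txt`):

llm_for_img2txt = True  # stand-in for predictor_args.llm_for_img2txt

if llm_for_img2txt:
    # miniGPT4/blip2-style img2txt generation
    from paddlenlp.experimental.transformers import (
        LlamaForminiGPT4InferenceModel as LlamaInferenceModel,
    )
else:
    # plain causal-LM generation
    from paddlenlp.experimental.transformers import (
        LlamaForCausalLMInferenceModel as LlamaInferenceModel,
    )

# Construction is identical either way, e.g.:
# model = LlamaInferenceModel.from_pretrained(path, config=config, dtype="float16")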
110 changes: 3 additions & 107 deletions paddlenlp/experimental/transformers/generation_utils.py
@@ -40,6 +40,7 @@ def get_cache_kvs_shape(cls, max_batch_size: int = None, max_length: int = None)

    def to_static(self, output_path: str, config: dict):
        dtype = config.get("dtype", paddle.get_default_dtype())
+
        cache_kvs_shapes = self.get_cache_kvs_shape(self.config, max_length=config.get("max_length", None))

        input_spec = [
@@ -82,111 +83,6 @@ def to_static(self, output_path: str, config: dict):
        model = paddle.jit.to_static(self.generate, input_spec=input_spec)
        paddle.jit.save(model, output_path)

-    # this function make generate_with_image_features to static inference model.
-    def generate_with_image_features_to_static(self, output_path: str, config: dict):
-        dtype = config.get("dtype", paddle.get_default_dtype())
-        cache_kvs_shapes = self.get_cache_kvs_shape(self.config, max_length=config.get("max_length", None))
-        input_spec = [
-            paddle.static.InputSpec(
-                shape=[None, None, None], dtype="float32", name="image_features"
-            ),  # image_features
-            paddle.static.InputSpec(shape=[None, None], dtype="int64", name="first_input_ids"),  # first_input_ids
-            paddle.static.InputSpec(shape=[None, None], dtype="int64", name="second_input_ids"),  # second_input_ids
-            paddle.static.InputSpec(shape=[None, None], dtype=dtype, name="attention_mask"),  # attention_mask
-            paddle.static.InputSpec(shape=[None, None], dtype="int64", name="position_ids"),  # position_ids
-            paddle.static.InputSpec(shape=[None, 1], dtype="float32", name="penalty_score"),  # penalty_score
-            paddle.static.InputSpec(shape=[None, 1], dtype="float32", name="frequency_score"),  # frequency_score
-            paddle.static.InputSpec(shape=[None, 1], dtype="float32", name="presence_score"),  # presence_score
-            paddle.static.InputSpec(shape=[None, 1], dtype="int64", name="min_length"),  # min_decode_length
-            paddle.static.InputSpec(shape=[None, 1], dtype="int64", name="max_length"),  # max_decode_length
-            paddle.static.InputSpec(shape=[None, 1], dtype="float32", name="temperature"),  # temperature
-            paddle.static.InputSpec(shape=[None, 1], dtype="float32", name="top_p"),  # top_p
-            paddle.static.InputSpec(shape=[None], dtype="int64", name="eos_token_id"),  # eos_token_id
-            paddle.static.InputSpec(shape=[None, 1], dtype="int32", name="seq_len_encoder"),  # seq_len_encoder
-            paddle.static.InputSpec(shape=[None, 1], dtype="int32", name="seq_len_decoder"),  # seq_len_decoder
-            paddle.static.InputSpec(shape=[None, 1], dtype="int64", name="step_idx"),  # step_idx
-            paddle.static.InputSpec(shape=[None, 1], dtype="bool", name="stop_flags"),  # stop_flags
-            paddle.static.InputSpec(shape=[None, 1], dtype="int64", name="tgt_ids"),  # tgt_ids
-            paddle.static.InputSpec(shape=[None, 1], dtype="int64", name="tgt_pos"),  # tgt_pos
-            paddle.static.InputSpec(
-                shape=[None, 1, 1, None], dtype=dtype, name="tgt_generation_mask"
-            ),  # tgt_generation_mask
-            paddle.static.InputSpec(shape=[None, None], dtype="int64", name="pre_ids"),  # pre_ids
-            paddle.static.InputSpec(shape=[1], dtype="int64", name="stop_nums"),  # stop_nums
-            [
-                paddle.static.InputSpec(
-                    shape=shape,
-                    dtype=dtype,
-                    name="cache_kvs_{}".format(i),
-                )
-                for i, shape in enumerate(cache_kvs_shapes)
-            ],  # cache_kvs
-        ]
-
-        model = paddle.jit.to_static(self.generate_with_image_features, input_spec=input_spec)
-        paddle.jit.save(model, output_path)
-
-    # This function is called by miniGPT4's second part.
-    @paddle.no_grad()
-    def generate_with_image_features(
-        self,
-        image_features: paddle.Tensor,
-        first_input_ids: paddle.Tensor,
-        second_input_ids: paddle.Tensor,
-        attention_mask: paddle.Tensor,
-        position_ids=None,
-        penalty_score=None,
-        frequency_score=None,
-        presence_score=None,
-        min_length=None,
-        max_length=None,
-        temperature=None,
-        top_p=None,
-        eos_token_id=None,
-        seq_len_encoder=None,
-        seq_len_decoder=None,
-        step_idx=None,
-        stop_flags=None,
-        tgt_ids=None,
-        tgt_pos=None,
-        tgt_generation_mask=None,
-        pre_ids=None,
-        stop_nums=None,
-        cache_kvs=[],
-        inputs_embeds=None,
-        **generate_kwargs
-    ) -> paddle.Tensor:
-
-        first_embeds = self.llama.embed_tokens(first_input_ids)
-        second_embeds = self.llama.embed_tokens(second_input_ids)
-        image_features = paddle.cast(image_features, dtype=first_embeds.dtype)
-        inputs_embeds = paddle.concat([first_embeds, image_features, second_embeds], axis=1)
-
-        outputs = self.generate(
-            inputs_embeds=inputs_embeds,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            penalty_score=penalty_score,
-            frequency_score=frequency_score,
-            presence_score=presence_score,
-            min_length=min_length,
-            max_length=max_length,
-            temperature=temperature,
-            top_p=top_p,
-            eos_token_id=eos_token_id,
-            seq_len_encoder=seq_len_encoder,
-            seq_len_decoder=seq_len_decoder,
-            step_idx=step_idx,
-            stop_flags=stop_flags,
-            tgt_ids=tgt_ids,
-            tgt_pos=tgt_pos,
-            tgt_generation_mask=tgt_generation_mask,
-            pre_ids=pre_ids,
-            stop_nums=stop_nums,
-            cache_kvs=cache_kvs,
-        )
-        return outputs
-
    @staticmethod
    def prepare_input_ids_for_generation(bos_token_id, encoder_output=None):
        batch_size = 1
@@ -246,11 +142,11 @@ def generate(

        ret = self.sample(
            input_ids,
-            inputs_embeds,
            eos_token_id,
            top_p=top_p,
            cache_kvs=cache_kvs,
            temperature=temperature,
+            inputs_embeds=inputs_embeds,
            **model_kwargs,
        )
        return ret
@@ -329,11 +225,11 @@ def update_model_kwargs_for_generation(self, cache, just_decoder, next_tokens, e
    def sample(
        self,
        input_ids=None,
-        inputs_embeds=None,
        eos_token_id=None,
        cache_kvs=[],
        top_p=None,
        temperature=None,
+        inputs_embeds=None,
        **model_kwargs,
    ):
        step_idx_ori = paddle.full(shape=[1], dtype="int64", fill_value=1)
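The two hunks above also change how `inputs_embeds` travels into `sample()`: it leaves the second positional slot of the signature and is passed by keyword at the call site. A small, purely illustrative example (hypothetical function, not from this file) of the mix-up that positional calls invite once a parameter moves within a signature:

def sample_new(input_ids=None, eos_token_id=None, inputs_embeds=None):
    # mimics the reordered signature; returns its arguments for inspection
    return input_ids, inputs_embeds, eos_token_id

ids, embeds, eos = [1, 2], None, 0

# A positional call written against the old order (input_ids, inputs_embeds,
# eos_token_id) silently puts eos into inputs_embeds:
print(sample_new(ids, embeds, eos))  # ([1, 2], 0, None)

# The keyword form adopted here is immune to reordering:
print(sample_new(ids, eos_token_id=eos, inputs_embeds=embeds))  # ([1, 2], None, 0)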
114 changes: 113 additions & 1 deletion paddlenlp/experimental/transformers/llama/modeling.py
@@ -33,7 +33,7 @@
)
from paddlenlp.transformers.model_utils import register_base_model

-__all__ = ["LlamaInferenceModel", "LlamaForCausalLMInferenceModel"]
+__all__ = ["LlamaInferenceModel", "LlamaForCausalLMInferenceModel", "LlamaForminiGPT4InferenceModel"]


class FusedLlamaRMSNorm(nn.Layer):
@@ -439,3 +439,115 @@ def set_state_dict(self, state_dict):
        if "lm_head.weight" in state_dict:
            self.lm_head.weight.set_value(state_dict["lm_head.weight"])
        self.llama.set_state_dict({k: state_dict[k] for k in state_dict.keys()})
+
+
+class LlamaForminiGPT4InferenceModel(LlamaForCausalLMInferenceModel):
+    """
+    This class is 99% like LlamaForCausalLMInferenceModel.
+    Used only for miniGPT4's second part.
+    """
+
+    # This function corresponds to miniGPT4's second part, only used in miniGPT4.
+    @paddle.no_grad()
+    def generate_text_with_image_features(
+        self,
+        image_features: paddle.Tensor,
+        first_input_ids: paddle.Tensor,
+        second_input_ids: paddle.Tensor,
+        attention_mask: paddle.Tensor,
+        position_ids=None,
+        penalty_score=None,
+        frequency_score=None,
+        presence_score=None,
+        min_length=None,
+        max_length=None,
+        temperature=None,
+        top_p=None,
+        eos_token_id=None,
+        seq_len_encoder=None,
+        seq_len_decoder=None,
+        step_idx=None,
+        stop_flags=None,
+        tgt_ids=None,
+        tgt_pos=None,
+        tgt_generation_mask=None,
+        pre_ids=None,
+        stop_nums=None,
+        cache_kvs=[],
+        inputs_embeds=None,
+        **generate_kwargs
+    ) -> paddle.Tensor:
+
+        first_embeds = self.llama.embed_tokens(first_input_ids)
+        second_embeds = self.llama.embed_tokens(second_input_ids)
+        image_features = paddle.cast(image_features, dtype=first_embeds.dtype)
+        inputs_embeds = paddle.concat([first_embeds, image_features, second_embeds], axis=1)
+
+        outputs = self.generate(
+            inputs_embeds=inputs_embeds,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            penalty_score=penalty_score,
+            frequency_score=frequency_score,
+            presence_score=presence_score,
+            min_length=min_length,
+            max_length=max_length,
+            temperature=temperature,
+            top_p=top_p,
+            eos_token_id=eos_token_id,
+            seq_len_encoder=seq_len_encoder,
+            seq_len_decoder=seq_len_decoder,
+            step_idx=step_idx,
+            stop_flags=stop_flags,
+            tgt_ids=tgt_ids,
+            tgt_pos=tgt_pos,
+            tgt_generation_mask=tgt_generation_mask,
+            pre_ids=pre_ids,
+            stop_nums=stop_nums,
+            cache_kvs=cache_kvs,
+        )
+        return outputs
+
+    # rewrite to_static function in generation_utils.py
+    def to_static(self, output_path: str, config: dict):
+        dtype = config.get("dtype", paddle.get_default_dtype())
+        cache_kvs_shapes = self.get_cache_kvs_shape(self.config, max_length=config.get("max_length", None))
+        input_spec = [
+            paddle.static.InputSpec(
+                shape=[None, None, None], dtype="float32", name="image_features"
+            ),  # image_features
+            paddle.static.InputSpec(shape=[None, None], dtype="int64", name="first_input_ids"),  # first_input_ids
+            paddle.static.InputSpec(shape=[None, None], dtype="int64", name="second_input_ids"),  # second_input_ids
+            paddle.static.InputSpec(shape=[None, None], dtype=dtype, name="attention_mask"),  # attention_mask
+            paddle.static.InputSpec(shape=[None, None], dtype="int64", name="position_ids"),  # position_ids
+            paddle.static.InputSpec(shape=[None, 1], dtype="float32", name="penalty_score"),  # penalty_score
+            paddle.static.InputSpec(shape=[None, 1], dtype="float32", name="frequency_score"),  # frequency_score
+            paddle.static.InputSpec(shape=[None, 1], dtype="float32", name="presence_score"),  # presence_score
+            paddle.static.InputSpec(shape=[None, 1], dtype="int64", name="min_length"),  # min_decode_length
+            paddle.static.InputSpec(shape=[None, 1], dtype="int64", name="max_length"),  # max_decode_length
+            paddle.static.InputSpec(shape=[None, 1], dtype="float32", name="temperature"),  # temperature
+            paddle.static.InputSpec(shape=[None, 1], dtype="float32", name="top_p"),  # top_p
+            paddle.static.InputSpec(shape=[None], dtype="int64", name="eos_token_id"),  # eos_token_id
+            paddle.static.InputSpec(shape=[None, 1], dtype="int32", name="seq_len_encoder"),  # seq_len_encoder
+            paddle.static.InputSpec(shape=[None, 1], dtype="int32", name="seq_len_decoder"),  # seq_len_decoder
+            paddle.static.InputSpec(shape=[None, 1], dtype="int64", name="step_idx"),  # step_idx
+            paddle.static.InputSpec(shape=[None, 1], dtype="bool", name="stop_flags"),  # stop_flags
+            paddle.static.InputSpec(shape=[None, 1], dtype="int64", name="tgt_ids"),  # tgt_ids
+            paddle.static.InputSpec(shape=[None, 1], dtype="int64", name="tgt_pos"),  # tgt_pos
+            paddle.static.InputSpec(
+                shape=[None, 1, 1, None], dtype=dtype, name="tgt_generation_mask"
+            ),  # tgt_generation_mask
+            paddle.static.InputSpec(shape=[None, None], dtype="int64", name="pre_ids"),  # pre_ids
+            paddle.static.InputSpec(shape=[1], dtype="int64", name="stop_nums"),  # stop_nums
+            [
+                paddle.static.InputSpec(
+                    shape=shape,
+                    dtype=dtype,
+                    name="cache_kvs_{}".format(i),
+                )
+                for i, shape in enumerate(cache_kvs_shapes)
+            ],  # cache_kvs
+        ]
+
+        model = paddle.jit.to_static(self.generate_text_with_image_features, input_spec=input_spec)
+        paddle.jit.save(model, output_path)
