[Inference]qwen2-a8w8c8 support use_fake_parameter #9109

Merged: 5 commits, Sep 12, 2024
Changes from 3 commits
2 changes: 1 addition & 1 deletion paddlenlp/experimental/model_utils.py
@@ -332,7 +332,7 @@
self.key_map = key_map_dict
self.scale = {}
for scale_type, key_template in self.key_map.items():
self.scale[scale_type] = np.full([num_of_layers], fill_value=-1.0)
self.scale[scale_type] = np.full([num_of_layers], fill_value=-1.0, dtype="float32")
for i in range(num_of_layers):
if key_template.replace("#", str(i)) in self.scale_dict.keys():
self.scale[scale_type][i] = 1 / self.scale_dict[key_template.replace("#", str(i))]
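For context, a minimal standalone sketch (not part of the diff) of the dtype behavior this one-line change addresses: with a Python float fill value, np.full infers float64, which is what the added dtype="float32" argument overrides.

import numpy as np

num_of_layers = 4
# Without an explicit dtype, np.full infers float64 from the float fill value.
scales_default = np.full([num_of_layers], fill_value=-1.0)
# The changed line pins the placeholder scales to float32 instead.
scales_fp32 = np.full([num_of_layers], fill_value=-1.0, dtype="float32")
print(scales_default.dtype, scales_fp32.dtype)  # float64 float32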
146 changes: 117 additions & 29 deletions paddlenlp/experimental/transformers/qwen2/modeling.py
@@ -49,7 +49,12 @@
GenerationBlockInferenceModel,
GenerationInferenceModel,
)
from paddlenlp.experimental.transformers.utils import infererence_model_from_pretrained
from paddlenlp.experimental.transformers.utils import (
EmptyActScale,
EmptyCacheScale,
EmptyWeightScale,
infererence_model_from_pretrained,
)
from paddlenlp.transformers import Qwen2Config, Qwen2PretrainedModel
from paddlenlp.transformers.conversion_utils import split_param_func
from paddlenlp.transformers.model_outputs import ( # CausalLMOutputWithCrossAttentions,
@@ -113,6 +118,8 @@
self.smooth = config.quantization_config.smooth
self.shift_smooth_all_linears = config.quantization_config.shift_smooth_all_linears

self.use_fake_parameter = config.get("use_fake_parameter", False)

if self.use_weight_only:
assert (
self.quant_type == "weight_only_int8" or self.quant_type == "weight_only_int4"
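A hedged usage sketch of the new flag, for orientation only: the checkpoint path and the commented-out model class name are placeholders and are not taken from this PR; the diff only establishes that the flag is read via config.get("use_fake_parameter", False).

from paddlenlp.transformers import Qwen2Config

# Placeholder path; an a8w8/a8w8c8-quantized Qwen2 checkpoint directory is assumed here.
config = Qwen2Config.from_pretrained("/path/to/qwen2-a8w8c8")
# New flag: missing PTQ scale/shift/smooth tensors are filled with placeholder values.
config.use_fake_parameter = True
# model = Qwen2ForCausalLMInferenceModel.from_pretrained("/path/to/qwen2-a8w8c8", config=config)  # hypothetical class name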
@@ -603,6 +610,30 @@

if "a8w8" in self.quant_type:
if self.shift_smooth_all_linears:
if self.use_fake_parameter:
if "qwen2.layers.{}.self_attn.o_proj.shift_bias".format(idx) not in state_dict:
state_dict["qwen2.layers.{}.self_attn.o_proj.shift_bias".format(idx)] = paddle.zeros(
shape=[
(self.num_attention_heads // self.config.tensor_parallel_degree)
* (self.hidden_size // self.num_attention_heads)
],
dtype=paddle.get_default_dtype(),
)
state_dict["qwen2.layers.{}.self_attn.o_proj.smooth_weight".format(idx)] = paddle.ones(
shape=[
(self.num_attention_heads // self.config.tensor_parallel_degree)
* (self.hidden_size // self.num_attention_heads)
],
dtype=paddle.get_default_dtype(),
)
state_dict["qwen2.layers.{}.mlp.down_proj.shift_bias".format(idx)] = paddle.zeros(
shape=[self.intermediate_size // self.config.tensor_parallel_degree],
dtype=paddle.get_default_dtype(),
)
state_dict["qwen2.layers.{}.mlp.down_proj.smooth_weight".format(idx)] = paddle.ones(
shape=[self.intermediate_size // self.config.tensor_parallel_degree],
dtype=paddle.get_default_dtype(),
)
self.transformer_block.linear_shifts[idx].set_value(
paddle.to_tensor(state_dict["qwen2.layers.{}.self_attn.o_proj.shift_bias".format(idx)])
)
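The placeholder shapes created above follow the usual tensor-parallel split of the attention output projection. A small standalone sketch with assumed config values (not read from any checkpoint):

import paddle

# Assumed example values, for illustration only.
num_attention_heads = 28
hidden_size = 3584
tensor_parallel_degree = 2
head_dim = hidden_size // num_attention_heads

# o_proj shift/smooth placeholders are sized for the local (per-rank) heads:
# a zero shift and a unit smooth weight act as no-ops.
local_width = (num_attention_heads // tensor_parallel_degree) * head_dim
shift_bias = paddle.zeros([local_width], dtype=paddle.get_default_dtype())
smooth_weight = paddle.ones([local_width], dtype=paddle.get_default_dtype())
print(shift_bias.shape, smooth_weight.shape)  # [1792] [1792]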
@@ -617,6 +648,32 @@
)

if self.shift:
if self.use_fake_parameter:
if "qwen2.layers.{}.input_layernorm.bias".format(idx) not in state_dict:
state_dict["qwen2.layers.{}.input_layernorm.bias".format(idx)] = paddle.zeros(
shape=[self.hidden_size], dtype=paddle.get_default_dtype()
)
state_dict["qwen2.layers.{}.post_attention_layernorm.bias".format(idx)] = paddle.zeros(
[self.hidden_size], dtype=paddle.get_default_dtype()
)
unfused_state_dict["self_attn.q_proj.bias"] = paddle.zeros(
shape=[self.num_attention_heads * (self.hidden_size // self.num_attention_heads)],
dtype=paddle.get_default_dtype(),
)
unfused_state_dict["self_attn.k_proj.bias"] = paddle.zeros(
shape=[self.num_key_value_heads * (self.hidden_size // self.num_attention_heads)],
dtype=paddle.get_default_dtype(),
)
unfused_state_dict["self_attn.v_proj.bias"] = paddle.zeros(
shape=[self.num_key_value_heads * (self.hidden_size // self.num_attention_heads)],
dtype=paddle.get_default_dtype(),
)
unfused_state_dict["mlp.gate_proj.bias"] = paddle.zeros(
shape=[self.intermediate_size], dtype=paddle.get_default_dtype()
)
unfused_state_dict["mlp.up_proj.bias"] = paddle.zeros(
shape=[self.intermediate_size], dtype=paddle.get_default_dtype()
)
self.transformer_block.ln_biases[idx].set_value(
paddle.to_tensor(state_dict["qwen2.layers.{}.input_layernorm.bias".format(idx)])
)
@@ -657,6 +714,14 @@
self.transformer_block.ffn1_biases[idx].set_value(paddle.to_tensor(concated_ffn1_bias))

if self.shift_smooth_all_linears:
if self.use_fake_parameter:
if "qwen2.layers.{}.self_attn.o_proj.bias".format(idx) not in state_dict:
state_dict["qwen2.layers.{}.self_attn.o_proj.bias".format(idx)] = paddle.zeros(
[self.hidden_size], dtype=paddle.get_default_dtype()
)
state_dict["qwen2.layers.{}.mlp.down_proj.layer.bias".format(idx)] = paddle.zeros(
[self.hidden_size], dtype=paddle.get_default_dtype()
)
self.transformer_block.linear_biases[idx].set_value(
paddle.to_tensor(state_dict["qwen2.layers.{}.self_attn.o_proj.bias".format(idx)])
)
@@ -678,40 +743,63 @@
cache_scale_map_dict = scale_map_dict["cachekv_scale"]
# TODO(RichardWooSJTU): support multi-cards

act_scale_json_path = os.path.join(self.quant_model_path, "act_scales.json")
weight_scale_json_path = os.path.join(self.quant_model_path, "weight_scales.json")
if self.config.tensor_parallel_degree > 1 and not self.config.single_card_ptq:
act_scale_json_path = os.path.join(
self.quant_model_path, f"act_scales_{self.config.tensor_parallel_rank}.json"
if not self.use_fake_parameter:
act_scale_json_path = os.path.join(self.quant_model_path, "act_scales.json")
weight_scale_json_path = os.path.join(self.quant_model_path, "weight_scales.json")
if self.config.tensor_parallel_degree > 1 and not self.config.single_card_ptq:
act_scale_json_path = os.path.join(
self.quant_model_path, f"act_scales_{self.config.tensor_parallel_rank}.json"
)
weight_scale_json_path = os.path.join(
self.quant_model_path, f"weight_scales_{self.config.tensor_parallel_rank}.json"
)
act_scale_loader = ActScalesLoader(
act_scale_json_path, act_scale_map_dict, num_of_layers=self.config.num_hidden_layers
)
weight_scale_json_path = os.path.join(
self.quant_model_path, f"weight_scales_{self.config.tensor_parallel_rank}.json"
weight_scales_loader = WeightScalesLoader(
weight_scale_json_path,
weight_scale_map_dict,
num_of_layers=self.config.num_hidden_layers,
concat_qkv=True,
concat_ffn1=True,
)
else:
act_scale_loader = EmptyActScale(act_scale_map_dict, num_of_layers=self.config.num_hidden_layers)
weight_scales_loader = EmptyWeightScale(
weight_scale_map_dict,
num_of_layers=self.config.num_hidden_layers,
num_head=self.num_attention_heads,
dim_head=self.hidden_size // self.num_attention_heads,
ffn_hidden_size=self.intermediate_size,
num_key_value_heads=self.num_key_value_heads,
mp_size=self.config.tensor_parallel_degree,
)
act_scale_loader = ActScalesLoader(
act_scale_json_path, act_scale_map_dict, num_of_layers=self.config.num_hidden_layers
)
self.transformer_block.act_scales = act_scale_loader.scale
weight_scales_loader = WeightScalesLoader(
weight_scale_json_path,
weight_scale_map_dict,
num_of_layers=self.config.num_hidden_layers,
concat_qkv=True,
concat_ffn1=True,
)

if self.config.cachekv_int8_type == "static":
cache_scale_json_path = os.path.join(self.quant_model_path, "cachekv_scales.json")
if self.config.tensor_parallel_degree > 1 and not self.config.single_card_ptq:
cache_scale_json_path = os.path.join(
self.quant_model_path, f"cachekv_scales_{self.config.tensor_parallel_rank}.json"
if not self.use_fake_parameter:
cache_scale_json_path = os.path.join(self.quant_model_path, "cachekv_scales.json")
if self.config.tensor_parallel_degree > 1 and not self.config.single_card_ptq:
cache_scale_json_path = os.path.join(
self.quant_model_path, f"cachekv_scales_{self.config.tensor_parallel_rank}.json"
)
cache_scales_loader = CacheScaleLoader(
cache_scale_json_path,
cache_scale_map_dict,
num_of_layers=self.config.num_hidden_layers,
num_heads=self.num_attention_heads // self.config.tensor_parallel_degree,
num_key_value_heads=self.num_key_value_heads // self.config.tensor_parallel_degree,
)
else:
cache_scales_loader = EmptyCacheScale(
cache_scale_map_dict,
num_of_layers=self.config.num_hidden_layers,
num_heads=self.num_attention_heads,
dim_heads=self.hidden_size // self.num_attention_heads,
is_channel_wise=False,
num_key_value_heads=self.num_key_value_heads,
mp_size=self.config.tensor_parallel_degree,
)
cache_scales_loader = CacheScaleLoader(
cache_scale_json_path,
cache_scale_map_dict,
num_of_layers=self.config.num_hidden_layers,
num_heads=self.num_attention_heads // self.config.tensor_parallel_degree,
num_key_value_heads=self.num_key_value_heads // self.config.tensor_parallel_degree,
)

for k, v in cache_scales_loader.scale.items():
for i_layer, weight_scale in enumerate(v):
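The EmptyActScale / EmptyWeightScale / EmptyCacheScale loaders used in the new branches are defined in paddlenlp/experimental/transformers/utils.py and are not shown in this diff. As a hedged sketch of the idea only (class name, keys, and details below are illustrative, mirroring the ActScalesLoader pattern from model_utils.py), such a placeholder loader can simply expose neutral scales:

import numpy as np

class EmptyActScaleSketch:
    """Illustrative stand-in: one neutral scale per layer for every scale type,
    so the a8w8 path can run without a real act_scales.json."""

    def __init__(self, key_map_dict, num_of_layers):
        self.scale = {scale_type: np.ones([num_of_layers], dtype="float32") for scale_type in key_map_dict}

# Keys below are illustrative, not the real scale map.
loader = EmptyActScaleSketch({"qkv_in_scale": "#", "out_linear_in_scale": "#"}, num_of_layers=2)
print(loader.scale["qkv_in_scale"])  # [1. 1.]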