diff --git a/src/turbomind/models/llama/LlamaDecoderLayerWeight.cc b/src/turbomind/models/llama/LlamaDecoderLayerWeight.cc
index 30cc363c41..701d8a9a03 100644
--- a/src/turbomind/models/llama/LlamaDecoderLayerWeight.cc
+++ b/src/turbomind/models/llama/LlamaDecoderLayerWeight.cc
@@ -33,6 +33,25 @@ namespace turbomind {
 
+static bool is_fuse_silu_act()
+{
+    static const bool value = [] {
+        const auto str = std::getenv("TM_FUSE_SILU_ACT");
+        if (str) {
+            try {
+                auto v = std::stoi(str) != 0;
+                TM_LOG_INFO("TM_FUSE_SILU_ACT=%d", (int)v);
+                return v;
+            }
+            catch (...) {
+            }
+        }
+        TM_LOG_INFO("TM_FUSE_SILU_ACT=1");
+        return true;
+    }();
+    return value;
+}
+
 template<typename T>
 LlamaDecoderLayerWeight<T>::LlamaDecoderLayerWeight(int    layer_idx,
                                                     size_t head_num,
@@ -91,9 +110,8 @@ LlamaDecoderLayerWeight<T>::LlamaDecoderLayerWeight(int layer_idx,
             }
         }
     }
-    // fused_up_and_gate_ = weight_type_ == WeightType::kINT4 && ffn_weights.gating.lora.policy != LoraPolicy::kPlora;
-    fused_up_and_gate_ = true && ffn_weights.gating.lora.policy != LoraPolicy::kPlora;
+    fused_up_and_gate_ = ffn_weights.gating.lora.policy != LoraPolicy::kPlora;
 
     self_attn_weights.qkv.input_dims  = hidden_units_;
     self_attn_weights.qkv.output_dims = (head_num + 2 * kv_head_num) * size_per_head / tensor_para_size_;
@@ -119,7 +137,7 @@ LlamaDecoderLayerWeight<T>::LlamaDecoderLayerWeight(int layer_idx,
     ffn_weights.fused_gating_intermediate.output_dims = inter_size_ / tensor_para_size_ * 2;
     ffn_weights.fused_gating_intermediate.type        = weight_type;
     ffn_weights.fused_gating_intermediate.group_size  = group_size;
-    ffn_weights.is_fused_silu = weight_type == WeightType::kINT4;
+    ffn_weights.is_fused_silu = weight_type == WeightType::kINT4 && is_fuse_silu_act();
 
     ffn_weights.output.input_dims  = inter_size_ / tensor_para_size_;
     ffn_weights.output.output_dims = hidden_units_;
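
For context on the new toggle, below is a self-contained sketch of the behavior the patch introduces: `TM_FUSE_SILU_ACT` is parsed once, cached in a function-local static, and defaults to enabled when the variable is unset or unparsable. `TM_LOG_INFO` is swapped for `std::printf` so the snippet builds without TurboMind's logger, and the `main()` driver is purely illustrative; neither is part of the patch.

```cpp
#include <cstdio>
#include <cstdlib>
#include <string>

// Mirrors the is_fuse_silu_act() helper added in the diff above,
// with TurboMind's TM_LOG_INFO replaced by std::printf.
static bool is_fuse_silu_act()
{
    // Parsed once and cached: the first call fixes the value for the
    // lifetime of the process.
    static const bool value = [] {
        const auto str = std::getenv("TM_FUSE_SILU_ACT");
        if (str) {
            try {
                auto v = std::stoi(str) != 0;
                std::printf("TM_FUSE_SILU_ACT=%d\n", (int)v);
                return v;
            }
            catch (...) {
                // Non-numeric value: fall through and keep the default.
            }
        }
        std::printf("TM_FUSE_SILU_ACT=1\n");  // default: fusion enabled
        return true;
    }();
    return value;
}

int main()
{
    // Running with `TM_FUSE_SILU_ACT=0 ./a.out` prints "fused silu: off";
    // unset or any non-zero integer keeps the fused path on, matching the
    // patch's default.
    std::printf("fused silu: %s\n", is_fuse_silu_act() ? "on" : "off");
    return 0;
}
```

The function-local static means the environment is read and logged exactly once, so in the patched code `ffn_weights.is_fused_silu` can now be turned off for INT4 weights at startup via the environment variable, without recompiling.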