add environment variable to turn off silu fusion (#2343)
lzhangzz authored Aug 20, 2024
1 parent 8ed696c commit 6f81013
Showing 1 changed file with 21 additions and 3 deletions.
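For context, a minimal usage sketch (not part of the commit): disabling the fused SiLU activation path from a launcher process by setting the new variable before the TurboMind engine is constructed. setenv() is POSIX (on Windows, _putenv_s would be the analogue), and the engine-construction step is only indicated by a comment.

#include <cstdlib>

int main()
{
    // "0" (or any value std::stoi parses as 0) turns the fusion off;
    // leaving TM_FUSE_SILU_ACT unset keeps the default (fusion enabled).
    setenv("TM_FUSE_SILU_ACT", "0", /*overwrite=*/1);

    // ... construct the TurboMind engine / load the Llama weights here ...
    return 0;
}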
24 changes: 21 additions & 3 deletions src/turbomind/models/llama/LlamaDecoderLayerWeight.cc
@@ -33,6 +33,25 @@

namespace turbomind {

+static bool is_fuse_silu_act()
+{
+    static const bool value = [] {
+        const auto str = std::getenv("TM_FUSE_SILU_ACT");
+        if (str) {
+            try {
+                auto v = std::stoi(str) != 0;
+                TM_LOG_INFO("TM_FUSE_SILU_ACT=%d", (int)v);
+                return v;
+            }
+            catch (...) {
+            }
+        }
+        TM_LOG_INFO("TM_FUSE_SILU_ACT=1");
+        return true;
+    }();
+    return value;
+}

template<typename T>
LlamaDecoderLayerWeight<T>::LlamaDecoderLayerWeight(int layer_idx,
size_t head_num,
@@ -91,9 +110,8 @@ LlamaDecoderLayerWeight<T>::LlamaDecoderLayerWeight(int layer_idx,
}
}
}
-// fused_up_and_gate_ = weight_type_ == WeightType::kINT4 && ffn_weights.gating.lora.policy != LoraPolicy::kPlora;

-fused_up_and_gate_ = true && ffn_weights.gating.lora.policy != LoraPolicy::kPlora;
+fused_up_and_gate_ = ffn_weights.gating.lora.policy != LoraPolicy::kPlora;

self_attn_weights.qkv.input_dims = hidden_units_;
self_attn_weights.qkv.output_dims = (head_num + 2 * kv_head_num) * size_per_head / tensor_para_size_;
@@ -119,7 +137,7 @@ LlamaDecoderLayerWeight<T>::LlamaDecoderLayerWeight(int layer_idx,
ffn_weights.fused_gating_intermediate.output_dims = inter_size_ / tensor_para_size_ * 2;
ffn_weights.fused_gating_intermediate.type = weight_type;
ffn_weights.fused_gating_intermediate.group_size = group_size;
-ffn_weights.is_fused_silu = weight_type == WeightType::kINT4;
+ffn_weights.is_fused_silu = weight_type == WeightType::kINT4 && is_fuse_silu_act();

ffn_weights.output.input_dims = inter_size_ / tensor_para_size_;
ffn_weights.output.output_dims = hidden_units_;

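As a side note on the parsing behaviour introduced above, a self-contained sketch (not lmdeploy code) that mirrors the lambda inside is_fuse_silu_act() and shows how different TM_FUSE_SILU_ACT values are interpreted: unset or unparsable values fall back to the default (fusion on), "0" turns the fusion off, and any other integer keeps it on.

#include <cstdio>
#include <string>

// Mirrors the lambda in is_fuse_silu_act(); nullptr stands for an unset variable.
static bool parse_flag(const char* str)
{
    if (str) {
        try {
            return std::stoi(str) != 0;
        }
        catch (...) {
            // non-numeric values fall through to the default below
        }
    }
    return true;  // unset or unparsable -> keep the fusion enabled
}

int main()
{
    const char* samples[] = {nullptr, "0", "1", "2", "abc"};
    for (const char* s : samples) {
        std::printf("TM_FUSE_SILU_ACT=%s -> fused: %d\n", s ? s : "(unset)", parse_flag(s));
    }
}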