diff --git a/lmdeploy/turbomind/deploy/source_model/llama.py b/lmdeploy/turbomind/deploy/source_model/llama.py
index 08b8f214f..63d60908e 100644
--- a/lmdeploy/turbomind/deploy/source_model/llama.py
+++ b/lmdeploy/turbomind/deploy/source_model/llama.py
@@ -199,6 +199,7 @@ def model_info(self):
             kv_head_num = model_arg['num_key_value_heads']
         else:
             kv_head_num = model_arg['num_attention_heads']
+        tie_word_embeddings = model_arg.get('tie_word_embeddings', False)
         rope_theta = float(model_arg.get('rope_theta', 10000.0))
         max_position_embeddings = int(
             model_arg.get('max_position_embeddings', 0))
@@ -233,6 +234,7 @@ def model_info(self):
                     norm_eps=norm_eps,
                     attn_head_num=attn_head_num,
                     kv_head_num=kv_head_num,
+                    tie_word_embeddings=tie_word_embeddings,
                     rope_theta=rope_theta,
                     max_position_embeddings=max_position_embeddings,
                     original_max_position_embeddings=original_max_position_embeddings,
diff --git a/lmdeploy/turbomind/deploy/target_model/base.py b/lmdeploy/turbomind/deploy/target_model/base.py
index 2e54a9fa9..0c70f6464 100644
--- a/lmdeploy/turbomind/deploy/target_model/base.py
+++ b/lmdeploy/turbomind/deploy/target_model/base.py
@@ -79,6 +79,7 @@ class TurbomindModelConfig:
     lora_max_wo_r: int = 0
     lora_rank_pattern: str = ''
     lora_scale_pattern: str = ''
+    tie_word_embeddings: bool = False
 
     @classmethod
     def from_dict(cls, env, allow_none=False):
diff --git a/src/turbomind/models/llama/LlamaWeight.cc b/src/turbomind/models/llama/LlamaWeight.cc
index 78e36aaff..c59eaa684 100644
--- a/src/turbomind/models/llama/LlamaWeight.cc
+++ b/src/turbomind/models/llama/LlamaWeight.cc
@@ -33,6 +33,7 @@ LlamaWeight<T>::LlamaWeight(size_t head_num,
                             WeightType weight_type,
                             int        group_size,
                             LoraParams lora_params,
+                            bool       tie_word_embeddings,
                             size_t     tensor_para_size,
                             size_t     tensor_para_rank):
     hidden_units_(head_num * size_per_head),
@@ -41,6 +42,7 @@ LlamaWeight<T>::LlamaWeight(size_t head_num,
     vocab_size_padded_(vocab_size),
     num_layer_(num_layer),
     weight_type_(weight_type),
+    tie_word_embeddings_(tie_word_embeddings),
     tensor_para_size_(tensor_para_size),
     tensor_para_rank_(tensor_para_rank)
 {
@@ -71,7 +73,9 @@ LlamaWeight<T>::~LlamaWeight()
 {
     cudaFree((void*)pre_decoder_embedding_table);
     cudaFree((void*)output_norm_weight);
-    cudaFree((void*)post_decoder_embedding_kernel);
+    if (pre_decoder_embedding_table != post_decoder_embedding_kernel) {
+        cudaFree((void*)post_decoder_embedding_kernel);
+    }
 
     pre_decoder_embedding_table   = nullptr;
     post_decoder_embedding_kernel = nullptr;
@@ -87,7 +91,12 @@ void LlamaWeight<T>::mallocWeights()
     FT_CHECK(vocab_size_padded_ % tensor_para_size_ == 0);
     deviceMalloc((T**)&pre_decoder_embedding_table, vocab_size_padded_ * hidden_units_ / tensor_para_size_);
     deviceMalloc((T**)&output_norm_weight, hidden_units_);
-    deviceMalloc((T**)&post_decoder_embedding_kernel, hidden_units_ * vocab_size_padded_ / tensor_para_size_);
+    if (!tie_word_embeddings_) {
+        deviceMalloc((T**)&post_decoder_embedding_kernel, hidden_units_ * vocab_size_padded_ / tensor_para_size_);
+    }
+    else {
+        post_decoder_embedding_kernel = pre_decoder_embedding_table;
+    }
 }
 
 template<typename T>
diff --git a/src/turbomind/models/llama/LlamaWeight.h b/src/turbomind/models/llama/LlamaWeight.h
index 65eb986d8..dbd6d6685 100644
--- a/src/turbomind/models/llama/LlamaWeight.h
+++ b/src/turbomind/models/llama/LlamaWeight.h
@@ -39,6 +39,7 @@ struct LlamaWeight {
                 WeightType weight_type,
                 int        group_size,
                 LoraParams lora_params,
+                bool       tie_word_embeddings,
                 size_t     tensor_para_size,
                 size_t     tensor_para_rank);
 
@@ -59,6 +60,7 @@ struct LlamaWeight {
 private:
     void mallocWeights();
 
+    bool   tie_word_embeddings_{false};
     size_t hidden_units_;
     size_t inter_size_;
     size_t vocab_size_;
diff --git a/src/turbomind/triton_backend/llama/LlamaTritonModel.cc b/src/turbomind/triton_backend/llama/LlamaTritonModel.cc
index 87fd2cdf5..b5bddb1d4 100644
--- a/src/turbomind/triton_backend/llama/LlamaTritonModel.cc
+++ b/src/turbomind/triton_backend/llama/LlamaTritonModel.cc
@@ -206,6 +206,7 @@ LlamaTritonModel<T>::LlamaTritonModel(size_t tensor_para_size,
     end_id_              = reader.GetInteger("llama", "end_id");
     use_context_fmha_    = reader.GetInteger("llama", "use_context_fmha", 1);
     cache_block_seq_len_ = reader.GetInteger("llama", "cache_block_seq_len", 0);
+    tie_word_embeddings_ = reader.GetBoolean("llama", "tie_word_embeddings", false);
     attn_bias_           = reader.GetInteger("llama", "attn_bias", 0);
     quant_policy_        = reader.GetInteger("llama", "quant_policy", 0);
 
@@ -416,6 +417,7 @@ void LlamaTritonModel<T>::createSharedWeights(int device_id, int rank)
                                                                       weight_type_,
                                                                       group_size_,
                                                                       lora_params_,
+                                                                      tie_word_embeddings_,
                                                                       tensor_para_size_,
                                                                       tensor_para_rank);
     // model inited with model_dir
diff --git a/src/turbomind/triton_backend/llama/LlamaTritonModel.h b/src/turbomind/triton_backend/llama/LlamaTritonModel.h
index c0a0ebf3a..a1b934184 100644
--- a/src/turbomind/triton_backend/llama/LlamaTritonModel.h
+++ b/src/turbomind/triton_backend/llama/LlamaTritonModel.h
@@ -102,6 +102,7 @@ struct LlamaTritonModel: public AbstractTransformerModel {
     int                   quant_policy_;
     int                   group_size_;
     turbomind::LoraParams lora_params_;
+    bool                  tie_word_embeddings_;
 
     // shared weights for each device
     std::vector<std::shared_ptr<ft::LlamaWeight<T>>> shared_weights_;
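Reviewer note (not part of the patch): with `tie_word_embeddings` enabled in the HF config, the output projection reuses the input embedding matrix, so `mallocWeights()` aliases `post_decoder_embedding_kernel` to `pre_decoder_embedding_table` instead of allocating a second vocab-sized buffer, and the destructor frees that pointer only when it is a distinct allocation. Below is a minimal host-side sketch of this alias-and-guarded-free pattern; plain `malloc`/`free` stand in for `deviceMalloc`/`cudaFree`, and `ToyWeights` with its fields and sizes are illustrative names, not from the codebase.

```cpp
// Sketch of the alias-and-guarded-free pattern the patch applies to LlamaWeight.
#include <cstdio>
#include <cstdlib>

struct ToyWeights {
    float* embedding   = nullptr;  // plays the role of pre_decoder_embedding_table
    float* output_head = nullptr;  // plays the role of post_decoder_embedding_kernel

    ToyWeights(bool tie_word_embeddings, size_t vocab, size_t hidden)
    {
        embedding = static_cast<float*>(std::malloc(vocab * hidden * sizeof(float)));
        if (!tie_word_embeddings) {
            // Untied: the output head gets its own buffer.
            output_head = static_cast<float*>(std::malloc(hidden * vocab * sizeof(float)));
        }
        else {
            // Tied: the output head aliases the input embedding buffer.
            output_head = embedding;
        }
    }

    ~ToyWeights()
    {
        // Free the head only when it is a distinct allocation; in the tied
        // case an unconditional free would release the same buffer twice.
        if (output_head != embedding) {
            std::free(output_head);
        }
        std::free(embedding);
        embedding = output_head = nullptr;
    }
};

int main()
{
    ToyWeights tied(true, 1000, 64);
    ToyWeights untied(false, 1000, 64);
    std::printf("tied shares one buffer: %s\n", tied.embedding == tied.output_head ? "yes" : "no");
    std::printf("untied has two buffers: %s\n", untied.embedding != untied.output_head ? "yes" : "no");
}
```

Guarding on pointer identity rather than re-reading the flag keeps teardown safe regardless of how the alias was established; this mirrors the `pre_decoder_embedding_table != post_decoder_embedding_kernel` check the patch adds to `~LlamaWeight()`.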