support tie_word_embeddings
irexyc committed Aug 7, 2024
1 parent b409d0b commit b7dc61b
Showing 6 changed files with 19 additions and 2 deletions.
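This commit threads the Hugging Face tie_word_embeddings flag (when set, the input token embedding and the output LM-head projection share one weight matrix) from the Python converter into TurbomindModelConfig and on to the C++ LlamaWeight, where the output-projection buffer then aliases the embedding table instead of being allocated and freed separately. A standalone sketch of that aliasing pattern follows the LlamaWeight.cc hunk below.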
2 changes: 2 additions & 0 deletions lmdeploy/turbomind/deploy/source_model/llama.py
@@ -199,6 +199,7 @@ def model_info(self):
kv_head_num = model_arg['num_key_value_heads']
else:
kv_head_num = model_arg['num_attention_heads']
tie_word_embeddings = model_arg.get('tie_word_embeddings', False)
rope_theta = float(model_arg.get('rope_theta', 10000.0))
max_position_embeddings = int(
model_arg.get('max_position_embeddings', 0))
@@ -233,6 +234,7 @@ def model_info(self):
norm_eps=norm_eps,
attn_head_num=attn_head_num,
kv_head_num=kv_head_num,
tie_word_embeddings=tie_word_embeddings,
rope_theta=rope_theta,
max_position_embeddings=max_position_embeddings,
original_max_position_embeddings=original_max_position_embeddings,
1 change: 1 addition & 0 deletions lmdeploy/turbomind/deploy/target_model/base.py
@@ -79,6 +79,7 @@ class TurbomindModelConfig:
lora_max_wo_r: int = 0
lora_rank_pattern: str = ''
lora_scale_pattern: str = ''
tie_word_embeddings: bool = False

@classmethod
def from_dict(cls, env, allow_none=False):
13 changes: 11 additions & 2 deletions src/turbomind/models/llama/LlamaWeight.cc
@@ -33,6 +33,7 @@ LlamaWeight<T>::LlamaWeight(size_t head_num,
WeightType weight_type,
int group_size,
LoraParams lora_params,
bool tie_word_embeddings,
size_t tensor_para_size,
size_t tensor_para_rank):
hidden_units_(head_num * size_per_head),
@@ -41,6 +42,7 @@ LlamaWeight<T>::LlamaWeight(size_t head_num,
vocab_size_padded_(vocab_size),
num_layer_(num_layer),
weight_type_(weight_type),
tie_word_embeddings_(tie_word_embeddings),
tensor_para_size_(tensor_para_size),
tensor_para_rank_(tensor_para_rank)
{
@@ -71,7 +73,9 @@ LlamaWeight<T>::~LlamaWeight()
{
cudaFree((void*)pre_decoder_embedding_table);
cudaFree((void*)output_norm_weight);
cudaFree((void*)post_decoder_embedding_kernel);
if (pre_decoder_embedding_table != post_decoder_embedding_kernel) {
cudaFree((void*)post_decoder_embedding_kernel);
}

pre_decoder_embedding_table = nullptr;
post_decoder_embedding_kernel = nullptr;
@@ -87,7 +91,12 @@ void LlamaWeight<T>::mallocWeights()
FT_CHECK(vocab_size_padded_ % tensor_para_size_ == 0);
deviceMalloc((T**)&pre_decoder_embedding_table, vocab_size_padded_ * hidden_units_ / tensor_para_size_);
deviceMalloc((T**)&output_norm_weight, hidden_units_);
deviceMalloc((T**)&post_decoder_embedding_kernel, hidden_units_ * vocab_size_padded_ / tensor_para_size_);
if (!tie_word_embeddings_) {
deviceMalloc((T**)&post_decoder_embedding_kernel, hidden_units_ * vocab_size_padded_ / tensor_para_size_);
}
else {
post_decoder_embedding_kernel = pre_decoder_embedding_table;
}
}

template<typename T>
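For reference, the allocation and cleanup logic above can be read as the following standalone sketch. The struct and function names (Weights, alloc, release) are illustrative only, not the actual LlamaWeight API: when the flag is set, the output-projection pointer reuses the embedding buffer, and the cleanup path must compare the two pointers before freeing so the shared buffer is not freed twice.

// Minimal sketch of the tied-embedding allocation pattern above.
// Names here are illustrative, not the real LlamaWeight members.
#include <cuda_runtime.h>
#include <cstdio>

struct Weights {
    float* embedding_table   = nullptr;  // analogue of pre_decoder_embedding_table
    float* output_projection = nullptr;  // analogue of post_decoder_embedding_kernel
    bool   tie_word_embeddings = false;
};

void alloc(Weights& w, size_t vocab, size_t hidden)
{
    cudaMalloc((void**)&w.embedding_table, sizeof(float) * vocab * hidden);
    if (w.tie_word_embeddings) {
        // Reuse the same device buffer instead of allocating a second
        // vocab x hidden matrix for the output projection.
        w.output_projection = w.embedding_table;
    }
    else {
        cudaMalloc((void**)&w.output_projection, sizeof(float) * vocab * hidden);
    }
}

void release(Weights& w)
{
    cudaFree(w.embedding_table);
    // When the pointers alias, freeing both would be a double free.
    if (w.output_projection != w.embedding_table) {
        cudaFree(w.output_projection);
    }
    w.embedding_table = w.output_projection = nullptr;
}

int main()
{
    Weights w;
    w.tie_word_embeddings = true;
    alloc(w, 32000, 128);
    std::printf("tied: %d\n", w.embedding_table == w.output_projection);
    release(w);
    return 0;
}

Sharing the buffer avoids a second vocab_size x hidden_units allocation per tensor-parallel rank; the pointer comparison mirrors the guard added to ~LlamaWeight() above.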
2 changes: 2 additions & 0 deletions src/turbomind/models/llama/LlamaWeight.h
@@ -39,6 +39,7 @@ struct LlamaWeight {
WeightType weight_type,
int group_size,
LoraParams lora_params,
bool tie_word_embeddings,
size_t tensor_para_size,
size_t tensor_para_rank);

@@ -59,6 +60,7 @@
private:
void mallocWeights();

bool tie_word_embeddings_{false};
size_t hidden_units_;
size_t inter_size_;
size_t vocab_size_;
2 changes: 2 additions & 0 deletions src/turbomind/triton_backend/llama/LlamaTritonModel.cc
@@ -206,6 +206,7 @@ LlamaTritonModel<T>::LlamaTritonModel(size_t tensor_para_size,
end_id_ = reader.GetInteger("llama", "end_id");
use_context_fmha_ = reader.GetInteger("llama", "use_context_fmha", 1);
cache_block_seq_len_ = reader.GetInteger("llama", "cache_block_seq_len", 0);
tie_word_embeddings_ = reader.GetBoolean("llama", "tie_word_embeddings", false);

attn_bias_ = reader.GetInteger("llama", "attn_bias", 0);
quant_policy_ = reader.GetInteger("llama", "quant_policy", 0);
@@ -416,6 +417,7 @@ void LlamaTritonModel<T>::createSharedWeights(int device_id, int rank)
weight_type_,
group_size_,
lora_params_,
tie_word_embeddings_,
tensor_para_size_,
tensor_para_rank);
// model inited with model_dir
1 change: 1 addition & 0 deletions src/turbomind/triton_backend/llama/LlamaTritonModel.h
@@ -102,6 +102,7 @@ struct LlamaTritonModel: public AbstractTransformerModel {
int quant_policy_;
int group_size_;
turbomind::LoraParams lora_params_;
bool tie_word_embeddings_;

// shared weights for each device
std::vector<std::shared_ptr<ft::LlamaWeight<T>>> shared_weights_;
