diff --git a/lmdeploy/turbomind/deploy/source_model/deepseek_vl.py b/lmdeploy/turbomind/deploy/source_model/deepseek_vl.py
index f17a3398b3..2b60454767 100644
--- a/lmdeploy/turbomind/deploy/source_model/deepseek_vl.py
+++ b/lmdeploy/turbomind/deploy/source_model/deepseek_vl.py
@@ -47,6 +47,7 @@ def model_info(self):
                     'language_config'].get('model_type', None) == 'llama':
                 model_arg = model_arg['language_config']  # depseek-vl
             num_layer = model_arg['num_hidden_layers']
+            hidden_units = model_arg['hidden_size']
             norm_eps = model_arg.get('rms_norm_eps', 1e-06)
             attn_head_num = model_arg.get('num_attention_heads', 32)
             if 'num_key_value_heads' in model_arg:
@@ -67,8 +68,9 @@ def model_info(self):

         return dict(num_layer=num_layer,
                     norm_eps=norm_eps,
-                    attn_head_num=attn_head_num,
+                    head_num=attn_head_num,
                     kv_head_num=kv_head_num,
+                    hidden_units=hidden_units,
                     rope_theta=rope_theta,
                     max_position_embeddings=max_position_embeddings,
                     use_dynamic_ntk=use_dynamic_ntk,
diff --git a/lmdeploy/turbomind/deploy/source_model/glm4.py b/lmdeploy/turbomind/deploy/source_model/glm4.py
index 2c69d5d0da..1c26e0649a 100644
--- a/lmdeploy/turbomind/deploy/source_model/glm4.py
+++ b/lmdeploy/turbomind/deploy/source_model/glm4.py
@@ -85,6 +85,7 @@ def tokenizer_info(self):
     def model_info(self):
         """Read model info."""
         config = self.config
+        hidden_units = config.get('hidden_size', None)
         num_layer = config.get('num_hidden_layers', None)
         num_layer = config.get('num_layers', num_layer)
         norm_eps = config['layernorm_epsilon']
@@ -98,8 +99,9 @@ def model_info(self):
         seq_length = config['seq_length']
         return dict(num_layer=num_layer,
                     norm_eps=norm_eps,
-                    attn_head_num=attn_head_num,
+                    head_num=attn_head_num,
                     kv_head_num=kv_head_num,
+                    hidden_units=hidden_units,
                     rope_theta=rope_theta,
                     max_position_embeddings=seq_length,
                     rotary_embedding=64,
diff --git a/lmdeploy/turbomind/deploy/source_model/internvl.py b/lmdeploy/turbomind/deploy/source_model/internvl.py
index d7f446da93..83161adb15 100644
--- a/lmdeploy/turbomind/deploy/source_model/internvl.py
+++ b/lmdeploy/turbomind/deploy/source_model/internvl.py
@@ -61,6 +61,7 @@ def model_info(self):
             model_arg = json.load(f)['llm_config']
             num_layer = model_arg['num_hidden_layers']
             norm_eps = model_arg['rms_norm_eps']
+            hidden_units = model_arg['hidden_size']
             attn_head_num = model_arg['num_attention_heads']
             if 'num_key_value_heads' in model_arg:
                 kv_head_num = model_arg['num_key_value_heads']
@@ -80,7 +81,8 @@ def model_info(self):

         return dict(num_layer=num_layer,
                     norm_eps=norm_eps,
-                    attn_head_num=attn_head_num,
+                    hidden_units=hidden_units,
+                    head_num=attn_head_num,
                     kv_head_num=kv_head_num,
                     rope_theta=rope_theta,
                     max_position_embeddings=max_position_embeddings,
diff --git a/lmdeploy/turbomind/deploy/source_model/llama.py b/lmdeploy/turbomind/deploy/source_model/llama.py
index fb94854a45..a67e3ee4e4 100644
--- a/lmdeploy/turbomind/deploy/source_model/llama.py
+++ b/lmdeploy/turbomind/deploy/source_model/llama.py
@@ -207,6 +207,7 @@ def model_info(self):
                 kv_head_num = model_arg['num_key_value_heads']
             else:
                 kv_head_num = model_arg['num_attention_heads']
+            hidden_units = model_arg['hidden_size']
             rope_theta = float(model_arg.get('rope_theta', 10000.0))
             max_position_embeddings = int(
                 model_arg.get('max_position_embeddings', 0))
@@ -239,8 +240,9 @@ def model_info(self):
         return dict(
             num_layer=num_layer,
             norm_eps=norm_eps,
-            attn_head_num=attn_head_num,
+            head_num=attn_head_num,
             kv_head_num=kv_head_num,
+            hidden_units=hidden_units,
             rope_theta=rope_theta,
             max_position_embeddings=max_position_embeddings,
             original_max_position_embeddings=original_max_position_embeddings,
diff --git a/lmdeploy/turbomind/deploy/source_model/qwen.py b/lmdeploy/turbomind/deploy/source_model/qwen.py
index 311f8e0a85..4e87057e62 100644
--- a/lmdeploy/turbomind/deploy/source_model/qwen.py
+++ b/lmdeploy/turbomind/deploy/source_model/qwen.py
@@ -65,6 +65,7 @@ def model_info(self):
         params_path = osp.join(self.model_path, 'config.json')
         with open(params_path) as f:
             config = json.load(f)
+            hidden_units = config['hidden_size']
             num_layer = config['num_hidden_layers']
             norm_eps = config['layer_norm_epsilon']
             rope_theta = float(config.get('rotary_emb_base', 10000.0))
@@ -72,11 +73,14 @@ def model_info(self):
                 kv_head_num = config['num_key_value_heads']
             else:
                 kv_head_num = config['num_attention_heads']
+            attn_head_num = config['num_attention_heads']
             seq_length = config['seq_length']
             use_dynamic_ntk = int(config['use_dynamic_ntk'])
             use_logn_attn = int(config['use_logn_attn'])
         return dict(num_layer=num_layer,
                     norm_eps=norm_eps,
+                    hidden_units=hidden_units,
+                    head_num=attn_head_num,
                     kv_head_num=kv_head_num,
                     rope_theta=rope_theta,
                     max_position_embeddings=seq_length,
diff --git a/lmdeploy/turbomind/deploy/target_model/base.py b/lmdeploy/turbomind/deploy/target_model/base.py
index f969055759..87983b2551 100644
--- a/lmdeploy/turbomind/deploy/target_model/base.py
+++ b/lmdeploy/turbomind/deploy/target_model/base.py
@@ -41,6 +41,7 @@ class TurbomindModelConfig:
     tensor_para_size: int = None
     head_num: int = None
     kv_head_num: int = None
+    hidden_units: int = None
     vocab_size: int = None
     num_layer: int = None
     inter_size: int = None
@@ -190,14 +191,13 @@ def get_config(self, cfg: TurbomindModelConfig) -> TurbomindModelConfig:
         final_cfg.update(dict(start_id=bos_id, end_id=eos_id))
         final_cfg.update(self.input_model.model_info())

-        # head_num, vocab_size
+        # vocab_size
         for bin in self.input_model.bins():
             emb = bin.tok_embeddings()
             if emb is not None:
                 _vocab_size, dim = emb.shape
-                head_num = dim // cfg.size_per_head
                 break
-        final_cfg.update(dict(head_num=head_num, vocab_size=_vocab_size))
+        final_cfg.update(dict(vocab_size=_vocab_size))
         return TurbomindModelConfig.from_dict(final_cfg, allow_none=True)

     def export_config(self) -> None:
diff --git a/src/turbomind/models/llama/LlamaDecoderLayerWeight.cc b/src/turbomind/models/llama/LlamaDecoderLayerWeight.cc
index 30cc363c41..61326be274 100644
--- a/src/turbomind/models/llama/LlamaDecoderLayerWeight.cc
+++ b/src/turbomind/models/llama/LlamaDecoderLayerWeight.cc
@@ -38,6 +38,7 @@ LlamaDecoderLayerWeight<T>::LlamaDecoderLayerWeight(int layer_idx,
                                                     size_t head_num,
                                                     size_t kv_head_num,
                                                     size_t size_per_head,
+                                                    size_t hidden_units,
                                                     size_t inter_size,
                                                     WeightType weight_type,
                                                     int group_size,
@@ -48,7 +49,7 @@ LlamaDecoderLayerWeight<T>::LlamaDecoderLayerWeight(int layer_idx,
     head_num_(head_num),
     kv_head_num_(kv_head_num),
     size_per_head_(size_per_head),
-    hidden_units_(head_num * size_per_head),
+    hidden_units_(hidden_units),
     inter_size_(inter_size),
     weight_type_(weight_type),
     attn_bias_(attn_bias),
@@ -100,7 +101,7 @@
     self_attn_weights.qkv.type = weight_type;
     self_attn_weights.qkv.group_size = group_size;

-    self_attn_weights.output.input_dims = hidden_units_ / tensor_para_size_;
+    self_attn_weights.output.input_dims = (head_num * size_per_head) / tensor_para_size_;
     self_attn_weights.output.output_dims = hidden_units_;
     self_attn_weights.output.type = weight_type;
     self_attn_weights.output.group_size = group_size;
diff --git a/src/turbomind/models/llama/LlamaDecoderLayerWeight.h b/src/turbomind/models/llama/LlamaDecoderLayerWeight.h
index 05600d0f56..07bc65cc5c 100644
--- a/src/turbomind/models/llama/LlamaDecoderLayerWeight.h
+++ b/src/turbomind/models/llama/LlamaDecoderLayerWeight.h
@@ -34,6 +34,7 @@ struct LlamaDecoderLayerWeight {
                             size_t head_num,
                             size_t kv_head_num,
                             size_t size_per_head,
+                            size_t hidden_units,
                             size_t inter_size,
                             WeightType weight_type,
                             int group_size,
diff --git a/src/turbomind/models/llama/LlamaFfnLayer.h b/src/turbomind/models/llama/LlamaFfnLayer.h
index 97465ad6d1..db5a94380c 100644
--- a/src/turbomind/models/llama/LlamaFfnLayer.h
+++ b/src/turbomind/models/llama/LlamaFfnLayer.h
@@ -33,6 +33,7 @@ class LlamaFfnLayer {
 public:
     LlamaFfnLayer(size_t head_num,
                   size_t size_per_head,
+                  size_t hidden_units,
                   size_t inter_size,
                   NcclParam tensor_para,
                   cudaStream_t stream,
@@ -42,7 +43,7 @@ class LlamaFfnLayer {
         head_num_(head_num),
         size_per_head_(size_per_head),
         inter_size_(inter_size / tensor_para.world_size_),
-        hidden_units_(head_num * size_per_head),
+        hidden_units_(hidden_units),
         stream_(stream),
         linear_(linear),
         allocator_(allocator),
diff --git a/src/turbomind/models/llama/LlamaV2.cc b/src/turbomind/models/llama/LlamaV2.cc
index f9f7922ff6..f3e9493698 100644
--- a/src/turbomind/models/llama/LlamaV2.cc
+++ b/src/turbomind/models/llama/LlamaV2.cc
@@ -54,6 +54,7 @@ template<typename T>
 LlamaV2<T>::LlamaV2(size_t head_num,
                     size_t kv_head_num,
                     size_t size_per_head,
+                    size_t hidden_units,
                     size_t inter_size,
                     size_t num_layer,
                     size_t vocab_size,
@@ -85,7 +86,7 @@ LlamaV2<T>::LlamaV2(size_t head_num,
     rmsnorm_eps_(norm_eps),
     start_id_(start_id),
     end_id_(end_id),
-    hidden_units_(head_num * size_per_head),
+    hidden_units_(hidden_units),
     local_head_num_(head_num / tensor_para.world_size_),
     local_kv_head_num_(kv_head_num / tensor_para.world_size_),
     weights_(weights),
@@ -137,6 +138,7 @@ void LlamaV2<T>::initialize(const LlamaAttentionParams& attn_params,
     unified_decoder_.reset(new UnifiedDecoder<T>(head_num_,
                                                  kv_head_num,
                                                  size_per_head_,
+                                                 hidden_units_,
                                                  inter_size_,
                                                  num_layer_,
                                                  attn_params,
diff --git a/src/turbomind/models/llama/LlamaV2.h b/src/turbomind/models/llama/LlamaV2.h
index adf6c4f9d4..b0a19f4239 100644
--- a/src/turbomind/models/llama/LlamaV2.h
+++ b/src/turbomind/models/llama/LlamaV2.h
@@ -57,6 +57,7 @@ class LlamaV2 {
     LlamaV2(size_t head_num,
            size_t kv_head_num,
            size_t size_per_head,
+           size_t hidden_units,
            size_t inter_size,
            size_t num_layer,
            size_t vocab_size,
diff --git a/src/turbomind/models/llama/LlamaWeight.cc b/src/turbomind/models/llama/LlamaWeight.cc
index 507f1a6f32..18ecc2507d 100644
--- a/src/turbomind/models/llama/LlamaWeight.cc
+++ b/src/turbomind/models/llama/LlamaWeight.cc
@@ -28,6 +28,7 @@ template<typename T>
 LlamaWeight<T>::LlamaWeight(size_t head_num,
                             size_t kv_head_num,
                             size_t size_per_head,
+                            size_t hidden_units,
                             size_t inter_size,
                             size_t vocab_size,
                             size_t num_layer,
@@ -37,7 +38,7 @@ LlamaWeight<T>::LlamaWeight(size_t head_num,
                             LoraParams lora_params,
                             size_t tensor_para_size,
                             size_t tensor_para_rank):
-    hidden_units_(head_num * size_per_head),
+    hidden_units_(hidden_units),
     inter_size_(inter_size),
     vocab_size_(vocab_size),
     vocab_size_padded_(vocab_size),
@@ -56,6 +57,7 @@ LlamaWeight<T>::LlamaWeight(size_t head_num,
                                                                        head_num,
                                                                        kv_head_num,
                                                                        size_per_head,
+                                                                       hidden_units_,
                                                                        inter_size_,
                                                                        weight_type_,
                                                                        group_size,
diff --git a/src/turbomind/models/llama/LlamaWeight.h b/src/turbomind/models/llama/LlamaWeight.h
index a180204ae2..f71e03715a 100644
--- a/src/turbomind/models/llama/LlamaWeight.h
+++ b/src/turbomind/models/llama/LlamaWeight.h
@@ -32,6 +32,7 @@ struct LlamaWeight {
     LlamaWeight(size_t head_num,
                 size_t kv_head_num,
                 size_t size_per_head,
+                size_t hidden_units,
                 size_t inter_size,
                 size_t vocab_size,
                 size_t num_layer,
diff --git a/src/turbomind/models/llama/unified_attention_layer.h b/src/turbomind/models/llama/unified_attention_layer.h
index f632830e5f..58bba45896 100644
--- a/src/turbomind/models/llama/unified_attention_layer.h
+++ b/src/turbomind/models/llama/unified_attention_layer.h
@@ -53,6 +53,7 @@ class UnifiedAttentionLayer {
     UnifiedAttentionLayer(size_t head_num,
                           size_t kv_head_num,
                           size_t size_per_head,
+                          size_t hidden_units,
                           LlamaAttentionParams attn_params,
                           NcclParam tensor_para,
                           LoraParams lora_params,
@@ -64,7 +65,7 @@
                           int quant_policy):
         head_num_(head_num),
         size_per_head_(size_per_head),
-        hidden_units_(head_num * size_per_head),
+        hidden_units_(hidden_units),
         local_head_num_(head_num / tensor_para.world_size_),
         local_kv_head_num_(kv_head_num / tensor_para.world_size_),
         head_n_rep_(head_num / kv_head_num),
diff --git a/src/turbomind/models/llama/unified_decoder.cc b/src/turbomind/models/llama/unified_decoder.cc
index e29d42680d..db9482fb48 100644
--- a/src/turbomind/models/llama/unified_decoder.cc
+++ b/src/turbomind/models/llama/unified_decoder.cc
@@ -39,6 +39,7 @@ void UnifiedDecoder<T>::initialize(const LlamaAttentionParams& attn_params,
     attn_layer_ = new UnifiedAttentionLayer<T>(head_num_,
                                                kv_head_num,
                                                size_per_head_,
+                                               hidden_units_,
                                                attn_params,
                                                tensor_para_,
                                                lora_params_,
@@ -51,6 +52,7 @@ void UnifiedDecoder<T>::initialize(const LlamaAttentionParams& attn_params,

     ffn_layer_ = new LlamaFfnLayer<T>(head_num_,
                                       size_per_head_,
+                                      hidden_units_,
                                       inter_size_,
                                       tensor_para_,
                                       stream_,
diff --git a/src/turbomind/models/llama/unified_decoder.h b/src/turbomind/models/llama/unified_decoder.h
index 0a80b415d5..b2acbe1b44 100644
--- a/src/turbomind/models/llama/unified_decoder.h
+++ b/src/turbomind/models/llama/unified_decoder.h
@@ -63,6 +63,7 @@ class UnifiedDecoder {
    UnifiedDecoder(size_t head_num,
                   size_t kv_head_num,
                   size_t size_per_head,
+                  size_t hidden_units,
                   size_t inter_size,
                   size_t num_layer,
                   const LlamaAttentionParams& attn_params,
@@ -84,7 +85,7 @@ class UnifiedDecoder {
        head_num_(head_num),
        size_per_head_(size_per_head),
        inter_size_(inter_size),
-       hidden_units_(head_num * size_per_head),
+       hidden_units_(hidden_units),
        num_layer_(num_layer),
        rmsnorm_eps_(rmsnorm_eps),
        tensor_para_(tensor_para),
diff --git a/src/turbomind/triton_backend/llama/LlamaTritonModel.cc b/src/turbomind/triton_backend/llama/LlamaTritonModel.cc
index d025935bf7..cb9ea29f48 100644
--- a/src/turbomind/triton_backend/llama/LlamaTritonModel.cc
+++ b/src/turbomind/triton_backend/llama/LlamaTritonModel.cc
@@ -199,6 +199,7 @@ LlamaTritonModel<T>::LlamaTritonModel(size_t tensor_para_size,
     model_name_    = reader.Get("llama", "model_name");
     head_num_      = reader.GetInteger("llama", "head_num");
     kv_head_num_   = reader.GetInteger("llama", "kv_head_num", 0);
+    hidden_units_  = reader.GetInteger("llama", "hidden_units");
     size_per_head_ = reader.GetInteger("llama", "size_per_head");
     inter_size_    = reader.GetInteger("llama", "inter_size");
     num_layer_     = reader.GetInteger("llama", "num_layer");
@@ -338,6 +339,7 @@ std::unique_ptr<LlamaTritonSharedModelInstance<T>> LlamaTritonModel<T>::createSh
     auto llama = std::make_unique<LlamaV2<T>>(head_num_,
                                               kv_head_num_,
                                               size_per_head_,
+                                              hidden_units_,
                                               inter_size_,
                                               num_layer_,
                                               vocab_size_,
@@ -401,6 +403,7 @@ void LlamaTritonModel<T>::createSharedWeights(int device_id, int rank)
     shared_weights_[device_id] = std::make_shared<LlamaWeight<T>>(head_num_,
                                                                   kv_head_num_,
                                                                   size_per_head_,
+                                                                  hidden_units_,
                                                                   inter_size_,
                                                                   vocab_size_,
                                                                   num_layer_,
diff --git a/src/turbomind/triton_backend/llama/LlamaTritonModel.h b/src/turbomind/triton_backend/llama/LlamaTritonModel.h
index fc7cfca0f2..02736e0f23 100644
--- a/src/turbomind/triton_backend/llama/LlamaTritonModel.h
+++ b/src/turbomind/triton_backend/llama/LlamaTritonModel.h
@@ -91,6 +91,7 @@ struct LlamaTritonModel: public AbstractTransformerModel {

     size_t head_num_;
     size_t kv_head_num_;
+    size_t hidden_units_;
     size_t size_per_head_;
     size_t inter_size_;
     size_t num_layer_;