Fix hidden size and support mistral nemo #2215

Merged: 4 commits, Aug 21, 2024
4 changes: 3 additions & 1 deletion lmdeploy/turbomind/deploy/source_model/deepseek_vl.py
@@ -47,6 +47,7 @@ def model_info(self):
                  'language_config'].get('model_type', None) == 'llama':
             model_arg = model_arg['language_config']  # depseek-vl
         num_layer = model_arg['num_hidden_layers']
+        hidden_units = model_arg['hidden_size']
         norm_eps = model_arg.get('rms_norm_eps', 1e-06)
         attn_head_num = model_arg.get('num_attention_heads', 32)
         if 'num_key_value_heads' in model_arg:
@@ -67,8 +68,9 @@ def model_info(self):

         return dict(num_layer=num_layer,
                     norm_eps=norm_eps,
-                    attn_head_num=attn_head_num,
+                    head_num=attn_head_num,
                     kv_head_num=kv_head_num,
+                    hidden_units=hidden_units,
                     rope_theta=rope_theta,
                     max_position_embeddings=max_position_embeddings,
                     use_dynamic_ntk=use_dynamic_ntk,
4 changes: 3 additions & 1 deletion lmdeploy/turbomind/deploy/source_model/glm4.py
@@ -85,6 +85,7 @@ def tokenizer_info(self):
     def model_info(self):
         """Read model info."""
         config = self.config
+        hidden_units = config.get('hidden_size', None)
         num_layer = config.get('num_hidden_layers', None)
         num_layer = config.get('num_layers', num_layer)
         norm_eps = config['layernorm_epsilon']
@@ -98,8 +99,9 @@ def model_info(self):
         seq_length = config['seq_length']
         return dict(num_layer=num_layer,
                     norm_eps=norm_eps,
-                    attn_head_num=attn_head_num,
+                    head_num=attn_head_num,
                     kv_head_num=kv_head_num,
+                    hidden_units=hidden_units,
                     rope_theta=rope_theta,
                     max_position_embeddings=seq_length,
                     rotary_embedding=64,
4 changes: 3 additions & 1 deletion lmdeploy/turbomind/deploy/source_model/internvl.py
@@ -61,6 +61,7 @@ def model_info(self):
             model_arg = json.load(f)['llm_config']
         num_layer = model_arg['num_hidden_layers']
         norm_eps = model_arg['rms_norm_eps']
+        hidden_units = model_arg['hidden_size']
         attn_head_num = model_arg['num_attention_heads']
         if 'num_key_value_heads' in model_arg:
             kv_head_num = model_arg['num_key_value_heads']
@@ -80,7 +81,8 @@ def model_info(self):

         return dict(num_layer=num_layer,
                     norm_eps=norm_eps,
-                    attn_head_num=attn_head_num,
+                    hidden_units=hidden_units,
+                    head_num=attn_head_num,
                     kv_head_num=kv_head_num,
                     rope_theta=rope_theta,
                     max_position_embeddings=max_position_embeddings,
4 changes: 3 additions & 1 deletion lmdeploy/turbomind/deploy/source_model/llama.py
@@ -207,6 +207,7 @@ def model_info(self):
                 kv_head_num = model_arg['num_key_value_heads']
             else:
                 kv_head_num = model_arg['num_attention_heads']
+            hidden_units = model_arg['hidden_size']
             rope_theta = float(model_arg.get('rope_theta', 10000.0))
             max_position_embeddings = int(
                 model_arg.get('max_position_embeddings', 0))
@@ -239,8 +240,9 @@ def model_info(self):
         return dict(
             num_layer=num_layer,
             norm_eps=norm_eps,
-            attn_head_num=attn_head_num,
+            head_num=attn_head_num,
             kv_head_num=kv_head_num,
+            hidden_units=hidden_units,
             rope_theta=rope_theta,
             max_position_embeddings=max_position_embeddings,
             original_max_position_embeddings=original_max_position_embeddings,
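The reader now takes hidden_size directly from config.json and reports it as hidden_units, instead of leaving downstream code to derive it from the head count. A minimal sketch of why the derivation breaks; the values below follow Mistral Nemo's published config and are assumptions for illustration, not numbers taken from this diff:

```python
# Minimal sketch, not part of this PR; config values are assumptions.
config = {
    'hidden_size': 5120,
    'num_attention_heads': 32,
    'head_dim': 128,
    'num_key_value_heads': 8,
    'num_hidden_layers': 40,
}

derived = config['num_attention_heads'] * config['head_dim']  # 4096
hidden_units = config['hidden_size']                          # 5120
assert derived != hidden_units  # deriving hidden_units from heads would be wrong here
```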
4 changes: 4 additions & 0 deletions lmdeploy/turbomind/deploy/source_model/qwen.py
@@ -65,18 +65,22 @@ def model_info(self):
         params_path = osp.join(self.model_path, 'config.json')
         with open(params_path) as f:
             config = json.load(f)
+            hidden_units = config['hidden_size']
             num_layer = config['num_hidden_layers']
             norm_eps = config['layer_norm_epsilon']
             rope_theta = float(config.get('rotary_emb_base', 10000.0))
             if 'num_key_value_heads' in config:
                 kv_head_num = config['num_key_value_heads']
             else:
                 kv_head_num = config['num_attention_heads']
+            attn_head_num = config['num_attention_heads']
             seq_length = config['seq_length']
             use_dynamic_ntk = int(config['use_dynamic_ntk'])
             use_logn_attn = int(config['use_logn_attn'])
         return dict(num_layer=num_layer,
                     norm_eps=norm_eps,
+                    hidden_units=hidden_units,
+                    head_num=attn_head_num,
                     kv_head_num=kv_head_num,
                     rope_theta=rope_theta,
                     max_position_embeddings=seq_length,
6 changes: 3 additions & 3 deletions lmdeploy/turbomind/deploy/target_model/base.py
@@ -41,6 +41,7 @@ class TurbomindModelConfig:
     tensor_para_size: int = None
     head_num: int = None
     kv_head_num: int = None
+    hidden_units: int = None
     vocab_size: int = None
     num_layer: int = None
     inter_size: int = None
@@ -190,14 +191,13 @@ def get_config(self, cfg: TurbomindModelConfig) -> TurbomindModelConfig:
         final_cfg.update(dict(start_id=bos_id, end_id=eos_id))
         final_cfg.update(self.input_model.model_info())

-        # head_num, vocab_size
+        # vocab_size
         for bin in self.input_model.bins():
             emb = bin.tok_embeddings()
             if emb is not None:
                 _vocab_size, dim = emb.shape
-                head_num = dim // cfg.size_per_head
                 break
-        final_cfg.update(dict(head_num=head_num, vocab_size=_vocab_size))
+        final_cfg.update(dict(vocab_size=_vocab_size))
         return TurbomindModelConfig.from_dict(final_cfg, allow_none=True)

     def export_config(self) -> None:
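With hidden_units and head_num now supplied by model_info(), get_config() only infers vocab_size from the token-embedding tensor; deriving head_num from the embedding width stops being valid once that width is hidden_size rather than head_num * size_per_head. A small sketch with assumed shapes:

```python
# Sketch only; shapes are assumptions chosen to mimic a Mistral-Nemo-like model.
import numpy as np

size_per_head = 128
emb = np.zeros((131072, 5120), dtype=np.float16)  # [vocab_size, hidden_size]

vocab_size, dim = emb.shape
print(vocab_size)            # 131072 -> still safe to read off the embedding
print(dim // size_per_head)  # 40, not the assumed head_num (32) -> no longer inferred
```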
5 changes: 3 additions & 2 deletions src/turbomind/models/llama/LlamaDecoderLayerWeight.cc
@@ -38,6 +38,7 @@ LlamaDecoderLayerWeight<T>::LlamaDecoderLayerWeight(int layer_idx,
     size_t head_num,
     size_t kv_head_num,
     size_t size_per_head,
+    size_t hidden_units,
     size_t inter_size,
     WeightType weight_type,
     int group_size,
@@ -48,7 +49,7 @@ LlamaDecoderLayerWeight<T>::LlamaDecoderLayerWeight(int layer_idx,
     head_num_(head_num),
     kv_head_num_(kv_head_num),
     size_per_head_(size_per_head),
-    hidden_units_(head_num * size_per_head),
+    hidden_units_(hidden_units),
     inter_size_(inter_size),
     weight_type_(weight_type),
     attn_bias_(attn_bias),
@@ -100,7 +101,7 @@ LlamaDecoderLayerWeight<T>::LlamaDecoderLayerWeight(int layer_idx,
     self_attn_weights.qkv.type = weight_type;
     self_attn_weights.qkv.group_size = group_size;

-    self_attn_weights.output.input_dims = hidden_units_ / tensor_para_size_;
+    self_attn_weights.output.input_dims = (head_num * size_per_head) / tensor_para_size_;
     self_attn_weights.output.output_dims = hidden_units_;
     self_attn_weights.output.type = weight_type;
     self_attn_weights.output.group_size = group_size;
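Since hidden_units_ can now differ from head_num * size_per_head, the attention output projection keeps an input width of head_num * size_per_head per tensor-parallel rank while projecting back to hidden_units_. A shape sketch with hypothetical sizes:

```python
# Shape sketch only, with assumed Mistral-Nemo-like sizes and tensor_para_size = 1.
import numpy as np

head_num, size_per_head, hidden_units, tp = 32, 128, 5120, 1

attn_out = np.zeros((1, head_num * size_per_head // tp))        # [1, 4096] per rank
w_o = np.zeros((head_num * size_per_head // tp, hidden_units))  # input_dims x output_dims
hidden_state = attn_out @ w_o                                   # [1, 5120]
assert hidden_state.shape == (1, hidden_units)
```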
1 change: 1 addition & 0 deletions src/turbomind/models/llama/LlamaDecoderLayerWeight.h
@@ -34,6 +34,7 @@ struct LlamaDecoderLayerWeight {
     size_t head_num,
     size_t kv_head_num,
     size_t size_per_head,
+    size_t hidden_units,
     size_t inter_size,
     WeightType weight_type,
     int group_size,
3 changes: 2 additions & 1 deletion src/turbomind/models/llama/LlamaFfnLayer.h
@@ -33,6 +33,7 @@ class LlamaFfnLayer {
 public:
     LlamaFfnLayer(size_t head_num,
         size_t size_per_head,
+        size_t hidden_units,
         size_t inter_size,
         NcclParam tensor_para,
         cudaStream_t stream,
@@ -42,7 +43,7 @@ class LlamaFfnLayer {
     head_num_(head_num),
     size_per_head_(size_per_head),
     inter_size_(inter_size / tensor_para.world_size_),
-    hidden_units_(head_num * size_per_head),
+    hidden_units_(hidden_units),
     stream_(stream),
     linear_(linear),
     allocator_(allocator),
4 changes: 3 additions & 1 deletion src/turbomind/models/llama/LlamaV2.cc
@@ -54,6 +54,7 @@ template<typename T>
 LlamaV2<T>::LlamaV2(size_t head_num,
     size_t kv_head_num,
     size_t size_per_head,
+    size_t hidden_units,
     size_t inter_size,
     size_t num_layer,
     size_t vocab_size,
@@ -85,7 +86,7 @@ LlamaV2<T>::LlamaV2(size_t head_num,
     rmsnorm_eps_(norm_eps),
     start_id_(start_id),
     end_id_(end_id),
-    hidden_units_(head_num * size_per_head),
+    hidden_units_(hidden_units),
     local_head_num_(head_num / tensor_para.world_size_),
     local_kv_head_num_(kv_head_num / tensor_para.world_size_),
     weights_(weights),
@@ -137,6 +138,7 @@ void LlamaV2<T>::initialize(const LlamaAttentionParams& attn_params,
     unified_decoder_.reset(new UnifiedDecoder<T>(head_num_,
         kv_head_num,
         size_per_head_,
+        hidden_units_,
         inter_size_,
         num_layer_,
         attn_params,
1 change: 1 addition & 0 deletions src/turbomind/models/llama/LlamaV2.h
@@ -57,6 +57,7 @@ class LlamaV2 {
     LlamaV2(size_t head_num,
         size_t kv_head_num,
         size_t size_per_head,
+        size_t hidden_units,
         size_t inter_size,
         size_t num_layer,
         size_t vocab_size,
4 changes: 3 additions & 1 deletion src/turbomind/models/llama/LlamaWeight.cc
@@ -28,6 +28,7 @@ template<typename T>
 LlamaWeight<T>::LlamaWeight(size_t head_num,
     size_t kv_head_num,
     size_t size_per_head,
+    size_t hidden_units,
     size_t inter_size,
     size_t vocab_size,
     size_t num_layer,
@@ -37,7 +38,7 @@ LlamaWeight<T>::LlamaWeight(size_t head_num,
     LoraParams lora_params,
     size_t tensor_para_size,
     size_t tensor_para_rank):
-    hidden_units_(head_num * size_per_head),
+    hidden_units_(hidden_units),
     inter_size_(inter_size),
     vocab_size_(vocab_size),
     vocab_size_padded_(vocab_size),
@@ -56,6 +57,7 @@ LlamaWeight<T>::LlamaWeight(size_t head_num,
         head_num,
         kv_head_num,
         size_per_head,
+        hidden_units_,
         inter_size_,
         weight_type_,
         group_size,
1 change: 1 addition & 0 deletions src/turbomind/models/llama/LlamaWeight.h
@@ -32,6 +32,7 @@ struct LlamaWeight {
     LlamaWeight(size_t head_num,
         size_t kv_head_num,
         size_t size_per_head,
+        size_t hidden_units,
         size_t inter_size,
         size_t vocab_size,
         size_t num_layer,
3 changes: 2 additions & 1 deletion src/turbomind/models/llama/unified_attention_layer.h
@@ -53,6 +53,7 @@ class UnifiedAttentionLayer {
     UnifiedAttentionLayer(size_t head_num,
         size_t kv_head_num,
         size_t size_per_head,
+        size_t hidden_units,
         LlamaAttentionParams attn_params,
         NcclParam tensor_para,
         LoraParams lora_params,
@@ -64,7 +65,7 @@ class UnifiedAttentionLayer {
         int quant_policy):
     head_num_(head_num),
     size_per_head_(size_per_head),
-    hidden_units_(head_num * size_per_head),
+    hidden_units_(hidden_units),
     local_head_num_(head_num / tensor_para.world_size_),
     local_kv_head_num_(kv_head_num / tensor_para.world_size_),
     head_n_rep_(head_num / kv_head_num),
2 changes: 2 additions & 0 deletions src/turbomind/models/llama/unified_decoder.cc
@@ -39,6 +39,7 @@ void UnifiedDecoder<T>::initialize(const LlamaAttentionParams& attn_params,
     attn_layer_ = new UnifiedAttentionLayer<T>(head_num_,
         kv_head_num,
         size_per_head_,
+        hidden_units_,
         attn_params,
         tensor_para_,
         lora_params_,
@@ -51,6 +52,7 @@ void UnifiedDecoder<T>::initialize(const LlamaAttentionParams& attn_params,

     ffn_layer_ = new LlamaFfnLayer<T>(head_num_,
         size_per_head_,
+        hidden_units_,
         inter_size_,
         tensor_para_,
         stream_,
3 changes: 2 additions & 1 deletion src/turbomind/models/llama/unified_decoder.h
@@ -63,6 +63,7 @@ class UnifiedDecoder {
     UnifiedDecoder(size_t head_num,
         size_t kv_head_num,
         size_t size_per_head,
+        size_t hidden_units,
         size_t inter_size,
         size_t num_layer,
         const LlamaAttentionParams& attn_params,
@@ -84,7 +85,7 @@ class UnifiedDecoder {
     head_num_(head_num),
     size_per_head_(size_per_head),
     inter_size_(inter_size),
-    hidden_units_(head_num * size_per_head),
+    hidden_units_(hidden_units),
     num_layer_(num_layer),
     rmsnorm_eps_(rmsnorm_eps),
     tensor_para_(tensor_para),
3 changes: 3 additions & 0 deletions src/turbomind/triton_backend/llama/LlamaTritonModel.cc
@@ -199,6 +199,7 @@ LlamaTritonModel<T>::LlamaTritonModel(size_t tensor_para_size,
     model_name_ = reader.Get("llama", "model_name");
     head_num_ = reader.GetInteger("llama", "head_num");
     kv_head_num_ = reader.GetInteger("llama", "kv_head_num", 0);
+    hidden_units_ = reader.GetInteger("llama", "hidden_units");
     size_per_head_ = reader.GetInteger("llama", "size_per_head");
     inter_size_ = reader.GetInteger("llama", "inter_size");
     num_layer_ = reader.GetInteger("llama", "num_layer");
@@ -338,6 +339,7 @@ std::unique_ptr<LlamaTritonSharedModelInstance<T>> LlamaTritonModel<T>::createSh
     auto llama = std::make_unique<ft::LlamaV2<T>>(head_num_,
         kv_head_num_,
         size_per_head_,
+        hidden_units_,
         inter_size_,
         num_layer_,
         vocab_size_,
@@ -401,6 +403,7 @@ void LlamaTritonModel<T>::createSharedWeights(int device_id, int rank)
     shared_weights_[device_id] = std::make_shared<ft::LlamaWeight<T>>(head_num_,
         kv_head_num_,
         size_per_head_,
+        hidden_units_,
         inter_size_,
         vocab_size_,
         num_layer_,
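hidden_units is read from the [llama] section of the exported config alongside head_num and size_per_head. A sketch of checking that file from Python; the path below is a placeholder, and the key names simply mirror what this diff reads through INIReader:

```python
# Sketch only; not part of this PR.
import configparser

ini = configparser.ConfigParser()
ini.read('triton_models/weights/config.ini')  # placeholder path

head_num = ini.getint('llama', 'head_num')
size_per_head = ini.getint('llama', 'size_per_head')
hidden_units = ini.getint('llama', 'hidden_units')  # key added by this PR

# For a Mistral-Nemo-like model the two quantities differ: 5120 vs 32 * 128 = 4096.
print(hidden_units, head_num * size_per_head)
```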
1 change: 1 addition & 0 deletions src/turbomind/triton_backend/llama/LlamaTritonModel.h
@@ -91,6 +91,7 @@ struct LlamaTritonModel: public AbstractTransformerModel {

     size_t head_num_;
     size_t kv_head_num_;
+    size_t hidden_units_;
     size_t size_per_head_;
     size_t inter_size_;
     size_t num_layer_;