Split token_embs and lm_head weights #2252

Merged: 20 commits, Sep 5, 2024
Changes from 8 commits
8 changes: 6 additions & 2 deletions lmdeploy/turbomind/deploy/target_model/base.py
@@ -333,12 +333,16 @@ def pad_weight(tensor):

         if emb is not None:
             emb = pad_weight(emb)
-            self.export_weight(emb, 'tok_embeddings.weight')
+            # try split along hidden dim
+            if emb.shape[1] % self.cfg.tensor_para_size == 0:
+                self.save_split(emb, 'tok_embeddings.weight', 1)
+            else:
+                self.export_weight(emb, 'tok_embeddings.weight')
         if norm_weight is not None:
             self.export_weight(norm_weight, 'norm.weight')
         if output_weight is not None:
             output_weight = pad_weight(output_weight)
-            self.export_weight(output_weight, 'output.weight')
+            self.save_split(output_weight, 'output.weight', 0)

     def export_transformer_block(self, bin: BaseReader, i: int) -> None:
         """Export transformer block."""
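Below is a minimal numpy sketch (not part of the PR) of the splitting logic above, assuming save_split(tensor, name, dim) writes one chunk per tensor-parallel rank along dim while export_weight writes the full tensor for every rank; the shapes and tp value are illustrative:

    import numpy as np

    tp = 2
    emb = np.zeros((32000, 4096), dtype=np.float16)      # tok_embeddings: [vocab, hidden]
    lm_head = np.zeros((32000, 4096), dtype=np.float16)  # output.weight

    def save_split(tensor, name, dim):
        # one shard per rank, split along `dim`
        for rank, shard in enumerate(np.split(tensor, tp, axis=dim)):
            print(f"{name} -> rank {rank}: {shard.shape}")

    if emb.shape[1] % tp == 0:
        save_split(emb, 'tok_embeddings.weight', 1)   # each rank gets [32000, 2048]
    save_split(lm_head, 'output.weight', 0)           # each rank gets [16000, 4096]

Splitting the token embedding along the hidden dim is what lets the runtime rebuild full hidden vectors with one all-gather plus a transpose, as the LlamaV2.cc hunk below shows.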
4 changes: 2 additions & 2 deletions src/turbomind/models/llama/LlamaBatch.cc
@@ -722,8 +722,8 @@ void LlamaBatch<T>::AllocateBuffer(size_t batch_size, size_t session_len, int ca
         context_decoder_output_buf_, sizeof(T) * max_forward_token_num_ * hidden_units, false);
     }

-    context_decoder_input_buf_ =
-        (T*)allocator_->reMalloc(context_decoder_input_buf_, sizeof(T) * max_forward_token_num_ * hidden_units, false);
+    context_decoder_input_buf_ = (T*)allocator_->reMalloc(
+        context_decoder_input_buf_, sizeof(T) * max_forward_token_num_ * hidden_units * 2, false);
     context_decoder_ids_buf_ =
         (int*)allocator_->reMalloc(context_decoder_ids_buf_, sizeof(int) * max_forward_token_num_, false);

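The input buffer is doubled because its second half serves as the per-rank workspace that forwardUnified (next file) fills before the embedding all-gather. A rough sketch of the assumed layout, with illustrative sizes:

    import numpy as np

    max_forward_token_num, hidden_units = 8, 16                   # illustrative values
    buf = np.empty(2 * max_forward_token_num * hidden_units, dtype=np.float16)
    decoder_input = buf[:max_forward_token_num * hidden_units]    # final [token_num, hidden] embeddings
    workspace = buf[max_forward_token_num * hidden_units:]        # staging area for the TP gather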
68 changes: 54 additions & 14 deletions src/turbomind/models/llama/LlamaV2.cc
@@ -212,18 +212,59 @@ void LlamaV2<T>::forwardUnified(T* out,
 {
     TM_LOG_DEBUG(__PRETTY_FUNCTION__);

-    invokeInputIdsEmbeddingLookupPosEncoding(decoder_input,
-                                             nullptr,  // processed somewhere else
-                                             weights_->pre_decoder_embedding_table,
-                                             static_cast<T*>(nullptr),
-                                             pPromptTuningParam<T>{},
-                                             input_ids,
-                                             0,  // only used for position encoding
-                                             token_num,
-                                             token_num,
-                                             1,
-                                             hidden_units_,
-                                             stream_);
+    if (tensor_para_.world_size_ == 1) {
+        invokeInputIdsEmbeddingLookupPosEncoding(decoder_input,
+                                                 nullptr,  // processed somewhere else
+                                                 weights_->pre_decoder_embedding_table,
+                                                 static_cast<T*>(nullptr),
+                                                 pPromptTuningParam<T>{},
+                                                 input_ids,
+                                                 0,  // only used for position encoding
+                                                 token_num,
+                                                 token_num,
+                                                 1,
+                                                 hidden_units_,
+                                                 stream_);
+    }
+    else {
+        const size_t local_hidden_units  = hidden_units_ / tensor_para_.world_size_;
+        T*           local_decoder_input = decoder_input + token_num * hidden_units_;  // workspace
+        invokeInputIdsEmbeddingLookupPosEncoding(local_decoder_input
+                                                     + tensor_para_.rank_ * token_num * local_hidden_units,
+                                                 nullptr,  // processed somewhere else
+                                                 weights_->pre_decoder_embedding_table,
+                                                 static_cast<T*>(nullptr),
+                                                 pPromptTuningParam<T>{},
+                                                 input_ids,
+                                                 0,  // only used for position encoding
+                                                 token_num,
+                                                 token_num,
+                                                 1,
+                                                 local_hidden_units,
+                                                 stream_);
+
+        {
+            NcclGuard nccl_guard(tensor_para_, stream_);
+            ftNcclAllGather(local_decoder_input,                                   // send_buf
+                            local_decoder_input,                                   // recv_buf
+                            token_num * hidden_units_ / tensor_para_.world_size_,  // data_size
+                            tensor_para_.rank_,
+                            tensor_para_,
+                            stream_);
+
+            sync_check_cuda_error();
+        }
+
+        invokeInPlaceTranspose102(decoder_input,
+                                  local_decoder_input,
+                                  tensor_para_.world_size_,
+                                  token_num,
+                                  local_hidden_units,
+                                  false,
+                                  stream_);
+
+        sync_check_cuda_error();
+    }

     count_and_fix(decoder_input, token_num * hidden_units_, "embedding", 1);

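A numpy model (not the PR's code) of what the tensor-parallel branch above computes, assuming each rank holds a [vocab, hidden/tp] slice of the embedding table: every rank looks up its slice into its own segment of the workspace, the NCCL all-gather fills the remaining segments, and the 102 transpose turns the [tp, token_num, hidden/tp] layout into the usual [token_num, hidden] input:

    import numpy as np

    tp, vocab, hidden, token_num = 2, 10, 8, 3
    local_hidden = hidden // tp

    rng = np.random.default_rng(0)
    emb = rng.standard_normal((vocab, hidden)).astype(np.float32)  # full table (reference only)
    shards = np.split(emb, tp, axis=1)                             # what each rank actually loads
    input_ids = np.array([1, 4, 7])

    # workspace laid out as [tp, token_num, local_hidden]
    workspace = np.zeros((tp, token_num, local_hidden), dtype=np.float32)
    for rank in range(tp):
        # each rank looks up its hidden slice into its own segment ...
        workspace[rank] = shards[rank][input_ids]
    # ... and in the real code an NCCL all-gather fills the other segments on every rank

    # transpose102: [tp, token_num, local_hidden] -> [token_num, tp, local_hidden]
    decoder_input = workspace.transpose(1, 0, 2).reshape(token_num, hidden)

    assert np.array_equal(decoder_input, emb[input_ids])  # matches the single-GPU lookup

The transpose reads from the workspace and writes straight into decoder_input, which is why invokeInPlaceTranspose102 is called with copy = false here.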
@@ -299,8 +340,7 @@ void LlamaV2<T>::postDecodeEmbedding(float* logits, float* local_logits, const T
                                batch_size,
                                hidden_units_,  // k
                                &alpha,
-                               weights_->post_decoder_embedding_kernel
-                                   + tensor_para_.rank_ * local_vocab_size * hidden_units_,
+                               weights_->post_decoder_embedding_kernel,
                                data_type,
                                hidden_units_,  // k
                                decoder_output,
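The rank offset into post_decoder_embedding_kernel disappears because each rank now loads only its own vocab shard of the lm_head (see LlamaWeight.cc below), so the local GEMM starts at the beginning of the per-rank buffer. A small numpy sketch of the assumed per-rank logits computation, with illustrative shapes:

    import numpy as np

    tp, vocab, hidden, batch = 2, 10, 8, 3

    rng = np.random.default_rng(1)
    lm_head = rng.standard_normal((vocab, hidden)).astype(np.float32)
    hidden_states = rng.standard_normal((batch, hidden)).astype(np.float32)

    # before: every rank held the full [vocab, hidden] kernel and offset into it by rank;
    # after: each rank loads only its own [vocab/tp, hidden] shard, so no offset is needed
    shards = np.split(lm_head, tp, axis=0)
    local_logits = [hidden_states @ shards[r].T for r in range(tp)]  # [batch, local_vocab] per rank

    # gathering the per-rank logits reproduces the full-vocab result
    assert np.allclose(np.concatenate(local_logits, axis=1), hidden_states @ lm_head.T)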
35 changes: 21 additions & 14 deletions src/turbomind/models/llama/LlamaWeight.cc
@@ -88,9 +88,13 @@ LlamaWeight<T>::~LlamaWeight()
 template<typename T>
 void LlamaWeight<T>::mallocWeights()
 {
-    deviceMalloc((T**)&pre_decoder_embedding_table, vocab_size_padded_ * hidden_units_);
+    FT_CHECK(vocab_size_padded_ % tensor_para_size_ == 0);
+    size_t embedding_table_size = (hidden_units_ % tensor_para_size_ == 0) ?
+                                      vocab_size_padded_ * hidden_units_ / tensor_para_size_ :
+                                      vocab_size_padded_ * hidden_units_;
+    deviceMalloc((T**)&pre_decoder_embedding_table, embedding_table_size);
     deviceMalloc((T**)&output_norm_weight, hidden_units_);
-    deviceMalloc((T**)&post_decoder_embedding_kernel, hidden_units_ * vocab_size_padded_);
+    deviceMalloc((T**)&post_decoder_embedding_kernel, hidden_units_ * vocab_size_padded_ / tensor_para_size_);
 }

 template<typename T>
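The conditional mirrors the exporter's fallback: the embedding table is sharded only when hidden_units_ divides evenly by the tensor-parallel size, otherwise every rank keeps a full copy. A small sketch of the size computation with illustrative numbers:

    def embedding_table_size(vocab_padded, hidden, tp):
        # shard along hidden only when it divides evenly; otherwise keep a full copy per rank
        return vocab_padded * hidden // tp if hidden % tp == 0 else vocab_padded * hidden

    print(embedding_table_size(32000, 4096, 2))  # 65536000 elements per rank (sharded)
    print(embedding_table_size(32000, 4090, 4))  # 130880000 elements per rank (full copy)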
@@ -102,15 +106,16 @@ void LlamaWeight<T>::loadModel(std::string dir_path)
     }
     dir_path += '/';

-    loadWeightFromBin((T*)pre_decoder_embedding_table,
-                      {vocab_size_padded_ * hidden_units_},
-                      dir_path + "tok_embeddings.weight",
-                      model_file_type);
+    size_t embedding_table_size = (hidden_units_ % tensor_para_size_ == 0) ?
+                                      vocab_size_padded_ * hidden_units_ / tensor_para_size_ :
+                                      vocab_size_padded_ * hidden_units_;
+    loadWeightFromBin(
+        (T*)pre_decoder_embedding_table, {embedding_table_size}, dir_path + "tok_embeddings.weight", model_file_type);

     loadWeightFromBin((T*)output_norm_weight, {hidden_units_}, dir_path + "norm.weight", model_file_type);

     loadWeightFromBin((T*)post_decoder_embedding_kernel,
-                      {hidden_units_ * vocab_size_padded_},
+                      {hidden_units_ * vocab_size_padded_ / tensor_para_size_},
                       dir_path + "output.weight",
                       model_file_type);

@@ -123,20 +128,22 @@ template<typename T>
 TensorMap LlamaWeight<T>::getParams()
 {
     TensorMap output;
+    FT_CHECK(vocab_size_padded_ % tensor_para_size_ == 0);

-    output.insert("tok_embeddings.weight",
-                  Tensor{MEMORY_GPU,
-                         getTensorType<T>(),
-                         {vocab_size_padded_ * hidden_units_ * sizeof(T)},
-                         pre_decoder_embedding_table});
+    size_t embedding_table_size = (hidden_units_ % tensor_para_size_ == 0) ?
+                                      vocab_size_padded_ * hidden_units_ / tensor_para_size_ :
+                                      vocab_size_padded_ * hidden_units_;
+    output.insert(
+        "tok_embeddings." + std::to_string(tensor_para_rank_) + ".weight",
+        Tensor{MEMORY_GPU, getTensorType<T>(), {embedding_table_size * sizeof(T)}, pre_decoder_embedding_table});

     output.insert("norm.weight",
                   Tensor{MEMORY_GPU, getTensorType<T>(), {hidden_units_ * sizeof(T)}, output_norm_weight});

-    output.insert("output.weight",
+    output.insert("output." + std::to_string(tensor_para_rank_) + ".weight",
                   Tensor{MEMORY_GPU,
                          getTensorType<T>(),
-                         {hidden_units_ * vocab_size_padded_ * sizeof(T)},
+                         {hidden_units_ * vocab_size_padded_ * sizeof(T) / tensor_para_size_},
                          post_decoder_embedding_kernel});

     // transformer layers
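With the rank suffix added above, a tp = 2 checkpoint would expose parameter names along these lines (a hypothetical illustration, not output of the code):

    tp = 2
    for rank in range(tp):
        print(rank, [
            f"tok_embeddings.{rank}.weight",  # per-rank slice (or a full copy if hidden is not divisible)
            "norm.weight",                    # replicated on every rank
            f"output.{rank}.weight",          # lm_head shard: hidden * vocab_padded / tp elements
        ])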
35 changes: 27 additions & 8 deletions src/turbomind/utils/memory_utils.cu
@@ -874,23 +874,42 @@ __global__ void transpose102(T_OUT* dst, T_IN* src, const int dim0, const int di
 }

 template<typename T>
-void invokeInPlaceTranspose102(T* data, T* workspace, const int dim0, const int dim1, const int dim2)
+void invokeInPlaceTranspose102(
+    T* data, T* workspace, const int dim0, const int dim1, const int dim2, bool copy, cudaStream_t stream)
 {
     // copy data to workspace, and then transpose from workspace to data
     // Note that this kernel is used for pre-processing and not very efficient.
-    cudaD2Dcpy(workspace, data, dim0 * dim1 * dim2);
-    transpose102<<<256, 256>>>(data, workspace, dim0, dim1, dim2);
+    const size_t count = dim0 * dim1 * dim2;
+    if (copy) {
+        cudaAutoCpy(workspace, data, count, stream);
+    }
+    const int block = 512;
+    const int grid  = std::min((count + block - 1) / block, 8192ul);
+    transpose102<<<grid, block, 0, stream>>>(data, workspace, dim0, dim1, dim2);
 }

 #ifdef ENABLE_FP8
-template void invokeInPlaceTranspose102(
-    __nv_fp8_e4m3* data, __nv_fp8_e4m3* workspace, const int dim0, const int dim1, const int dim2);
+template void invokeInPlaceTranspose102(__nv_fp8_e4m3* data,
+                                        __nv_fp8_e4m3* workspace,
+                                        const int      dim0,
+                                        const int      dim1,
+                                        const int      dim2,
+                                        bool           copy,
+                                        cudaStream_t   stream);
 #endif  // ENABLE_FP8
 #ifdef ENABLE_BF16
-template void invokeInPlaceTranspose102(
-    __nv_bfloat16* data, __nv_bfloat16* workspace, const int dim0, const int dim1, const int dim2);
+template void invokeInPlaceTranspose102(__nv_bfloat16* data,
+                                        __nv_bfloat16* workspace,
+                                        const int      dim0,
+                                        const int      dim1,
+                                        const int      dim2,
+                                        bool           copy,
+                                        cudaStream_t   stream);
 #endif  // ENABLE_BF16
-template void invokeInPlaceTranspose102(float* data, float* workspace, const int dim0, const int dim1, const int dim2);
+template void invokeInPlaceTranspose102(
+    half* data, half* workspace, const int dim0, const int dim1, const int dim2, bool copy, cudaStream_t stream);
+template void invokeInPlaceTranspose102(
+    float* data, float* workspace, const int dim0, const int dim1, const int dim2, bool copy, cudaStream_t stream);

 template<typename T>
 void __global__ multiplyScale(T* tensor, float scale, const size_t size)
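A numpy model (an assumption, not the kernel itself) of what transpose102 produces: the first two axes of a contiguous [dim0, dim1, dim2] tensor are swapped. The new copy flag lets callers that already staged the source in workspace (as LlamaV2.cc does) skip the extra device copy, and the launch now sizes the grid as ceil(count / 512) capped at 8192 blocks instead of the fixed 256-block, 256-thread configuration:

    import numpy as np

    def transpose102(data, workspace, dim0, dim1, dim2, copy=True):
        # mirrors invokeInPlaceTranspose102: read from `workspace`, write into `data`
        if copy:
            workspace[:] = data                       # in-place use: stage the source first
        src = workspace.reshape(dim0, dim1, dim2)
        data[:] = src.transpose(1, 0, 2).reshape(-1)  # axes (1, 0, 2), hence "102"
        return data

    x = np.arange(2 * 3 * 4, dtype=np.float32)
    out = transpose102(x.copy(), np.empty_like(x), 2, 3, 4)
    assert np.array_equal(out.reshape(3, 2, 4), x.reshape(2, 3, 4).transpose(1, 0, 2))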
3 changes: 2 additions & 1 deletion src/turbomind/utils/memory_utils.h
@@ -115,7 +115,8 @@ template<typename T>
 void invokeInPlaceTranspose0213(T* data, T* workspace, const int dim0, const int dim1, const int dim2, const int dim3);

 template<typename T>
-void invokeInPlaceTranspose102(T* data, T* workspace, const int dim0, const int dim1, const int dim2);
+void invokeInPlaceTranspose102(
+    T* data, T* workspace, const int dim0, const int dim1, const int dim2, bool copy = true, cudaStream_t stream = 0);

 template<typename T>
 void invokeMultiplyScale(T* tensor, float scale, const size_t size, cudaStream_t stream);