diff --git a/lmdeploy/turbomind/deploy/target_model/base.py b/lmdeploy/turbomind/deploy/target_model/base.py
index 6b839876f..b2d69cdd6 100644
--- a/lmdeploy/turbomind/deploy/target_model/base.py
+++ b/lmdeploy/turbomind/deploy/target_model/base.py
@@ -231,12 +231,12 @@ def pad_weight(tensor):
 
         if emb is not None:
             emb = pad_weight(emb)
-            self.export_weight(emb, 'tok_embeddings.weight')
+            self.save_split(emb, 'tok_embeddings.weight', 1)
         if norm_weight is not None:
             self.export_weight(norm_weight, 'norm.weight')
         if output_weight is not None:
             output_weight = pad_weight(output_weight)
-            self.export_weight(output_weight, 'output.weight')
+            self.save_split(output_weight, 'output.weight', 0)
 
     def export_transformer_block(self, bin: BaseReader, i: int) -> None:
         """Export transformer block."""
diff --git a/lmdeploy/turbomind/turbomind.py b/lmdeploy/turbomind/turbomind.py
index 3b05e5717..7f74e904d 100644
--- a/lmdeploy/turbomind/turbomind.py
+++ b/lmdeploy/turbomind/turbomind.py
@@ -240,12 +240,13 @@ def _from_workspace(self, model_path: str,
             _cfg = yaml.safe_load(f)
         cfg = TurbomindModelConfig.from_dict(_cfg)
 
-        # check whether input tp is valid
-        self.gpu_count = engine_config.tp
-        if cfg.tensor_para_size != 1 and \
-                self.gpu_count != cfg.tensor_para_size:
-            logger.info(f'found tp={cfg.tensor_para_size} in config.yaml.')
-            self.gpu_count = cfg.tensor_para_size
+        # always use tp in converted model (config.yaml)
+        if cfg.tensor_para_size != engine_config.tp:
+            logger.warning(
+                'tp in engine_config is different from in config.yaml'
+                f'({config_path}), {engine_config.tp} vs '
+                f'{cfg.tensor_para_size}, using tp={cfg.tensor_para_size}')
+        self.gpu_count = cfg.tensor_para_size
         engine_config.tp = self.gpu_count
 
         self._postprocess_config(cfg, engine_config)
diff --git a/src/turbomind/models/llama/LlamaDecoderLayerWeight.cc b/src/turbomind/models/llama/LlamaDecoderLayerWeight.cc
index c2f39a096..3beba40c6 100644
--- a/src/turbomind/models/llama/LlamaDecoderLayerWeight.cc
+++ b/src/turbomind/models/llama/LlamaDecoderLayerWeight.cc
@@ -281,98 +281,28 @@ void getWeightTensor(LlamaDenseWeight<T>& weights, bool bias, const std::string&
 }
 
 template<typename T>
-void loadWeights(LlamaDenseWeight<T>& w,
-                 std::string          prefix,
-                 int                  rank,
-                 FtCudaDataType       model_file_type,
-                 size_t               tensor_para_size,
-                 int                  slice_dim   = 0,
-                 std::vector<size_t>  slice_shape = {})
+void loadWeights(
+    LlamaDenseWeight<T>& w, std::string prefix, int rank, FtCudaDataType model_file_type, size_t tensor_para_size)
 {
-    auto       max_prefix = prefix + "." + std::to_string(tensor_para_size - 1);
-    const auto type       = model_file_type;
-
-    bool enable_slice = true;
-    // Disable slice if tensor param rank is 1
-    if (tensor_para_size <= 1) {
-        enable_slice = false;
-    }
-    else {
-        // Disable slice if weight has already been sliced
-        if (std::filesystem::exists(max_prefix + ".weight") || std::filesystem::exists(max_prefix + ".qweight")) {
-            TM_LOG_DEBUG("TP weight exists. Disable runtime TP.");
-            enable_slice = false;
-        }
+    auto weight_file  = prefix + "." + std::to_string(tensor_para_size - 1) + ".weight";
+    auto qweight_file = prefix + "." + std::to_string(tensor_para_size - 1) + ".qweight";
+    if (!std::filesystem::exists(weight_file) && !std::filesystem::exists(qweight_file)) {
+        TM_LOG_ERROR("%s and %s does not exist", weight_file.c_str(), qweight_file.c_str());
+        FT_CHECK(false);
     }
 
-    size_t dim0 = w.input_dims;
-    size_t dim1 = w.output_dims;
-    if (enable_slice) {
-        // multiple tp size for slice stride
-        if (slice_dim == 0) {
-            dim0 = dim0 * tensor_para_size;
-            if (slice_shape.size() == 0) {
-                slice_shape = {dim0};
-            }
-        }
-        else {
-            dim1 = dim1 * tensor_para_size;
-            if (slice_shape.size() == 0) {
-                slice_shape = {dim1};
-            }
-        }
+    prefix += "." + std::to_string(rank);
 
-        prefix += "." + std::to_string(0);
-    }
-    else {
-        prefix += "." + std::to_string(rank);
-    }
+    size_t     dim0 = w.input_dims;
+    size_t     dim1 = w.output_dims;
+    const auto type = model_file_type;
 
     if (w.bias) {
-        std::vector<ConcateSlice> bias_slices{};
-        if (enable_slice) {
-            if (slice_dim == 1) {
-                size_t       start = 0;
-                ConcateSlice slice0{{{0, 1}}};
-                ConcateSlice slice1{{{}}};
-                for (auto len : slice_shape) {
-                    size_t stride = len / tensor_para_size;
-                    slice1.slices.push_back({start + stride * rank, start + stride * (rank + 1)});
-                    start += len;
-                }
-                bias_slices = {slice0, slice1};
-            }
-        }
-        loadWeightFromBin((T*)w.bias, {1, dim1}, prefix + ".bias", type, bias_slices);
+        loadWeightFromBin((T*)w.bias, {1, dim1}, prefix + ".bias", type);
     }
     const size_t bit_size = getBitSize(w.type);
     if (bit_size >= 16) {  // fp16, fp32
-        std::vector<ConcateSlice> weight_slices{};
-        if (enable_slice) {
-            if (slice_dim == 1) {
-                size_t       start = 0;
-                ConcateSlice slice0{{{0, dim0}}};
-                ConcateSlice slice1{{{}}};
-                for (auto len : slice_shape) {
-                    size_t stride = len / tensor_para_size;
-                    slice1.slices.push_back({start + stride * rank, start + stride * (rank + 1)});
-                    start += len;
-                }
-                weight_slices = {slice0, slice1};
-            }
-            else {
-                size_t       start = 0;
-                ConcateSlice slice0{{}};
-                ConcateSlice slice1{{{0, dim1}}};
-                for (auto len : slice_shape) {
-                    size_t stride = len / tensor_para_size;
-                    slice0.slices.push_back({start + stride * rank, start + stride * (rank + 1)});
-                    start += len;
-                }
-                weight_slices = {slice0, slice1};
-            }
-        }
-        loadWeightFromBin((T*)w.kernel, {dim0, dim1}, prefix + ".weight", type, weight_slices);
+        loadWeightFromBin((T*)w.kernel, {dim0, dim1}, prefix + ".weight", type);
     }
     else {  // int8, int4
         const int factor = sizeof(float) * 8 / bit_size;
@@ -380,12 +310,12 @@ void loadWeights(LlamaDenseWeight<T>& w,
         FT_CHECK(dim1 % factor == 0);
 
         std::vector<size_t> w_shape{dim0, dim1 / factor * sizeof(uint32_t)};
-        loadWeightFromBin((int8_t*)w.kernel, w_shape, prefix + ".qweight", FtCudaDataType::INT8, {});
+        loadWeightFromBin((int8_t*)w.kernel, w_shape, prefix + ".qweight", FtCudaDataType::INT8);
 
         const size_t group_count = w.group_size > 0 ? dim0 / w.group_size : 1;
 
-        loadWeightFromBin((half*)w.scales, {group_count, dim1}, prefix + ".scales", type, {});
-        loadWeightFromBin((half*)w.zeros, {group_count, dim1}, prefix + ".zeros", type, {});
+        loadWeightFromBin((half*)w.scales, {group_count, dim1}, prefix + ".scales", type);
+        loadWeightFromBin((half*)w.zeros, {group_count, dim1}, prefix + ".zeros", type);
     }
 }
@@ -430,29 +360,12 @@ void LlamaDecoderLayerWeight<T>::loadModel(std::string dir_path, FtCudaDataType
         (T*)self_attn_norm_weights, {hidden_units_}, dir_path + ".attention_norm.weight", model_file_type);
     loadWeightFromBin((T*)ffn_norm_weights, {hidden_units_}, dir_path + ".ffn_norm.weight", model_file_type);
 
-    loadWeights(self_attn_weights.qkv,
-                dir_path + ".attention.w_qkv",
-                tensor_para_rank_,
-                type,
-                tensor_para_size_,
-                1,
-                {head_num_ * size_per_head_, kv_head_num_ * size_per_head_, kv_head_num_ * size_per_head_});
-
-    loadWeights(self_attn_weights.output, dir_path + ".attention.wo", tensor_para_rank_, type, tensor_para_size_, 0);
-
-    // if (fused_up_and_gate_) {
-    //     loadWeights(ffn_weights.fused_gating_intermediate,
-    //                 dir_path + ".feed_forward.w13",
-    //                 tensor_para_rank_,
-    //                 type,
-    //                 tensor_para_size_,
-    //                 1);
-    // }
-    // else {
-    loadWeights(ffn_weights.gating, dir_path + ".feed_forward.w1", tensor_para_rank_, type, tensor_para_size_, 1);
-    loadWeights(ffn_weights.intermediate, dir_path + ".feed_forward.w3", tensor_para_rank_, type, tensor_para_size_, 1);
-    // }
-    loadWeights(ffn_weights.output, dir_path + ".feed_forward.w2", tensor_para_rank_, type, tensor_para_size_, 0);
+    loadWeights(self_attn_weights.qkv, dir_path + ".attention.w_qkv", tensor_para_rank_, type, tensor_para_size_);
+    loadWeights(self_attn_weights.output, dir_path + ".attention.wo", tensor_para_rank_, type, tensor_para_size_);
+
+    loadWeights(ffn_weights.gating, dir_path + ".feed_forward.w1", tensor_para_rank_, type, tensor_para_size_);
+    loadWeights(ffn_weights.intermediate, dir_path + ".feed_forward.w3", tensor_para_rank_, type, tensor_para_size_);
+    loadWeights(ffn_weights.output, dir_path + ".feed_forward.w2", tensor_para_rank_, type, tensor_para_size_);
 }
 
 template<typename T>
diff --git a/src/turbomind/models/llama/LlamaV2.cc b/src/turbomind/models/llama/LlamaV2.cc
index a69df127d..40af823c0 100644
--- a/src/turbomind/models/llama/LlamaV2.cc
+++ b/src/turbomind/models/llama/LlamaV2.cc
@@ -212,18 +212,54 @@ void LlamaV2<T>::forwardUnified(T* out,
 {
     TM_LOG_DEBUG(__PRETTY_FUNCTION__);
 
-    invokeInputIdsEmbeddingLookupPosEncoding(decoder_input,
-                                             nullptr,  // processed somewhere else
-                                             weights_->pre_decoder_embedding_table,
-                                             static_cast<T*>(nullptr),
-                                             pPromptTuningParam<T>{},
-                                             input_ids,
-                                             0,  // only used for position encoding
-                                             token_num,
-                                             token_num,
-                                             1,
-                                             hidden_units_,
-                                             stream_);
+    if (tensor_para_.world_size_ == 1) {
+        invokeInputIdsEmbeddingLookupPosEncoding(decoder_input,
+                                                 nullptr,  // processed somewhere else
+                                                 weights_->pre_decoder_embedding_table,
+                                                 static_cast<T*>(nullptr),
+                                                 pPromptTuningParam<T>{},
+                                                 input_ids,
+                                                 0,  // only used for position encoding
+                                                 token_num,
+                                                 token_num,
+                                                 1,
+                                                 hidden_units_,
+                                                 stream_);
+        sync_check_cuda_error();
+    }
+    else {
+        const size_t local_hidden_units = hidden_units_ / tensor_para_.world_size_;
+        invokeInputIdsEmbeddingLookupPosEncoding(decoder_output + tensor_para_.rank_ * token_num * local_hidden_units,
+                                                 nullptr,  // processed somewhere else
+                                                 weights_->pre_decoder_embedding_table,
+                                                 static_cast<T*>(nullptr),
+                                                 pPromptTuningParam<T>{},
+                                                 input_ids,
+                                                 0,  // only used for position encoding
+                                                 token_num,
+                                                 token_num,
+                                                 1,
+                                                 local_hidden_units,
+                                                 stream_);
+        sync_check_cuda_error();
+
+        {
+            NcclGuard nccl_guard(tensor_para_, stream_);
+            ftNcclAllGather(decoder_output,                                        // send_buf
+                            decoder_output,                                        // recv_buf
+                            token_num * hidden_units_ / tensor_para_.world_size_,  // data_size
+                            tensor_para_.rank_,
+                            tensor_para_,
+                            stream_);
+
+            sync_check_cuda_error();
+        }
+
+        invokeInPlaceTranspose102(
+            decoder_input, decoder_output, tensor_para_.world_size_, token_num, local_hidden_units, false, stream_);
+
+        sync_check_cuda_error();
+    }
 
     count_and_fix(decoder_input, token_num * hidden_units_, "embedding", 1);
@@ -299,8 +335,7 @@ void LlamaV2<T>::postDecodeEmbedding(float* logits, float* local_logits, const T
                               batch_size,
                               hidden_units_,  // k
                               &alpha,
-                              weights_->post_decoder_embedding_kernel
-                                  + tensor_para_.rank_ * local_vocab_size * hidden_units_,
+                              weights_->post_decoder_embedding_kernel,
                               data_type,
                               hidden_units_,  // k
                               decoder_output,
diff --git a/src/turbomind/models/llama/LlamaWeight.cc b/src/turbomind/models/llama/LlamaWeight.cc
index 1b1172f51..e83d736b2 100644
--- a/src/turbomind/models/llama/LlamaWeight.cc
+++ b/src/turbomind/models/llama/LlamaWeight.cc
@@ -51,6 +51,9 @@ LlamaWeight<T>::LlamaWeight(size_t head_num,
         vocab_size_padded_ = (vocab_size_padded_ + tensor_para_size_ - 1) / tensor_para_size_ * tensor_para_size_;
         TM_LOG_WARNING("pad vocab size from %d to %d", vocab_size_, vocab_size_padded_);
     }
+
+    FT_CHECK(hidden_units_ % tensor_para_size_ == 0);
+
     decoder_layer_weights.reserve(num_layer_);
     for (unsigned l = 0; l < num_layer_; ++l) {
         decoder_layer_weights.push_back(new LlamaDecoderLayerWeight<T>(l,
@@ -78,6 +81,7 @@ LlamaWeight<T>::~LlamaWeight()
     cudaFree((void*)post_decoder_embedding_kernel);
 
     pre_decoder_embedding_table   = nullptr;
+    output_norm_weight            = nullptr;
     post_decoder_embedding_kernel = nullptr;
 
     for (auto& p : decoder_layer_weights) {
@@ -88,9 +92,10 @@ LlamaWeight<T>::~LlamaWeight()
 template<typename T>
 void LlamaWeight<T>::mallocWeights()
 {
-    deviceMalloc((T**)&pre_decoder_embedding_table, vocab_size_padded_ * hidden_units_);
+    FT_CHECK(vocab_size_padded_ % tensor_para_size_ == 0);
+    deviceMalloc((T**)&pre_decoder_embedding_table, vocab_size_padded_ * hidden_units_ / tensor_para_size_);
     deviceMalloc((T**)&output_norm_weight, hidden_units_);
-    deviceMalloc((T**)&post_decoder_embedding_kernel, hidden_units_ * vocab_size_padded_);
+    deviceMalloc((T**)&post_decoder_embedding_kernel, hidden_units_ * vocab_size_padded_ / tensor_para_size_);
 }
 
 template<typename T>
@@ -103,15 +108,15 @@ void LlamaWeight<T>::loadModel(std::string dir_path)
         dir_path += '/';
 
     loadWeightFromBin((T*)pre_decoder_embedding_table,
-                      {vocab_size_padded_ * hidden_units_},
-                      dir_path + "tok_embeddings.weight",
+                      {vocab_size_padded_ * hidden_units_ / tensor_para_size_},
+                      dir_path + "tok_embeddings." + std::to_string(tensor_para_rank_) + ".weight",
                       model_file_type);
 
     loadWeightFromBin((T*)output_norm_weight, {hidden_units_}, dir_path + "norm.weight", model_file_type);
 
     loadWeightFromBin((T*)post_decoder_embedding_kernel,
-                      {hidden_units_ * vocab_size_padded_},
-                      dir_path + "output.weight",
+                      {hidden_units_ * vocab_size_padded_ / tensor_para_size_},
+                      dir_path + "output." + std::to_string(tensor_para_rank_) + ".weight",
                       model_file_type);
 
     for (unsigned layer = 0; layer < num_layer_; ++layer) {
@@ -124,19 +129,19 @@ TensorMap LlamaWeight<T>::getParams()
 {
     TensorMap output;
 
-    output.insert("tok_embeddings.weight",
+    output.insert("tok_embeddings." + std::to_string(tensor_para_rank_) + ".weight",
                   Tensor{MEMORY_GPU,
                          getTensorType<T>(),
-                         {vocab_size_padded_ * hidden_units_ * sizeof(T)},
+                         {vocab_size_padded_ * hidden_units_ / tensor_para_size_ * sizeof(T)},
                          pre_decoder_embedding_table});
 
     output.insert("norm.weight",
                   Tensor{MEMORY_GPU, getTensorType<T>(), {hidden_units_ * sizeof(T)}, output_norm_weight});
 
-    output.insert("output.weight",
+    output.insert("output." + std::to_string(tensor_para_rank_) + ".weight",
                   Tensor{MEMORY_GPU,
                          getTensorType<T>(),
-                         {hidden_units_ * vocab_size_padded_ * sizeof(T)},
+                         {hidden_units_ * vocab_size_padded_ * sizeof(T) / tensor_para_size_},
                          post_decoder_embedding_kernel});
 
     // transformer layers
diff --git a/src/turbomind/utils/memory_utils.cu b/src/turbomind/utils/memory_utils.cu
index 93547f364..f8bfb8efe 100644
--- a/src/turbomind/utils/memory_utils.cu
+++ b/src/turbomind/utils/memory_utils.cu
@@ -302,8 +302,7 @@ template void cudaRandomUniform<__nv_fp8_e4m3>(__nv_fp8_e4m3* buffer, const size_t size);
 // loads data from binary file. If it succeeds, returns a non-empty vector. If loading fails or
 // the product of the elements in shape is 0, this function will return an empty vector.
 template<typename T>
-std::vector<T>
-loadWeightFromBinHelper(std::vector<size_t> shape, std::string filename, std::vector<ConcateSlice> slices = {})
+std::vector<T> loadWeightFromBinHelper(std::vector<size_t> shape, std::string filename)
 {
     if (shape.size() > 2) {
         printf("[ERROR] shape should have less than two dims \n");
@@ -315,145 +314,48 @@ loadWeightFromBinHelper(std::vector<size_t> shape, std::string filename, std::ve
         dim1 = shape[1];
     }
 
-    if (slices.size() == 0) {
-        size_t size = dim0 * dim1;
-        if (size == 0) {
-            TM_LOG_WARNING("shape is zero, skip loading weight from file %s \n", filename.c_str());
-            return std::vector<T>();
-        }
-
-        std::vector<T> host_array(size);
-        std::ifstream  in(filename, std::ios::in | std::ios::binary);
-        if (!in.is_open()) {
-            TM_LOG_WARNING("file %s cannot be opened, loading model fails! \n", filename.c_str());
-            return std::vector<T>();
-        }
-
-        size_t loaded_data_size = sizeof(T) * size;
-        in.seekg(0, in.end);
-        const auto file_size_in_bytes = (size_t)in.tellg();
-        in.seekg(0, in.beg);
-
-        TM_LOG_DEBUG("Read " + std::to_string(loaded_data_size) + " bytes from " + filename);
-        in.read((char*)host_array.data(), loaded_data_size);
-
-        if (file_size_in_bytes != loaded_data_size) {
-            TM_LOG_WARNING("file %s has %ld, but request %ld, loading model fails!",
-                           filename.c_str(),
-                           file_size_in_bytes,
-                           loaded_data_size);
-            return std::vector<T>();
-        }
-        in.close();
-        // If we succeed, return an array with values.
-        return host_array;
+    size_t size = dim0 * dim1;
+    if (size == 0) {
+        TM_LOG_WARNING("shape is zero, skip loading weight from file %s \n", filename.c_str());
+        return std::vector<T>();
     }
-    else {
-        // concate all slices on the same dims
-
-        if (slices.size() != shape.size()) {
-            printf("[ERROR] slices should have same dims as shape \n");
-            return std::vector<T>();
-        }
-
-        // get slices
-        ConcateSlice slice0{{{0, dim0}}};
-        ConcateSlice slice1{{{0, dim1}}};
-        if (slices.size() > 0 && slices[0].slices.size() > 0) {
-            slice0 = slices[0];
-        }
-        if (shape.size() == 2 && slices[1].slices.size() > 0) {
-            slice1 = slices[1];
-        }
-
-        size_t w0 = 0;
-        for (auto& s : slice0.slices) {
-            if (s.second > dim0) {
-                s.second = dim0;
-            }
-            if (s.second < s.first) {
-                printf("[ERROR] slice0: end < start \n");
-                return std::vector<T>();
-            }
-            w0 += s.second - s.first;
-        }
-        size_t w1 = 0;
-        for (auto& s : slice1.slices) {
-            if (s.second > dim1) {
-                s.second = dim1;
-            }
-            if (s.second < s.first) {
-                printf("[ERROR] slice1: end < start \n");
-                return std::vector<T>();
-            }
-            w1 += s.second - s.first;
-        }
-
-        size_t size             = w0 * w1;
-        size_t loaded_data_size = size * sizeof(T);
-
-        TM_LOG_DEBUG("Read " + std::to_string(loaded_data_size) + " bytes from " + filename + " with slice.");
-        if (size == 0) {
-            TM_LOG_WARNING("shape is zero, skip loading weight from file %s \n", filename.c_str());
-            return std::vector<T>();
-        }
+    std::vector<T> host_array(size);
+    std::ifstream  in(filename, std::ios::in | std::ios::binary);
+    if (!in.is_open()) {
+        TM_LOG_WARNING("file %s cannot be opened, loading model fails! \n", filename.c_str());
+        return std::vector<T>();
+    }
 
-        std::vector<T> host_array(size);
-        std::ifstream  in(filename, std::ios::in | std::ios::binary);
-        if (!in.is_open()) {
-            TM_LOG_WARNING("file %s cannot be opened, loading model fails! \n", filename.c_str());
-            return std::vector<T>();
-        }
+    size_t loaded_data_size = sizeof(T) * size;
+    in.seekg(0, in.end);
+    const auto file_size_in_bytes = (size_t)in.tellg();
+    in.seekg(0, in.beg);
 
-        char* host_ptr = (char*)host_array.data();
-        if (slice1.slices.size() == 0
-            || (slice1.slices.size() == 1 && slice1.slices[0].second - slice1.slices[0].first == dim1)) {
-            for (auto& s : slice0.slices) {
-                size_t read_size = (s.second - s.first) * dim1 * sizeof(T);
-                size_t pos       = s.first * dim1;
-                in.seekg(pos * sizeof(T));
-                in.read((char*)host_ptr, read_size);
-                host_ptr += read_size;
-            }
-            in.close();
-            return host_array;
-        }
+    TM_LOG_DEBUG("Read " + std::to_string(loaded_data_size) + " bytes from " + filename);
+    in.read((char*)host_array.data(), loaded_data_size);
 
-        {
-            for (auto& s0 : slice0.slices) {
-                // loop over outer slice
-                for (size_t line_id = s0.first; line_id < s0.second; ++line_id) {
-                    // loop over lines
-                    size_t pos0 = line_id * dim1;
-                    for (auto& s1 : slice1.slices) {
-                        // loop over inner slice
-                        size_t pos       = pos0 + s1.first;
-                        size_t read_size = (s1.second - s1.first) * sizeof(T);
-                        in.seekg(pos * sizeof(T));
-                        in.read(host_ptr, read_size);
-                        host_ptr += read_size;
-                    }
-                }
-            }
-            in.close();
-        }
-        return host_array;
+    if (file_size_in_bytes != loaded_data_size) {
+        TM_LOG_WARNING("file %s has %ld, but request %ld, loading model fails!",
+                       filename.c_str(),
+                       file_size_in_bytes,
+                       loaded_data_size);
+        return std::vector<T>();
     }
+    in.close();
+    // If we succeed, return an array with values.
+    return host_array;
 }
 
-std::vector<float> loadArrayFromBin(std::vector<size_t> shape, std::string filename, std::vector<ConcateSlice> slices)
+std::vector<float> loadArrayFromBin(std::vector<size_t> shape, std::string filename)
 {
-    return loadWeightFromBinHelper<float>(shape, filename, slices);
+    return loadWeightFromBinHelper<float>(shape, filename);
 }
 
 template<typename T, typename T_IN>
-int loadWeightFromBinFunc(T* ptr,
-                          std::vector<size_t> shape,
-                          std::string filename,
-                          std::vector<ConcateSlice> slices = std::vector<ConcateSlice>())
+int loadWeightFromBinFunc(T* ptr, std::vector<size_t> shape, std::string filename)
 {
-    std::vector<T_IN> host_array = loadWeightFromBinHelper<T_IN>(shape, filename, slices);
+    std::vector<T_IN> host_array = loadWeightFromBinHelper<T_IN>(shape, filename);
 
     if (host_array.empty()) {
         return 0;
@@ -472,84 +374,49 @@ int loadWeightFromBinFunc(T* ptr,
     return 0;
 }
 
-template int loadWeightFromBinFunc<float, float>(float* ptr,
-                                                 std::vector<size_t> shape,
-                                                 std::string filename,
-                                                 std::vector<ConcateSlice> slices);
-template int loadWeightFromBinFunc<half, float>(half* ptr,
-                                                std::vector<size_t> shape,
-                                                std::string filename,
-                                                std::vector<ConcateSlice> slices);
-template int loadWeightFromBinFunc<float, half>(float* ptr,
-                                                std::vector<size_t> shape,
-                                                std::string filename,
-                                                std::vector<ConcateSlice> slices);
-template int loadWeightFromBinFunc<half, half>(half* ptr,
-                                               std::vector<size_t> shape,
-                                               std::string filename,
-                                               std::vector<ConcateSlice> slices);
-template int loadWeightFromBinFunc<int8_t, int8_t>(int8_t* ptr,
-                                                   std::vector<size_t> shape,
-                                                   std::string filename,
-                                                   std::vector<ConcateSlice> slices);
+template int loadWeightFromBinFunc<float, float>(float* ptr, std::vector<size_t> shape, std::string filename);
+template int loadWeightFromBinFunc<half, float>(half* ptr, std::vector<size_t> shape, std::string filename);
+template int loadWeightFromBinFunc<float, half>(float* ptr, std::vector<size_t> shape, std::string filename);
+template int loadWeightFromBinFunc<half, half>(half* ptr, std::vector<size_t> shape, std::string filename);
+template int loadWeightFromBinFunc<int8_t, int8_t>(int8_t* ptr, std::vector<size_t> shape, std::string filename);
 #ifdef ENABLE_BF16
-template int loadWeightFromBinFunc<__nv_bfloat16, float>(__nv_bfloat16* ptr,
-                                                         std::vector<size_t> shape,
-                                                         std::string filename,
-                                                         std::vector<ConcateSlice> slices);
-template int loadWeightFromBinFunc<__nv_bfloat16, half>(__nv_bfloat16* ptr,
-                                                        std::vector<size_t> shape,
-                                                        std::string filename,
-                                                        std::vector<ConcateSlice> slices);
-template int loadWeightFromBinFunc<float, __nv_bfloat16>(float* ptr,
                                                          std::vector<size_t> shape,
                                                          std::string filename,
                                                          std::vector<ConcateSlice> slices);
-template int loadWeightFromBinFunc<half, __nv_bfloat16>(half* ptr,
                                                         std::vector<size_t> shape,
                                                         std::string filename,
                                                         std::vector<ConcateSlice> slices);
-template int loadWeightFromBinFunc<__nv_bfloat16, __nv_bfloat16>(__nv_bfloat16* ptr,
                                                                  std::vector<size_t> shape,
                                                                  std::string filename,
                                                                  std::vector<ConcateSlice> slices);
+template int
+loadWeightFromBinFunc<__nv_bfloat16, float>(__nv_bfloat16* ptr, std::vector<size_t> shape, std::string filename);
+template int
+loadWeightFromBinFunc<__nv_bfloat16, half>(__nv_bfloat16* ptr, std::vector<size_t> shape, std::string filename);
+template int loadWeightFromBinFunc<float, __nv_bfloat16>(float* ptr, std::vector<size_t> shape, std::string filename);
+template int loadWeightFromBinFunc<half, __nv_bfloat16>(half* ptr, std::vector<size_t> shape, std::string filename);
+template int loadWeightFromBinFunc<__nv_bfloat16, __nv_bfloat16>(__nv_bfloat16* ptr,
+                                                                 std::vector<size_t> shape,
+                                                                 std::string filename);
 #endif  // ENABLE_BF16
-template int loadWeightFromBinFunc<int, int>(int* ptr,
-                                             std::vector<size_t> shape,
-                                             std::string filename,
-                                             std::vector<ConcateSlice> slices);
+template int loadWeightFromBinFunc<int, int>(int* ptr, std::vector<size_t> shape, std::string filename);
 #ifdef ENABLE_FP8
-template int loadWeightFromBinFunc<__nv_fp8_e4m3, float>(__nv_fp8_e4m3* ptr,
                                                          std::vector<size_t> shape,
                                                          std::string filename,
                                                          std::vector<ConcateSlice> slices);
+template int
+loadWeightFromBinFunc<__nv_fp8_e4m3, float>(__nv_fp8_e4m3* ptr, std::vector<size_t> shape, std::string filename);
 #endif  // ENABLE_FP8
 
 template<typename T>
-int loadWeightFromBin(T* ptr,
-                      std::vector<size_t> shape,
-                      std::string filename,
-                      FtCudaDataType model_file_type,
-                      std::vector<ConcateSlice> slices)
+int loadWeightFromBin(T* ptr, std::vector<size_t> shape, std::string filename, FtCudaDataType model_file_type)
 {
     switch (model_file_type) {
         case FtCudaDataType::FP32:
-            loadWeightFromBinFunc<T, float>(ptr, shape, filename, slices);
+            loadWeightFromBinFunc<T, float>(ptr, shape, filename);
             break;
         case FtCudaDataType::FP16:
-            loadWeightFromBinFunc<T, half>(ptr, shape, filename, slices);
+            loadWeightFromBinFunc<T, half>(ptr, shape, filename);
            break;
         case FtCudaDataType::INT8:
-            loadWeightFromBinFunc<T, int8_t>(ptr, shape, filename, slices);
+            loadWeightFromBinFunc<T, int8_t>(ptr, shape, filename);
            break;
 #ifdef ENABLE_BF16
         case FtCudaDataType::BF16:
-            loadWeightFromBinFunc<T, __nv_bfloat16>(ptr, shape, filename, slices);
+            loadWeightFromBinFunc<T, __nv_bfloat16>(ptr, shape, filename);
            break;
 #endif
 #ifdef ENABLE_FP8
         case FtCudaDataType::FP8:
-            loadWeightFromBinFunc<T, float>(ptr, shape, filename, slices);
+            loadWeightFromBinFunc<T, float>(ptr, shape, filename);
            break;
 #endif
         default:
@@ -560,50 +427,28 @@ int loadWeightFromBin(T* ptr,
 }
 
 template<>
-int loadWeightFromBin(int* ptr,
-                      std::vector<size_t> shape,
-                      std::string filename,
-                      FtCudaDataType model_file_type,
-                      std::vector<ConcateSlice> slices)
+int loadWeightFromBin(int* ptr, std::vector<size_t> shape, std::string filename, FtCudaDataType model_file_type)
 {
-    loadWeightFromBinFunc<int, int>(ptr, shape, filename, slices);
+    loadWeightFromBinFunc<int, int>(ptr, shape, filename);
     return 0;
 }
 
-template int loadWeightFromBin(float* ptr,
-                               std::vector<size_t> shape,
-                               std::string filename,
-                               FtCudaDataType model_file_type,
-                               std::vector<ConcateSlice> slices);
-template int loadWeightFromBin(half* ptr,
-                               std::vector<size_t> shape,
-                               std::string filename,
-                               FtCudaDataType model_file_type,
-                               std::vector<ConcateSlice> slices);
-template int loadWeightFromBin(int8_t* ptr,
-                               std::vector<size_t> shape,
-                               std::string filename,
-                               FtCudaDataType model_file_type,
-                               std::vector<ConcateSlice> slices);
+template int
+loadWeightFromBin(float* ptr, std::vector<size_t> shape, std::string filename, FtCudaDataType model_file_type);
+template int
+loadWeightFromBin(half* ptr, std::vector<size_t> shape, std::string filename, FtCudaDataType model_file_type);
+template int
+loadWeightFromBin(int8_t* ptr, std::vector<size_t> shape, std::string filename, FtCudaDataType model_file_type);
 #ifdef ENABLE_BF16
-template int loadWeightFromBin(__nv_bfloat16* ptr,
-                               std::vector<size_t> shape,
-                               std::string filename,
-                               FtCudaDataType model_file_type,
-                               std::vector<ConcateSlice> slices);
+template int
+loadWeightFromBin(__nv_bfloat16* ptr, std::vector<size_t> shape, std::string filename, FtCudaDataType model_file_type);
 #endif
 #ifdef ENABLE_FP8
-template int loadWeightFromBin(__nv_fp8_e4m3* ptr,
-                               std::vector<size_t> shape,
-                               std::string filename,
-                               FtCudaDataType model_file_type,
-                               std::vector<ConcateSlice> slices);
+template int
+loadWeightFromBin(__nv_fp8_e4m3* ptr, std::vector<size_t> shape, std::string filename, FtCudaDataType model_file_type);
 #endif
-template int loadWeightFromBin(int* ptr,
-                               std::vector<size_t> shape,
-                               std::string filename,
-                               FtCudaDataType model_file_type,
-                               std::vector<ConcateSlice> slices);
+template int
+loadWeightFromBin(int* ptr, std::vector<size_t> shape, std::string filename, FtCudaDataType model_file_type);
 
 template<typename T_OUT, typename T_IN>
 __global__ void cudaD2DcpyConvert(T_OUT* dst, const T_IN* src, const size_t size)
 {
@@ -874,23 +719,42 @@ __global__ void transpose102(T_OUT* dst, T_IN* src, const int dim0, const int di
 }
 
 template<typename T>
-void invokeInPlaceTranspose102(T* data, T* workspace, const int dim0, const int dim1, const int dim2)
+void invokeInPlaceTranspose102(
+    T* data, T* workspace, const int dim0, const int dim1, const int dim2, bool copy, cudaStream_t stream)
 {
     // copy data to workspace, and then transpose from workspace to data
     // Note that this kernel is used for pre-processing and not very efficient.
-    cudaD2Dcpy(workspace, data, dim0 * dim1 * dim2);
-    transpose102<<<256, 256>>>(data, workspace, dim0, dim1, dim2);
+    const size_t count = dim0 * dim1 * dim2;
+    if (copy) {
+        cudaAutoCpy(workspace, data, count, stream);
+    }
+    const int block = 512;
+    const int grid  = std::min((count + block - 1) / block, (size_t)8192);
+    transpose102<<<grid, block, 0, stream>>>(data, workspace, dim0, dim1, dim2);
 }
 
 #ifdef ENABLE_FP8
-template void invokeInPlaceTranspose102(
-    __nv_fp8_e4m3* data, __nv_fp8_e4m3* workspace, const int dim0, const int dim1, const int dim2);
+template void invokeInPlaceTranspose102(__nv_fp8_e4m3* data,
+                                        __nv_fp8_e4m3* workspace,
+                                        const int dim0,
+                                        const int dim1,
+                                        const int dim2,
+                                        bool copy,
+                                        cudaStream_t stream);
 #endif  // ENABLE_FP8
 #ifdef ENABLE_BF16
-template void invokeInPlaceTranspose102(
-    __nv_bfloat16* data, __nv_bfloat16* workspace, const int dim0, const int dim1, const int dim2);
+template void invokeInPlaceTranspose102(__nv_bfloat16* data,
+                                        __nv_bfloat16* workspace,
+                                        const int dim0,
+                                        const int dim1,
+                                        const int dim2,
+                                        bool copy,
+                                        cudaStream_t stream);
 #endif  // ENABLE_BF16
-template void invokeInPlaceTranspose102(float* data, float* workspace, const int dim0, const int dim1, const int dim2);
+template void invokeInPlaceTranspose102(
+    half* data, half* workspace, const int dim0, const int dim1, const int dim2, bool copy, cudaStream_t stream);
+template void invokeInPlaceTranspose102(
+    float* data, float* workspace, const int dim0, const int dim1, const int dim2, bool copy, cudaStream_t stream);
 
 template<typename T>
 void __global__ multiplyScale(T* tensor, float scale, const size_t size)
diff --git a/src/turbomind/utils/memory_utils.h b/src/turbomind/utils/memory_utils.h
index e51c90390..bb7a4f9c0 100644
--- a/src/turbomind/utils/memory_utils.h
+++ b/src/turbomind/utils/memory_utils.h
@@ -49,20 +49,13 @@ void cudaAutoCpy(T* tgt, const T* src, const size_t size, cudaStream_t stream =
 template<typename T>
 void cudaRandomUniform(T* buffer, const size_t size);
 
-struct ConcateSlice {
-    std::vector<std::pair<size_t, size_t>> slices;
-};
-
 template<typename T>
-int loadWeightFromBin(T* ptr,
-                      std::vector<size_t> shape,
-                      std::string filename,
-                      FtCudaDataType model_file_type = FtCudaDataType::FP32,
-                      std::vector<ConcateSlice> slices = std::vector<ConcateSlice>());
+int loadWeightFromBin(T* ptr,
+                      std::vector<size_t> shape,
+                      std::string filename,
+                      FtCudaDataType model_file_type = FtCudaDataType::FP32);
 
-std::vector<float> loadArrayFromBin(std::vector<size_t> shape,
-                                    std::string filename,
-                                    std::vector<ConcateSlice> slices = std::vector<ConcateSlice>());
+std::vector<float> loadArrayFromBin(std::vector<size_t> shape, std::string filename);
 
 // template<typename T>
 // int loadWeightFromBinAndQuantizeForWeightOnly(int8_t* quantized_weight_ptr,
@@ -115,7 +108,8 @@ template<typename T>
 void invokeInPlaceTranspose0213(T* data, T* workspace, const int dim0, const int dim1, const int dim2, const int dim3);
 
 template<typename T>
-void invokeInPlaceTranspose102(T* data, T* workspace, const int dim0, const int dim1, const int dim2);
+void invokeInPlaceTranspose102(
+    T* data, T* workspace, const int dim0, const int dim1, const int dim2, bool copy = true, cudaStream_t stream = 0);
 
 template<typename T>
 void invokeMultiplyScale(T* tensor, float scale, const size_t size, cudaStream_t stream);