From eef4cf6a4c3f880275bf951c46aae09a607f4549 Mon Sep 17 00:00:00 2001
From: Daya Khudia
Date: Wed, 17 May 2023 23:48:37 +0000
Subject: [PATCH 1/2] fix multi-gpu build

---
 src/fastertransformer/utils/CMakeLists.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/fastertransformer/utils/CMakeLists.txt b/src/fastertransformer/utils/CMakeLists.txt
index 9796ad076..22f735c27 100644
--- a/src/fastertransformer/utils/CMakeLists.txt
+++ b/src/fastertransformer/utils/CMakeLists.txt
@@ -57,7 +57,7 @@ add_library(mpi_utils STATIC mpi_utils.cc)
 set_property(TARGET mpi_utils PROPERTY POSITION_INDEPENDENT_CODE ON)
 set_property(TARGET mpi_utils PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
 if (BUILD_MULTI_GPU)
-    target_link_libraries(mpi_utils PUBLIC -lmpi logger)
+    target_link_libraries(mpi_utils PUBLIC -lmpi -lmpi_cxx logger)
 endif()

 add_library(nccl_utils STATIC nccl_utils.cc)

From 7a71b1888f81bb6f5d9ffffb6b1af493927e041a Mon Sep 17 00:00:00 2001
From: Daya Khudia
Date: Fri, 9 Jun 2023 01:55:12 +0000
Subject: [PATCH 2/2] add support for size_per_head=112 for gpt decoder

---
 .../decoder_masked_multihead_attention.cu     |   3 +
 .../decoder_masked_multihead_attention_112.cu | 101 ++++++++++++++++++
 .../DecoderCrossAttentionLayer.cu             |   6 +-
 .../DecoderSelfAttentionLayer.cc              |   4 +-
 ...used_self_multihead_attention_unit_test.py |   6 +-
 5 files changed, 112 insertions(+), 8 deletions(-)
 create mode 100644 src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_112.cu

diff --git a/src/fastertransformer/kernels/decoder_masked_multihead_attention.cu b/src/fastertransformer/kernels/decoder_masked_multihead_attention.cu
index 4618673d8..8c6e682a4 100644
--- a/src/fastertransformer/kernels/decoder_masked_multihead_attention.cu
+++ b/src/fastertransformer/kernels/decoder_masked_multihead_attention.cu
@@ -41,6 +41,9 @@ void multihead_attention_(const KERNEL_PARAMS_TYPE& params, const cudaStream_t&
         case 96:
             mmha_launch_kernel<T, 96, 128, KERNEL_PARAMS_TYPE>(params, stream);
             break;
+        case 112:
+            mmha_launch_kernel<T, 112, 128, KERNEL_PARAMS_TYPE>(params, stream);
+            break;
         case 128:
             mmha_launch_kernel<T, 128, 128, KERNEL_PARAMS_TYPE>(params, stream);
             break;
diff --git a/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_112.cu b/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_112.cu
new file mode 100644
index 000000000..3261791c3
--- /dev/null
+++ b/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_112.cu
@@ -0,0 +1,101 @@
+/*
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "decoder_masked_multihead_attention_template.hpp"
+#include "src/fastertransformer/kernels/decoder_masked_multihead_attention.h"
+#include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h"
+#include "src/fastertransformer/utils/cuda_bf16_wrapper.h"
+#include <assert.h>
+#include <float.h>
+#include <type_traits>
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+#define MMHA_LAUNCH_KERNEL(                                                                                            \
+    T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, DO_CROSS_ATTENTION, HAS_BEAMS, stream)                \
+    size_t smem_sz = mmha::smem_size_in_bytes<T>(params, THDS_PER_VALUE, THDS_PER_BLOCK);                              \
+    dim3   grid(params.num_heads, params.batch_size);                                                                  \
+    mmha::masked_multihead_attention_kernel<T,                                                                         \
+                                            Dh,                                                                        \
+                                            Dh_MAX,                                                                    \
+                                            THDS_PER_KEY,                                                              \
+                                            THDS_PER_VALUE,                                                            \
+                                            THDS_PER_BLOCK,                                                            \
+                                            DO_CROSS_ATTENTION,                                                        \
+                                            HAS_BEAMS><<<grid, THDS_PER_BLOCK, smem_sz, stream>>>(params)
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// !!! Specialize the launcher for Cross attention
+template<typename T, int Dh, int Dh_MAX, typename KERNEL_PARAMS_TYPE>
+void mmha_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream)
+{
+    constexpr int  THREADS_PER_VALUE  = threads_per_value_t<T, Dh_MAX>::value;
+    constexpr bool DO_CROSS_ATTENTION = std::is_same<KERNEL_PARAMS_TYPE, Cross_multihead_attention_params<T>>::value;
+    int            tlength            = (DO_CROSS_ATTENTION) ? params.memory_max_len : params.timestep;
+    if (params.cache_indir == nullptr) {
+        if (tlength < 32) {
+            MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, DO_CROSS_ATTENTION, false, stream);
+        }
+        else if (tlength < 2048) {
+            MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, DO_CROSS_ATTENTION, false, stream);
+        }
+        else {
+            MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, DO_CROSS_ATTENTION, false, stream);
+        }
+    }
+    else {
+        if (tlength < 32) {
+            MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, DO_CROSS_ATTENTION, true, stream);
+        }
+        else if (tlength < 2048) {
+            MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, DO_CROSS_ATTENTION, true, stream);
+        }
+        else {
+            MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, DO_CROSS_ATTENTION, true, stream);
+        }
+    }
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template void mmha_launch_kernel<float, 112, 128, Masked_multihead_attention_params<float>>(
+    const Masked_multihead_attention_params<float>& params, const cudaStream_t& stream);
+template void mmha_launch_kernel<uint16_t, 112, 128, Masked_multihead_attention_params<uint16_t>>(
+    const Masked_multihead_attention_params<uint16_t>& params, const cudaStream_t& stream);
+#ifdef ENABLE_BF16
+template void mmha_launch_kernel<__nv_bfloat16, 112, 128, Masked_multihead_attention_params<__nv_bfloat16>>(
+    const Masked_multihead_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream);
+#endif
+#ifdef ENABLE_FP8
+template void mmha_launch_kernel<__nv_fp8_e4m3, 112, 128, Masked_multihead_attention_params<__nv_fp8_e4m3>>(
+    const Masked_multihead_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream);
+#endif
+
+template void mmha_launch_kernel<float, 112, 128, Cross_multihead_attention_params<float>>(
+    const Cross_multihead_attention_params<float>& params, const cudaStream_t& stream);
+template void mmha_launch_kernel<uint16_t, 112, 128, Cross_multihead_attention_params<uint16_t>>(
+    const Cross_multihead_attention_params<uint16_t>& params, const cudaStream_t& stream);
+#ifdef ENABLE_BF16
+template void mmha_launch_kernel<__nv_bfloat16, 112, 128, Cross_multihead_attention_params<__nv_bfloat16>>(
+    const Cross_multihead_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream);
+#endif
+#ifdef ENABLE_FP8
+template void mmha_launch_kernel<__nv_fp8_e4m3, 112, 128, Cross_multihead_attention_params<__nv_fp8_e4m3>>(
+    const Cross_multihead_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream);
+#endif
+
+#undef MMHA_LAUNCH_KERNEL
diff --git a/src/fastertransformer/layers/attention_layers/DecoderCrossAttentionLayer.cu b/src/fastertransformer/layers/attention_layers/DecoderCrossAttentionLayer.cu
index 7d022c4f4..55c4d9071 100644
--- a/src/fastertransformer/layers/attention_layers/DecoderCrossAttentionLayer.cu
+++ b/src/fastertransformer/layers/attention_layers/DecoderCrossAttentionLayer.cu
@@ -796,8 +796,8 @@ DecoderCrossAttentionLayer<T>::DecoderCrossAttentionLayer(size_t max_b
     q_scaling_(q_scaling)
 {
     FT_CHECK(size_per_head_ == 32 || size_per_head_ == 48 || size_per_head_ == 64 || size_per_head_ == 80
-             || size_per_head_ == 96 || size_per_head_ == 128 || size_per_head_ == 144 || size_per_head_ == 160
-             || size_per_head_ == 192 || size_per_head_ == 224 || size_per_head_ == 256);
+             || size_per_head_ == 96 || size_per_head_ == 112 || size_per_head_ == 128 || size_per_head_ == 144
+             || size_per_head_ == 160 || size_per_head_ == 192 || size_per_head_ == 224 || size_per_head_ == 256);
 }

 template<typename T>
@@ -1030,4 +1030,4 @@ template class DecoderCrossAttentionLayer<half>;
 template class DecoderCrossAttentionLayer<__nv_bfloat16>;
 #endif

-}  // namespace fastertransformer
\ No newline at end of file
+}  // namespace fastertransformer
diff --git a/src/fastertransformer/layers/attention_layers/DecoderSelfAttentionLayer.cc b/src/fastertransformer/layers/attention_layers/DecoderSelfAttentionLayer.cc
index 7ff426128..44fed478b 100644
--- a/src/fastertransformer/layers/attention_layers/DecoderSelfAttentionLayer.cc
+++ b/src/fastertransformer/layers/attention_layers/DecoderSelfAttentionLayer.cc
@@ -278,8 +278,8 @@ DecoderSelfAttentionLayer<T>::DecoderSelfAttentionLayer(size_t max_bat
     int8_mode_(int8_mode)
 {
     FT_CHECK(size_per_head_ == 32 || size_per_head_ == 48 || size_per_head_ == 64 || size_per_head_ == 80
-             || size_per_head_ == 96 || size_per_head_ == 128 || size_per_head_ == 144 || size_per_head_ == 160
-             || size_per_head_ == 192 || size_per_head_ == 224 || size_per_head_ == 256);
+             || size_per_head_ == 96 || size_per_head_ == 112 || size_per_head_ == 128 || size_per_head_ == 144
+             || size_per_head_ == 160 || size_per_head_ == 192 || size_per_head_ == 224 || size_per_head_ == 256);
     if (int8_mode_ == 1) {
         FT_CHECK_WITH_INFO(!(std::is_same<T, float>::value), "Weight only quant not supported for fp32.");
         weight_only_int8_fc_runner_ = std::make_shared<CutlassFpAIntBGemmRunner<T, uint8_t>>();
diff --git a/tests/decoding/tf_fused_self_multihead_attention_unit_test.py b/tests/decoding/tf_fused_self_multihead_attention_unit_test.py
index a4a7031e5..a09c0028d 100644
--- a/tests/decoding/tf_fused_self_multihead_attention_unit_test.py
+++ b/tests/decoding/tf_fused_self_multihead_attention_unit_test.py
@@ -56,12 +56,12 @@ def test_attn_head_fp16(self):
             self.run_attn(4, 128, head, 64, tf.float16)

     def test_attn_size_fp32(self):
-        for size in [32, 64, 80, 96, 128, 144, 160, 192, 224, 256]:
+        for size in [32, 64, 80, 96, 112, 128, 144, 160, 192, 224, 256]:
             tf.reset_default_graph()
             self.run_attn(4, 128, 12, size, tf.float32)

     def test_attn_size_fp16(self):
-        for size in [32, 64, 80, 96, 128, 144, 160, 192, 224, 256]:
+        for size in [32, 64, 80, 96, 112, 128, 144, 160, 192, 224, 256]:
             tf.reset_default_graph()
             self.run_attn(4, 128, 12, size, tf.float16)

@@ -171,4 +171,4 @@ def run_attn(self, batch_size, seq_len, head_num, size_per_head, data_type):
         assert(v_cache_max_diff < threshold)

 if __name__ == "__main__":
-    unittest.main()
\ No newline at end of file
+    unittest.main()
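
Reviewer note (not part of the patch): the new head size is reached purely through the existing size_per_head dispatch in decoder_masked_multihead_attention.cu, so no caller-side API changes are needed. A quick way to exercise the size_per_head=112 path is the updated TF unit test, which now includes 112 in its size sweep. Assuming a TensorFlow 1.x environment with the FasterTransformer TensorFlow ops built, the whole sweep runs via the module's unittest entry point:

    # runs every case in the module, including the new 112 entries in
    # test_attn_size_fp32 / test_attn_size_fp16
    python tests/decoding/tf_fused_self_multihead_attention_unit_test.py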