Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support size_per_head=112 #660

Merged
merged 3 commits into from
Jun 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ void multihead_attention_(const KERNEL_PARAMS_TYPE& params, const cudaStream_t&
case 96:
mmha_launch_kernel<T, 96, 128, KERNEL_PARAMS_TYPE>(params, stream);
break;
case 112:
mmha_launch_kernel<T, 112, 128, KERNEL_PARAMS_TYPE>(params, stream);
break;
case 128:
mmha_launch_kernel<T, 128, 128, KERNEL_PARAMS_TYPE>(params, stream);
break;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "decoder_masked_multihead_attention_template.hpp"
#include "src/fastertransformer/kernels/decoder_masked_multihead_attention.h"
#include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h"
#include "src/fastertransformer/utils/cuda_bf16_wrapper.h"
#include <assert.h>
#include <float.h>
#include <type_traits>

////////////////////////////////////////////////////////////////////////////////////////////////////

#define MMHA_LAUNCH_KERNEL( \
T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, DO_CROSS_ATTENTION, HAS_BEAMS, stream) \
size_t smem_sz = mmha::smem_size_in_bytes<T, DO_CROSS_ATTENTION>(params, THDS_PER_VALUE, THDS_PER_BLOCK); \
dim3 grid(params.num_heads, params.batch_size); \
mmha::masked_multihead_attention_kernel<T, \
Dh, \
Dh_MAX, \
THDS_PER_KEY, \
THDS_PER_VALUE, \
THDS_PER_BLOCK, \
DO_CROSS_ATTENTION, \
HAS_BEAMS><<<grid, THDS_PER_BLOCK, smem_sz, stream>>>(params)

////////////////////////////////////////////////////////////////////////////////////////////////////

// !!! Specialize the launcher for Cross attention
template<typename T, int Dh, int Dh_MAX, typename KERNEL_PARAMS_TYPE>
void mmha_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream)
{
constexpr int THREADS_PER_VALUE = threads_per_value_t<T, Dh_MAX>::value;
constexpr bool DO_CROSS_ATTENTION = std::is_same<KERNEL_PARAMS_TYPE, Cross_multihead_attention_params<T>>::value;
int tlength = (DO_CROSS_ATTENTION) ? params.memory_max_len : params.timestep;
if (params.cache_indir == nullptr) {
if (tlength < 32) {
MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, DO_CROSS_ATTENTION, false, stream);
}
else if (tlength < 2048) {
MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, DO_CROSS_ATTENTION, false, stream);
}
else {
MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, DO_CROSS_ATTENTION, false, stream);
}
}
else {
if (tlength < 32) {
MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, DO_CROSS_ATTENTION, true, stream);
}
else if (tlength < 2048) {
MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, DO_CROSS_ATTENTION, true, stream);
}
else {
MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, DO_CROSS_ATTENTION, true, stream);
}
}
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template void mmha_launch_kernel<float, 112, 128, Masked_multihead_attention_params<float>>(
const Masked_multihead_attention_params<float>& params, const cudaStream_t& stream);
template void mmha_launch_kernel<uint16_t, 112, 128, Masked_multihead_attention_params<uint16_t>>(
const Masked_multihead_attention_params<uint16_t>& params, const cudaStream_t& stream);
#ifdef ENABLE_BF16
template void mmha_launch_kernel<__nv_bfloat16, 112, 128, Masked_multihead_attention_params<__nv_bfloat16>>(
const Masked_multihead_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream);
#endif
#ifdef ENABLE_FP8
template void mmha_launch_kernel<__nv_fp8_e4m3, 112, 128, Masked_multihead_attention_params<__nv_fp8_e4m3>>(
const Masked_multihead_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream);
#endif

template void mmha_launch_kernel<float, 112, 128, Cross_multihead_attention_params<float>>(
const Cross_multihead_attention_params<float>& params, const cudaStream_t& stream);
template void mmha_launch_kernel<uint16_t, 112, 128, Cross_multihead_attention_params<uint16_t>>(
const Cross_multihead_attention_params<uint16_t>& params, const cudaStream_t& stream);
#ifdef ENABLE_BF16
template void mmha_launch_kernel<__nv_bfloat16, 112, 128, Cross_multihead_attention_params<__nv_bfloat16>>(
const Cross_multihead_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream);
#endif
#ifdef ENABLE_FP8
template void mmha_launch_kernel<__nv_fp8_e4m3, 112, 128, Cross_multihead_attention_params<__nv_fp8_e4m3>>(
const Cross_multihead_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream);
#endif

#undef MMHA_LAUNCH_KERNEL
Original file line number Diff line number Diff line change
Expand Up @@ -796,8 +796,8 @@ DecoderCrossAttentionLayer<T>::DecoderCrossAttentionLayer(size_t max_b
q_scaling_(q_scaling)
{
FT_CHECK(size_per_head_ == 32 || size_per_head_ == 48 || size_per_head_ == 64 || size_per_head_ == 80
|| size_per_head_ == 96 || size_per_head_ == 128 || size_per_head_ == 144 || size_per_head_ == 160
|| size_per_head_ == 192 || size_per_head_ == 224 || size_per_head_ == 256);
|| size_per_head_ == 96 || size_per_head_ == 112 || size_per_head_ == 128 || size_per_head_ == 144
|| size_per_head_ == 160 || size_per_head_ == 192 || size_per_head_ == 224 || size_per_head_ == 256);
}

template<typename T>
Expand Down Expand Up @@ -1030,4 +1030,4 @@ template class DecoderCrossAttentionLayer<half>;
template class DecoderCrossAttentionLayer<__nv_bfloat16>;
#endif

} // namespace fastertransformer
} // namespace fastertransformer
Original file line number Diff line number Diff line change
Expand Up @@ -278,8 +278,8 @@ DecoderSelfAttentionLayer<T>::DecoderSelfAttentionLayer(size_t max_bat
int8_mode_(int8_mode)
{
FT_CHECK(size_per_head_ == 32 || size_per_head_ == 48 || size_per_head_ == 64 || size_per_head_ == 80
|| size_per_head_ == 96 || size_per_head_ == 128 || size_per_head_ == 144 || size_per_head_ == 160
|| size_per_head_ == 192 || size_per_head_ == 224 || size_per_head_ == 256);
|| size_per_head_ == 96 || size_per_head_ == 112 || size_per_head_ == 128 || size_per_head_ == 144
|| size_per_head_ == 160 || size_per_head_ == 192 || size_per_head_ == 224 || size_per_head_ == 256);
if (int8_mode_ == 1) {
FT_CHECK_WITH_INFO(!(std::is_same<T, float>::value), "Weight only quant not supported for fp32.");
weight_only_int8_fc_runner_ = std::make_shared<CutlassFpAIntBGemmRunner<T, uint8_t>>();
Expand Down
2 changes: 1 addition & 1 deletion src/fastertransformer/utils/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ add_library(mpi_utils STATIC mpi_utils.cc)
set_property(TARGET mpi_utils PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET mpi_utils PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
if (BUILD_MULTI_GPU)
target_link_libraries(mpi_utils PUBLIC -lmpi logger)
target_link_libraries(mpi_utils PUBLIC -lmpi -lmpi_cxx logger)
Copy link
Contributor

@dwyatte dwyatte Jun 26, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(I'm not reviewing, but am interested in this functionality)

@dskhudia can you say more about this change and whether it's needed? I'm building this via https://github.com/triton-inference-server/fastertransformer_backend and mpi_cxx is not there for linking (so may require some additional changes across the FasterTransformer ecosystem). If this is not actually needed, omitting it may be easier to get this approved/merged (edit: I am able to run this via my fork(s) without mpi_cxx)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was needed when I was building FT as described in the GPT doc. I created a separate PR for this as well. Willing to remove it from this PR. It accidentally got added in this.

#616

endif()

add_library(nccl_utils STATIC nccl_utils.cc)
Expand Down
6 changes: 3 additions & 3 deletions tests/decoding/tf_fused_self_multihead_attention_unit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,12 @@ def test_attn_head_fp16(self):
self.run_attn(4, 128, head, 64, tf.float16)

def test_attn_size_fp32(self):
for size in [32, 64, 80, 96, 128, 144, 160, 192, 224, 256]:
for size in [32, 64, 80, 96, 112, 128, 144, 160, 192, 224, 256]:
tf.reset_default_graph()
self.run_attn(4, 128, 12, size, tf.float32)

def test_attn_size_fp16(self):
for size in [32, 64, 80, 96, 128, 144, 160, 192, 224, 256]:
for size in [32, 64, 80, 96, 112, 128, 144, 160, 192, 224, 256]:
tf.reset_default_graph()
self.run_attn(4, 128, 12, size, tf.float16)

Expand Down Expand Up @@ -171,4 +171,4 @@ def run_attn(self, batch_size, seq_len, head_num, size_per_head, data_type):
assert(v_cache_max_diff < threshold)

if __name__ == "__main__":
unittest.main()
unittest.main()