diff --git a/csrc/transformer/inference/csrc/apply_rotary_pos_emb.cu b/csrc/transformer/inference/csrc/apply_rotary_pos_emb.cu
index c1536427e412..cfb819a1791d 100644
--- a/csrc/transformer/inference/csrc/apply_rotary_pos_emb.cu
+++ b/csrc/transformer/inference/csrc/apply_rotary_pos_emb.cu
@@ -91,6 +91,103 @@ __global__ void apply_rotary_pos_emb(__half* mixed_query,
     }
 #endif
 }
+__global__ void apply_rotary_pos_emb1(float* mixed_query,
+                                      float* key_layer,
+                                      unsigned rotary_dim,
+                                      unsigned seq_len,
+                                      unsigned seq_offset,
+                                      unsigned num_heads,
+                                      unsigned head_size,
+                                      unsigned total_count)
+{
+    cg::thread_block b = cg::this_thread_block();
+    cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
+
+    int id = threadIdx.x;
+    int gid = id >> 5;
+    int lane = id & 0x1f;
+
+    unsigned head_id = blockIdx.x * MAX_WARP_NUM + gid;
+    unsigned offset = head_id * head_size;
+
+    unsigned seq_id = (head_id / num_heads) % seq_len + seq_offset;
+
+    if (head_id < total_count) {
+        while (lane < rotary_dim) {
+            float inv_freq = (float)((lane / 2) * 2) / (float)rotary_dim;
+            inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id;
+            float q = mixed_query[offset + lane];
+            float k = key_layer[offset + lane];
+            float rotary_sign = (lane % 2 == 1 ? -1.0 : 1.0);
+            float q_rot = (q * rotary_sign);
+            float k_rot = (k * rotary_sign);
+            q_rot = g.shfl_xor(q_rot, 1);
+            k_rot = g.shfl_xor(k_rot, 1);
+            q = q * cosf(inv_freq) + q_rot * sinf(inv_freq);
+            k = k * cosf(inv_freq) + k_rot * sinf(inv_freq);
+
+            mixed_query[offset + lane] = q;
+            key_layer[offset + lane] = k;
+
+            lane += WARP_SIZE;
+        }
+    }
+}
+__global__ void apply_rotary_pos_emb1(__half* mixed_query,
+                                      __half* key_layer,
+                                      unsigned rotary_dim,
+                                      unsigned seq_len,
+                                      unsigned seq_offset,
+                                      unsigned num_heads,
+                                      unsigned head_size,
+                                      unsigned total_count)
+{
+#if __CUDA_ARCH__ >= 700
+    cg::thread_block b = cg::this_thread_block();
+    cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
+
+    int id = threadIdx.x;
+    int gid = id >> 5;
+    int lane = id & 0x1f;
+
+    unsigned head_id = blockIdx.x * MAX_WARP_NUM + gid;
+    unsigned offset = head_id * head_size;
+
+    constexpr unsigned mask[32] = {
+        0x1 | 0x1000,     0x2 | 0x2000,     0x4 | 0x4000,     0x8 | 0x8000,     0x10 | 0x10000,
+        0x20 | 0x20000,   0x40 | 0x40000,   0x80 | 0x80000,   0x100 | 0x100000, 0x200 | 0x200000,
+        0x400 | 0x400000, 0x800 | 0x800000, 0x1000 | 0x1,     0x2000 | 0x2,     0x4000 | 0x4,
+        0x8000 | 0x8,     0x10000 | 0x10,   0x20000 | 0x20,   0x40000 | 0x40,   0x80000 | 0x80,
+        0x100000 | 0x100, 0x200000 | 0x200, 0x400000 | 0x400, 0x800000 | 0x800, 0x1000000,
+        0x2000000,        0x4000000,        0x8000000,        0x10000000,       0x20000000,
+        0x40000000,       0x80000000};
+
+    unsigned seq_id = (head_id / num_heads) % seq_len + seq_offset;
+    unsigned half_dim = rotary_dim >> 1;
+    if (head_id < total_count) {
+        while (lane < rotary_dim) {
+            float inv_freq = (float)((lane % half_dim) * 2) / (float)rotary_dim;
+            inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id;
+            float q = (float)mixed_query[offset + lane];
+            float k = (float)key_layer[offset + lane];
+            float rotary_sign = (lane > (half_dim - 1) ? -1.0 : 1.0);
+            float q_rot = (q * rotary_sign);
+            float k_rot = (k * rotary_sign);
+            auto q_rot_tmp = lane < half_dim ? __shfl_sync(mask[lane], q_rot, lane + half_dim)
+                                             : __shfl_sync(mask[lane], q_rot, lane - half_dim);
+            auto k_rot_tmp = lane < half_dim ? __shfl_sync(mask[lane], k_rot, lane + half_dim)
+                                             : __shfl_sync(mask[lane], k_rot, lane - half_dim);
+            q = q * cosf(inv_freq) + q_rot_tmp * sinf(inv_freq);
+            k = k * cosf(inv_freq) + k_rot_tmp * sinf(inv_freq);
+
+            mixed_query[offset + lane] = (__half)q;
+            key_layer[offset + lane] = (__half)k;
+
+            lane += WARP_SIZE;
+        }
+    }
+#endif
+}
 
 template <typename T>
 void launch_apply_rotary_pos_emb(T* mixed_query,
@@ -101,14 +198,19 @@ void launch_apply_rotary_pos_emb(T* mixed_query,
                                  unsigned offset,
                                  unsigned num_heads,
                                  unsigned batch,
+                                 bool rotate_half,
+                                 bool rotate_every_two,
                                  cudaStream_t stream)
 {
     int total_count = batch * num_heads * seq_len;
     dim3 block_dims(1024);
     dim3 grid_dims((total_count - 1) / MAX_WARP_NUM + 1);  // (batch_size);
-
-    apply_rotary_pos_emb<<<grid_dims, block_dims, 0, stream>>>(
-        mixed_query, key_layer, rotary_dim, seq_len, offset, num_heads, head_size, total_count);
+    if (rotate_every_two)
+        apply_rotary_pos_emb<<<grid_dims, block_dims, 0, stream>>>(
+            mixed_query, key_layer, rotary_dim, seq_len, offset, num_heads, head_size, total_count);
+    else if (rotate_half)
+        apply_rotary_pos_emb1<<<grid_dims, block_dims, 0, stream>>>(
+            mixed_query, key_layer, rotary_dim, seq_len, offset, num_heads, head_size, total_count);
 }
 
 template void launch_apply_rotary_pos_emb<float>(float*,
@@ -119,6 +221,8 @@ template void launch_apply_rotary_pos_emb<float>(float*,
                                                  unsigned,
                                                  unsigned,
                                                  unsigned,
+                                                 bool,
+                                                 bool,
                                                  cudaStream_t);
 template void launch_apply_rotary_pos_emb<__half>(__half*,
                                                   __half*,
@@ -128,4 +232,141 @@ template void launch_apply_rotary_pos_emb<__half>(__half*,
                                                   unsigned,
                                                   unsigned,
                                                   unsigned,
+                                                  bool,
+                                                  bool,
                                                   cudaStream_t);
+/*
+__global__ void apply_rotary_pos_emb(float* mixed_query,
+float* key_layer,
+unsigned rotary_dim,
+unsigned seq_len,
+unsigned seq_offset,
+unsigned num_heads,
+unsigned head_size,
+unsigned total_count)
+{
+cg::thread_block b = cg::this_thread_block();
+cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
+
+int id = threadIdx.x;
+int gid = id >> 5;
+int lane = id & 0x1f;
+
+unsigned head_id = blockIdx.x * MAX_WARP_NUM + gid;
+unsigned offset = head_id * head_size;
+
+unsigned seq_id = (head_id / num_heads) % seq_len + seq_offset;
+
+if (head_id < total_count) {
+while (lane < rotary_dim) {
+float inv_freq = (float)((lane / 2) * 2) / (float)rotary_dim;
+inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id;
+float q = mixed_query[offset + lane];
+float k = key_layer[offset + lane];
+float rotary_sign = (lane % 2 == 1 ?
-1.0 : 1.0); +float q_rot = (q * rotary_sign); +float k_rot = (k * rotary_sign); +q_rot = g.shfl_xor(q_rot, 1); +k_rot = g.shfl_xor(k_rot, 1); +q = q * cosf(inv_freq) + q_rot * sinf(inv_freq); +k = k * cosf(inv_freq) + k_rot * sinf(inv_freq); + +mixed_query[offset + lane] = q; +key_layer[offset + lane] = k; + +lane += WARP_SIZE; +} +} +} + +__global__ void apply_rotary_pos_emb(__half* mixed_query, +__half* key_layer, +unsigned rotary_dim, +unsigned seq_len, +unsigned seq_offset, +unsigned num_heads, +unsigned head_size, +unsigned total_count) +{ +#if __CUDA_ARCH__ >= 700 +cg::thread_block b = cg::this_thread_block(); +cg::thread_block_tile g = cg::tiled_partition(b); + +int id = threadIdx.x; +int gid = id >> 5; +int lane = id & 0x1f; + +unsigned head_id = blockIdx.x * MAX_WARP_NUM + gid; +unsigned offset = head_id * head_size; +constexpr unsigned mask[32] = {0x1 | 0x1000, 0x2 | 0x2000, 0x4 | 0x4000, 0x8 | 0x8000, +0x10 | 0x10000, 0x20 | 0x20000, 0x40 | 0x40000, 0x80 | 0x80000, +0x100 | 0x100000, 0x200 | 0x200000, 0x400 | 0x400000, 0x800 | 0x800000, +0x1000 | 0x1, 0x2000 | 0x2, 0x4000 | 0x4, 0x8000 | 0x8, +0x10000 | 0x10, 0x20000 | 0x20, 0x40000 | 0x40, 0x80000 | 0x80, +0x100000 | 0x100, 0x200000 | 0x200, 0x400000 | 0x400, 0x800000 | 0x800, +0x1000000, 0x2000000, 0x4000000, 0x8000000, +0x10000000, 0x20000000, 0x40000000, 0x80000000}; +unsigned seq_id = (head_id / num_heads) % seq_len + seq_offset; + +if (head_id < total_count) { +while (lane < rotary_dim) { +//float inv_freq = (float)((lane / 2) * 2) / (float)rotary_dim; +float inv_freq = (float)((lane % (rotary_dim >> 1)) * 2) / (float)rotary_dim; +inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id; +float q = (float)mixed_query[offset + lane]; +float k = (float)key_layer[offset + lane]; +float rotary_sign = (lane > 11 ? -1.0 : 1.0); +float q_rot = (q * rotary_sign); +float k_rot = (k * rotary_sign); +auto q_rot_tmp = lane < 12 ? __shfl_sync(mask[lane], q_rot, lane + 12) : __shfl_sync(mask[lane], +q_rot, lane - 12);//g.shfl_xor(q_rot, 12); auto k_rot_tmp = lane < 12 ? 
__shfl_sync(mask[lane], +k_rot, lane + 12) : __shfl_sync(mask[lane], k_rot, lane - 12);//g.shfl_xor(k_rot, 12); q = q * +cosf(inv_freq) + q_rot_tmp * sinf(inv_freq); k = k * cosf(inv_freq) + k_rot_tmp * sinf(inv_freq); + +mixed_query[offset + lane] = (__half)q; +key_layer[offset + lane] = (__half)k; + +lane += WARP_SIZE; +} +} +#endif +} + +template +void launch_apply_rotary_pos_emb(T* mixed_query, +T* key_layer, +unsigned head_size, +unsigned seq_len, +unsigned rotary_dim, +unsigned offset, +unsigned num_heads, +unsigned batch, +cudaStream_t stream) +{ +int total_count = batch * num_heads * seq_len; +dim3 block_dims(1024); +dim3 grid_dims((total_count - 1) / MAX_WARP_NUM + 1); // (batch_size); + +apply_rotary_pos_emb<<>>( +mixed_query, key_layer, rotary_dim, seq_len, offset, num_heads, head_size, total_count); +} + +template void launch_apply_rotary_pos_emb(float*, +float*, +unsigned, +unsigned, +unsigned, +unsigned, +unsigned, +unsigned, +cudaStream_t); +template void launch_apply_rotary_pos_emb<__half>(__half*, +__half*, +unsigned, +unsigned, +unsigned, +unsigned, +unsigned, +unsigned, +cudaStream_t); +*/ diff --git a/csrc/transformer/inference/csrc/gelu.cu b/csrc/transformer/inference/csrc/gelu.cu index e6a2c905a356..70bbf42cf9ed 100644 --- a/csrc/transformer/inference/csrc/gelu.cu +++ b/csrc/transformer/inference/csrc/gelu.cu @@ -168,123 +168,148 @@ template void launch_bias_add(float*, const float*, int, int, cudaStream_ template void launch_bias_add<__half>(__half*, const __half*, int, int, cudaStream_t); __global__ void fused_bias_residual(float* input, - const float* residual, - const float* bias, + float* output, + float* attn, + float* bias, + float* attnbias, int total_count, int intermediate_size, - bool add_bias) + int mp_size) { float4* input_cast = reinterpret_cast(input); - const float4* residual_cast = reinterpret_cast(residual); - const float4* bias_cast = reinterpret_cast(bias); + float4* output_cast = reinterpret_cast(output); + float4* attn_cast = reinterpret_cast(attn); + float4* bias_cast = reinterpret_cast(bias); + float4* attnbias_cast = reinterpret_cast(attnbias); int offset = blockIdx.x * blockDim.x + threadIdx.x; if (offset < total_count) { float4 data = input_cast[offset]; - float4 res_vec = residual_cast[offset]; - if (add_bias) { - float4 bias_data = bias_cast[offset % intermediate_size]; - data.x += (res_vec.x + bias_data.x); - data.y += (res_vec.y + bias_data.y); - data.z += (res_vec.z + bias_data.z); - data.w += (res_vec.w + bias_data.w); - } else { - data.x += res_vec.x; - data.y += res_vec.y; - data.z += res_vec.z; - data.w += res_vec.w; - } + float4 out = output_cast[offset]; + float4 res_vec = attn_cast[offset]; + float4 bias_data = bias_cast[offset % intermediate_size]; + float4 attn_bias = attnbias_cast[offset % intermediate_size]; - input_cast[offset] = data; + data.x = (data.x + res_vec.x) * mp_size + (out.x + bias_data.x + attn_bias.x); + data.y = (data.y + res_vec.y) * mp_size + (out.y + bias_data.y + attn_bias.y); + data.z = (data.z + res_vec.z) * mp_size + (out.z + bias_data.z + attn_bias.z); + data.w = (data.w + res_vec.w) * mp_size + (out.w + bias_data.w + attn_bias.w); + + output_cast[offset] = data; } } __global__ void fused_bias_residual(__half* input, - const __half* residual, - const __half* bias, + __half* output, + __half* attn, + __half* bias, + __half* attn_bias, int total_count, int intermediate_size, - bool add_bias) + int mp_size) { #ifdef HALF_PRECISION_AVAILABLE float2* input_cast = reinterpret_cast(input); - const float2* 
residual_cast = reinterpret_cast(residual); + float2* output_cast = reinterpret_cast(output); + float2* attn_cast = reinterpret_cast(attn); - const float2* bias_cast = reinterpret_cast(bias); + float2* bias_cast = reinterpret_cast(bias); + float2* attnbias_cast = reinterpret_cast(attn_bias); int offset = blockIdx.x * blockDim.x + threadIdx.x; if (offset < total_count) { float2 vals_vec = input_cast[offset]; - float2 res_vec = residual_cast[offset]; + float2 out_vec = output_cast[offset]; + float2 res_vec = attn_cast[offset]; + + float2 bias_vec = bias_cast[offset % intermediate_size]; + float2 attn_bias_vec = attnbias_cast[offset % intermediate_size]; __half2* vals_half = reinterpret_cast<__half2*>(&vals_vec); + __half2* out_half = reinterpret_cast<__half2*>(&out_vec); __half2* res_half = reinterpret_cast<__half2*>(&res_vec); + __half2* bias_half = reinterpret_cast<__half2*>(&bias_vec); + __half2* attnbias_half = reinterpret_cast<__half2*>(&attn_bias_vec); float2 low_data = __half22float2(vals_half[0]); float2 high_data = __half22float2(vals_half[1]); + float2 low_out = __half22float2(out_half[0]); + float2 high_out = __half22float2(out_half[1]); + float2 low_res = __half22float2(res_half[0]); float2 high_res = __half22float2(res_half[1]); - if (add_bias) { - float2 bias_vec = bias_cast[offset % intermediate_size]; - __half2* bias_half = reinterpret_cast<__half2*>(&bias_vec); - float2 low_bias = __half22float2(bias_half[0]); - float2 high_bias = __half22float2(bias_half[1]); - low_data.x += (low_res.x + low_bias.x); - low_data.y += (low_res.y + low_bias.y); - high_data.x += (high_res.x + high_bias.x); - high_data.y += (high_res.y + high_bias.y); - } else { - low_data.x += low_res.x; - low_data.y += low_res.y; - high_data.x += high_res.x; - high_data.y += high_res.y; - } + float2 low_bias = __half22float2(bias_half[0]); + float2 high_bias = __half22float2(bias_half[1]); + + float2 attn_low_bias = __half22float2(attnbias_half[0]); + float2 attn_high_bias = __half22float2(attnbias_half[1]); + + low_data.x = + (low_data.x + low_res.x) * mp_size + (low_out.x + (low_bias.x + attn_low_bias.x)); + low_data.y = + (low_data.y + low_res.y) * mp_size + (low_out.y + (low_bias.y + attn_low_bias.y)); + high_data.x = + (high_data.x + high_res.x) * mp_size + (high_out.x + (high_bias.x + attn_high_bias.x)); + high_data.y = + (high_data.y + high_res.y) * mp_size + (high_out.y + (high_bias.y + attn_high_bias.y)); vals_half[0] = __float22half2_rn(low_data); vals_half[1] = __float22half2_rn(high_data); - input_cast[offset] = vals_vec; + output_cast[offset] = vals_vec; } #endif } template void launch_bias_residual(T* input, - const T* residual, - const T* bias, + T* output, + T* attn, + T* bias, + T* attn_bias, int batch, - int intermediate_size, - bool add_bias, + int hidden_dim, + int mp_size, cudaStream_t stream) { - int total_count = batch * intermediate_size / 4; + int total_count = batch * hidden_dim / 4; dim3 block_dims(1024); dim3 grid_dims((total_count - 1) / 1024 + 1); // (batch_size); fused_bias_residual<<>>( - input, residual, bias, total_count, intermediate_size / 4, add_bias); + input, output, attn, bias, attn_bias, total_count, hidden_dim / 4, 1.0 / mp_size); } template void -launch_bias_residual(float*, const float*, const float*, int, int, bool, cudaStream_t); -template void -launch_bias_residual<__half>(__half*, const __half*, const __half*, int, int, bool, cudaStream_t); +launch_bias_residual(float*, float*, float*, float*, float*, int, int, int, cudaStream_t); +template void 
launch_bias_residual<__half>(__half*, + __half*, + __half*, + __half*, + __half*, + int, + int, + int, + cudaStream_t); __global__ void gptj_residual_add(float* input, float* output, float* attn, float* bias, + float* attnbias, int total_count, - int intermediate_size) + int intermediate_size, + float mp_size) { float4* input_cast = reinterpret_cast(input); float4* output_cast = reinterpret_cast(output); float4* attn_cast = reinterpret_cast(attn); float4* bias_cast = reinterpret_cast(bias); + float4* attnbias_cast = reinterpret_cast(attnbias); int offset = blockIdx.x * blockDim.x + threadIdx.x; if (offset < total_count) { @@ -292,11 +317,12 @@ __global__ void gptj_residual_add(float* input, float4 out = output_cast[offset]; float4 res_vec = attn_cast[offset]; float4 bias_data = bias_cast[offset % intermediate_size]; + float4 attn_bias = attnbias_cast[offset % intermediate_size]; - data.x += (out.x + res_vec.x + bias_data.x); - data.y += (out.y + res_vec.y + bias_data.y); - data.z += (out.z + res_vec.z + bias_data.z); - data.w += (out.w + res_vec.w + bias_data.w); + data.x = data.x * mp_size + (out.x + res_vec.x + bias_data.x + attn_bias.x); + data.y = data.y * mp_size + (out.y + res_vec.y + bias_data.y + attn_bias.y); + data.z = data.z * mp_size + (out.z + res_vec.z + bias_data.z + attn_bias.z); + data.w = data.w * mp_size + (out.w + res_vec.w + bias_data.w + attn_bias.w); output_cast[offset] = data; } @@ -306,8 +332,10 @@ __global__ void gptj_residual_add(__half* input, __half* output, __half* attn, __half* bias, + __half* attn_bias, int total_count, - int intermediate_size) + int intermediate_size, + float mp_size) { #if __CUDA_ARCH__ >= 700 || defined(__HIP_PLATFORM_HCC__) @@ -316,6 +344,7 @@ __global__ void gptj_residual_add(__half* input, float2* attn_cast = reinterpret_cast(attn); float2* bias_cast = reinterpret_cast(bias); + float2* attnbias_cast = reinterpret_cast(attn_bias); int offset = blockIdx.x * blockDim.x + threadIdx.x; @@ -325,11 +354,13 @@ __global__ void gptj_residual_add(__half* input, float2 res_vec = attn_cast[offset]; float2 bias_vec = bias_cast[offset % intermediate_size]; + float2 attn_bias_vec = attnbias_cast[offset % intermediate_size]; __half2* vals_half = reinterpret_cast<__half2*>(&vals_vec); __half2* out_half = reinterpret_cast<__half2*>(&out_vec); __half2* res_half = reinterpret_cast<__half2*>(&res_vec); __half2* bias_half = reinterpret_cast<__half2*>(&bias_vec); + __half2* attnbias_half = reinterpret_cast<__half2*>(&attn_bias_vec); float2 low_data = __half22float2(vals_half[0]); float2 high_data = __half22float2(vals_half[1]); @@ -343,10 +374,17 @@ __global__ void gptj_residual_add(__half* input, float2 low_bias = __half22float2(bias_half[0]); float2 high_bias = __half22float2(bias_half[1]); - low_data.x += (low_out.x + low_res.x + low_bias.x); - low_data.y += (low_out.y + low_res.y + low_bias.y); - high_data.x += (high_out.x + high_res.x + high_bias.x); - high_data.y += (high_out.y + high_res.y + high_bias.y); + float2 attn_low_bias = __half22float2(attnbias_half[0]); + float2 attn_high_bias = __half22float2(attnbias_half[1]); + + low_data.x = + low_data.x * mp_size + (low_out.x + low_res.x + (low_bias.x + attn_low_bias.x)); + low_data.y = + low_data.y * mp_size + (low_out.y + low_res.y + (low_bias.y + attn_low_bias.y)); + high_data.x = + high_data.x * mp_size + (high_out.x + high_res.x + (high_bias.x + attn_high_bias.x)); + high_data.y = + high_data.y * mp_size + (high_out.y + high_res.y + (high_bias.y + attn_high_bias.y)); vals_half[0] = 
__float22half2_rn(low_data); vals_half[1] = __float22half2_rn(high_data); @@ -361,8 +399,10 @@ void launch_gptj_residual_add(T* input, T* output, T* attn, T* bias, + T* attn_bias, int hidden_dim, int batch, + int mp_size, cudaStream_t stream) { int total_count = batch * hidden_dim / 4; @@ -370,13 +410,27 @@ void launch_gptj_residual_add(T* input, dim3 grid_dims((total_count - 1) / 1024 + 1); // (batch_size); gptj_residual_add<<>>( - input, output, attn, bias, total_count, hidden_dim / 4); + input, output, attn, bias, attn_bias, total_count, hidden_dim / 4, 1.0 / mp_size); } -template void -launch_gptj_residual_add(float*, float*, float*, float*, int, int, cudaStream_t); -template void -launch_gptj_residual_add<__half>(__half*, __half*, __half*, __half*, int, int, cudaStream_t); +template void launch_gptj_residual_add(float*, + float*, + float*, + float*, + float*, + int, + int, + int, + cudaStream_t); +template void launch_gptj_residual_add<__half>(__half*, + __half*, + __half*, + __half*, + __half*, + int, + int, + int, + cudaStream_t); __global__ void moe_res_matmul(float* residual, float* coef, diff --git a/csrc/transformer/inference/csrc/normalize.cu b/csrc/transformer/inference/csrc/normalize.cu old mode 100755 new mode 100644 index 417588f2ded3..e5da6cc6fe2e --- a/csrc/transformer/inference/csrc/normalize.cu +++ b/csrc/transformer/inference/csrc/normalize.cu @@ -211,7 +211,8 @@ __global__ void fused_residual_layer_norm(float* norm, const float* beta, float epsilon, int row_stride, - bool preLN) + bool preLN, + bool mlp_after_attn) { int iteration_stride = blockDim.x; @@ -233,8 +234,8 @@ __global__ void fused_residual_layer_norm(float* norm, inp_reg[k] = vals[input_id + row * row_stride]; float res_f = (residual[input_id + row * row_stride]); float bias_f = (bias[input_id]); - inp_reg[k] += res_f + bias_f; - if (preLN) res_add[input_id + row * row_stride] = inp_reg[k]; + if (mlp_after_attn) inp_reg[k] += res_f + bias_f; + // if (preLN) res_add[input_id + row * row_stride] = inp_reg[k]; sum += inp_reg[k++]; input_id += iteration_stride; } @@ -285,7 +286,8 @@ __global__ void fused_residual_layer_norm(__half* norm, const __half* beta, float epsilon, int row_stride, - bool preLN) + bool preLN, + bool mlp_after_attn) { #ifdef HALF_PRECISION_AVAILABLE int iteration_stride = blockDim.x; @@ -315,11 +317,13 @@ __global__ void fused_residual_layer_norm(__half* norm, float2 inp_f = __half22float2(inp_reg[k]); float2 res_f = __half22float2(residual_cast[input_id + row * row_stride]); float2 bias_f = __half22float2(bias_cast[input_id]); - inp_f.x += res_f.x + bias_f.x; - inp_f.y += res_f.y + bias_f.y; + if (mlp_after_attn) { + inp_f.x += res_f.x + bias_f.x; + inp_f.y += res_f.y + bias_f.y; + } inp_reg[k] = __float22half2_rn(inp_f); - - if (preLN) res_add_cast[input_id + row * row_stride] = inp_reg[k]; + // if (preLN) res_add_cast[input_id + row * row_stride] = __float22half2_rn(res_f); + // //inp_reg[k]; sum += inp_f.x + inp_f.y; input_id += iteration_stride; k++; @@ -376,6 +380,7 @@ void launch_residual_layer_norm(T* norm, int batch_size, int hidden_dim, bool preLN, + bool mlp_after_attn, cudaStream_t stream); template <> @@ -390,6 +395,7 @@ void launch_residual_layer_norm(float* norm, int batch_size, int hidden_dim, bool preLN, + bool mlp_after_attn, cudaStream_t stream) { constexpr int threads = 1024; @@ -398,8 +404,17 @@ void launch_residual_layer_norm(float* norm, dim3 block_dim(threads); - fused_residual_layer_norm<<>>( - norm, res_add, vals, residual, bias, gamma, beta, epsilon, hidden_dim, 
preLN); + fused_residual_layer_norm<<>>(norm, + res_add, + vals, + residual, + bias, + gamma, + beta, + epsilon, + hidden_dim, + preLN, + mlp_after_attn); } template <> @@ -414,6 +429,7 @@ void launch_residual_layer_norm<__half>(__half* norm, int batch_size, int hidden_dim, bool preLN, + bool mlp_after_attn, cudaStream_t stream) { constexpr int threads = 1024; @@ -421,6 +437,15 @@ void launch_residual_layer_norm<__half>(__half* norm, dim3 grid_dim(batch_size); dim3 block_dim(threads); - fused_residual_layer_norm<<>>( - norm, res_add, vals, residual, bias, gamma, beta, epsilon, hidden_dim / 2, preLN); + fused_residual_layer_norm<<>>(norm, + res_add, + vals, + residual, + bias, + gamma, + beta, + epsilon, + hidden_dim / 2, + preLN, + mlp_after_attn); } diff --git a/csrc/transformer/inference/csrc/pt_binding.cpp b/csrc/transformer/inference/csrc/pt_binding.cpp index 39ea65e5a22a..5432314bb6dd 100644 --- a/csrc/transformer/inference/csrc/pt_binding.cpp +++ b/csrc/transformer/inference/csrc/pt_binding.cpp @@ -233,13 +233,13 @@ at::Tensor ds_bias_residual(at::Tensor& input, at::Tensor& residual, at::Tensor& auto residual_cont = residual.contiguous(); int bsz = input_cont.size(0) * input_cont.size(1); - launch_bias_residual((T*)input_cont.data_ptr(), - (T*)residual_cont.data_ptr(), - (T*)bias.data_ptr(), - bsz, - input_cont.size(2), - (bias.size(0) > 1), - Context::Instance().GetCurrentStream()); + // launch_bias_residual((T*)input_cont.data_ptr(), + // (T*)residual_cont.data_ptr(), + // (T*)bias.data_ptr(), + // bsz, + // input_cont.size(2), + // (bias.size(0) > 1), + // Context::Instance().GetCurrentStream()); return input_cont; } @@ -517,7 +517,6 @@ at::Tensor ds_vector_matmul_int8(at::Tensor& input, template void mlp_unfused_cublas(at::Tensor& output, - at::Tensor& residual_add, at::Tensor& input, at::Tensor& residual, at::Tensor& input_bias, @@ -526,13 +525,14 @@ void mlp_unfused_cublas(at::Tensor& output, at::Tensor& gamma, at::Tensor& beta, const float epsilon, - bool preLayerNorm) + bool preLayerNorm, + bool mlp_after_attn) { int bsz = input.size(0) * input.size(1); - auto inp_norm = preLayerNorm ? 
at::empty_like(input) : residual_add; + auto inp_norm = at::empty_like(input); launch_residual_layer_norm((T*)inp_norm.data_ptr(), - (T*)residual_add.data_ptr(), + (T*)nullptr, (T*)input.data_ptr(), (T*)residual.data_ptr(), (T*)input_bias.data_ptr(), @@ -542,6 +542,7 @@ void mlp_unfused_cublas(at::Tensor& output, bsz, input.size(2), preLayerNorm, + mlp_after_attn, Context::Instance().GetCurrentStream()); float alpha = (T)1.0; @@ -566,15 +567,16 @@ void mlp_unfused_cublas(at::Tensor& output, Context::Instance().GetCurrentStream()); } template -std::vector ds_mlp_gemm(at::Tensor& input, - at::Tensor& residual, - at::Tensor& input_bias, - at::Tensor& weight, - at::Tensor& bias, - at::Tensor& gamma, - at::Tensor& beta, - const float epsilon, - bool preLayerNorm) +at::Tensor ds_mlp_gemm(at::Tensor& input, + at::Tensor& residual, + at::Tensor& input_bias, + at::Tensor& weight, + at::Tensor& bias, + at::Tensor& gamma, + at::Tensor& beta, + const float epsilon, + bool preLayerNorm, + bool mlp_after_attn) { auto input_cont = input.contiguous(); auto options = at::TensorOptions() @@ -584,12 +586,10 @@ std::vector ds_mlp_gemm(at::Tensor& input, .requires_grad(false); auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options); - auto residual_add = at::empty_like(input_cont); int bsz = input_cont.size(0) * input_cont.size(1); mlp_unfused_cublas(output, - residual_add, - input, + mlp_after_attn ? input : residual, residual, input_bias, weight, @@ -597,9 +597,10 @@ std::vector ds_mlp_gemm(at::Tensor& input, gamma, beta, epsilon, - preLayerNorm); + preLayerNorm, + mlp_after_attn); - return {output, residual_add}; + return output; } template @@ -629,18 +630,18 @@ std::vector ds_mlp_gemm_int8(at::Tensor& input, auto residual_add = (preLayerNorm ? 
at::empty_like(input_cont) : inp_norm); // computing the blocking across K dimension - launch_residual_layer_norm((T*)inp_norm.data_ptr(), - (T*)residual_add.data_ptr(), - (T*)input_cont.data_ptr(), - (T*)residual.data_ptr(), - (T*)input_bias.data_ptr(), - (T*)gamma.data_ptr(), - (T*)beta.data_ptr(), - epsilon, - bsz, - input_cont.size(2), - preLayerNorm, - Context::Instance().GetCurrentStream()); + // launch_residual_layer_norm((T*)inp_norm.data_ptr(), + // (T*)residual_add.data_ptr(), + // (T*)input_cont.data_ptr(), + // (T*)residual.data_ptr(), + // (T*)input_bias.data_ptr(), + // (T*)gamma.data_ptr(), + // (T*)beta.data_ptr(), + // epsilon, + // bsz, + // input_cont.size(2), + // preLayerNorm, + // Context::Instance().GetCurrentStream()); quantized_gemm(output, inp_norm, weight, q_scale, groups, 0); launch_bias_gelu((T*)output.data_ptr(), @@ -710,30 +711,58 @@ at::Tensor fused_gemm_gelu(at::Tensor& input, return output; } -void gptj_residual_add(at::Tensor& output, +void residual_add_bias(at::Tensor& output, at::Tensor& input, at::Tensor& attention_output, - at::Tensor& output_b) + at::Tensor& output_b, + at::Tensor& attention_b, + int mp_size, + bool mlp_after_attn) { int bsz = input.size(0) * input.size(1); int hidden_size = input.size(2); // cudaStreamWaitEvent( // Context::Instance().GetCurrentStream(), Context::Instance().GetCompEvent(2), 0); if (input.scalar_type() == at::kFloat) - launch_gptj_residual_add((float*)input.data_ptr(), - (float*)output.data_ptr(), - (float*)attention_output.data_ptr(), - (float*)output_b.data_ptr(), - hidden_size, - bsz, - Context::Instance().GetCurrentStream()); + if (mlp_after_attn) + launch_bias_residual((float*)input.data_ptr(), + (float*)output.data_ptr(), + (float*)attention_output.data_ptr(), + (float*)output_b.data_ptr(), + (float*)attention_b.data_ptr(), + bsz, + hidden_size, + mp_size, + Context::Instance().GetCurrentStream()); + else + launch_gptj_residual_add((float*)input.data_ptr(), + (float*)output.data_ptr(), + (float*)attention_output.data_ptr(), + (float*)output_b.data_ptr(), + (float*)attention_b.data_ptr(), + hidden_size, + bsz, + mp_size, + Context::Instance().GetCurrentStream()); + else if (mlp_after_attn) + launch_bias_residual((__half*)input.data_ptr(), + (__half*)output.data_ptr(), + (__half*)attention_output.data_ptr(), + (__half*)output_b.data_ptr(), + (__half*)attention_b.data_ptr(), + bsz, + hidden_size, + mp_size, + Context::Instance().GetCurrentStream()); else launch_gptj_residual_add<__half>((__half*)input.data_ptr(), (__half*)output.data_ptr(), (__half*)attention_output.data_ptr(), (__half*)output_b.data_ptr(), + (__half*)attention_b.data_ptr(), hidden_size, bsz, + mp_size, Context::Instance().GetCurrentStream()); } @@ -741,7 +770,9 @@ std::vector apply_rotary_pos_emb(at::Tensor& mixed_query, at::Tensor& key_layer, unsigned rotary_dim, unsigned offset, - unsigned num_heads) + unsigned num_heads, + bool rotate_half, + bool rotate_every_two) { auto query_cont = mixed_query.contiguous(); auto key_cont = key_layer.contiguous(); @@ -759,6 +790,8 @@ std::vector apply_rotary_pos_emb(at::Tensor& mixed_query, offset, num_heads, bsz, + rotate_half, + rotate_every_two, Context::Instance().GetCurrentStream()); else launch_apply_rotary_pos_emb<__half>((__half*)query_cont.data_ptr(), @@ -769,6 +802,8 @@ std::vector apply_rotary_pos_emb(at::Tensor& mixed_query, offset, num_heads, bsz, + rotate_half, + rotate_every_two, Context::Instance().GetCurrentStream()); return {query_cont, key_cont}; } @@ -863,7 +898,7 @@ 
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) "DeepSpeed linear_layer with int8 (CUDA)"); m.def("fused_gemm_gelu_fp32", &fused_gemm_gelu, "DeepSpeed mlp with fp32 (CUDA)"); m.def("fused_gemm_gelu_fp16", &fused_gemm_gelu<__half>, "DeepSpeed mlp with fp16 (CUDA)"); - m.def("gptj_residual_add", &gptj_residual_add, "DeepSpeed mlp with fp16 (CUDA)"); + m.def("residual_add", &residual_add_bias, "DeepSpeed mlp with fp16 (CUDA)"); m.def("apply_rotary_pos_emb", &apply_rotary_pos_emb, "DeepSpeed mlp with fp16 (CUDA)"); m.def("einsum_sec_sm_ecm_fp32", &einsum_sec_sm_ecm, diff --git a/csrc/transformer/inference/includes/custom_cuda_layers.h b/csrc/transformer/inference/includes/custom_cuda_layers.h index 94ab9bf185c4..06b4340061c9 100644 --- a/csrc/transformer/inference/includes/custom_cuda_layers.h +++ b/csrc/transformer/inference/includes/custom_cuda_layers.h @@ -48,11 +48,13 @@ void launch_bias_add(T* input, const T* bias, int hidden_size, int batch_size, c template void launch_bias_residual(T* input, - const T* residual, - const T* bias, - int size, - int intermediate_size, - bool add_bias, + T* output, + T* attn, + T* bias, + T* attn_bias, + int batch, + int hidden_dim, + int mp_size, cudaStream_t stream); template @@ -77,6 +79,7 @@ void launch_residual_layer_norm(T* norm, int batch_size, int hidden_dim, bool preLN, + bool mlp_after_attn, cudaStream_t stream); template void launch_dequantize(T* output, @@ -93,9 +96,12 @@ void launch_gptj_residual_add(T* input, T* output, T* attn, T* bias, + T* attn_bias, int batch, int head_size, + int mp_size, cudaStream_t stream); + template void launch_apply_rotary_pos_emb(T* mixed_query, T* key_layer, @@ -105,6 +111,8 @@ void launch_apply_rotary_pos_emb(T* mixed_query, unsigned offset, unsigned num_heads, unsigned batch, + bool rotate_half, + bool rotate_every_two, cudaStream_t stream); template diff --git a/deepspeed/__init__.py b/deepspeed/__init__.py index eca00056337d..2292b4195c07 100755 --- a/deepspeed/__init__.py +++ b/deepspeed/__init__.py @@ -220,6 +220,7 @@ def add_config_arguments(parser): def init_inference(model, triangular_masking=True, mp_size=1, + training_mp_size=1, mpu=None, ep_group=None, expert_mp_group=None, @@ -233,7 +234,8 @@ def init_inference(model, ep_size=1, moe=False, moe_experts=1, - moe_type='standard'): + moe_type='standard', + args=None): """Initialize the DeepSpeed InferenceEngine. Arguments: @@ -245,6 +247,9 @@ def init_inference(model, mp_size: Optional: Desired model parallel size, default is 1 meaning no model parallelism. + training_mp_size: Optional: if loading a checkpoint this is the mp size that it was trained with, + it may be different than what the mp size that you want to use during inference. 
+ mpu: Optional: A model parallelism unit object that implements get_{model,data}_parallel_{rank,group,world_size}() @@ -277,25 +282,24 @@ def init_inference(model, __git_branch__), ranks=[0]) - if isinstance(model, PipelineModule): - raise NotImplementedError("pipeline module support is not implemented yet") - else: - engine = InferenceEngine(model, - triangular_masking, - mp_size, - ep_size, - mpu, - ep_group, - expert_mp_group, - checkpoint, - dtype, - injection_policy, - return_tuple, - replace_method, - quantization_setting, - replace_with_kernel_inject, - moe, - moe_experts, - moe_type) + engine = InferenceEngine(model, + triangular_masking, + mp_size, + training_mp_size, + ep_size, + mpu, + ep_group, + expert_mp_group, + checkpoint, + dtype, + injection_policy, + return_tuple, + replace_method, + quantization_setting, + replace_with_kernel_inject, + moe, + moe_experts, + moe_type, + args) return engine diff --git a/deepspeed/inference/engine.py b/deepspeed/inference/engine.py index 3cd24894092a..a37f7c23f599 100755 --- a/deepspeed/inference/engine.py +++ b/deepspeed/inference/engine.py @@ -17,6 +17,8 @@ import torch.distributed as dist import deepspeed.utils.groups as groups +DS_INFERENCE_ENABLED = False + class InferenceEngine(Module): inference_mp_group = None @@ -27,6 +29,7 @@ def __init__(self, model, triangular_masking=True, mp_size=1, + training_mp_size=1, ep_size=1, mpu=None, ep_group=None, @@ -40,7 +43,8 @@ def __init__(self, replace_with_kernel_inject=False, moe=False, moe_experts=1, - moe_type='standard'): + moe_type='standard', + config=None): """ Args: model: torch.nn.Module @@ -58,12 +62,14 @@ def __init__(self, replace_with_kernel_inject: this flag need to be set to true to inject inference kernels for models such as, Bert, GPT2, GPT-Neo and GPT-J. 
Otherwise, the injection_dict provides the names of two linear layers as a tuple: (attention_output projection, transformer output projection) """ + global DS_INFERENCE_ENABLED + DS_INFERENCE_ENABLED = True super().__init__() self.module = model - self._get_model_config_generate() + self._get_model_config_generate(config) self.mp_world_size = mp_size self.checkpoint = checkpoint @@ -109,14 +115,16 @@ def __init__(self, replace_with_kernel_inject, moe, moe_experts, - moe_type) + moe_type, + training_mp_size) elif replace_method == 'auto': self._apply_injection_policy( return_tuple=return_tuple, replace_with_kernel_inject=replace_with_kernel_inject, moe=moe, moe_experts=moe_experts, - moe_type=moe_type) + moe_type=moe_type, + training_mp_size=training_mp_size) device = torch.cuda.current_device() logger.info(f"Place model to device: {device}") @@ -128,8 +136,8 @@ def __init__(self, else: self.module.register_forward_pre_hook(self._pre_forward_hook) - def _get_model_config_generate(self): - self.config = getattr(self.module, 'config', None) + def _get_model_config_generate(self, config): + self.config = getattr(self.module, 'config', None) if config is None else config self.generate = getattr(self.module, 'generate', None) def _create_model_parallel_group(self): @@ -221,7 +229,8 @@ def _apply_injection_policy(self, replace_with_kernel_inject=False, moe=False, moe_experts=1, - moe_type='standard'): + moe_type='standard', + training_mp_size=1): replace_transformer_layer(client_module, self.module, @@ -243,7 +252,8 @@ def _apply_injection_policy(self, replace_with_kernel_inject=replace_with_kernel_inject, moe=moe, moe_experts=moe_experts, - moe_type=moe_type) + moe_type=moe_type, + training_mp_size=training_mp_size) def _get_all_ckpt_names(self, checkpoints_path, tag): ckpt_file_pattern = self._get_ckpt_name(checkpoints_path, diff --git a/deepspeed/module_inject/__init__.py b/deepspeed/module_inject/__init__.py index 315d1b963d4b..7fe1a3b36b2e 100755 --- a/deepspeed/module_inject/__init__.py +++ b/deepspeed/module_inject/__init__.py @@ -1,3 +1,3 @@ from .replace_module import replace_transformer_layer, revert_transformer_layer from .module_quantize import quantize_transformer_layer -from .replace_policy import DSPolicy, HFBertLayerPolicy, MegatronLayerPolicy +from .replace_policy import DSPolicy, HFBertLayerPolicy diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py index a76c905f8012..62d5cd75e9a1 100755 --- a/deepspeed/module_inject/replace_module.py +++ b/deepspeed/module_inject/replace_module.py @@ -2,7 +2,7 @@ import torch import deepspeed import deepspeed.ops.transformer as transformer_inference -from .replace_policy import HFBertLayerPolicy, MegatronLayerPolicy, HFGPT2LayerPolicy, HFGPTJLayerPolicy +from .replace_policy import HFBertLayerPolicy, HFGPT2LayerPolicy, HFGPTJLayerPolicy from .replace_policy import replace_policies from ..constants import INFERENCE_GENERIC_MODE, INFERENCE_SPECIALIZED_MODE from ..runtime.weight_quantizer import WeightQuantization @@ -53,7 +53,7 @@ def merge_assert(self, dim1, dim2): def qkv_copy(self, dst, src): if src is None: - return src + return torch.nn.Parameter(src) src_shape = src.shape dst_shape = dst.shape @@ -61,7 +61,7 @@ def qkv_copy(self, dst, src): if (len(src_shape) == 2 and len(dst_shape) == 2): if src_shape[1] == dst_shape[1]: - return src + return torch.nn.Parameter(src) self.merge_assert(src_shape[1], dst_shape[1]) qkv_size = dst_shape[1] // 3 @@ -75,7 +75,7 @@ def qkv_copy(self, dst, src): 
torch.cuda.current_device()).contiguous()) else: if src_shape[0] == dst_shape[0]: - return src + return torch.nn.Parameter(src) qkv_size = dst_shape[0] // 3 qkv_split = [torch.split(src_s, qkv_size, dim=0) for src_s in src_split] @@ -86,11 +86,11 @@ def qkv_copy(self, dst, src): dst.data.copy_(bias_split[self.gpu_index].to( torch.cuda.current_device()).contiguous()) - return dst + return torch.nn.Parameter(dst) def copy(self, dst, src): if src is None: - return src + return torch.nn.Parameter(src) src_shape = src.shape dst_shape = dst.shape @@ -98,7 +98,7 @@ def copy(self, dst, src): if (len(src_shape) == 2 and len(dst_shape) == 2): if src_shape[0] == dst_shape[0] and src_shape[1] == dst_shape[1]: - return src + return torch.nn.Parameter(src) if src_shape[0] != dst_shape[0]: self.merge_assert(src_shape[0], dst_shape[0]) @@ -111,13 +111,13 @@ def copy(self, dst, src): torch.cuda.current_device()).contiguous()) else: if src_shape[0] == dst_shape[0]: - return src + return torch.nn.Parameter(src) bias_split = torch.split(src.data, dst_shape[-1]) dst.data.copy_(bias_split[self.gpu_index].to( torch.cuda.current_device()).contiguous()) - return dst + return torch.nn.Parameter(dst) def replace_transformer_layer(orig_layer_impl, @@ -129,6 +129,7 @@ def replace_transformer_layer(orig_layer_impl, hidden_size=-1, num_attention_heads=-1, mp_size=1, + training_mp_size=1, mp_group=None, ep_group=None, expert_mp_group=None, @@ -203,7 +204,7 @@ def replace_with_policy(child, num_experts = child.mlp.num_experts moe = True - attn_linear_layer, qkvw, qkvb, dense_w, dense_b, scale_attention = policy.attention() + attn_linear_layer, qkvw, qkvb, dense_w, dense_b, scale_attention, megatron_v2 = policy.attention() if not moe or moe_type == 'standard': mlp_linear_layer, _h4h_w, _h4h_b, _4hh_w, _4hh_b = policy.mlp() else: @@ -263,14 +264,19 @@ def replace_with_policy(child, global_experts=num_experts, mlp_type=moe_type) else: + rotary_dim = config.rotary_dim if hasattr(config, 'rotary_dim') else child.attention.rotary_ndims \ + if hasattr(child, 'attention') and hasattr(child.attention,'rotary_ndims') else -1 transformer_config = transformer_inference.DeepSpeedInferenceConfig( hidden_size=hidden_size, heads=num_attention_heads, layer_norm_eps=config.layer_norm_eps if hasattr( config, - 'layer_norm_eps') else (config.layer_norm_epsilon if hasattr( - config, - 'layer_norm_epsilon') else 1e-12), + 'layer_norm_eps') else + (config.layer_norm_epsilon + if hasattr(config, + 'layer_norm_epsilon') else config.layernorm_epsilon + if hasattr(config, + 'layernorm_epsilon') else 1.0e-12), fp16=fp16, pre_layer_norm=preln, mp_size=mp_size, @@ -282,9 +288,9 @@ def replace_with_policy(child, 'attention_layers') else False), window_size=(config.window_size if hasattr(config, 'window_size') else 1), - rotary_dim=(config.rotary_dim if hasattr(config, - 'rotary_dim') else -1), - mlp_after_attn=(policy_cls is not HFGPTJLayerPolicy)) + rotary_dim=rotary_dim, + mlp_after_attn=(rotary_dim is None or rotary_dim < 0), + training_mp_size=training_mp_size) if quantize and quantize_settings is not None: (quantization_scales, @@ -353,6 +359,43 @@ def transpose(data): qkvw.data = transpose(qkvw.data) dense_w.data = transpose(dense_w.data) + if megatron_v2: + new_module.config.rotate_half = True + new_module.config.rotate_every_two = False + + def _transpose(x): + num_attention_heads_per_partition = transformer_config.heads // transformer_config.mp_size + attention_head_size = x.shape[-1] // num_attention_heads_per_partition + new_x_shape = 
x.size()[:-1] + (num_attention_heads_per_partition, + attention_head_size) + x_1 = x.view(*new_x_shape) + (q, + k, + v) = torch.split(x_1, + (x_1.shape[-1] // 3), + dim=(x_1.dim() - 1)) + if len(q.shape) > 2: + return torch.cat((q.reshape(q.shape[0], + -1), + k.reshape(q.shape[0], + -1), + v.reshape(q.shape[0], + -1)), + dim=-1).reshape(x.shape) + else: + return torch.cat((q.reshape(-1), + k.reshape(-1), + v.reshape(-1)), + dim=-1).reshape(x.shape) + + qkvw = torch.nn.Parameter(_transpose(qkvw).contiguous()) + qkvb = torch.nn.Parameter(_transpose(qkvb).contiguous()) + + dense_b = dense_b * (transformer_config.training_mp_size / + transformer_config.mp_size) + _4hh_b = _4hh_b * (transformer_config.training_mp_size / + transformer_config.mp_size) + if mlp_linear_layer: _h4h_w = [transpose(moe_w1.data) for moe_w1 in _h4h_w] if moe else transpose(_h4h_w.data) @@ -683,6 +726,9 @@ def replace_module(model, orig_class, replace_fn, _replace_policy): return replaced_module +from ..pipe import PipelineModule + + def _replace_module(model, policies, layer_id=0): """ Traverse model's children recursively and apply any transformations in ``policies``. Arguments: @@ -693,12 +739,14 @@ def _replace_module(model, policies, layer_id=0): """ for name, child in model.named_children(): if child.__class__ in policies: - setattr( - model, - name, - policies[child.__class__][0](child, - policies[child.__class__][-1], - layer_id)) + replaced_module = policies[child.__class__][0](child, + policies[child.__class__][-1], + layer_id) + setattr(model, name, replaced_module) + if isinstance(model, PipelineModule): + assert hasattr(model, 'forward_funcs'),\ + "we require pipe-module to have the list of fwd_functions" + model.forward_funcs[model.fwd_map[name]] = replaced_module layer_id += 1 else: _, layer_id = _replace_module(child, policies, layer_id=layer_id) diff --git a/deepspeed/module_inject/replace_policy.py b/deepspeed/module_inject/replace_policy.py index 47c54f5bb092..c8d14e431d08 100755 --- a/deepspeed/module_inject/replace_policy.py +++ b/deepspeed/module_inject/replace_policy.py @@ -5,10 +5,15 @@ class DSPolicy(ABC): - def __init__(self, inference=True, linear_layer=True, scale_attention=True): + def __init__(self, + inference=True, + linear_layer=True, + scale_attention=True, + megatron_v2=False): self.inference = inference self.linear_layer = linear_layer self.scale_attention = scale_attention + self.is_megatron_v2 = megatron_v2 def attention(self): """ @@ -70,15 +75,16 @@ def attention(self): vw = self.client_module.attention.self.value.weight vb = self.client_module.attention.self.value.bias - qkvw = Parameter(torch.cat((qw, kw, vw), dim=0)) - qkvb = Parameter(torch.cat((qb, kb, vb), dim=0)) + qkvw = Parameter(torch.cat((qw, kw, vw), dim=0), requires_grad=False) + qkvb = Parameter(torch.cat((qb, kb, vb), dim=0), requires_grad=False) return self.linear_layer, \ qkvw, \ qkvb, \ self.client_module.attention.output.dense.weight, \ self.client_module.attention.output.dense.bias, \ - self.scale_attention + self.scale_attention, \ + self.is_megatron_v2 def mlp(self): if self.preln: @@ -124,14 +130,15 @@ def attention(self): kw = self.client_module.attn.attention.k_proj.weight vw = self.client_module.attn.attention.v_proj.weight - qkvw = Parameter(torch.cat((qw, kw, vw), dim=0)) + qkvw = Parameter(torch.cat((qw, kw, vw), dim=0), requires_grad=False) return self.linear_layer, \ qkvw, \ None, \ self.client_module.attn.attention.out_proj.weight, \ self.client_module.attn.attention.out_proj.bias, \ - 
self.scale_attention + self.scale_attention, \ + self.is_megatron_v2 def mlp(self): return self.linear_layer, \ @@ -168,14 +175,15 @@ def attention(self): kw = self.client_module.attn.k_proj.weight vw = self.client_module.attn.v_proj.weight - qkvw = Parameter(torch.cat((qw, kw, vw), dim=0)) + qkvw = Parameter(torch.cat((qw, kw, vw), dim=0), requires_grad=False) return self.linear_layer, \ qkvw, \ None, \ self.client_module.attn.out_proj.weight, \ None, \ - self.scale_attention + self.scale_attention, \ + self.is_megatron_v2 def mlp(self): return self.linear_layer, \ @@ -225,7 +233,8 @@ def attention(self): attention.query_key_value.bias, \ attention.dense.weight, \ attention.dense.bias, \ - self.scale_attention + self.scale_attention, \ + self.is_megatron_v2 def mlp(self, moe_type='standard'): from deepspeed.moe.utils import has_moe_layers @@ -278,7 +287,7 @@ def __init__(self, client_module, inference=True): try: import transformers HFGPT2LayerPolicy._orig_layer_class = transformers.models.gpt2.modeling_gpt2.GPT2Block - except ImportError: + except: HFGPT2LayerPolicy._orig_layer_class = None def get_hidden_heads(self): @@ -291,7 +300,8 @@ def attention(self): self.client_module.attn.c_attn.bias, \ self.client_module.attn.c_proj.weight, \ self.client_module.attn.c_proj.bias, \ - self.scale_attention + self.scale_attention, \ + self.is_megatron_v2 def mlp(self): return self.linear_layer, \ @@ -307,9 +317,62 @@ def layerNorm(self): self.client_module.ln_1.bias +class GPTNEOXLayerPolicy(DSPolicy): + _orig_layer_class = None + version = 0 + + def __init__(self, client_module, inference=True, megatron_v2=True): + super().__init__(inference, megatron_v2=megatron_v2) + self.client_module = client_module + if GPTNEOXLayerPolicy._orig_layer_class is None: + try: + import megatron + from megatron.model.transformer import ParallelTransformerLayerPipe + GPTNEOXLayerPolicy._orig_layer_class = ParallelTransformerLayerPipe + except ImportError: + GPTNEOXLayerPolicy._orig_layer_class = None + + def get_hidden_heads(self): + if GPTNEOXLayerPolicy.version == 0: + attention = self.client_module.attention + else: + attention = self.client_module.self_attention + + return self.client_module.attention.query_key_value.weight.shape[1], \ + self.client_module.attention.num_attention_heads + + def attention(self): + if GPTNEOXLayerPolicy.version == 0: + attention = self.client_module.attention + else: + attention = self.client_module.self_attention + + return self.linear_layer, \ + attention.query_key_value.weight, \ + attention.query_key_value.bias, \ + attention.dense.weight, \ + attention.dense.bias, \ + self.scale_attention, \ + self.is_megatron_v2 + + def mlp(self): + return self.linear_layer, \ + self.client_module.mlp.dense_h_to_4h.weight, \ + self.client_module.mlp.dense_h_to_4h.bias, \ + self.client_module.mlp.dense_4h_to_h.weight, \ + self.client_module.mlp.dense_4h_to_h.bias + + def layerNorm(self): + return self.client_module.post_attention_layernorm.weight, \ + self.client_module.post_attention_layernorm.bias, \ + self.client_module.input_layernorm.weight, \ + self.client_module.input_layernorm.bias + + replace_policies = [ HFBertLayerPolicy, HFGPTNEOLayerPolicy, + GPTNEOXLayerPolicy, HFGPTJLayerPolicy, MegatronLayerPolicy, HFGPT2LayerPolicy, diff --git a/deepspeed/ops/transformer/inference/transformer_inference.py b/deepspeed/ops/transformer/inference/transformer_inference.py index 9acca965c16e..aed03148e919 100755 --- a/deepspeed/ops/transformer/inference/transformer_inference.py +++ 
b/deepspeed/ops/transformer/inference/transformer_inference.py @@ -11,7 +11,6 @@ from ... import op_builder import torch.nn as nn import torch.distributed as dist - # Cuda modules will be imported if needed inference_cuda_module = None @@ -67,8 +66,11 @@ def __init__(self, local_attention=False, window_size=256, rotary_dim=-1, + rotate_half=False, + rotate_every_two=True, return_tuple=True, - mlp_after_attn=True): + mlp_after_attn=True, + training_mp_size=1): super(DeepSpeedInferenceConfig, self).__init__( hidden_size, @@ -87,9 +89,12 @@ def __init__(self, self.local_attention = local_attention self.window_size = window_size self.rotary_dim = rotary_dim + self.rotate_half = rotate_half + self.rotate_every_two = rotate_every_two self.return_tuple = return_tuple self.mlp_after_attn = mlp_after_attn self.specialized_mode = False + self.training_mp_size = training_mp_size @classmethod def from_dict(cls, json_object): @@ -131,11 +136,6 @@ def forward(ctx, q_groups, merge_count, qkv_merging): - - #while len(input_mask.shape) < 4: - # input_mask = input_mask.unsqueeze(0) - input_mask = torch.empty(1, device='cuda') - def _transpose_for_scores(x, key=False, reshape=False): attention_head_size = x.shape[-1] // num_attention_heads_per_partition new_x_shape = x.size()[:-1] + (num_attention_heads_per_partition, @@ -147,13 +147,13 @@ def _transpose_for_scores(x, key=False, reshape=False): x_1 = x_1.permute(0, 2, 1, 3) if reshape: return x_1.reshape(x.shape) - return x_1 + return x_1.contiguous() def _transpose_for_context(x): x = x.permute(0, 2, 1, 3).contiguous() new_x_layer_shape = x.size()[:-2] + \ (hidden_size_per_partition,) - return x.view(*new_x_layer_shape) + return x.view(*new_x_layer_shape).contiguous() def compute_attention(qkv_out, input_mask): score_context_func = inference_cuda_module.softmax_context_fp32 if (not config.fp16) else \ @@ -195,8 +195,9 @@ def compute_attention(qkv_out, input_mask): key_layer, config.rotary_dim, 0 if layer_past is None else layer_past[0].shape[-2], - num_attention_heads_per_partition) - + num_attention_heads_per_partition, + config.rotate_half, + config.rotate_every_two) if layer_past is not None: past_key, past_value = layer_past if unfused_mode: @@ -207,7 +208,6 @@ def compute_attention(qkv_out, input_mask): value_layer), dim=-2) presents = (key_layer, value_layer) - if unfused_mode: mixed_query = _transpose_for_scores(mixed_query, False, True) key_layer = _transpose_for_scores( @@ -215,7 +215,7 @@ def compute_attention(qkv_out, input_mask): True, True) / (norm_factor if config.scale_attention else 1.0) value_layer = _transpose_for_scores(value_layer, False, True) - + #print(f'[{torch.distributed.get_rank()}] {config.layer_id}: {mixed_query.norm()}') if layer_past is None: attn_key_value = score_context_func( mixed_query, @@ -274,8 +274,10 @@ def selfAttention_fp(): norm_b, config.epsilon, (attn_qkvb is not None)) + context_layer, key_layer, value_layer = compute_attention(qkv_out[0] if isinstance(qkv_out, list) else qkv_out, input_mask) output = vector_matmul_func(context_layer, attn_ow, False) + #print(f'[{torch.distributed.get_rank()}] {config.layer_id}: oooooo -> {output.norm()}') return output, key_layer, value_layer, context_layer, qkv_out[-1] # attn_out, present_key, present_value, context_output, inp_norm @@ -289,7 +291,6 @@ def selfAttention_int8(): (q_groups * (3 if qkv_merging else 1) * (2**merge_count))) else: - #import pdb;pdb.set_trace() qkv_out = inference_cuda_module.qkv_gemm_int8( input, attn_qkvw, @@ -312,8 +313,8 @@ def 
selfAttention_int8(): output, key_layer, value_layer, context_layer = selfAttention_int8() else: output, key_layer, value_layer, context_layer, inp_norm = selfAttention_fp() - - if mp_group is not None and dist.get_world_size(group=mp_group) > 1: + if config.mlp_after_attn and mp_group is not None and dist.get_world_size( + group=mp_group) > 1: dist.all_reduce(output, group=mp_group) return (output, key_layer, value_layer, context_layer, inp_norm) @@ -409,6 +410,7 @@ class DeepSpeedMLPFunction(Function): def forward(ctx, input, residual, + residual_norm, bias, inter_w, inter_b, @@ -425,6 +427,7 @@ def forward(ctx, fused_gemm_gelu, vector_matmul_func, bias_residual_func): + if config.q_int8: (intermediate, residual_add) = inference_cuda_module.mlp_gemm_int8( @@ -446,7 +449,7 @@ def forward(ctx, (merge_count)) else: if attn_nw is None: - output = fused_gemm_gelu(input, + output = fused_gemm_gelu(residual_norm, inter_w, inter_b, output_w, @@ -454,24 +457,26 @@ def forward(ctx, config.pre_layer_norm, False) else: - (intermediate, - residual_add) = mlp_gemm_func(input, - residual, - bias, - inter_w, - inter_b, - attn_nw, - attn_nb, - config.epsilon, - config.pre_layer_norm) + intermediate = mlp_gemm_func(input, + residual, + bias, + inter_w, + inter_b, + attn_nw, + attn_nb, + config.epsilon, + config.pre_layer_norm, + config.mlp_after_attn) output = vector_matmul_func(intermediate, output_w, False) - + inference_cuda_module.residual_add(output, + residual, + input, + output_b, + bias, + config.mp_size, + config.mlp_after_attn) if mp_group is not None and dist.get_world_size(group=mp_group) > 1: dist.all_reduce(output, group=mp_group) - - if attn_nw is not None: - output = bias_residual_func(output, residual_add, output_b) - return output @staticmethod @@ -519,9 +524,10 @@ def __init__(self, self.bias_residual_func = inference_cuda_module.bias_residual_fp16 if config.fp16 or config.q_int8 else \ inference_cuda_module.bias_residual_fp32 - def forward(self, input, residual, bias): + def forward(self, input, residual, residual_norm, bias): return DeepSpeedMLPFunction.apply(input, residual, + residual_norm, bias, self.inter_w, self.inter_b, @@ -594,6 +600,7 @@ def __init__(self, self.norm_w = nn.Parameter(torch.Tensor(self.config.hidden_size)) self.norm_b = nn.Parameter(torch.Tensor(self.config.hidden_size)) + self.layer_past = None def forward(self, input, @@ -611,7 +618,12 @@ def forward(self, output_attentions=False): get_present = (get_present or get_key_value or use_cache) input_mask = input_mask if attention_mask is None else attention_mask + layer_past = layer_past if layer_past is not None else self.layer_past + attn_mask = None + if isinstance(input, tuple): + attn_mask = input[1] + input = input[0] input_type = input.dtype if (self.config.fp16 or self.config.q_int8) \ @@ -619,7 +631,8 @@ def forward(self, input = input.half() with torch.no_grad(): - attention_output = self.attention(input, + attention_output, key, value, context_outputtn_ctx, inp_norm = \ + self.attention(input, input_mask, head_mask, layer_past, @@ -629,17 +642,10 @@ def forward(self, output_attentions, self.norm_w, self.norm_b) + presents = (key, value) + self.layer_past = presents - if get_present: - presents = (attention_output[1], attention_output[2]) - elif output_attentions: - context_output = attention_output[3] - - output = self.mlp( - attention_output[0] - if self.config.mlp_after_attn else attention_output[-1], - input, - self.attention.attn_ob) + output = self.mlp(attention_output, input, inp_norm, 
self.attention.attn_ob) if not self.config.pre_layer_norm: ds_layernorm = inference_cuda_module.layer_norm_fp16 if self.config.fp16 or self.config.q_int8 else \ @@ -649,17 +655,13 @@ def forward(self, self.norm_b, self.config.epsilon) - if not self.config.mlp_after_attn: - inference_cuda_module.gptj_residual_add(output, - input, - attention_output[0], - self.mlp.output_b) - output = output.to(input_type) + #print(f'[{torch.distributed.get_rank()}] {self.config.layer_id}: {output.norm()}') + #exit() if get_present: output = (output, presents) if self.config.return_tuple: - return output if type(output) is tuple else (output, ) + return output if type(output) is tuple else (output, attn_mask) else: return output diff --git a/deepspeed/runtime/pipe/engine.py b/deepspeed/runtime/pipe/engine.py index 2a96ef897d01..5807fb983ed7 100644 --- a/deepspeed/runtime/pipe/engine.py +++ b/deepspeed/runtime/pipe/engine.py @@ -74,6 +74,8 @@ def __init__(self, has_bool_tensors=False, *super_args, **super_kwargs): # We schedule the all-reduces, so disable it in super().backward() self.enable_backward_allreduce = False self.has_bool_tensors = has_bool_tensors + self.eval_return_logits = False + self.outputs = None # used to disable the pipeline all-reduce when used with 1-bit Adam/1-bit LAMB self.pipeline_enable_backward_allreduce = True @@ -191,6 +193,7 @@ def __init__(self, has_bool_tensors=False, *super_args, **super_kwargs): if self.is_last_stage(): self.loss_model = self.module.loss_fn + self.has_attention_mask = self.module.__class__.__name__ == 'GPT2ModelPipe' # Initialize pipeline communicators. Just send a 0. if is_even(self.stage_id): if not self.is_last_stage(): @@ -219,6 +222,10 @@ def __init__(self, has_bool_tensors=False, *super_args, **super_kwargs): self.timers('step_microstep').start() self.timers('step_microstep').stop() + def set_has_attention_mask(self, value): + assert isinstance(value, bool) + self.has_attention_mask = value + def _build_data_iter(self, dataset): sampler = torch.utils.data.distributed.DistributedSampler( dataset, @@ -381,7 +388,11 @@ def train_batch(self, data_iter=None): # TODO: should return precisely what loss returned and allow others to be queried? return self.agg_train_loss - def eval_batch(self, data_iter, compute_loss=True, reduce_output='avg'): + def eval_batch(self, + data_iter, + return_logits=False, + compute_loss=True, + reduce_output='avg'): """Evaluate the pipeline on a batch of data from ``data_iter``. The engine will evaluate ``self.train_batch_size()`` total samples collectively across all workers. @@ -408,7 +419,7 @@ def eval_batch(self, data_iter, compute_loss=True, reduce_output='avg'): Returns: The arithmetic mean of the losses computed this batch. """ - + self.eval_return_logits = return_logits self.module.eval() # Curriculum learning could change activation shape @@ -457,7 +468,11 @@ def eval_batch(self, data_iter, compute_loss=True, reduce_output='avg'): # Reset any buffers that may have been populated during the forward passes. 
         #ds_checkpointing.reset()
-
+        self.eval_return_logits = False
+        if return_logits:
+            outputs = self.outputs
+            self.outputs = None
+            return eval_output, outputs
         return eval_output
 
     def set_train_batch_size(self, train_batch_size):
@@ -681,7 +696,8 @@ def _exec_forward_pass(self, buffer_id):
             else:
                 # Some models just return loss from forward()
                 self.loss = outputs
-
+            if self.eval_return_logits:
+                self.outputs = outputs
             if isinstance(self.loss, torch.Tensor):
                 self.fwd_outputs.append(self.loss.detach())
 
@@ -941,7 +957,7 @@ def _exec_send_activations(self, buffer_id):
         # NCCL does not like to send torch.BoolTensor types, so cast the mask to half().
        # We could do char, but with half() we can eventually flatten with other fp16
        # messages (TODO)
-        if self.module.__class__.__name__ == 'GPT2ModelPipe' or self.has_bool_tensors:
+        if self.has_attention_mask or self.has_bool_tensors:
            outputs = list(outputs)
            outputs[-1] = outputs[-1].half()
            outputs = tuple(outputs)
@@ -960,7 +976,7 @@ def _exec_send_activations(self, buffer_id):
                              f'{type(outputs)}')
 
        # Restore the boolean tensor
-        if self.module.__class__.__name__ == 'GPT2ModelPipe' or self.has_bool_tensors:
+        if self.has_attention_mask or self.has_bool_tensors:
            outputs = list(outputs)
            outputs[-1] = outputs[-1].bool()
            outputs = tuple(outputs)
@@ -998,7 +1014,7 @@ def _exec_send_grads(self, buffer_id):
        # a grad that needs to be communicated. We free the buffer immediately
        # after, so no need to restore it. The receiver also has a hack that skips
        # the recv. This is because NCCL does not let us send torch.BoolTensor :-(.
-        if self.module.__class__.__name__ == 'GPT2ModelPipe' or self.has_bool_tensors:
+        if self.has_attention_mask or self.has_bool_tensors:
            inputs = list(inputs)
            inputs.pop()
            inputs = tuple(inputs)
@@ -1059,7 +1075,7 @@ def _exec_recv_activations(self, buffer_id):
 
            # NCCL does not like to send torch.BoolTensor types, so un-cast the
            # attention mask
-            if self.module.__class__.__name__ == 'GPT2ModelPipe' or self.has_bool_tensors:
+            if self.has_attention_mask or self.has_bool_tensors:
                recvd[-1] = recvd[-1].bool()
 
            recvd = tuple(recvd)
@@ -1323,6 +1339,7 @@ def load_module_state_dict(self, state_dict, strict=True):
            state_dict (str, None): unused
            strict (bool, optional): Strict state loading. Defaults to True.
        """
+
        if (state_dict is not None) and (not isinstance(state_dict, str)):
            super().load_module_state_dict(state_dict, strict)
            return
diff --git a/deepspeed/runtime/pipe/module.py b/deepspeed/runtime/pipe/module.py
index 3efbc62e2b8d..4121a0f8fd3f 100644
--- a/deepspeed/runtime/pipe/module.py
+++ b/deepspeed/runtime/pipe/module.py
@@ -189,6 +189,7 @@ def forward(self, inputs):
         self._partition_layers(method=partition_method)
 
         self.forward_funcs = []
+        self.fwd_map = {}
         self.tied_modules = nn.ModuleDict()
         self.tied_weight_attrs = {}
 
@@ -225,6 +226,7 @@ def _build(self):
            elif isinstance(layer, nn.Module):
                name = str(layer_idx)
                self.forward_funcs.append(layer)
+                self.fwd_map.update({name: len(self.forward_funcs) - 1})
                self.add_module(name, layer)
 
            # TiedLayerSpec objects contain an nn.Module that should be allocated now.
@@ -248,6 +250,7 @@ def _build(self):
                module = layer.build()
                name = str(layer_idx)
                self.forward_funcs.append(module)
+                self.fwd_map.update({name: len(self.forward_funcs) - 1})
                self.add_module(name, module)
 
            # Last option: layer may be a functional (e.g., lambda). We do nothing in
diff --git a/deepspeed/runtime/state_dict_factory.py b/deepspeed/runtime/state_dict_factory.py
index e5c4dafb638b..09887aaa275c 100755
--- a/deepspeed/runtime/state_dict_factory.py
+++ b/deepspeed/runtime/state_dict_factory.py
@@ -371,7 +371,7 @@ def split_state_dict(self,
                         quantize_bits=8,
                         groups=64,
                         mlp_extra_grouping=True):
-        self.sanity_check(self.ckpt_list[0])
+        #self.sanity_check(self.ckpt_list[0])
 
         sd, num_to_split, ckpt_offset = self.get_split_state_dict(mp_world_size, mp_rank)
         ds_sd = copy.deepcopy(sd)
@@ -405,7 +405,7 @@ def split_state_dict(self,
                        num_to_split,
                        ckpt_offset,
                        ckpt_ver)
-            elif "mlp.dense_h_to_4h.weight" in key or "word_embeddings.weight" in key or "mlp.dense_h_to_4h.bias" in key:
+            elif "mlp.dense_h_to_4h.weight" in key or "word_embeddings.weight" in key or "mlp.dense_h_to_4h.bias" in key or "final_linear.weight" in key:
                assert value.shape[0] % num_to_split == 0
                split_size = value.shape[0] // num_to_split
                if quantize and "mlp.dense_h_to_4h.weight" in key:
diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py
index 5c3b67bd6b65..80b1ee34bcec 100755
--- a/deepspeed/runtime/zero/partition_parameters.py
+++ b/deepspeed/runtime/zero/partition_parameters.py
@@ -23,6 +23,7 @@
 from .linear import LinearModuleForZeroStage3, LinearFunctionForZeroStage3
 from .offload_constants import *
 
+import deepspeed
 from ..utils import get_only_unique_item, see_memory_usage
 from deepspeed.runtime.zero.utils import assert_ints_same_as_other_ranks
 from deepspeed.utils import init_distributed, instrument_w_nvtx, logger
@@ -30,7 +31,6 @@
 from deepspeed.utils.logging import logger
 from ..swap_tensor.partitioned_param_swapper import AsyncPartitionedParameterSwapper, PartitionedParamStatus
 
-from ..config import DeepSpeedConfig
 
 param_count = 0
 partitioned_param_data_shape = [0]
@@ -663,8 +663,9 @@ def get_model():
                f'zero.Init: the `config` argument is deprecated. Please use `config_dict_or_path` instead.'
            )
 
-        _ds_config = DeepSpeedConfig(config_dict_or_path,
-                                     mpu) if config_dict_or_path is not None else None
+        _ds_config = deepspeed.runtime.config.DeepSpeedConfig(
+            config_dict_or_path,
+            mpu) if config_dict_or_path is not None else None
         super().__init__(enabled=enabled,
                          mem_efficient_linear=mem_efficient_linear,
                          ds_config=_ds_config,
diff --git a/op_builder/transformer.py b/op_builder/transformer.py
index 40d6f4944557..239f29552d98 100644
--- a/op_builder/transformer.py
+++ b/op_builder/transformer.py
@@ -15,6 +15,12 @@ def __init__(self, name=None):
     def absolute_name(self):
         return f'deepspeed.ops.transformer.{self.NAME}_op'
 
+    def extra_ldflags(self):
+        if not self.is_rocm_pytorch():
+            return ['-lcurand']
+        else:
+            return []
+
     def sources(self):
         return [
             'csrc/transformer/ds_transformer_cuda.cpp',
diff --git a/op_builder/transformer_inference.py b/op_builder/transformer_inference.py
index bfae87825b01..23eab4886e80 100755
--- a/op_builder/transformer_inference.py
+++ b/op_builder/transformer_inference.py
@@ -22,5 +22,11 @@ def sources(self):
             'csrc/transformer/inference/csrc/apply_rotary_pos_emb.cu',
         ]
 
+    def extra_ldflags(self):
+        if not self.is_rocm_pytorch():
+            return ['-lcurand']
+        else:
+            return []
+
     def include_paths(self):
         return ['csrc/transformer/inference/includes']