bwinline: Forceinline all functions of the bw #453

Merged
merged 1 commit on Oct 6, 2022
@@ -118,7 +118,7 @@ struct TypeTraits<float> {
 };
 
 template <typename integer>
-constexpr __host__ __device__ inline integer ceil_div(integer n, integer m) {
+constexpr CUTLASS_HOST_DEVICE integer ceil_div(integer n, integer m) {
   return (n + m - 1) / m;
 }
 
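For context on why this swap forceinlines the function: CUTLASS's portability macros bundle the inline hint with the host/device qualifiers. A rough sketch of the NVCC-path definitions (the exact preprocessor guards in cutlass/cutlass.h differ):

// Approximate definitions, sketch only, not the verbatim CUTLASS header.
#define CUTLASS_HOST_DEVICE __forceinline__ __device__ __host__
#define CUTLASS_DEVICE __forceinline__ __device__
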
@@ -95,7 +95,7 @@ struct AttentionBackwardKernel {
     int32_t num_batches;
     bool causal;
 
-    __device__ void advance_batches(int32_t batch_id) {
+    CUTLASS_DEVICE void advance_batches(int32_t batch_id) {
       constexpr int32_t kAlignLSE = 32; // block size of backward
       auto lse_dim = ceil_div((int32_t)num_queries, kAlignLSE) * kAlignLSE;
 
@@ -110,6 +110,23 @@ struct AttentionBackwardKernel {
       grad_query_ptr += batch_id * head_dim * num_queries;
       grad_key_ptr += batch_id * head_dim * num_keys;
       grad_value_ptr += batch_id * head_dim_value * num_keys;
+
+      head_dim = warp_uniform(head_dim);
+      head_dim_value = warp_uniform(head_dim_value);
+      num_queries = warp_uniform(num_queries);
+      num_keys = warp_uniform(num_keys);
+
+      query_ptr = warp_uniform(query_ptr);
+      key_ptr = warp_uniform(key_ptr);
+      value_ptr = warp_uniform(value_ptr);
+      logsumexp_ptr = warp_uniform(logsumexp_ptr);
+      output_ptr = warp_uniform(output_ptr);
+      grad_output_ptr = warp_uniform(grad_output_ptr);
+      delta_ptr = warp_uniform(delta_ptr);
+
+      grad_query_ptr = warp_uniform(grad_query_ptr);
+      grad_key_ptr = warp_uniform(grad_key_ptr);
+      grad_value_ptr = warp_uniform(grad_value_ptr);
Comment on lines +114 to +129 (Contributor):

note to self: this trick is needed otherwise there would be a performance slowdown
     }
 
     __host__ dim3 getBlocksGrid() const {
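The warp_uniform helper applied above is defined elsewhere in the source tree and is not part of this hunk. A minimal sketch of the usual technique, broadcasting lane 0's value with __shfl_sync so the compiler can prove the result is identical across the warp (the real helper may be implemented differently):

#include <cstdint>

// Sketch only: force a value to be provably warp-uniform by broadcasting
// lane 0's copy to every lane. All lanes of the warp must reach this call.
__device__ __forceinline__ int32_t warp_uniform(int32_t value) {
  return __shfl_sync(0xffffffff, value, 0);
}

// Pointer overload: broadcast the two 32-bit halves of the 64-bit address.
template <typename T>
__device__ __forceinline__ T* warp_uniform(T* ptr) {
  uint64_t bits = reinterpret_cast<uint64_t>(ptr);
  uint32_t lo = __shfl_sync(0xffffffff, static_cast<uint32_t>(bits), 0);
  uint32_t hi = __shfl_sync(0xffffffff, static_cast<uint32_t>(bits >> 32), 0);
  return reinterpret_cast<T*>((static_cast<uint64_t>(hi) << 32) | lo);
}

Because the results no longer depend on values the compiler must assume differ per lane, nvcc can keep these pointers and sizes in uniform registers instead of recomputing them per thread, which is what the review comment above refers to.
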
@@ -600,7 +617,7 @@ struct AttentionBackwardKernel {
     typename MatmulGradV::Mma::FragmentC gradV;
     typename MatmulGradK::Mma::FragmentC gradK;
 
-    __device__ __forceinline__ void clear() {
+    CUTLASS_DEVICE void clear() {
       gradV.clear();
       gradK.clear();
     }
@@ -620,7 +637,7 @@ struct AttentionBackwardKernel {
         "value is not correctly aligned");
   }
 
-  static __device__ void kernel(Params& p_) {
+  static CUTLASS_DEVICE void kernel(Params& p_) {
     // Hint to nvcc to store points & tensor shapes in registers
     // as we use them a lot
     register const Params p = p_;
@@ -657,7 +674,7 @@ struct AttentionBackwardKernel {
       __syncthreads();
     }
 
-    OutputFragments output_frags;
+    OutputFragments register output_frags;
     int32_t key_start = 0;
     int32_t key_end = p.num_keys / kBlockSizeJ * kBlockSizeJ;
     for (; key_start < key_end; key_start += kBlockSizeJ) {
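Both register hints in this function follow the comment in the previous hunk: copy the kernel parameters and accumulators into local objects and ask nvcc to keep them in registers. A stripped-down illustration of the pattern, with hypothetical names and assuming a pre-C++17 dialect where the register keyword is still accepted:

#include <cstdint>

struct ExampleParams {  // hypothetical, far smaller than the real Params
  const float* query_ptr;
  int32_t num_queries;
};

__global__ void example_kernel(ExampleParams p_) {
  // `register` has no guaranteed effect; it is only a hint, and C++17 and
  // later reject it, so whether it helps depends on the build settings.
  register const ExampleParams p = p_;
  // ... hot loops read p.query_ptr and p.num_queries ...
}
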
@@ -694,7 +711,7 @@ struct AttentionBackwardKernel {
     }
   }
 
-  static __device__ __forceinline__ void loadDi(
+  static CUTLASS_DEVICE void loadDi(
       cutlass::Array<accum_t, kBlockSizeI>& di,
       Params const& p,
       int32_t query_start) {
@@ -709,17 +726,17 @@ struct AttentionBackwardKernel {
   }
 
   template <bool skipBoundsChecks>
-  static __device__ __forceinline__ void processBlockIJ(
+  static CUTLASS_DEVICE void processBlockIJ(
       SharedStorage& shared_storage,
       OutputFragments& output_frags,
       Params const& p,
       int32_t query_start,
       int32_t key_start) {
     cutlass::MatrixCoord no_offset{0, 0};
     accum_t scale = accum_t(1.0 / std::sqrt(float(p.head_dim)));
-    int32_t thread_id = threadIdx.x + threadIdx.y * blockDim.x;
-    int32_t warp_id = threadIdx.y;
-    int32_t lane_id = threadIdx.x;
+    int16_t thread_id = threadIdx.x + threadIdx.y * blockDim.x;
+    int8_t warp_id = warp_uniform(threadIdx.y);
+    int8_t lane_id = threadIdx.x;
     __syncthreads();
     loadDi(shared_storage.di(), p, query_start);
 
@@ -1258,7 +1275,7 @@ struct AttentionBackwardKernel {
   }
 
   template <bool skipBoundsChecks>
-  static __device__ __forceinline__ void writeFragsToGmem(
+  static CUTLASS_DEVICE void writeFragsToGmem(
      SharedStorage& shared_storage,
      OutputFragments& output_frags,
      Params const& p,
@@ -1291,7 +1308,7 @@ struct AttentionBackwardKernel {
   }
 
   template <typename MatmulT>
-  static __device__ __forceinline__ void accumulateInGmem(
+  static CUTLASS_DEVICE void accumulateInGmem(
      typename MatmulT::DefaultEpilogue::SharedStorage& epilogue_smem,
      typename MatmulT::Mma::FragmentC const& accum,
      typename MatmulT::OutputTileIterator output_it,
@@ -1333,7 +1350,9 @@ struct AttentionBackwardKernel {
   }
 
   template <int kElementsPerAccess>
-  static __device__ void computeDelta(Params const& p, int32_t query_start) {
+  static CUTLASS_DEVICE void computeDelta(
+      Params const& p,
+      int32_t query_start) {
     // Each thread computes one value for Delta
     // Depending on warp configuration, we might have multiple
     // threads of the same warp working on the same row
@@ -1429,13 +1448,13 @@ struct AttentionBackwardKernel {
     }
   }
 
-  static __device__ __forceinline__ int8_t get_lane_id() {
+  static CUTLASS_DEVICE int8_t get_lane_id() {
     return threadIdx.x;
   }
-  static __device__ __forceinline__ int8_t get_warp_id() {
+  static CUTLASS_DEVICE int8_t get_warp_id() {
     return threadIdx.y;
   }
-  static __device__ __forceinline__ int16_t get_thread_id() {
+  static CUTLASS_DEVICE int16_t get_thread_id() {
     return threadIdx.x + threadIdx.y * blockDim.x;
   }
 };