Skip to content

Commit

Permalink
fix illegal address bug
Browse files (browse the repository at this point in the history)
  • Loading branch information
zhang rui authored and ruizhang1230 committed Jan 26, 2024
1 parent 4b127f9 commit 7376839
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions csrc/flash_attn/src/flash_fwd_kernel.h
Original file line number | Diff line number | Diff line change
Expand Up @@ -163,10 +163,10 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
 Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr) + row_offset_q),
 Shape<Int<kBlockM>, Int<kHeadDim>>{},
 make_stride(params.q_row_stride, _1{}));
-Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>((char*)params.k_ptr + bidb * params.k_batch_stride) + row_offset_k),
+Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>((char*)params.k_ptr + params.slot_m[bidb] * params.k_batch_stride) + row_offset_k),
 Shape<Int<kBlockN>, Int<kHeadDim>>{},
 make_stride(params.k_row_stride, _1{}));
-Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>((char*)params.v_ptr + bidb * params.v_batch_stride) + row_offset_v),
+Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>((char*)params.v_ptr + params.slot_m[bidb] * params.v_batch_stride) + row_offset_v),
 Shape<Int<kBlockN>, Int<kHeadDim>>{},
 make_stride(params.v_row_stride, _1{}));
 Tensor gP = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.p_ptr) + row_offset_p),
Expand Down Expand Up @@ -673,11 +673,11 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
 Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr) + row_offset_q),
 Shape<Int<kBlockM>, Int<kHeadDim>>{},
 make_stride(params.q_row_stride, _1{}));
-Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>((char*)params.k_ptr + bidb * params.k_batch_stride) + row_offset_k),
+Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>((char*)params.k_ptr + params.slot_m[bidb] * params.k_batch_stride) + row_offset_k),
 Shape<Int<kBlockN>, Int<kHeadDim>>{},
 make_stride(params.k_row_stride, _1{}));
 // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("k_ptr = %p, row_offset_k = %d, gK_ptr = %p\n", params.k_ptr, row_offset_k, gK.data()); }
-Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>((char*)params.v_ptr + bidb * params.v_batch_stride) + row_offset_v),
+Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>((char*)params.v_ptr + params.slot_m[bidb] * params.v_batch_stride) + row_offset_v),
 Shape<Int<kBlockN>, Int<kHeadDim>>{},
 make_stride(params.v_row_stride, _1{}));

Expand Down

0 comments on commit 7376839

Please sign in to comment.