Skip to content

Commit

Permalink
Fix gmem to smem WAW conflict in awq gemm kernel (#2111)
Browse files Browse the repository at this point in the history
  • Loading branch information
foreverrookie authored Jul 30, 2024
1 parent cd0b6d8 commit fed65b1
Showing 1 changed file with 9 additions and 3 deletions.
12 changes: 9 additions & 3 deletions src/turbomind/kernels/gemm_s_f16/cta_iterator.h
Original file line number Diff line number Diff line change
Expand Up @@ -510,6 +510,9 @@ struct IteratorB {
// Iteration counters, both starting at zero. iter_n_ is incremented by the
// iterator's advance step; iter_k_ is not modified in the code visible here.
int iter_k_{0};
int iter_n_{0};

// Exclusive upper bound for source offsets along N. It is assigned
// min(cta_n_ + CTA_N, n_) during setup and compared against
// tmp_src_offset_n_ to derive is_valid_n_. Zero-initialized like the
// counters above so a default-constructed iterator never compares against
// an indeterminate value.
int upper_n_{0};

IteratorB() = default;

__device__ IteratorB(const half* src, void* smem, int k, int n, int cta_n, int cta_k, int warp_id, int lane_id):
Expand Down Expand Up @@ -557,7 +560,10 @@ struct IteratorB {
tmp_src_offset_ = src_offset_;
tmp_dst_offset_ = dst_offset_;
tmp_src_offset_n_ = src_offset_n_;
is_valid_n_ = tmp_src_offset_n_ < n_;

// Bound valid N offsets by min(cta_n_ + CTA_N, n_) instead of n_ alone to
// avoid a global mem -> shared mem WAW (write-after-write) conflict.
upper_n_ = std::min(cta_n_ + CTA_N, n_);
is_valid_n_ = tmp_src_offset_n_ < upper_n_;
}

__device__ void prefetch_stage(bool mask)
Expand Down Expand Up @@ -601,7 +607,7 @@ struct IteratorB {
tmp_src_offset_n_ += kWarpAccessN;
tmp_src_offset_ += src_step_n_;
tmp_dst_offset_ += dst_step_n_;
is_valid_n_ = tmp_src_offset_n_ < n_;
is_valid_n_ = tmp_src_offset_n_ < upper_n_;
++iter_n_;

return *this;
Expand All @@ -621,7 +627,7 @@ struct IteratorB {
tmp_dst_offset_ = dst_offset_;
tmp_src_offset_n_ = src_offset_n_;

is_valid_n_ = tmp_src_offset_n_ < n_;
is_valid_n_ = tmp_src_offset_n_ < upper_n_;
}

__device__ void prefetch(bool mask)
Expand Down

0 comments on commit fed65b1

Please sign in to comment.