Skip to content

Commit

Permalink
gpu: intel: ocl: gemm_with_post_ops: move c_scales to io
Browse files (browse the repository at this point in the history)
  • Loading branch information
dzarukin committed Jan 24, 2025
1 parent 933fcbd commit ca82718
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 3 deletions.
6 changes: 5 additions & 1 deletion src/gpu/intel/ocl/gemm/gemm_with_post_ops.cl
Original file line number | Diff line number | Diff line change
Expand Up @@ -62,7 +62,11 @@ __kernel void gemm_post_ops(__global SRC_DATA_T *src,
APPLY_POST_OPS_SERIAL(accumulator, POST_OP_DATA_T, sum_src,
POST_OP_DATA_T, d0, 1, d1, 1, d2, 1, d3, 1, 0, 1, 0, 1);

- if (C_SCALES) accumulator /= DST_SCALES_TO_REF(c_scales[0]);
+ float c_scale = 1;
+ if (C_SCALES) {
+ load(&c_scale, c_scales);
+ accumulator /= c_scale;
+ }
if (DST_ZERO_POINT) accumulator += dst_zp[0];
}

Expand Down
4 changes: 2 additions & 2 deletions src/gpu/intel/ocl/gemm/gemm_with_post_ops.cpp
Original file line number | Diff line number | Diff line change
Expand Up @@ -178,8 +178,8 @@ status_t gemm_with_post_ops_t::pd_t::init_kernel_ctx(
def_data_type(kernel_ctx, acc_type, "ACC");

kernel_ctx.define_int("NDIMS", ndims);
- CHECK(def_attr_info(
- kernel_ctx, attr_info_, attr()->post_ops_, *gemm_pd_->dst_md()));
+ CHECK(def_attr_info(kernel_ctx, attr_info_, attr()->post_ops_,
+ *gemm_pd_->dst_md(), false));
kernel_ctx.define_int("A_SCALES", with_src_scales);
kernel_ctx.define_int("B_SCALES", with_wei_scales);
kernel_ctx.define_int("C_SCALES", with_dst_scales);
Expand Down

0 comments on commit ca82718

Please sign in to comment.