xe: sdpa: fix several errors related to out of bound accesses in sdpa #2497

Open · wants to merge 8 commits into main
2 changes: 2 additions & 0 deletions src/gpu/intel/jit/gemm/generator/microkernel_provider.cpp
@@ -339,6 +339,8 @@ static inline bool getStrategyByHeuristics(HW hw, GEMMStrategy &strategy, bool l
return false;

s.systolic = systolic;
if (systolic && hw >= HW::XeHPC)
s.extendedAtomicFMA = s.atomicFMA = true;
s.registerScheme = GEMMStrategy::VAvoid;
if (s.wgTile(LoopM) * s.wgTile(LoopN) > 512)
s.GRFs = 256;
44 changes: 32 additions & 12 deletions src/gpu/intel/jit/gemm/generator/pieces/k_loop.cxx
@@ -50,6 +50,8 @@ void BLASKernelGenerator<hw>::kLoop(KLoop type, const GEMMProblem &problem, GEMM
auto Ta = problem.Ta, Tb = problem.Tb, Tc = problem.Tc;
auto Ta_ext = problem.Ta_ext, Tb_ext = problem.Tb_ext;
auto Ta_load = state.Ta_load, Tb_load = state.Tb_load;
auto Tao = problem.Tao, Tbo = problem.Tbo;
auto Ta_scale = problem.Ta_scale, Tb_scale = problem.Tb_scale;

bool cLoadAhead = strategy.cLoadAhead;
auto opCountMain = outerProductCount(hw, problem, strategy);
@@ -667,24 +669,24 @@ void BLASKernelGenerator<hw>::kLoop(KLoop type, const GEMMProblem &problem, GEMM
// Quantization parameter address increment helpers.
auto doIncAq = [&](Iteration h) {
auto kaInc = kInc(h, state.kaqStride, problem.aqGroupK);
if (ao2D) incAddrK(problem.Tao, state.A_offsetAddrs, true, kaInc, state.ldao, state.ldaoIncrements, state.A_offsetLayout, problem.AO, strategy.AO, strategy, state);
if (as2D) incAddrK(problem.Ta_scale, state.A_scaleAddrs, true, kaInc, state.ldaScale, state.ldasIncrements, state.A_scaleLayout, problem.A_scale, strategy.A_scale, strategy, state);
if (ao2D) incAddrK(Tao, state.A_offsetAddrs, true, kaInc, state.ldao, state.ldaoIncrements, state.A_offsetLayout, problem.AO, strategy.AO, strategy, state);
if (as2D) incAddrK(Ta_scale, state.A_scaleAddrs, true, kaInc, state.ldaScale, state.ldasIncrements, state.A_scaleLayout, problem.A_scale, strategy.A_scale, strategy, state);
};

auto doIncBq = [&](Iteration h) {
auto kbInc = kInc(h, state.kbqStride, problem.bqGroupK);
if (bo2D) incAddrK(problem.Tbo, state.B_offsetAddrs, false, kbInc, state.ldbo, state.ldboIncrements, state.B_offsetLayout, problem.BO, strategy.BO, strategy, state);
if (bs2D) incAddrK(problem.Tb_scale, state.B_scaleAddrs, false, kbInc, state.ldbScale, state.ldbsIncrements, state.B_scaleLayout, problem.B_scale, strategy.B_scale, strategy, state);
if (bo2D) incAddrK(Tbo, state.B_offsetAddrs, false, kbInc, state.ldbo, state.ldboIncrements, state.B_offsetLayout, problem.BO, strategy.BO, strategy, state);
if (bs2D) incAddrK(Tb_scale, state.B_scaleAddrs, false, kbInc, state.ldbScale, state.ldbsIncrements, state.B_scaleLayout, problem.B_scale, strategy.B_scale, strategy, state);
};

auto doIncAqLate = [&](Iteration h) {
auto kaInc = kInc(h, state.kaqLate, problem.aqGroupK);
incAddrK(problem.Ta_scale, state.A_scaleAddrs, true, kaInc, state.ldaScale, state.ldasIncrements, state.A_scaleLayout, problem.A_scale, strategy.A_scale, strategy, state);
incAddrK(Ta_scale, state.A_scaleAddrs, true, kaInc, state.ldaScale, state.ldasIncrements, state.A_scaleLayout, problem.A_scale, strategy.A_scale, strategy, state);
};

auto doIncBqLate = [&](Iteration h) {
auto kbInc = kInc(h, state.kbqLate, problem.bqGroupK);
incAddrK(problem.Tb_scale, state.B_scaleAddrs, false, kbInc, state.ldbScale, state.ldbsIncrements, state.B_scaleLayout, problem.B_scale, strategy.B_scale, strategy, state);
incAddrK(Tb_scale, state.B_scaleAddrs, false, kbInc, state.ldbScale, state.ldbsIncrements, state.B_scaleLayout, problem.B_scale, strategy.B_scale, strategy, state);
};

// SLM quantization parameter address increment.
@@ -875,24 +877,42 @@ void BLASKernelGenerator<hw>::kLoop(KLoop type, const GEMMProblem &problem, GEMM
auto reqRepackAqLate = every(kaq_loadLate);
auto reqRepackBqLate = every(kbq_loadLate);

bool remaskAs = as2D && (minOPCount > 1) && (problem.aqGroupK == 1);
bool remaskBs = bs2D && (minOPCount > 1) && (problem.bqGroupK == 1);
int iremaskScale = 2;
if (dequantize2DA) ls.schedule(reqRepackAq, [&](Iteration h) {
if (ao2D) gemmRepack2DOffsetData(Ta_ext, problem.Tao, state.Tao_int, state.A_offsetLayout, state.Ar_offsetLayout, state.A_offsetRegs, state.Ar_offsetRegs, problem, strategy, state);
if (as2D) gemmRepack2DQuantizationData(problem.Ta_scale, state.Ta_scaleOp, state.A_scaleLayout, state.Ar_scaleLayout, state.A_scaleRegs, state.Ar_scaleRegs, problem, strategy, state);
if (remaskAs) {
int ms, ks;
getLayoutDims(state.A_scaleLayout, ms, ks);
setupTeardownRemask(Ta_scale, iremaskScale, true, ks, state.K, strategy, state, -h.counterOffset());
remaskLayout(Ta_scale, iremaskScale, true, state.A_scaleLayout, state.A_scaleRegs, strategy, state, h % ks);
setupTeardownRemask(Ta_scale, iremaskScale, false, ks, state.K, strategy, state);
}
if (ao2D) gemmRepack2DOffsetData(Ta_ext, Tao, state.Tao_int, state.A_offsetLayout, state.Ar_offsetLayout, state.A_offsetRegs, state.Ar_offsetRegs, problem, strategy, state);
if (as2D) gemmRepack2DQuantizationData(Ta_scale, state.Ta_scaleOp, state.A_scaleLayout, state.Ar_scaleLayout, state.A_scaleRegs, state.Ar_scaleRegs, problem, strategy, state);
});

if (dequantize2DB) ls.schedule(reqRepackBq, [&](Iteration h) {
if (bo2D) gemmRepack2DOffsetData(Tb_ext, problem.Tbo, state.Tbo_int, state.B_offsetLayout, state.Br_offsetLayout, state.B_offsetRegs, state.Br_offsetRegs, problem, strategy, state);
if (bs2D) gemmRepack2DQuantizationData(problem.Tb_scale, state.Tb_scaleOp, state.B_scaleLayout, state.Br_scaleLayout, state.B_scaleRegs, state.Br_scaleRegs, problem, strategy, state);
if (remaskBs) {
int ks, ns;
getLayoutDims(state.B_scaleLayout, ks, ns);
setupTeardownRemask(Tb_scale, iremaskScale, true, ks, state.K, strategy, state, -h.counterOffset());
remaskLayout(Tb_scale, iremaskScale, false, state.B_scaleLayout, state.B_scaleRegs, strategy, state, h % ks);
setupTeardownRemask(Tb_scale, iremaskScale, false, ks, state.K, strategy, state);
}
if (bo2D) gemmRepack2DOffsetData(Tb_ext, Tbo, state.Tbo_int, state.B_offsetLayout, state.Br_offsetLayout, state.B_offsetRegs, state.Br_offsetRegs, problem, strategy, state);
if (bs2D) gemmRepack2DQuantizationData(Tb_scale, state.Tb_scaleOp, state.B_scaleLayout, state.Br_scaleLayout, state.B_scaleRegs, state.Br_scaleRegs, problem, strategy, state);
});

if (as2DLate) ls.schedule(reqRepackAqLate, [&](Iteration h) {
gemmRepack2DQuantizationData(problem.Ta_scale, state.Ta_scaleOp, state.A_scaleLayout, state.Ar_scaleLayout, state.A_scaleRegs, state.Ar_scaleRegs, problem, strategy, state);
gemmRepack2DQuantizationData(Ta_scale, state.Ta_scaleOp, state.A_scaleLayout, state.Ar_scaleLayout, state.A_scaleRegs, state.Ar_scaleRegs, problem, strategy, state);
});

if (bs2DLate) ls.schedule(reqRepackBqLate, [&](Iteration h) {
gemmRepack2DQuantizationData(problem.Tb_scale, state.Tb_scaleOp, state.B_scaleLayout, state.Br_scaleLayout, state.B_scaleRegs, state.Br_scaleRegs, problem, strategy, state);
gemmRepack2DQuantizationData(Tb_scale, state.Tb_scaleOp, state.B_scaleLayout, state.Br_scaleLayout, state.B_scaleRegs, state.Br_scaleRegs, problem, strategy, state);
});


// A/B repacking.
auto reqRepackA = every(ka_repackMain)
| variants(A_copies);
2 changes: 1 addition & 1 deletion src/gpu/intel/jit/gemm/generator/pieces/state.hpp
@@ -150,7 +150,7 @@ struct CommonState {
ngen::FlagRegister flagSwizzle;
EmulationState emulate;
ngen::GRFRange eatomicAddRegs[2];
ngen::GRFRange remaskRegs[2];
ngen::GRFRange remaskRegs[3];
VirtualFlag vflagEAtomicAdd;
VirtualFlag blockEMask;
ngen::Label blockDone;
18 changes: 14 additions & 4 deletions src/gpu/intel/ocl/micro_sdpa.cl
@@ -103,9 +103,13 @@ DECLARE_2D_TILE(
#define mask_nbc ugemm_kq_c_type_nblock1
#endif

DECLARE_2D_TILE(kmask_tile_type_float, float, SUBGROUP_SIZE, ugemm_kq_sg_tile_m,
1, 1, 1)

#if WITH_ATTN_MASK
DECLARE_2D_TILE(mask_tile_type, MSK_DATA_T, SUBGROUP_SIZE, mask_br, mask_bc,
mask_nbr, mask_nbc)

#if BROADCAST_MASK_Q
DECLARE_2D_TILE_BLOCK_OPS(mask_tile_type, MSK_DATA_T, SUBGROUP_SIZE, mask_br,
mask_bc, mask_nbr, mask_nbc)
@@ -143,8 +147,14 @@ DECLARE_2D_TILE_VREDUCE(s_tile_type, SUBGROUP_SIZE, ugemm_kq_c_type_block0,

DECLARE_2D_TILE_HREDUCE(s_tile_type, SUBGROUP_SIZE, ugemm_kq_c_type_block0,
ugemm_kq_c_type_block1, ugemm_kq_c_type_nblock0,
ugemm_kq_c_type_nblock1, mask_tile_type_float, SUBGROUP_SIZE,
ugemm_kq_c_type_nblock1, kmask_tile_type_float, SUBGROUP_SIZE,
ugemm_kq_sg_tile_m, 1, 1, 1)
#if WITH_ATTN_MASK
DECLARE_2D_TILE_HREDUCE(s_tile_type, SUBGROUP_SIZE, ugemm_kq_c_type_block0,
ugemm_kq_c_type_block1, ugemm_kq_c_type_nblock0,
ugemm_kq_c_type_nblock1, mask_tile_type_float, SUBGROUP_SIZE, mask_br,
mask_bc, mask_nbr, mask_nbc)
#endif

DECLARE_2D_TILE_HREDUCE(a_tile_type, SUBGROUP_SIZE, ugemm_vs_c_type_block0,
ugemm_vs_c_type_block1, ugemm_vs_c_type_nblock0,
@@ -214,9 +224,11 @@ micro_sdpa(const global KEY_DATA_T *K, const global QRY_DATA_T *Q,

#if KEY_SCALES || KEY_ZERO_POINTS
uint ldkq = KEY_D3;
uint num_key_groups = d / KEY_GROUP_SIZE;
#endif
#if VAL_SCALES || VAL_ZERO_POINTS
uint ldvq = div_up(d, VAL_GROUP_SIZE);
uint num_val_groups = d / VAL_GROUP_SIZE;
#endif

/* Subgroup IDs for each GEMM */
@@ -263,7 +275,6 @@ micro_sdpa(const global KEY_DATA_T *K, const global QRY_DATA_T *Q,

#if KEY_SCALES
K_scales += KEY_OFF(b1, b0_kv, 0, 0) / KEY_GROUP_SIZE;
uint num_key_groups = d / KEY_GROUP_SIZE;
#endif
#if KEY_SCALES == QUANTIZE_COMMON
float k_scale = KEY_SCALES_TO_FLOAT(*K_scales);
@@ -274,7 +285,6 @@ micro_sdpa(const global KEY_DATA_T *K, const global QRY_DATA_T *Q,
#endif
#if VAL_SCALES
V_scales += VAL_OFF(b1, b0_kv, 0, 0) / VAL_GROUP_SIZE;
uint num_val_groups = d / VAL_GROUP_SIZE;
#endif
#if VAL_SCALES == QUANTIZE_COMMON
float v_scale = VAL_SCALES_TO_FLOAT(*V_scales);
@@ -410,7 +420,7 @@ micro_sdpa(const global KEY_DATA_T *K, const global QRY_DATA_T *Q,

#if REMAINDER_K
/* Prepare k mask: NaN in bounds, -inf out of bounds */
mask_tile_type_float k_mask;
kmask_tile_type_float k_mask;
#pragma unroll
for (int ii = 0; ii < ugemm_kq_sg_tile_m / SUBGROUP_SIZE; ii++)
k_mask.x[0][ii] = (k0 + sg_i0_kq + ii * SUBGROUP_SIZE
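The REMAINDER_K block above encodes the k mask as NaN for in-bounds entries and -inf for out-of-bounds entries. A minimal standalone C sketch of why that encoding works, assuming the kernel folds the mask into the score tile with an elementwise min (fmin returns the non-NaN operand when exactly one argument is NaN):

#include <math.h>
#include <stdio.h>

/* Sketch: combine a raw attention score with a k-mask value via fmin().
 * In-bounds entries carry NaN, so fminf(score, NaN) passes the score through;
 * out-of-bounds entries carry -inf, so the score is forced to -inf and the
 * padded key contributes nothing after softmax. */
static float apply_k_mask(float score, float mask) { return fminf(score, mask); }

int main(void) {
    printf("%g\n", apply_k_mask(1.25f, NAN));       /* 1.25  (in bounds)     */
    printf("%g\n", apply_k_mask(1.25f, -INFINITY)); /* -inf  (out of bounds) */
    return 0;
}

With this encoding a single min per entry both preserves valid scores and pushes out-of-bound keys to -inf.
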
12 changes: 6 additions & 6 deletions src/gpu/intel/ocl/micro_sdpa.cpp
@@ -99,8 +99,8 @@ sdpa_config_t xehpc_h128_2nd = {32, 32, 32, 16, 8, 1, 4, 2};
sdpa_config_t xehpc_q_h128 = {16, 64, 16, 32, 16, 2, 8, 4};
sdpa_config_t xehpc_q_h128_s64 = {16, 16, 32, 16, 4, 4, 4, 4};
sdpa_config_t xehpc_q_h128_s32 = {16, 16, 32, 16, 4, 2, 4, 2};
sdpa_config_t xehpc_q_h128_2nd = {32, 32, 16, 32, 4, 1, 4, 1};
sdpa_config_t xehpc_q_h128_s32_2nd = {16, 32, 16, 16, 8, 1, 4, 2};
sdpa_config_t xehpc_q_h128_2nd = {32, 16, 32, 16, 4, 1, 4, 1};
sdpa_config_t xehpc_q_h128_s32_2nd = {16, 32, 32, 16, 8, 1, 4, 2};

sdpa_config_t xehpc_h256 = {16, 32, 32, 32, 8, 4, 8, 4};
sdpa_config_t xehpc_h256_s64 = {16, 32, 32, 32, 8, 1, 8, 1};
@@ -236,8 +236,8 @@ status_t micro_sdpa_t::pd_t::init_microkernels(impl::engine_t *engine) {
/* Retrieve pre-tuned kernel configuration */
sdpa_config_t *config = nullptr;
bool thin_q = (d->queries() <= 16);
bool quantized = types::is_integral_dt(key_md()->data_type)
|| types::is_integral_dt(val_md()->data_type);
bool quantized = with_key_scales() || with_key_zp() || with_value_scales()
|| with_value_zp();

switch (arch_) {
case arch_t::xe_hpg:
@@ -379,7 +379,6 @@ status_t micro_sdpa_t::pd_t::init_microkernels(impl::engine_t *engine) {
if (with_value_zp()) {
auto zp_dt = value_zp_dt();
problem_vs.Tao = jit::convert_dnnl_to_kernel_type(zp_dt);
problem_vs.AO.alignment = uint8_t(types::data_type_size(zp_dt));
problem_vs.AO.setAlignment(uint8_t(d->head_size() / value_group_size()
* types::data_type_size(zp_dt)));
problem_vs.AO.layout = MatrixLayout::N;
@@ -539,7 +538,8 @@ status_t micro_sdpa_t::init(impl::engine_t *engine) {
kernel_ctx.define_int(
"BROADCAST_MASK_Q", msk_mdw.dims()[pd_t::mask_q_index] == 1);

kernel_ctx.define_int("WITH_CAUSAL_MASK", pd()->with_causal_mask());
kernel_ctx.define_int(
"WITH_CAUSAL_MASK", pd()->with_causal_mask() && (d->queries() > 1));

kernel_ctx.define_int("SUBGROUP_SIZE", pd()->sg_size());
kernel_ctx.define_int("D_MAX", pd()->d_max());
20 changes: 18 additions & 2 deletions tests/benchdnn/inputs/graph/complex_fusion/harness_mha_all
@@ -64,8 +64,24 @@
# phi3-mini-4k-instruct
--reset
--dt=0:s8+2:s8+6:s8+8:s8
--in-shapes=0:1x32x96x384*abdc+1:1x32x1x384+2:1x32x1x384+3:1x32x384x96+5:1x1x384x384+6:1x32x384x96+7:1x32x384x1+8:1x32x384x1
--op-attrs=0:group_shape:1x1x96x1+99:group_shape:1x1x1x96 --case=complex_fusion/mha/sdpa-compressed-kv-int4-gs32.json
--in-shapes=0:1x32x96x384*abdc+1:1x32x1x384+2:1x32x1x384+3:1x32x384x96+5:1x1x384x384+6:1x32x384x96+7:1x32x384x1+8:1x32x384x1,\
0:1x32x96x385*abdc+1:1x32x1x385+2:1x32x1x385+3:1x32x1x96+5:1x1x385x385+6:1x32x385x96+7:1x32x385x1+8:1x32x385x1,\
0:1x32x96x512*abdc+1:1x32x1x512+2:1x32x1x512+3:1x32x512x96+5:1x1x512x512+6:1x32x512x96+7:1x32x512x1+8:1x32x512x1,\
0:1x32x96x513*abdc+1:1x32x1x513+2:1x32x1x513+3:1x32x1x96+5:1x1x513x513+6:1x32x513x96+7:1x32x513x1+8:1x32x513x1,\
0:1x32x96x1024*abdc+1:1x32x1x1024+2:1x32x1x1024+3:1x32x1024x96+5:1x1x1024x1024+6:1x32x1024x96+7:1x32x1024x1+8:1x32x1024x1,\
0:1x32x96x1025*abdc+1:1x32x1x1025+2:1x32x1x1025+3:1x32x1x96+5:1x1x1025x1025+6:1x32x1025x96+7:1x32x1025x1+8:1x32x1025x1
--op-attrs=0:group_shape:1x1x96x1+99:group_shape:1x1x1x96
--case=complex_fusion/mha/sdpa-compressed-kv-int4-gs32.json

# llama-2-7b-chat
--in-shapes=0:1x32x128x384*abdc+1:1x32x1x384+2:1x32x1x384+3:1x32x384x128+5:1x1x384x384+6:1x32x384x128+7:1x32x384x1+8:1x32x384x1,\
0:1x32x128x385*abdc+1:1x32x1x385+2:1x32x1x385+3:1x32x1x128+5:1x1x385x385+6:1x32x385x128+7:1x32x385x1+8:1x32x385x1,\
0:1x32x128x512*abdc+1:1x32x1x512+2:1x32x1x512+3:1x32x512x128+5:1x1x512x512+6:1x32x512x128+7:1x32x512x1+8:1x32x512x1,\
0:1x32x128x513*abdc+1:1x32x1x513+2:1x32x1x513+3:1x32x1x128+5:1x1x513x513+6:1x32x513x128+7:1x32x513x1+8:1x32x513x1,\
0:1x32x128x1024*abdc+1:1x32x1x1024+2:1x32x1x1024+3:1x32x1024x128+5:1x1x1024x1024+6:1x32x1024x128+7:1x32x1024x1+8:1x32x1024x1,\
0:1x32x128x1025*abdc+1:1x32x1x1025+2:1x32x1x1025+3:1x32x1x128+5:1x1x1025x1025+6:1x32x1025x128+7:1x32x1025x1+8:1x32x1025x1
--op-attrs=0:group_shape:1x1x128x1+99:group_shape:1x1x1x128
--case=complex_fusion/mha/sdpa-compressed-kv-int4-gs32.json

Review comment (Contributor), on the llama-2-7b-chat block: nit: better to reset dt here, to align with other tests, and avoid misunderstanding.
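One way to apply that suggestion would be to open the llama-2-7b-chat block with the same explicit reset the phi3 entry uses (an illustrative sketch, not part of this patch):

--reset
--dt=0:s8+2:s8+6:s8+8:s8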

# 0: key, 2: key zps, 6: value, 8: value zps. Change them to use s8 data type.
--reset --dt=0:s8+2:s8+6:s8+8:s8 --case=complex_fusion/mha/sdpa-compressed-kv-int4-gs32.json