diff --git a/sycl/test-e2e/Matrix/get_coord_bf16_matA_impl.hpp b/sycl/test-e2e/Matrix/get_coord_bf16_matA_impl.hpp index e1a0b2411973c..9065e24d0cb5e 100644 --- a/sycl/test-e2e/Matrix/get_coord_bf16_matA_impl.hpp +++ b/sycl/test-e2e/Matrix/get_coord_bf16_matA_impl.hpp @@ -117,7 +117,7 @@ void matrix_sum_rows(queue q, big_matrix &A, nd_range<2> &r) { sub_a; joint_matrix_load( - sg, sub_a, accA.template get_multi_ptr() + (global_idx * TM * K) + TK, + sg, sub_a, accA.template get_multi_ptr() + (sg_startx * TM * K) + sg_starty / SG_SZ * TK, K); // calculate sum of rows in sum_rows_v[8], there are 8 rows in sub_a @@ -175,7 +175,7 @@ int main() { for (int i = 0; i < MATRIX_M; i++) { for (int j = 0; j < MATRIX_K; j++) { - A[i][j] = i; + A[i][j] = i + j; } } diff --git a/sycl/test-e2e/Matrix/get_coord_bf16_matB_impl.hpp b/sycl/test-e2e/Matrix/get_coord_bf16_matB_impl.hpp index 76a8968239ced..b86b8c89dfe71 100644 --- a/sycl/test-e2e/Matrix/get_coord_bf16_matB_impl.hpp +++ b/sycl/test-e2e/Matrix/get_coord_bf16_matB_impl.hpp @@ -142,7 +142,7 @@ void matrix_sum_cols(queue q, big_matrix &B, nd_range<2> &r) { joint_matrix_load( sg, sub_b, accB.template get_multi_ptr() + - (global_idx * (TK / 4) * N) + sg_starty / SG_SZ * TN * 4, + (sg_startx * (TK / 4) * N) + sg_starty / SG_SZ * TN * 4, N); int32_t sum_local_cols[N] = {0}; // 4 local cols, N total @@ -207,7 +207,7 @@ int main() { for (int i = 0; i < MATRIX_K; i++) { for (int j = 0; j < MATRIX_N; j++) { - B[i][j] = i; + B[i][j] = i + j; } }