Skip to content

Commit

Permalink
Update unittest.
Browse files: browse the repository at this point in the history
  • Loading branch information
wenscarl committed Sep 27, 2024
1 parent 1d45b4d commit 7837845
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 23 deletions.
5 changes: 2 additions & 3 deletions xla/service/gpu/transforms/gemm_rewriter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1254,9 +1254,8 @@ class GemmRewriterVisitor : public DfsHloRewriteVisitor {
} else {
dim_nums->set_lhs_contracting_dimensions(0, num_batch_dims);
}
a.fp8_input =
TransposeMatrix(a.fp8_input, a_contracting_dims[0],
a_batch_dims, /*col_maj*/true);
a.fp8_input = TransposeMatrix(a.fp8_input, a_contracting_dims[0],
a_batch_dims, /*col_maj*/ true);
}

// Similarly, cuBLASLt requires the second operand to be column-major, so
Expand Down
38 changes: 18 additions & 20 deletions xla/service/gpu/transforms/gemm_rewriter_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5034,19 +5034,17 @@ TEST_P(ParameterizedFp8GemmRewriteTest, UnscaledABUnscaledDMatrixBiasF8) {

TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDColMajorLhsF8) {
const char* hlo_text = R"(
HloModule test
HloModule test
ENTRY test {
x = <<F8E4M3>>[16,32]{0,1} parameter(0)
y = <<F8E4M3>>[32,16]{1,0} parameter(1)
x = <<F8E4M3>>[2,16,32]{1,0,2} parameter(0)
y = <<F8E4M3>>[2,32,16]{2,1,0} parameter(1)
x_scale = f32[] parameter(2)
y_scale = f32[] parameter(3)
dq_scale = f32[] multiply(x_scale, y_scale)
dq_scale_bcast = f32[16,16] broadcast(dq_scale), dimensions={}
out = f32[16,16] dot(x, y), lhs_contracting_dims={1}, rhs_contracting_dims={0}
ROOT out.1 = f32[16,16] multiply(out, dq_scale_bcast)
dq_scale_bcast = f32[2,16,16] broadcast(dq_scale), dimensions={}
out.0 = f32[2,16,16] dot(x, y), lhs_contracting_dims={2}, rhs_contracting_dims={1}, lhs_batch_dims={0}, rhs_batch_dims={0}
ROOT out = f32[2,16,16] multiply(out.0, dq_scale_bcast)
}
)";

CheckFp8IfSupported(hlo_text);
Expand All @@ -5055,28 +5053,28 @@ TEST_P(ParameterizedFp8GemmRewriteTest, ScaledABUnscaledDColMajorLhsF8) {
GemmRewriter(CudaHopperOrRocmMI300(), GetToolkitVersion(),
GemmRewriterOptions{GemmRewriterOptions::DType::kFp8Only}),
R"(
; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[16,32], {{.*}}: <<F8E4M3>>[32,16], {{.*}}: f32[], {{.*}}: f32[]) -> f32[16,16] {
; CHECK-NEXT: [[P0:%[^ ]+]] = <<F8E4M3>>[16,32]{0,1} parameter(0)
; CHECK-NEXT: [[P0_BT:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} bitcast([[P0]])
; CHECK-NEXT: [[P0_TR:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} transpose([[P0_BT]]), dimensions={1,0}
; CHECK-NEXT: [[P0_BT1:%[^ ]+]] = <<F8E4M3>>[32,16]{0,1} bitcast([[P0_TR]])
; CHECK-NEXT: [[P1:%[^ ]+]] = <<F8E4M3>>[32,16]{1,0} parameter(1)
; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[16,32]{1,0} transpose([[P1]]), dimensions={1,0}
; CHECK-LABEL: ENTRY %test ({{.*}}: <<F8E4M3>>[2,16,32], {{.*}}: <<F8E4M3>>[2,32,16], {{.*}}: f32[], {{.*}}: f32[]) -> f32[2,16,16] {
; CHECK-NEXT: [[P0:%[^ ]+]] = <<F8E4M3>>[2,16,32]{1,0,2} parameter(0)
; CHECK-NEXT: [[P0_BT:%[^ ]+]] = <<F8E4M3>>[32,2,16]{2,1,0} bitcast([[P0]])
; CHECK-NEXT: [[P0_TR:%[^ ]+]] = <<F8E4M3>>[16,2,32]{2,1,0} transpose([[P0_BT]]), dimensions={2,1,0}
; CHECK-NEXT: [[P0_BT1:%[^ ]+]] = <<F8E4M3>>[2,32,16]{1,0,2} bitcast([[P0_TR]])
; CHECK-NEXT: [[P1:%[^ ]+]] = <<F8E4M3>>[2,32,16]{2,1,0} parameter(1)
; CHECK-NEXT: [[P1_TRANSPOSE:%[^ ]+]] = <<F8E4M3>>[2,16,32]{2,1,0} transpose([[P1]]), dimensions={0,2,1}
; CHECK-NEXT: [[P2:%[^ ]+]] = f32[] parameter(2)
; CHECK-NEXT: [[P3:%[^ ]+]] = f32[] parameter(3)
; CHECK-NEXT: [[DQ:%[^ ]+]] = f32[] multiply([[P2]], [[P3]])
; CHECK-NEXT: [[C1:%[^ ]+]] = f32[] constant(1)
; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[16,16]{1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0_BT1]], [[P1_TRANSPOSE]], [[DQ]], [[C1]], [[C1]], /*index=5*/[[C1]]),
; CHECK-NEXT: [[OUT:%[^ ]+]] = (f32[2,16,16]{2,1,0}, s8[{{[0-9]+}}]{0}) custom-call([[P0_BT1]], [[P1_TRANSPOSE]], [[DQ]], [[C1]], [[C1]], /*index=5*/[[C1]]),
; CHECK: custom_call_target="__cublas$lt$matmul$f8",
; CHECK: backend_config={
; CHECK-DAG: "alpha_real":1
; CHECK-DAG: "alpha_imag":0
; CHECK-DAG: "beta":0
; CHECK-DAG: "dot_dimension_numbers":{
; CHECK-DAG: "lhs_contracting_dimensions":["0"]
; CHECK-DAG: "rhs_contracting_dimensions":["1"]
; CHECK-DAG: "lhs_batch_dimensions":[]
; CHECK-DAG: "rhs_batch_dimensions":[]
; CHECK-DAG: "lhs_contracting_dimensions":["1"]
; CHECK-DAG: "rhs_contracting_dimensions":["2"]
; CHECK-DAG: "lhs_batch_dimensions":["0"]
; CHECK-DAG: "rhs_batch_dimensions":["0"]
; CHECK-DAG: }
; CHECK-DAG: "precision_config":{
; CHECK-DAG: "operand_precision":["DEFAULT","DEFAULT"]
Expand Down

0 comments on commit 7837845

Please sign in to comment.