Skip to content

Commit

Permalink
Improve based on review #1
Browse files Browse the repository at this point in the history
  • Loading branch information
wenscarl committed Oct 30, 2023
1 parent a4140da commit 9684568
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 14 deletions.
24 changes: 11 additions & 13 deletions xla/stream_executor/cuda/cuda_blas_lt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,8 @@ cudaDataType_t BlasLt::MatrixLayout::type() const {
/*static*/ tsl::StatusOr<BlasLt::MatmulDesc> BlasLt::MatmulDesc::Create(
blas::ComputationType compute_type, blas::DataType scale_type,
blas::Transpose trans_a, blas::Transpose trans_b,
gpu::BlasLt::Epilogue epilogue, bool fast_accum, PointerMode pointer_mode) {
gpu::BlasLt::Epilogue epilogue, bool enable_fast_accum,
PointerMode pointer_mode) {
VLOG(2) << "MatmulDesc::Create: compute_type: " << (int)compute_type
<< " scale:" << (int)scale_type << " trans a/b: " << (int)trans_a
<< "," << (int)trans_b << " epilogue:" << (int)epilogue
Expand All @@ -210,8 +211,8 @@ cudaDataType_t BlasLt::MatrixLayout::type() const {
AsCublasOperation(trans_b)));
TF_ASSIGN_OR_RETURN(cublasLtEpilogue_t epi, AsCublasLtEpilogue(epilogue));
TF_RETURN_IF_ERROR(SetAttr(cu_desc, CUBLASLT_MATMUL_DESC_EPILOGUE, epi));
TF_RETURN_IF_ERROR(
SetAttr(cu_desc, CUBLASLT_MATMUL_DESC_FAST_ACCUM, int8_t(fast_accum)));
TF_RETURN_IF_ERROR(SetAttr(cu_desc, CUBLASLT_MATMUL_DESC_FAST_ACCUM,
static_cast<int8_t>(enable_fast_accum)));
return std::move(desc);
}

Expand Down Expand Up @@ -317,20 +318,17 @@ auto BlasLt::GetMatmulPlan(const gpu::GemmConfig& cfg,
cfg.compute_precision));
}

// For FP8 matrix multiplications, the PrecisionConfig determines whether fast
// accumulation should be enabled. In the DEFAULT precision mode, typically
// encountered during forward propagation with E4M3 operands, fast
// accumulation is enabled. When Precision is set to HIGHEST, indicative of
// scenarios in backward propagation, a higher precision accumulation method
// is utilized.
bool fast_accum = (xla::primitive_util::IsF8Type(lhs_layout.dtype) ||
xla::primitive_util::IsF8Type(rhs_layout.dtype)) &&
cfg.compute_precision == 0;
// For FP8 matmuls, there are two options available: fast
// accumulation(PrecisionConfig.Precision.DEFAULT) and
// higher precision accumulation (PrecisionConfig.Precision.HIGHEST).
bool enable_fast_accum = (xla::primitive_util::IsF8Type(lhs_layout.dtype) ||
xla::primitive_util::IsF8Type(rhs_layout.dtype)) &&
cfg.compute_precision == 0;
TF_ASSIGN_OR_RETURN(
auto op_desc,
MatmulDesc::Create(*compute_type,
gpu::GetScaleType(output_dtype, *compute_type),
trans_a, trans_b, epilogue, fast_accum));
trans_a, trans_b, epilogue, enable_fast_accum));

TF_ASSIGN_OR_RETURN(auto a_desc, MatrixLayout::Create(lhs_layout));
TF_ASSIGN_OR_RETURN(auto b_desc, MatrixLayout::Create(rhs_layout));
Expand Down
2 changes: 1 addition & 1 deletion xla/stream_executor/cuda/cuda_blas_lt.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ class BlasLt : public gpu::BlasLt {
blas::ComputationType compute_type, blas::DataType scale_type,
blas::Transpose trans_a = blas::Transpose::kNoTranspose,
blas::Transpose trans_b = blas::Transpose::kNoTranspose,
Epilogue epilogue = Epilogue::kDefault, bool fast_accum = false,
Epilogue epilogue = Epilogue::kDefault, bool enable_fast_accum = false,
PointerMode pointer_mode = PointerMode::kHost);

cublasComputeType_t compute_type() const;
Expand Down

0 comments on commit 9684568

Please sign in to comment.