Skip to content

Commit

Permalink
[mlir][NVGPU] Support cache all (.ca) in nvgpu.device_async_copy
Browse files Browse the repository at this point in the history
This patch adds support for cache all (.ca) in conversion from nvgpu-to-nvvm for inline asm `cp.async`.

For sizes other than 16 bytes cp.async cache global is not allowed and cache all is required to generate a valid ptx.

Differential revision: https://reviews.llvm.org/D148604

Authored-by: Manish Gupta <manigupta@google.com>
  • Loading branch information
nicolasvasilache committed Apr 18, 2023
1 parent 5fdf4d5 commit 95cb986
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 5 deletions.
17 changes: 16 additions & 1 deletion mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,9 @@ static void emitCpAsyncOpZfillAsm(Location loc, Value dstPtr, Value srcPtr,
ConversionPatternRewriter &rewriter) {
auto asmDialectAttr = LLVM::AsmDialectAttr::get(rewriter.getContext(),
LLVM::AsmDialect::AD_ATT);
const char *asmStr = "cp.async.cg.shared.global [$0], [$1], $2, $3;\n";

const char *cpAsyncCgStr = "cp.async.cg.shared.global [$0], [$1], $2, $3;\n";
const char *cpAsyncCaStr = "cp.async.ca.shared.global [$0], [$1], $2, $3;\n";
const char *asmConstraints = "r,l,n,r";

Value c3I32 = rewriter.create<LLVM::ConstantOp>(
Expand All @@ -382,6 +384,19 @@ static void emitCpAsyncOpZfillAsm(Location loc, Value dstPtr, Value srcPtr,

SmallVector<Value> asmVals{dstPtr, srcPtr, dstBytes, srcBytes};

// Pick the right asm string based on the dstBytes which is a compile-time
// constant.
auto dstByteConstOp =
dyn_cast<mlir::LLVM::ConstantOp>(dstBytes.getDefiningOp());
auto dstByteAttr = dstByteConstOp.getValue().dyn_cast<mlir::IntegerAttr>();
int64_t dstByteVal = dstByteAttr.getValue().getSExtValue();

assert((dstByteVal == 4 || dstByteVal == 8 || dstByteVal == 16) &&
"cp.async byte copy size must be 4, 8 or 16");
// Cache global (.cg) for 16 dst bytes, Cache all (.ca) for sizes other than
// 16 dst bytes.
const char *asmStr = (dstByteVal == 16) ? cpAsyncCgStr : cpAsyncCaStr;

rewriter.create<LLVM::InlineAsmOp>(
loc, LLVM::LLVMVoidType::get(rewriter.getContext()),
/*operands=*/asmVals,
Expand Down
26 changes: 22 additions & 4 deletions mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -295,12 +295,12 @@ func.func @async_cp_i4(

// -----

// CHECK-LABEL: @async_cp_zfill(
// CHECK-LABEL: @async_cp_zfill_f32_align4(
// CHECK-SAME: %[[IDX:[a-zA-Z0-9_]+]]: index, %[[SRCELEMENTS:[a-zA-Z0-9_]+]]: index)
func.func @async_cp_zfill(
func.func @async_cp_zfill_f32_align4(
%src: memref<128x128xf32>, %dst: memref<3x16x128xf32, 3>, %i : index, %srcElements : index) {

// CHECK-DAG: lvm.inline_asm has_side_effects asm_dialect = att "cp.async.cg.shared.global [$0], [$1], $2, $3;\0A", "r,l,n,r" %[[DSTPTR:.*]], %[[SRCPTR:.*]], %[[DSTBYTES:.*]], %[[SRCBYTES:.*]] : (!llvm.ptr<3>, !llvm.ptr<1>, i32, i32) -> !llvm.void
// CHECK-DAG: %[[DSTBYTES:.*]] = llvm.mlir.constant(16 : i32) : i32
// CHECK-DAG: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.cg.shared.global [$0], [$1], $2, $3;\0A", "r,l,n,r" %[[DSTPTR:.*]], %[[SRCPTR:.*]], %[[DSTBYTES]], %[[SRCBYTES:.*]] : (!llvm.ptr<3>, !llvm.ptr<1>, i32, i32) -> !llvm.void
%0 = nvgpu.device_async_copy %src[%i, %i], %dst[%i, %i, %i], 4, %srcElements {bypassL1}: memref<128x128xf32> to memref<3x16x128xf32, 3>
// CHECK: nvvm.cp.async.commit.group
%1 = nvgpu.device_async_create_group %0
Expand All @@ -312,6 +312,24 @@ func.func @async_cp_zfill(

// -----

// CHECK-LABEL: @async_cp_zfill_f32_align1(
// CHECK-SAME: %[[IDX:[a-zA-Z0-9_]+]]: index, %[[SRCELEMENTS:[a-zA-Z0-9_]+]]: index)
func.func @async_cp_zfill_f32_align1(
%src: memref<128x128xf32>, %dst: memref<3x16x128xf32, 3>, %i : index, %srcElements : index) {
// CHECK-DAG: %[[DSTBYTES:.*]] = llvm.mlir.constant(4 : i32) : i32
// CHECK-DAG: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.ca.shared.global [$0], [$1], $2, $3;\0A", "r,l,n,r" %[[DSTPTR:.*]], %[[SRCPTR:.*]], %[[DSTBYTES]], %[[SRCBYTES:.*]] : (!llvm.ptr<3>, !llvm.ptr<1>, i32, i32) -> !llvm.void
%0 = nvgpu.device_async_copy %src[%i, %i], %dst[%i, %i, %i], 1, %srcElements {bypassL1}: memref<128x128xf32> to memref<3x16x128xf32, 3>
// CHECK: nvvm.cp.async.commit.group
%1 = nvgpu.device_async_create_group %0
// CHECK: nvvm.cp.async.wait.group 1
nvgpu.device_async_wait %1 { numGroups = 1 : i32 }

return
}

// -----


// CHECK-LABEL: func @mma_sp_sync_f16_16832(
func.func @mma_sp_sync_f16_16832(%arg0: vector<4x2xf16>,
%arg1: vector<4x2xf16>,
Expand Down

0 comments on commit 95cb986

Please sign in to comment.