
Removed implicit split-k while moving from conv to gemm #1527

Draft
wants to merge 2 commits into
base: develop

ravil-mobile
Contributor

No description provided.

@ravil-mobile
Contributor Author

Hi @krzysz00 @manupak

I'd like to generate conv_bwd_weight with mfma enabled; an example is below. However, the resulting op doesn't
have the mfma flag set. I'm confused whether this is an old bug, was done on purpose, or I am doing something wrong.

rocmlir-gen --conv-config " --operation conv_bwd_weight --arch amdgcn-amd-amdhsa:gfx908:sramecc+:xnack- -mfma=on --in_type fp32 --fil_type fp32 --out_type fp32 --fil_layout GNCHW --in_layout NGCHW --out_layout NGCHW --batchsize 64 --in_channels 1024 --out_channels 1024 --in_h 14 --in_w 14 --out_h 14 --out_w 14 --fil_h 1 --fil_w 1 --dilation_h 1 --dilation_w 1 --conv_stride_h 1 --conv_stride_w 1 --padding_h 0 --padding_w 0 --kernel_name conv --groupsize 1 "
module attributes {mhal.arch = "amdgcn-amd-amdhsa:gfx908:sramecc+:xnack-"} {
  func.func @conv_0(%arg0: memref<1x1024x1024x1x1xf32>, %arg1: memref<64x1x1024x14x14xf32>, %arg2: memref<64x1x1024x14x14xf32>) attributes {kernel = 0 : i32, mhal.arch = "amdgcn-amd-amdhsa:gfx908:sramecc+:xnack-"} {
    rock.conv_bwd_weight(%arg0, %arg1, %arg2) features =  dot|atomic_add {arch = "amdgcn-amd-amdhsa:gfx908:sramecc+:xnack-", dilations = [1 : index, 1 : index], filter_layout = ["g", "k", "c", "0", "1"], input_layout = ["ni", "gi", "ci", "0i", "1i"], numCU = 120 : i32, output_layout = ["no", "go", "ko", "0o", "1o"], padding = [0 : index, 0 : index, 0 : index, 0 : index], strides = [1 : index, 1 : index]} : memref<1x1024x1024x1x1xf32>, memref<64x1x1024x14x14xf32>, memref<64x1x1024x14x14xf32>
    return
  }
}

The absence of the mfma flag results in skipping the following branch:

if (ConvOpType::BwdWeight == convOpType &&
    isWrWAtomicKernel(features, dataType, maybeGemmExtraPad.has_value())) {
  return backwardWeightAtomicAdd(cast<ConvBwdWeightOp>(op), b);
}

I will look into why rocmlir-gen drops the mfma flag. Maybe there is a reason for it.

@ravil-mobile
Contributor Author

@jerryyin, do you remember the reason for the !requiredPadding condition? I mean, why do we enable atomic write ops only when no padding is required?

bool mlir::rock::isWrWAtomicKernel(GemmFeatures features, Type dataType,
                                   bool requiredPadding) {
  return isAccel(features) &&
         bitEnumContainsAll(features, GemmFeatures::atomic_add) &&
         (dataType.isF32() || dataType.isF16()) && !requiredPadding;
}

It seems to me this part has become some kind of legacy code.
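
To make the interaction between the two snippets easier to follow, here is a minimal standalone model of the predicate. It is only a sketch, not the rocMLIR code: the Features enum, its bit values, the DataType enum, and the helper names (containsAll, containsAny, isAccelModel, isWrWAtomicKernelModel) are all hypothetical stand-ins, and it assumes that isAccel means "mfma or wmma is set". Under those assumptions, the feature set from the generated op (dot|atomic_add, no mfma) fails the check, which matches the skipped branch, and any required padding also disables the atomic path.

```cpp
#include <cstdint>

// Hypothetical stand-in for rock::GemmFeatures; bit values are illustrative only.
enum class Features : std::uint32_t {
  none = 0,
  mfma = 1u << 0,
  dot = 1u << 1,
  atomic_add = 1u << 2,
  wmma = 1u << 3,
};

constexpr std::uint32_t bits(Features f) {
  return static_cast<std::uint32_t>(f);
}

constexpr Features operator|(Features a, Features b) {
  return static_cast<Features>(bits(a) | bits(b));
}

// True if every bit in `wanted` is present in `set`.
constexpr bool containsAll(Features set, Features wanted) {
  return (bits(set) & bits(wanted)) == bits(wanted);
}

// True if at least one bit in `wanted` is present in `set`.
constexpr bool containsAny(Features set, Features wanted) {
  return (bits(set) & bits(wanted)) != 0;
}

// Hypothetical stand-in for the MLIR element type checks.
enum class DataType { f32, f16, i8 };

// Assumption: "accelerated" means mfma or wmma is enabled.
constexpr bool isAccelModel(Features f) {
  return containsAny(f, Features::mfma | Features::wmma);
}

// Mirrors the shape of the predicate quoted above.
constexpr bool isWrWAtomicKernelModel(Features f, DataType t,
                                      bool requiredPadding) {
  return isAccelModel(f) && containsAll(f, Features::atomic_add) &&
         (t == DataType::f32 || t == DataType::f16) && !requiredPadding;
}

// Features as printed on the generated op: no mfma, so the branch is skipped.
static_assert(!isWrWAtomicKernelModel(Features::dot | Features::atomic_add,
                                      DataType::f32, false),
              "dot|atomic_add without mfma must not select the atomic kernel");

// With mfma set and no padding, the atomic kernel is selected.
static_assert(isWrWAtomicKernelModel(
                  Features::mfma | Features::dot | Features::atomic_add,
                  DataType::f32, false),
              "mfma|dot|atomic_add with f32 and no padding selects it");

// Required padding disables the atomic kernel even with mfma.
static_assert(!isWrWAtomicKernelModel(Features::mfma | Features::atomic_add,
                                      DataType::f32, true),
              "padding disables the atomic kernel");
```

The static_assert lines are evaluated at compile time, so the file compiling cleanly is the check.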
