Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a check for packed tensors for convolution solvers #2471

Merged
merged 9 commits into from
Oct 25, 2023
11 changes: 11 additions & 0 deletions src/include/miopen/conv/problem_description.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,17 @@ struct ProblemDescription : ProblemDescriptionBase
bool IsNCHWc_NCHWc() const;
bool IsNCHWc_CHWNc() const;

bool HasNonPackedTensors() const

This comment was marked as resolved.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@atamazov : favoring the current name as it is more descriptive.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[resolved]

{
return !(in.IsPacked() && weights.IsPacked() && out.IsPacked());
}

bool HasMixedDataTypes() const
{
return !(GetInDataType() == GetWeightsDataType() &&
GetWeightsDataType() == GetOutDataType());
}

amberhassaan marked this conversation as resolved.
Show resolved Hide resolved
void HeuristicUpdateLayouts();

void MakeNetworkConfig(std::string& conf_key) const;
Expand Down
2 changes: 2 additions & 0 deletions src/include/miopen/fusion/utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@ inline bool WinoCommonIsApplicable(const FusionContext& context, const FusionDes
return false;
if(!conv_problem.IsFp32())
return false;
if(conv_problem.HasNonPackedTensors())
return false;
if(!conv_problem.IsLayoutDefault())
return false;
if(!conv_problem.direction.IsForward())
Expand Down
6 changes: 2 additions & 4 deletions src/solver/conv_MP_bidirectional_winograd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -229,9 +229,7 @@ static bool IsApplicableTransform(const ExecutionContext& ctx, const ProblemDesc
}

if(!problem.IsLayoutDefault())
{
return false;
}

{
unsigned int const waves_in_group = 512 / wave_size;
Expand Down Expand Up @@ -323,11 +321,11 @@ bool ConvMPBidirectWinograd<WinoDataH, WinoFilterH, WinoDataW, WinoFilterW>::IsA
{
// HIP backend required for sending ptr (buffer + offset)
// ROCBLAS for GEMM step
if(problem.HasNonPackedTensors())
return false;

if(!problem.IsLayoutDefault())
{
return false;
}

if(problem.IsTensorsCasted())
return false;
Expand Down
6 changes: 2 additions & 4 deletions src/solver/conv_asm_1x1u.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -527,6 +527,8 @@ bool ConvAsm1x1U::IsApplicable(const ExecutionContext& ctx, const ProblemDescrip
return false;
if(!problem.Is2d())
return false;
if(problem.HasNonPackedTensors())
return false;
if(!(problem.direction.IsForward() || problem.direction.IsBackwardData()))
return false;
if(problem.IsAsymmetricPadH() || problem.IsAsymmetricPadW())
Expand All @@ -545,13 +547,9 @@ bool ConvAsm1x1U::IsApplicable(const ExecutionContext& ctx, const ProblemDescrip

const std::string name = ctx.GetStream().GetDeviceName();
if(name.find("gfx9") == std::string::npos)
{
return false;
}
if(!problem.IsLayoutDefault())
{
return false;
}

if(problem.IsTensorsCasted() || problem.IsFp8() || problem.IsBfp8())
return false;
Expand Down
3 changes: 0 additions & 3 deletions src/solver/conv_asm_1x1u_bias_activ_fused.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -256,9 +256,6 @@ bool ConvBiasActivAsm1x1U::IsApplicable(const FusionContext& context,
if(conv_problem.GetDilationH() != 1)
return false;

if(conv_problem.IsTensorsCasted())
return false;

// Check if the conovlution part is applicable
return sol.IsApplicable(conv_ctx, conv_problem);
}
Expand Down
6 changes: 2 additions & 4 deletions src/solver/conv_asm_1x1u_stride2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -489,6 +489,8 @@ bool ConvAsm1x1UV2::IsApplicable(const ExecutionContext& ctx,
return false;
if(!(problem.direction.IsForward() || problem.direction.IsBackwardData()))
return false;
if(problem.HasNonPackedTensors())
return false;
if(problem.IsAsymmetricPadH() || problem.IsAsymmetricPadW())
return false;
if(!ctx.rmv.IsV2orV3())
Expand All @@ -505,13 +507,9 @@ bool ConvAsm1x1UV2::IsApplicable(const ExecutionContext& ctx,

const std::string name = ctx.GetStream().GetDeviceName();
if(name.find("gfx8") == std::string::npos && name.find("gfx9") == std::string::npos)
{
return false;
}
averinevg marked this conversation as resolved.
Show resolved Hide resolved
if(!problem.IsLayoutDefault())
{
return false;
}

if(problem.IsTensorsCasted() || problem.IsFp8() || problem.IsBfp8())
return false;
Expand Down
4 changes: 2 additions & 2 deletions src/solver/conv_asm_3x3u.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,8 @@ bool ConvAsm3x3U::IsApplicable(const ExecutionContext& ctx, const ProblemDescrip
return false;
if(!problem.Is2d())
return false;
if(problem.HasNonPackedTensors())
return false;
if(problem.IsAsymmetricPadH() || problem.IsAsymmetricPadW())
return false;
if(!(problem.direction.IsForward() || problem.direction.IsBackwardData()))
Expand All @@ -194,9 +196,7 @@ bool ConvAsm3x3U::IsApplicable(const ExecutionContext& ctx, const ProblemDescrip
if(!(StartsWith(name, "gfx8") || StartsWith(name, "gfx90")))
return false;
if(!problem.IsLayoutDefault())
{
return false;
}

if(problem.IsTensorsCasted() || problem.IsFp8() || problem.IsBfp8())
return false;
Expand Down
8 changes: 2 additions & 6 deletions src/solver/conv_asm_5x10u2v2b1.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ bool ConvAsm5x10u2v2b1::IsApplicable(const ExecutionContext& ctx,
return false;
if(!problem.Is2d())
return false;
if(problem.HasNonPackedTensors())
return false;
if(problem.IsAsymmetricPadH() || problem.IsAsymmetricPadW())
return false;
if(!ctx.rmv.IsV2orV3())
Expand All @@ -63,17 +65,11 @@ bool ConvAsm5x10u2v2b1::IsApplicable(const ExecutionContext& ctx,
return false;
#endif
if(!device_is_gfx8_9_no_xnack)
{
return false;
}
if(!problem.direction.IsBackwardData())
{
return false;
}
if(!problem.IsLayoutDefault())
{
return false;
}
if(problem.IsTensorsCasted() || problem.IsFp8() || problem.IsBfp8())
return false;

Expand Down
9 changes: 2 additions & 7 deletions src/solver/conv_asm_5x10u2v2f1.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ bool ConvAsm5x10u2v2f1::IsApplicable(const ExecutionContext& ctx,
return false;
if(!problem.Is2d())
return false;
if(problem.HasNonPackedTensors())
return false;
if(problem.IsAsymmetricPadH() || problem.IsAsymmetricPadW())
return false;
if(!ctx.rmv.IsV2orV3())
Expand All @@ -64,18 +66,11 @@ bool ConvAsm5x10u2v2f1::IsApplicable(const ExecutionContext& ctx,
return false;
#endif
if(!device_is_gfx8_9_no_xnack)
{
return false;
}
if(!problem.direction.IsForward())
{
return false;
}
if(!problem.IsLayoutDefault())
{
return false;
}

if(problem.IsTensorsCasted())
return false;

Expand Down
9 changes: 3 additions & 6 deletions src/solver/conv_asm_7x7c3h224w224k64u2v2p3q3f1.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ bool ConvAsm7x7c3h224w224k64u2v2p3q3f1::IsApplicable(const ExecutionContext& ctx
if(!ctx.rmv.IsV2orV3())
return false;

if(problem.HasNonPackedTensors())
return false;

if(problem.IsTensorsCasted())
return false;

Expand All @@ -65,17 +68,11 @@ bool ConvAsm7x7c3h224w224k64u2v2p3q3f1::IsApplicable(const ExecutionContext& ctx
#endif
if(!(name == "gfx800" || name == "gfx802" || name == "gfx803" || name == "gfx804" ||
name == "gfx900" || name == "gfx904" || name == "gfx906" || name == "gfx908"))
{
return false;
}
Comment on lines 69 to -70
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Curly braces are needed here. There was some misunderstanding here, I meant that curly braces are not required where the condition is short enough and the returned statement occupies one line. My comment related primarily to the new code. I should have explained this in more detail.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, this is quite arbitrary and I find it hard to agree with. The language allows two choices: (1) no braces when the conditional has a single statement in the body and that has nothing to do with the length of the condition. That can be a source of bugs in absence of formatting, so I use {} all the time (choice #2). What you're asking is an arbitrary style choice and I find it inconsistent. Should we define how long is a "long condition"?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it must be formalized and forced by clang-format.
If it's not forced by clang-format, we are free to do whatever we want - and most probably we want to argue what's the best codestyle ever.
@atamazov @averinevg please add appropriate check for it and we automatically will follow. I see no point to check codestyle on review.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll see if I can set up a clang format for this

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm of the opinion that we should always use {} and modify .clang-format to put the starting { on the same line as if so that code height remains somewhat small.

On the general topic of style, I've opened up #2482 for discussion on adopting Google Style Guide.

Copy link
Contributor

@atamazov atamazov Oct 25, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@CAHEK7

If it's not forced by clang-format, we are free to do whatever we want

NO. If there is no enforcing tool, then code review comes into play. That's one of the purposes of code reviews.

The reviewers are spending time, and good authors value this and learn from the reviews. And professional engineers are able to adapt his/her own tastes to whatever practices used in the communities (they joined).

@amberhassaan
Is it fine that a small code fix you were asked to make here transformed to a separate ticket, discussion thread here and additional work for Evgenii?

Copy link
Contributor

@CAHEK7 CAHEK7 Oct 25, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@atamazov NO for you NO)
The problem happened because in the first comment it was recommended to remove the brackets for a single statement conditions (which was done) and later for some of the cases it was recommended to return the brackets back if the condition is long enough. It is too fuzzy and it's a matter of taste, not a rule.

clang-format can handle curly braced case in one way or another. Probably it handles it not ideally, but in uniform and automated way, which simplify everything. So, it must be used, moreover, if the ideal rules can't be formalized and each reviewer may have his own flavor of ideal solution.
And as you've already said - we must adapt our tastes to make the life easier.

Just to mention, I'm talking about the case which can be handled by clang-format.
If it can't handle the case with includes (quotes vs brackets), so I don't ask for it, but @pfultz2 provided a reference how to properly fix it.
If it can't handle naming convention, I don't ask for it, I'm asking for some sort of small notes like "WeUseCamelCase for types and classes and we_use_underscores for variables". It would help a lot for newcomers or if the code is mess.
Since it's an opensource project and available for everyone, it would be nice to have such a document.

And as I've already said - we want to argue about the codestyle, it happens over and over again.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@atamazov : no, it's not fine that we are discussing style issues that would've been easily handled if we had a style document. We are creating unnecessary back and forth for minor matters. We need to have a developer's guide document in the repository that documents all the nuances of things like Invoker & InvokerFactory and how in is out and out is in in backward convolution and style issues like include files and variable/function naming etc. Hence the discussion in #2483

if(!problem.direction.IsForward())
{
return false;
}
if(!problem.IsLayoutDefault())
{
return false;
}

// clang-format off
return problem.GetPadW() == 3 // -q
Expand Down
7 changes: 2 additions & 5 deletions src/solver/conv_asm_dir_BwdWrW1x1.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -479,11 +479,12 @@ bool ConvAsmBwdWrW1x1::IsApplicable(const ExecutionContext& ctx,
return false;
if(!problem.direction.IsBackwardWrW())
return false;
if(problem.HasNonPackedTensors())
return false;
if(problem.IsAsymmetricPadH() || problem.IsAsymmetricPadW())
return false;
if(!ctx.rmv.IsV2orV3())
return false;

if(problem.IsTensorsCasted())
return false;

Expand All @@ -493,13 +494,9 @@ bool ConvAsmBwdWrW1x1::IsApplicable(const ExecutionContext& ctx,

const std::string name = ctx.GetStream().GetDeviceName();
if(name.find("gfx8") == std::string::npos && name.find("gfx9") == std::string::npos)
{
return false;
}
averinevg marked this conversation as resolved.
Show resolved Hide resolved
if(!problem.IsLayoutDefault())
{
return false;
}

if(name == "gfx90a" && problem.IsGfx90aFp16altRequired())
return false;
Expand Down
5 changes: 2 additions & 3 deletions src/solver/conv_asm_dir_BwdWrW3x3.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,8 @@ bool ConvAsmBwdWrW3x3::IsApplicable(const ExecutionContext& ctx,
return false;
if(!problem.direction.IsBackwardWrW())
return false;
if(problem.HasNonPackedTensors())
return false;
if(problem.IsAsymmetricPadH() || problem.IsAsymmetricPadW())
return false;
if(!ctx.rmv.IsV2orV3())
Expand All @@ -375,10 +377,7 @@ bool ConvAsmBwdWrW3x3::IsApplicable(const ExecutionContext& ctx,
if(!(StartsWith(name, "gfx8") || StartsWith(name, "gfx9")))
return false;
if(!problem.IsLayoutDefault())
{
return false;
}

if(problem.IsTensorsCasted())
return false;
#if WORKAROUND_ISSUE_532
Expand Down
5 changes: 3 additions & 2 deletions src/solver/conv_asm_implicit_gemm_bwd_v4r1_dynamic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,9 @@ bool ConvAsmImplicitGemmV4R1DynamicBwd::IsApplicable(const ExecutionContext& ctx
if(!problem.direction.IsBackwardData())
return false;

if(problem.HasNonPackedTensors())
return false;

if(!problem.Is2d())
return false;

Expand All @@ -159,9 +162,7 @@ bool ConvAsmImplicitGemmV4R1DynamicBwd::IsApplicable(const ExecutionContext& ctx
return false;

if(!problem.IsLayoutDefault())
{
return false;
}

const auto target = ctx.GetStream().GetTargetProperties();
if(target.Xnack() && *target.Xnack())
Expand Down
5 changes: 3 additions & 2 deletions src/solver/conv_asm_implicit_gemm_gtc_bwd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -992,6 +992,9 @@ bool ConvAsmImplicitGemmGTCDynamicBwdXdlops::IsApplicable(const ExecutionContext
if(!problem.Is2d())
return false;

if(problem.HasNonPackedTensors())
return false;

if(!problem.IsFp32() && !problem.IsFp16())
return false;

Expand All @@ -1006,9 +1009,7 @@ bool ConvAsmImplicitGemmGTCDynamicBwdXdlops::IsApplicable(const ExecutionContext
return false;

if(!problem.IsLayoutDefault())
{
return false;
}

const auto target = ctx.GetStream().GetTargetProperties();
if(target.Xnack() && *target.Xnack())
Expand Down
3 changes: 3 additions & 0 deletions src/solver/conv_asm_implicit_gemm_gtc_bwd_nhwc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -938,6 +938,9 @@ bool ConvAsmImplicitGemmGTCDynamicBwdXdlopsNHWC::IsApplicable(
if(!problem.Is2d())
return false;

if(problem.HasNonPackedTensors())
return false;

if(!problem.IsFp32() && !problem.IsFp16() &&
!(problem.IsBfp16() && (device_name == "gfx90a" || StartsWith(device_name, "gfx94"))))
return false;
Expand Down
5 changes: 3 additions & 2 deletions src/solver/conv_asm_implicit_gemm_gtc_fwd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1518,6 +1518,9 @@ bool ConvAsmImplicitGemmGTCDynamicFwdXdlops::IsApplicable(const ExecutionContext
if(!problem.Is2d())
return false;

if(problem.HasNonPackedTensors())
return false;

if(!problem.IsFp32() && !problem.IsFp16())
return false;

Expand All @@ -1531,9 +1534,7 @@ bool ConvAsmImplicitGemmGTCDynamicFwdXdlops::IsApplicable(const ExecutionContext
return false;

if(!problem.IsLayoutDefault())
{
return false;
}

#if WORKAROUND_SWDEV_306318
if((problem.GetWeightsHeight_() == 1) && (problem.GetWeightsWidth_() == 1) &&
Expand Down
3 changes: 3 additions & 0 deletions src/solver/conv_asm_implicit_gemm_gtc_fwd_nchwc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -560,6 +560,9 @@ bool ConvAsmImplicitGemmGTCDynamicFwdDlopsNCHWC::IsApplicable(
if(!problem.Is2d())
return false;

if(problem.HasNonPackedTensors())
return false;

if(!problem.IsLayoutNCHWc())
return false;

Expand Down
3 changes: 3 additions & 0 deletions src/solver/conv_asm_implicit_gemm_gtc_fwd_nhwc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -876,6 +876,9 @@ bool ConvAsmImplicitGemmGTCDynamicFwdXdlopsNHWC::IsApplicable(
if(!problem.Is2d())
return false;

if(problem.HasNonPackedTensors())
return false;

if(!problem.IsFp32() && !problem.IsFp16() &&
!(problem.IsBfp16() && (device_name == "gfx90a" || StartsWith(device_name, "gfx94"))))
return false;
Expand Down
3 changes: 3 additions & 0 deletions src/solver/conv_asm_implicit_gemm_gtc_wrw_nhwc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -866,6 +866,9 @@ bool ConvAsmImplicitGemmGTCDynamicWrwXdlopsNHWC::IsApplicable(
if(!problem.Is2d())
return false;

if(problem.HasNonPackedTensors())
return false;

if(!problem.IsFp32() && !problem.IsFp16() &&
!(problem.IsBfp16() && (device_name == "gfx90a" || StartsWith(device_name, "gfx94"))))
return false;
Expand Down
5 changes: 3 additions & 2 deletions src/solver/conv_asm_implicit_gemm_v4r1_dynamic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,9 @@ bool ConvAsmImplicitGemmV4R1DynamicFwd::IsApplicable(const ExecutionContext& ctx
if(!problem.Is2d())
return false;

if(problem.HasNonPackedTensors())
return false;

if(!problem.IsFp32())
return false;

Expand All @@ -305,9 +308,7 @@ bool ConvAsmImplicitGemmV4R1DynamicFwd::IsApplicable(const ExecutionContext& ctx
return false;

if(!problem.IsLayoutDefault())
{
return false;
}

const auto target = ctx.GetStream().GetTargetProperties();
if(target.Xnack() && *target.Xnack())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -833,6 +833,9 @@ bool ConvAsmImplicitGemmGTCDynamicWrwXdlops::IsApplicable(const ExecutionContext
if(!problem.Is2d())
return false;

if(problem.HasNonPackedTensors())
return false;

if(!problem.IsFp32() && !problem.IsFp16())
return false;

Expand All @@ -846,9 +849,7 @@ bool ConvAsmImplicitGemmGTCDynamicWrwXdlops::IsApplicable(const ExecutionContext
return false;

if(!problem.IsLayoutDefault())
{
return false;
}

const auto target = ctx.GetStream().GetTargetProperties();
if(target.Xnack() && *target.Xnack())
Expand Down
Loading