Skip to content

Commit

Permalink
fix-trans-conv-issue-2459(03) Fix issue 2459. Add some documentation.
Browse files Browse the repository at this point in the history
  • Loading branch information
atamazov committed Oct 26, 2023
1 parent 3d3710c commit 12f2c9d
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 42 deletions.
60 changes: 22 additions & 38 deletions src/convolution_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,16 +56,13 @@ static inline auto MakeFwdCtxAndProblem(miopenHandle_t handle,
const auto direction =
(conv.mode != miopenTranspose) ? Direction::Forward : Direction::BackwardData;

auto problem = (conv.mode != miopenTranspose) ? ProblemDescription{miopen::deref(xDesc),
miopen::deref(wDesc),
miopen::deref(yDesc),
conv,
direction}
: ProblemDescription{miopen::deref(yDesc),
miopen::deref(wDesc),
miopen::deref(xDesc),
conv,
direction};
/// \anchor transpose_convolutions_x_y_swapping
/// In transpose mode we exchange x with y. From the other hand, when Backward*
/// ProblemDescription instances are constructed, x and y shall be swapped as well.
/// As transpose mode swaps Forward with Backward AND x with y, the order of
/// ctor arguments remains the same.
auto problem = ProblemDescription{
miopen::deref(xDesc), miopen::deref(wDesc), miopen::deref(yDesc), conv, direction};

auto ctx = ExecutionContext{&miopen::deref(handle)};
problem.SetupFloats(ctx);
Expand All @@ -82,16 +79,9 @@ static inline auto MakeBwdCtxAndProblem(miopenHandle_t handle,
const auto direction =
(conv.mode != miopenTranspose) ? Direction::BackwardData : Direction::Forward;

auto problem = (conv.mode != miopenTranspose) ? ProblemDescription{miopen::deref(dyDesc),
miopen::deref(wDesc),
miopen::deref(dxDesc),
conv,
direction}
: ProblemDescription{miopen::deref(dxDesc),
miopen::deref(wDesc),
miopen::deref(dyDesc),
conv,
direction};
/// \ref transpose_convolutions_x_y_swapping
auto problem = ProblemDescription{
miopen::deref(dyDesc), miopen::deref(wDesc), miopen::deref(dxDesc), conv, direction};

auto ctx = ExecutionContext{&miopen::deref(handle)};
problem.SetupFloats(ctx);
Expand Down Expand Up @@ -490,7 +480,7 @@ miopenFindConvolutionForwardAlgorithm(miopenHandle_t handle,

miopen::debug::LogCmdFindConvolution(
xDesc, wDesc, convDesc, yDesc, miopen::debug::ConvDirection::Fwd, false);
/// workaround for previous trans conv logic

if(miopen::deref(convDesc).mode == miopenTranspose)
return miopen::try_([&] {
miopen::deref(convDesc).FindConvBwdDataAlgorithm(miopen::deref(handle),
Expand Down Expand Up @@ -563,7 +553,6 @@ extern "C" miopenStatus_t miopenConvolutionForward(miopenHandle_t handle,
miopen::debug::LogCmdConvolution(
xDesc, wDesc, convDesc, yDesc, miopen::debug::ConvDirection::Fwd, false);

/// workaround for previous trans conv logic
if(miopen::deref(convDesc).mode == miopenTranspose)
return miopen::try_([&] {
// It is guaranteed that enum values are equal, see conv_algo_name.cpp
Expand Down Expand Up @@ -1067,7 +1056,7 @@ miopenFindConvolutionBackwardDataAlgorithm(miopenHandle_t handle,

miopen::debug::LogCmdFindConvolution(
dxDesc, wDesc, convDesc, dyDesc, miopen::debug::ConvDirection::Bwd, false);
/// workaround for previous trans conv logic

if(miopen::deref(convDesc).mode == miopenTranspose)
return miopen::try_([&] {
miopen::deref(convDesc).FindConvFwdAlgorithm(miopen::deref(handle),
Expand Down Expand Up @@ -1141,7 +1130,6 @@ miopenConvolutionBackwardData(miopenHandle_t handle,
miopen::debug::LogCmdConvolution(
dxDesc, wDesc, convDesc, dyDesc, miopen::debug::ConvDirection::Bwd, false);

/// workaround for previous trans conv logic
if(miopen::deref(convDesc).mode == miopenTranspose)
return miopen::try_([&] {
// It is guaranteed that enum values are equal, see conv_algo_name.cpp
Expand Down Expand Up @@ -1247,15 +1235,13 @@ miopenFindConvolutionBackwardWeightsAlgorithm(miopenHandle_t handle,
xDesc, dwDesc, convDesc, dyDesc, miopen::debug::ConvDirection::WrW, false);

return miopen::try_([&] {
const auto trans = (miopen::deref(convDesc).mode == miopenTranspose);
miopen::deref(convDesc).FindConvBwdWeightsAlgorithm(
miopen::deref(handle),
/// workaround for previous trans conv logic
miopen::deref(convDesc).mode == miopenTranspose ? miopen::deref(xDesc)
: miopen::deref(dyDesc),
miopen::deref(convDesc).mode == miopenTranspose ? DataCast(x) : DataCast(dy),
miopen::deref(convDesc).mode == miopenTranspose ? miopen::deref(dyDesc)
: miopen::deref(xDesc),
miopen::deref(convDesc).mode == miopenTranspose ? DataCast(dy) : DataCast(x),
trans ? miopen::deref(xDesc) : miopen::deref(dyDesc),
trans ? DataCast(x) : DataCast(dy),
trans ? miopen::deref(dyDesc) : miopen::deref(xDesc),
trans ? DataCast(dy) : DataCast(x),
miopen::deref(dwDesc),
DataCast(dw),
requestAlgoCount,
Expand Down Expand Up @@ -1300,16 +1286,14 @@ miopenConvolutionBackwardWeights(miopenHandle_t handle,
xDesc, dwDesc, convDesc, dyDesc, miopen::debug::ConvDirection::WrW, false);

return miopen::try_([&] {
const auto trans = (miopen::deref(convDesc).mode == miopenTranspose);
miopen::deref(convDesc).ConvolutionBackwardWeights(
miopen::deref(handle),
alpha,
/// workaround for previous trans conv logic
miopen::deref(convDesc).mode == miopenTranspose ? miopen::deref(xDesc)
: miopen::deref(dyDesc),
miopen::deref(convDesc).mode == miopenTranspose ? DataCast(x) : DataCast(dy),
miopen::deref(convDesc).mode == miopenTranspose ? miopen::deref(dyDesc)
: miopen::deref(xDesc),
miopen::deref(convDesc).mode == miopenTranspose ? DataCast(dy) : DataCast(x),
trans ? miopen::deref(xDesc) : miopen::deref(dyDesc),
trans ? DataCast(x) : DataCast(dy),
trans ? miopen::deref(dyDesc) : miopen::deref(xDesc),
trans ? DataCast(dy) : DataCast(x),
algo,
beta,
miopen::deref(dwDesc),
Expand Down
5 changes: 3 additions & 2 deletions src/include/miopen/conv/problem_description.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,9 +140,10 @@ struct ProblemDescription : ProblemDescriptionBase
{
ProblemDescription() = default;

ProblemDescription(const TensorDescriptor& in_,
/// \todo Get rid of the swapping of x and y.
ProblemDescription(const TensorDescriptor& in_, // x for Forward, y for Backward*
const TensorDescriptor& weights_,
const TensorDescriptor& out_,
const TensorDescriptor& out_, // y for Forward, x for Backward*
const ConvolutionDescriptor& conv_,
Direction direction_,
int bias_ = 0)
Expand Down
4 changes: 2 additions & 2 deletions src/ocl/convolutionocl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -578,12 +578,12 @@ ConvolutionDescriptor::GetSolutionsFallback(const ExecutionContext& ctx,
/// \todo This is terrible. Should do away when we converge to
/// single conv::ProblemDescription type.
const auto legacy_problem = ProblemDescription{problem};
const auto& inDesc =
const auto& xDesc =
(problem.GetDirection() == conv::Direction::Forward) ? problem.GetIn() : problem.GetOut();
const auto& weightsDesc = problem.GetWeights();
// This check is needed on fallback path only.
// On regular path (find-db hit) this was checked during Find().
ValidateGroupCount(inDesc, weightsDesc, *this);
ValidateGroupCount(xDesc, weightsDesc, *this);

auto interim = std::vector<miopenConvSolution_t>{};
interim.reserve(maxSolutionCount); // For speed. In most cases we have less entries than asked.
Expand Down

0 comments on commit 12f2c9d

Please sign in to comment.