Skip to content

Commit

Permalink
Fixed incorrect transpose in find 2.0 (#3285)
Browse files Browse the repository at this point in the history
  • Loading branch information
DrizztDoUrden authored Oct 1, 2024
1 parent d68dbc1 commit 2d69aeb
Showing 1 changed file with 11 additions and 11 deletions.
22 changes: 11 additions & 11 deletions src/problem.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,11 @@ Problem::FindSolutions(Handle& handle, const FindOptions& options, std::size_t m
auto ret = std::visit(
boost::hof::match(
[&](const ConvolutionDescriptor& op_desc) {
return FindSolutionsImpl(handle, options, max_solutions, buffers, op_desc);
if(op_desc.mode == miopenTranspose)
return MakeTransposed().FindSolutionsImpl(
handle, options, max_solutions, buffers, op_desc);
else
return FindSolutionsImpl(handle, options, max_solutions, buffers, op_desc);
},
[&](const SoftmaxDescriptor& op_desc) {
return FindSolutionsImpl(handle, options, max_solutions, buffers, op_desc);
Expand Down Expand Up @@ -477,21 +481,17 @@ std::vector<Solution> Problem::FindSolutionsImpl(Handle& handle,
const auto& w = buffers.at(miopenTensorConvolutionW);
auto y = buffers.at(miopenTensorConvolutionY);

const auto conv_problem =
conv_desc.mode == miopenTranspose ? MakeTransposed().AsConvolution() : AsConvolution();

std::size_t workspace_size;
Allocator::ManageDataPtr owned_workspace;
Data_t workspace;

if(conv_desc.mode == miopenTranspose)
{
std::swap(x, y);
std::swap(x_desc, y_desc);
}

const auto conv_problem = AsConvolution();

ValidateGroupCount(x_desc, w_desc, conv_desc);

std::size_t workspace_size;
Allocator::ManageDataPtr owned_workspace;
Data_t workspace;

if(options.preallocated_workspace)
{
workspace = options.preallocated_workspace->buffer;
Expand Down

0 comments on commit 2d69aeb

Please sign in to comment.