Skip to content

Commit

Permalink
Fix reshape unpacked error caused by find_contiguous_pointwise (#2721)
Browse files Browse the repository at this point in the history
  • Loading branch information
shivadbhavsar authored Feb 5, 2024
1 parent b29d3d9 commit 2d4a650
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 7 deletions.
12 changes: 8 additions & 4 deletions src/targets/gpu/compile_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,16 +41,18 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_BENCHMARKING);

struct precompile_op
{
operation op = op::identity{};
std::size_t additional_args = 1;
bool ignore_modules = false;
operation op = op::identity{};
std::size_t additional_args = 1;
bool ignore_modules = false;
std::optional<shape> output_shape = nullopt;

template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.op, "op"),
f(self.additional_args, "additional_args"),
f(self.ignore_modules, "ignore_modules"));
f(self.ignore_modules, "ignore_modules"),
f(self.output_shape, "output_shape"));
}

std::string name() const { return "gpu::precompile_op"; }
Expand All @@ -59,6 +61,8 @@ struct precompile_op
{
// Pop off additional args
inputs.resize(inputs.size() - additional_args);
if(output_shape.has_value())
return output_shape.value();
if(ignore_modules)
return op.compute_shape(inputs);
return op.compute_shape(inputs, mods);
Expand Down
8 changes: 6 additions & 2 deletions src/targets/gpu/fuse_ops.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -782,7 +782,11 @@ struct find_contiguous_pointwise
auto args = pw->inputs();
args.back() = alloc;

m.replace_instruction(ins, pw->get_operator(), args, pw->module_inputs());
// Ensure the output shape of the pointwise module is contiguous
auto pw_op_val = pw->get_operator().to_value();
pw_op_val["output_shape"] = to_value(ins->get_shape());

m.replace_instruction(ins, make_op(pw->name(), pw_op_val), args, pw->module_inputs());
}
};

Expand Down
58 changes: 57 additions & 1 deletion test/gpu/fuse_ops.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -104,4 +104,60 @@ TEST_CASE(layernorm_pointwise)
}
}

TEST_CASE(contiguous_pointwise)
{
migraphx::shape s1{migraphx::shape::float_type, {128, 4, 196, 32}};
migraphx::shape s2{migraphx::shape::float_type, {128, 196, 4, 32}};

auto create_program = [=]() {
migraphx::program p;
auto* mm = p.get_main_module();
auto x = mm->add_parameter("x", s1);
auto y = mm->add_parameter("y", s2);
auto x_trans =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1, 3}}}), x);
auto alloc = migraphx::make_op("allocate", {{"shape", to_value(s2)}});
auto alloc_ins = mm->add_instruction(alloc);
auto* pw_add1 =
create_pointwise_module(p, "main:pointwise0", {x_trans, y}, single_pointwise("add"));
auto add1 = mm->add_instruction(
make_precompile_op("pointwise"), {x_trans, y, alloc_ins}, {pw_add1});

auto alloc_ins2 = mm->add_instruction(alloc);
auto cont = mm->add_instruction(migraphx::make_op("gpu::contiguous"), add1, alloc_ins2);
auto rsp =
mm->add_instruction(migraphx::make_op("reshape_lazy", {{"dims", {25088, 128}}}), cont);
mm->add_return({rsp});
return p;
};

auto create_fused_program = [=]() {
migraphx::program p;
auto* mm = p.get_main_module();
auto x = mm->add_parameter("x", s1);
auto y = mm->add_parameter("y", s2);
auto x_trans =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1, 3}}}), x);
auto alloc = migraphx::make_op("allocate", {{"shape", to_value(s2)}});
auto alloc_ins = mm->add_instruction(alloc);
auto* pw_add1 =
create_pointwise_module(p, "main:pointwise0", {x_trans, y}, single_pointwise("add"));

auto pw_op = migraphx::make_op("pointwise");
auto pre_comp_op = migraphx::make_op(
"gpu::precompile_op",
{{"op", migraphx::to_value(pw_op)}, {"output_shape", migraphx::to_value(s2)}});
auto add1 = mm->add_instruction(pre_comp_op, {x_trans, y, alloc_ins}, {pw_add1});
auto rsp =
mm->add_instruction(migraphx::make_op("reshape_lazy", {{"dims", {25088, 128}}}), add1);
mm->add_return({rsp});
return p;
};

migraphx::program p1 = create_program();
run_pass(p1);
migraphx::program p2 = create_fused_program();
EXPECT(p1 == p2);
}

int main(int argc, const char* argv[]) { test::run(argc, argv); }

0 comments on commit 2d4a650

Please sign in to comment.