Skip to content

Commit

Permalink
[GPU] Extend gemm to fuse broadcast and reshape layers (#23513)
Browse files Browse the repository at this point in the history
### Details:
- Fuse `broadcast` and `reshape` layers into `gemm` layer for LLM's 2nd
latency optimization
     - before : [`broadcast`] --> [`reshape`] --> `gemm`
     - after : `gemm`
- `gemm` is extended to have `input0_target_shape`,
`input1_target_shape`, `input0_output_pattern` and
`input1_output_pattern` from `broadcast` and `reshape` layers

### Tickets:
 - 128343

---------

Signed-off-by: Andrew Park <andrew.park@intel.com>
  • Loading branch information
andrew-k-park authored Mar 25, 2024
1 parent dbef32e commit 133b139
Show file tree
Hide file tree
Showing 19 changed files with 925 additions and 109 deletions.
29 changes: 26 additions & 3 deletions src/plugins/intel_gpu/include/intel_gpu/op/gemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,30 @@ class Gemm : public ov::op::v0::MatMul {
const std::vector<int64_t>& order_c,
const ov::element::Type output_type = ov::element::undefined);

Gemm(const ov::Output<Node>& A,
const ov::Output<Node>& B,
const std::vector<int32_t>& target_shape_a,
const std::vector<int32_t>& target_shape_b,
const std::vector<int64_t>& output_pattern_a,
const std::vector<int64_t>& output_pattern_b,
const std::vector<int64_t>& order_a,
const std::vector<int64_t>& order_b,
const std::vector<int64_t>& order_c,
const ov::element::Type output_type = ov::element::undefined);

bool visit_attributes(ov::AttributeVisitor &visitor) override;

void validate_and_infer_types() override;

std::shared_ptr<Node> clone_with_new_inputs(const ov::OutputVector& new_args) const override;

std::vector<int64_t> get_input0_order() const { return m_order_a; }
std::vector<int64_t> get_input1_order() const { return m_order_b; }
std::vector<int64_t> get_output_order() const { return m_order_c; }
std::vector<int32_t> get_input0_broadcast_target_shape() const { return m_target_shape_a; }
std::vector<int32_t> get_input1_broadcast_target_shape() const { return m_target_shape_b; }
std::vector<int64_t> get_input0_reshape_pattern() const { return m_output_pattern_a; }
std::vector<int64_t> get_input1_reshape_pattern() const { return m_output_pattern_b; }
std::vector<int64_t> get_input0_transpose_order() const { return m_order_a; }
std::vector<int64_t> get_input1_transpose_order() const { return m_order_b; }
std::vector<int64_t> get_output_transpose_order() const { return m_order_c; }
ov::element::Type get_output_type() const { return m_output_type; }

static std::vector<int64_t> default_order(size_t rank) {
Expand All @@ -44,6 +59,10 @@ class Gemm : public ov::op::v0::MatMul {
}

protected:
std::vector<int32_t> m_target_shape_a;
std::vector<int32_t> m_target_shape_b;
std::vector<int64_t> m_output_pattern_a;
std::vector<int64_t> m_output_pattern_b;
std::vector<int64_t> m_order_a;
std::vector<int64_t> m_order_b;
std::vector<int64_t> m_order_c;
Expand All @@ -52,6 +71,10 @@ class Gemm : public ov::op::v0::MatMul {

std::vector<ov::PartialShape> shape_infer(const Gemm* op,
std::vector<ov::PartialShape> input_shapes,
const std::vector<int32_t>& target_shape_a,
const std::vector<int32_t>& target_shape_b,
const std::vector<int64_t>& output_pattern_a,
const std::vector<int64_t>& output_pattern_b,
const std::vector<int64_t>& order_a,
const std::vector<int64_t>& order_b,
const std::vector<int64_t>& order_c);
Expand Down
109 changes: 71 additions & 38 deletions src/plugins/intel_gpu/include/intel_gpu/primitives/gemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@ struct gemm : public primitive_base<gemm> {
: primitive_base(id, inputs, {output_padding}, {optional_data_type{ data_type }}),
transpose_input0(transpose_input0 ? 1 : 0),
transpose_input1(transpose_input1 ? 1 : 0),
input0_broadcast_target_shape({}),
input1_broadcast_target_shape({}),
input0_reshape_pattern({}),
input1_reshape_pattern({}),
alpha(alpha),
beta(beta),
input_rank(input_rank),
Expand All @@ -70,9 +74,9 @@ struct gemm : public primitive_base<gemm> {
return order;
};

input0_order = get_transposed_order(input_rank, transpose_input0);
input1_order = get_transposed_order(weight_rank, transpose_input1);
output_order = {};
input0_transpose_order = get_transposed_order(input_rank, transpose_input0);
input1_transpose_order = get_transposed_order(weight_rank, transpose_input1);
output_transpose_order = {};
}

/// @brief Constructs gemm layer.
Expand All @@ -86,69 +90,89 @@ struct gemm : public primitive_base<gemm> {
gemm(const primitive_id& id,
const std::vector<input_info>& inputs,
const data_types data_type,
const std::vector<int64_t>& input0_order = {0, 1, 2, 3},
const std::vector<int64_t>& input1_order = {0, 1, 2, 3},
const std::vector<int64_t>& output_order = {},
const std::vector<int32_t>& input0_broadcast_target_shape = {},
const std::vector<int32_t>& input1_broadcast_target_shape = {},
const std::vector<int64_t>& input0_reshape_pattern = {},
const std::vector<int64_t>& input1_reshape_pattern = {},
const std::vector<int64_t>& input0_transpose_order = {0, 1, 2, 3},
const std::vector<int64_t>& input1_transpose_order = {0, 1, 2, 3},
const std::vector<int64_t>& output_transpose_order = {},
const float alpha = 1.0f,
const float beta = 0.0f,
const padding& output_padding = padding())
: primitive_base(id, inputs, {output_padding}, {optional_data_type{ data_type }}),
input0_order(input0_order),
input1_order(input1_order),
output_order(output_order),
input0_broadcast_target_shape(input0_broadcast_target_shape),
input1_broadcast_target_shape(input1_broadcast_target_shape),
input0_reshape_pattern(input0_reshape_pattern),
input1_reshape_pattern(input1_reshape_pattern),
input0_transpose_order(input0_transpose_order),
input1_transpose_order(input1_transpose_order),
output_transpose_order(output_transpose_order),
alpha(alpha),
beta(beta),
input_rank(input0_order.size()),
weight_rank(input1_order.size()) {
input_rank(input0_transpose_order.size()),
weight_rank(input1_transpose_order.size()) {
if (inputs.size() != 2 && inputs.size() != 3) {
throw std::invalid_argument("Invalid inputs count - gemm expects either two or three inputs");
}

transpose_input0 = get_transpose_mode(input0_order);
transpose_input1 = get_transpose_mode(input1_order);
transpose_input0 = get_transpose_mode(input0_transpose_order);
transpose_input1 = get_transpose_mode(input1_transpose_order);
}

gemm(const primitive_id& id,
const std::vector<input_info>& inputs,
const input_info& beam_table,
const data_types data_type,
const std::vector<int64_t>& input0_order,
const std::vector<int64_t>& input1_order,
const std::vector<int64_t>& output_order,
const std::vector<int64_t>& input0_transpose_order,
const std::vector<int64_t>& input1_transpose_order,
const std::vector<int64_t>& output_transpose_order,
bool indirect_a,
bool indirect_b,
const float alpha = 1.0f,
const float beta = 0.0f,
const padding& output_padding = padding())
: primitive_base(id, inputs, {output_padding}, {optional_data_type{ data_type }}),
input0_order(input0_order),
input1_order(input1_order),
output_order(output_order),
input0_broadcast_target_shape({}),
input1_broadcast_target_shape({}),
input0_reshape_pattern({}),
input1_reshape_pattern({}),
input0_transpose_order(input0_transpose_order),
input1_transpose_order(input1_transpose_order),
output_transpose_order(output_transpose_order),
alpha(alpha),
beta(beta),
input_rank(input0_order.size()),
weight_rank(input1_order.size()),
input_rank(input0_transpose_order.size()),
weight_rank(input1_transpose_order.size()),
beam_table(beam_table),
indirect_a(indirect_a),
indirect_b(indirect_b) {
if (inputs.size() != 2 && inputs.size() != 3) {
throw std::invalid_argument("Invalid inputs count - gemm expects either two or three inputs");
}

transpose_input0 = get_transpose_mode(input0_order);
transpose_input1 = get_transpose_mode(input1_order);
transpose_input0 = get_transpose_mode(input0_transpose_order);
transpose_input1 = get_transpose_mode(input1_transpose_order);
}

/// @brief Flag for transposing first input matrix
uint32_t transpose_input0 = 0;
/// @brief Flag for transposing second input matrix
uint32_t transpose_input1 = 0;
/// @brief broadcasted target shape of input 0
std::vector<int32_t> input0_broadcast_target_shape;
/// @brief broadcasted target shape of input 1
std::vector<int32_t> input1_broadcast_target_shape;
/// @brief reshaped output pattern of input 0
std::vector<int64_t> input0_reshape_pattern;
/// @brief reshaped output pattern of input 1
std::vector<int64_t> input1_reshape_pattern;
/// @brief order of input 0
std::vector<int64_t> input0_order;
std::vector<int64_t> input0_transpose_order;
/// @brief order of input 1
std::vector<int64_t> input1_order;
std::vector<int64_t> input1_transpose_order;
/// @brief order of output
std::vector<int64_t> output_order;
std::vector<int64_t> output_transpose_order;
/// @brief Variable containing ALPHA parameter
float alpha = 1.0f;
/// @brief Variable containing BETA parameter
Expand All @@ -169,12 +193,13 @@ struct gemm : public primitive_base<gemm> {
seed = hash_combine(seed, transpose_input1);
seed = hash_combine(seed, indirect_a);
seed = hash_combine(seed, indirect_b);
for (auto order : input0_order)
seed = hash_combine(seed, order);
for (auto order : input1_order)
seed = hash_combine(seed, order);
for (auto order : output_order)
seed = hash_combine(seed, order);
seed = hash_range(seed, input0_broadcast_target_shape.begin(), input0_broadcast_target_shape.end());
seed = hash_range(seed, input1_broadcast_target_shape.begin(), input1_broadcast_target_shape.end());
seed = hash_range(seed, input0_reshape_pattern.begin(), input0_reshape_pattern.end());
seed = hash_range(seed, input1_reshape_pattern.begin(), input1_reshape_pattern.end());
seed = hash_range(seed, input0_transpose_order.begin(), input0_transpose_order.end());
seed = hash_range(seed, input1_transpose_order.begin(), input1_transpose_order.end());
seed = hash_range(seed, output_transpose_order.begin(), output_transpose_order.end());
seed = hash_combine(seed, alpha);
seed = hash_combine(seed, beta);
return seed;
Expand All @@ -200,9 +225,13 @@ struct gemm : public primitive_base<gemm> {
primitive_base<gemm>::save(ob);
ob << transpose_input0;
ob << transpose_input1;
ob << input0_order;
ob << input1_order;
ob << output_order;
ob << input0_broadcast_target_shape;
ob << input1_broadcast_target_shape;
ob << input0_reshape_pattern;
ob << input1_reshape_pattern;
ob << input0_transpose_order;
ob << input1_transpose_order;
ob << output_transpose_order;
ob << alpha;
ob << beta;
ob << input_rank;
Expand All @@ -217,9 +246,13 @@ struct gemm : public primitive_base<gemm> {
primitive_base<gemm>::load(ib);
ib >> transpose_input0;
ib >> transpose_input1;
ib >> input0_order;
ib >> input1_order;
ib >> output_order;
ib >> input0_broadcast_target_shape;
ib >> input1_broadcast_target_shape;
ib >> input0_reshape_pattern;
ib >> input1_reshape_pattern;
ib >> input0_transpose_order;
ib >> input1_transpose_order;
ib >> output_transpose_order;
ib >> alpha;
ib >> beta;
ib >> input_rank;
Expand Down
Loading

0 comments on commit 133b139

Please sign in to comment.