-
Notifications
You must be signed in to change notification settings - Fork 219
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
aarch64: cherrypick onednn pr#1768 to improve torch.compile perf (#1716)
this improves the bert base torch.compile perf by 5.8x on AWS c7g instance.
- Loading branch information
Showing
1 changed file
with
106 additions
and
0 deletions.
There are no files selected for viewing
106 changes: 106 additions & 0 deletions
106
mkldnn_fix/onednn-pr1768-aarch64-add-acl-sbgemm-inner-product-primitive.patch
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
From b8595f6bcbc4f8e38ea9d1c42a65f82d72a07598 Mon Sep 17 00:00:00 2001 | ||
From: Sunita Nadampalli <nadampal@amazon.com> | ||
Date: Wed, 6 Mar 2024 00:37:30 +0000 | ||
Subject: [PATCH] onednn: pr1768: aarch64: add acl sbgemm inner product | ||
primitive | ||
|
||
--- | ||
src/cpu/aarch64/acl_inner_product.hpp | 33 +++++++++++++++++++++++---- | ||
src/cpu/cpu_inner_product_list.cpp | 9 ++++++++ | ||
2 files changed, 37 insertions(+), 5 deletions(-) | ||
|
||
diff --git a/src/cpu/aarch64/acl_inner_product.hpp b/src/cpu/aarch64/acl_inner_product.hpp | ||
index a2be164f09..762d5d4896 100644 | ||
--- a/src/cpu/aarch64/acl_inner_product.hpp | ||
+++ b/src/cpu/aarch64/acl_inner_product.hpp | ||
@@ -93,20 +93,33 @@ struct acl_inner_product_fwd_t : public primitive_t { | ||
|
||
status_t init(engine_t *engine) { | ||
using namespace data_type; | ||
+ const format_kind_t weights_format_kind_received | ||
+ = weights_md_.format_kind; | ||
+ | ||
const bool is_fp16_ok = expect_data_types(f16, f16, f16, f16, undef) | ||
&& attr()->has_default_values( | ||
primitive_attr_t::skip_mask_t::post_ops, f16); | ||
const bool is_fp32_ok = expect_data_types(f32, f32, f32, f32, undef) | ||
&& attr()->has_default_values( | ||
primitive_attr_t::skip_mask_t::post_ops, f32); | ||
+ | ||
+ const bool is_fp32_bf16_ok | ||
+ = expect_data_types(f32, bf16, f32, f32, undef) | ||
+ && attr()->has_default_values( | ||
+ primitive_attr_t::skip_mask_t::post_ops, f32); | ||
+ | ||
+ const bool is_weights_md_format_ok | ||
+ = utils::one_of(weights_format_kind_received, | ||
+ format_kind::any, format_kind::blocked); | ||
+ | ||
const bool ok = is_fwd() && !has_zero_dim_memory() | ||
- && utils::one_of(true, is_fp16_ok, is_fp32_ok) | ||
- && weights_md_.format_kind == format_kind::any | ||
- && set_default_params() == status::success; | ||
+ && utils::one_of(true, is_fp16_ok, is_fp32_ok, is_fp32_bf16_ok) | ||
+ && is_weights_md_format_ok | ||
+ && set_default_params(true) == status::success; | ||
|
||
if (!ok) return status::unimplemented; | ||
|
||
- CHECK(init_conf_ip(engine)); | ||
+ CHECK(init_conf_ip(engine, weights_format_kind_received)); | ||
|
||
return status::success; | ||
} | ||
@@ -115,7 +128,8 @@ struct acl_inner_product_fwd_t : public primitive_t { | ||
|
||
acl_post_ops_t post_ops; | ||
|
||
- status_t init_conf_ip(engine_t *engine) { | ||
+ status_t init_conf_ip( | ||
+ engine_t *engine, format_kind_t weights_format_kind_received) { | ||
|
||
ACL_CHECK_SUPPORT(src_md()->ndims != weights_md()->ndims, | ||
"source and weights dimensions must match"); | ||
@@ -257,10 +271,19 @@ struct acl_inner_product_fwd_t : public primitive_t { | ||
return status::unimplemented; | ||
} | ||
|
||
+ const memory_desc_t weights_md_received = weights_md_; | ||
acl_utils::reorder_to_weight_format(aip.wei_tensor_info, | ||
weights_md_, expected_weight_format, inner_dim, o_dim, | ||
remaining_dims, {}); | ||
|
||
+ ACL_CHECK_SUPPORT( | ||
+ (weights_format_kind_received == format_kind::blocked) | ||
+ && !(dnnl_memory_desc_equal( | ||
+ &weights_md_received, &weights_md_)), | ||
+ "specific blocked format not supported by ACL, use " | ||
+ "format_kind_t::any to find a supported blocked format for " | ||
+ "your platform"); | ||
+ | ||
// clang-format off | ||
|
||
// Validate fully connected layer manually to check for return status | ||
diff --git a/src/cpu/cpu_inner_product_list.cpp b/src/cpu/cpu_inner_product_list.cpp | ||
index fdd7b17769..1f59547304 100644 | ||
--- a/src/cpu/cpu_inner_product_list.cpp | ||
+++ b/src/cpu/cpu_inner_product_list.cpp | ||
@@ -83,6 +83,15 @@ const std::map<pk_dt_impl_key_t, std::vector<impl_list_item_t>> &impl_list_map() | ||
CPU_INSTANCE(ref_inner_product_fwd_t) | ||
nullptr, | ||
}}, | ||
+ /* With graph compilation, we are able to reorder and pre-pack the weights during the model load | ||
+ * and compilation phase itself so that redundant and on-the-fly reorders can be avoided. | ||
+ * This primitive definition is to support gemm fastmath mode for the compile scenario where src is | ||
+ * in fp32 and weights are in bf16 | ||
+ */ | ||
+ {{forward, f32, bf16, f32}, { | ||
+ CPU_INSTANCE_AARCH64_ACL(acl_inner_product_fwd_t) | ||
+ nullptr, | ||
+ }}, | ||
{{backward_data, f32, f32, f32}, REG_BWD_PK({ | ||
CPU_INSTANCE_AMX(brgemm_inner_product_bwd_data_t<avx512_core_amx>) // bf32 | ||
CPU_INSTANCE_AVX512(brgemm_inner_product_bwd_data_t<avx512_core>) | ||
-- | ||
2.34.1 | ||
|