diff --git a/mkldnn_fix/onednn-pr1768-aarch64-add-acl-sbgemm-inner-product-primitive.patch b/mkldnn_fix/onednn-pr1768-aarch64-add-acl-sbgemm-inner-product-primitive.patch new file mode 100644 index 000000000..2f860ca10 --- /dev/null +++ b/mkldnn_fix/onednn-pr1768-aarch64-add-acl-sbgemm-inner-product-primitive.patch @@ -0,0 +1,106 @@ +From b8595f6bcbc4f8e38ea9d1c42a65f82d72a07598 Mon Sep 17 00:00:00 2001 +From: Sunita Nadampalli +Date: Wed, 6 Mar 2024 00:37:30 +0000 +Subject: [PATCH] onednn: pr1768: aarch64: add acl sbgemm inner product + primitive + +--- + src/cpu/aarch64/acl_inner_product.hpp | 33 +++++++++++++++++++++++---- + src/cpu/cpu_inner_product_list.cpp | 9 ++++++++ + 2 files changed, 37 insertions(+), 5 deletions(-) + +diff --git a/src/cpu/aarch64/acl_inner_product.hpp b/src/cpu/aarch64/acl_inner_product.hpp +index a2be164f09..762d5d4896 100644 +--- a/src/cpu/aarch64/acl_inner_product.hpp ++++ b/src/cpu/aarch64/acl_inner_product.hpp +@@ -93,20 +93,33 @@ struct acl_inner_product_fwd_t : public primitive_t { + + status_t init(engine_t *engine) { + using namespace data_type; ++ const format_kind_t weights_format_kind_received ++ = weights_md_.format_kind; ++ + const bool is_fp16_ok = expect_data_types(f16, f16, f16, f16, undef) + && attr()->has_default_values( + primitive_attr_t::skip_mask_t::post_ops, f16); + const bool is_fp32_ok = expect_data_types(f32, f32, f32, f32, undef) + && attr()->has_default_values( + primitive_attr_t::skip_mask_t::post_ops, f32); ++ ++ const bool is_fp32_bf16_ok ++ = expect_data_types(f32, bf16, f32, f32, undef) ++ && attr()->has_default_values( ++ primitive_attr_t::skip_mask_t::post_ops, f32); ++ ++ const bool is_weights_md_format_ok ++ = utils::one_of(weights_format_kind_received, ++ format_kind::any, format_kind::blocked); ++ + const bool ok = is_fwd() && !has_zero_dim_memory() +- && utils::one_of(true, is_fp16_ok, is_fp32_ok) +- && weights_md_.format_kind == format_kind::any +- && set_default_params() == status::success; ++ && utils::one_of(true, is_fp16_ok, is_fp32_ok, is_fp32_bf16_ok) ++ && is_weights_md_format_ok ++ && set_default_params(true) == status::success; + + if (!ok) return status::unimplemented; + +- CHECK(init_conf_ip(engine)); ++ CHECK(init_conf_ip(engine, weights_format_kind_received)); + + return status::success; + } +@@ -115,7 +128,8 @@ struct acl_inner_product_fwd_t : public primitive_t { + + acl_post_ops_t post_ops; + +- status_t init_conf_ip(engine_t *engine) { ++ status_t init_conf_ip( ++ engine_t *engine, format_kind_t weights_format_kind_received) { + + ACL_CHECK_SUPPORT(src_md()->ndims != weights_md()->ndims, + "source and weights dimensions must match"); +@@ -257,10 +271,19 @@ struct acl_inner_product_fwd_t : public primitive_t { + return status::unimplemented; + } + ++ const memory_desc_t weights_md_received = weights_md_; + acl_utils::reorder_to_weight_format(aip.wei_tensor_info, + weights_md_, expected_weight_format, inner_dim, o_dim, + remaining_dims, {}); + ++ ACL_CHECK_SUPPORT( ++ (weights_format_kind_received == format_kind::blocked) ++ && !(dnnl_memory_desc_equal( ++ &weights_md_received, &weights_md_)), ++ "specific blocked format not supported by ACL, use " ++ "format_kind_t::any to find a supported blocked format for " ++ "your platform"); ++ + // clang-format off + + // Validate fully connected layer manually to check for return status +diff --git a/src/cpu/cpu_inner_product_list.cpp b/src/cpu/cpu_inner_product_list.cpp +index fdd7b17769..1f59547304 100644 +--- a/src/cpu/cpu_inner_product_list.cpp ++++ b/src/cpu/cpu_inner_product_list.cpp +@@ -83,6 +83,15 @@ const std::map> &impl_list_map() + CPU_INSTANCE(ref_inner_product_fwd_t) + nullptr, + }}, ++ /* With graph compilation, we are able to reorder and pre-pack the weights during the model load ++ * and compilation phase itself so that redundant and on-the-fly reorders can be avoided. ++ * This primitive definition is to support gemm fastmath mode for the compile scenario where src is ++ * in fp32 and weights are in bf16 ++ */ ++ {{forward, f32, bf16, f32}, { ++ CPU_INSTANCE_AARCH64_ACL(acl_inner_product_fwd_t) ++ nullptr, ++ }}, + {{backward_data, f32, f32, f32}, REG_BWD_PK({ + CPU_INSTANCE_AMX(brgemm_inner_product_bwd_data_t) // bf32 + CPU_INSTANCE_AVX512(brgemm_inner_product_bwd_data_t) +-- +2.34.1 +