From 7f989885b36b8e13eaf81bb82e83c65de796d110 Mon Sep 17 00:00:00 2001 From: Egor Duplensky Date: Sat, 23 Mar 2024 21:11:53 +0100 Subject: [PATCH] [CPU] Prohibit fc avx2_vnni decompression for bf16 input --- .../fullyconnected_implementations.cpp | 12 +++++- .../nodes/executors/precision_translation.cpp | 8 ++-- .../nodes/executors/precision_translation.hpp | 39 +++++++++++++++++-- 3 files changed, 52 insertions(+), 7 deletions(-) diff --git a/src/plugins/intel_cpu/src/nodes/executors/fullyconnected_implementations.cpp b/src/plugins/intel_cpu/src/nodes/executors/fullyconnected_implementations.cpp index eea989656e49b6..5574602fc10f04 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/fullyconnected_implementations.cpp +++ b/src/plugins/intel_cpu/src/nodes/executors/fullyconnected_implementations.cpp @@ -38,6 +38,13 @@ static const MappingNotation dnnlFCMappingNotation{ARG_SRC, ARG_WEI, ARG_BIAS, A using LayoutConfig = std::vector; static const LayoutConfig dnnlFCLayoutConfig{LayoutType::ncsp, LayoutType::ncsp, LayoutType::ncsp, LayoutType::ncsp}; +template +struct Require { + bool operator()() { + return dnnl::impl::cpu::x64::mayiuse(ISA); + } +}; + // clang-format off static const TypeMapping dnnlFCTypeMapping { // {src, wei, bia, dst} pt @@ -54,7 +61,10 @@ static const TypeMapping dnnlFCTypeMapping { {{_u8 | _i8, _i8, _any, _f16}, pt(bypass(), bypass(), just(), just())}, {{_u8 | _i8, _i8, _any, _u8 | _i8 | _i32 | _bf16 | _f32}, pt(bypass(), bypass(), use<3>(), bypass())}, // compresses int weights (@todo more strict requrements for output precision?) - {{_f32 | _bf16, _u8 | _nf4 | _u4 | _i4, _any, _any}, pt(bypass(), bypass(), use<0>(), use<0>())}, + {{_bf16, _u8 | _nf4 | _u4 | _i4, _any, _any}, pt(bypass(), bypass(), use<0>(), use<0>()), + Require()}, // Ticket 122347 + {{_bf16, _u8 | _nf4 | _u4 | _i4, _any, _any}, pt(just(), bypass(), just(), just())}, + {{_f32, _u8 | _nf4 | _u4 | _i4, _any, _any}, pt(bypass(), bypass(), use<0>(), use<0>())}, // @todo should we fallback to FPXX instead of _f32? {{_any, _any, _any, _any}, pt(just(), just(), just(), just())}, // @todo explicitly cover configuration limitations for oneDNN on ARM diff --git a/src/plugins/intel_cpu/src/nodes/executors/precision_translation.cpp b/src/plugins/intel_cpu/src/nodes/executors/precision_translation.cpp index cda7b4e47a0c0b..73aac151843b08 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/precision_translation.cpp +++ b/src/plugins/intel_cpu/src/nodes/executors/precision_translation.cpp @@ -21,12 +21,14 @@ InOutTypes getTypeConfiguration(const MemoryDescArgs& descriptors, const TypeMap }); for (const auto& entry : mapping) { - const auto& pattern = entry.first; + if (!entry.enabled()) + continue; + + const auto& pattern = entry.mask(); if (!match(pattern, types)) continue; - const auto& translator = entry.second; - return translator(types); + return entry.translate(types); } OPENVINO_THROW("Failed to create a type configuration for the provided memory descriptors"); diff --git a/src/plugins/intel_cpu/src/nodes/executors/precision_translation.hpp b/src/plugins/intel_cpu/src/nodes/executors/precision_translation.hpp index 4e9be3b3dd5691..374b584dd0ffb5 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/precision_translation.hpp +++ b/src/plugins/intel_cpu/src/nodes/executors/precision_translation.hpp @@ -6,7 +6,6 @@ #include #include -#include #include #include "nodes/executors/memory_arguments.hpp" @@ -82,9 +81,43 @@ struct PortsTranslation { // pros: should be more efficient and safe // cons: more template instances (binary size) of the translation utility functions using InOutTypes = std::vector; -using PortsConfigurationImpl = std::function; +using TypeTranslationFunction = std::function; using InOutTypeMask = std::vector; -using TypeMapping = std::vector>; + +class TypeMappingEntry { +public: + using EnabledPredicate = std::function; + + TypeMappingEntry(InOutTypeMask mask, + TypeTranslationFunction translation, + EnabledPredicate enabled = {}) + : m_mask(std::move(mask)), + m_translation(std::move(translation)), + m_enabled(std::move(enabled)) {} + + const InOutTypeMask& mask() const { + return m_mask; + } + + InOutTypes translate(const InOutTypes& types) const { + if (m_translation) + return m_translation(types); + return {}; + } + + bool enabled() const { + if (m_enabled) + return m_enabled(); + return true; + } + +private: + InOutTypeMask m_mask; + TypeTranslationFunction m_translation; + EnabledPredicate m_enabled; +}; + +using TypeMapping = std::vector; using MappingNotation = std::vector; using pt = PortsTranslation;