Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CPU] Prohibit fc avx2_vnni_2 decompression for bf16 input #23638

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,13 @@ static const MappingNotation dnnlFCMappingNotation{ARG_SRC, ARG_WEI, ARG_BIAS, A
using LayoutConfig = std::vector<LayoutType>;
static const LayoutConfig dnnlFCLayoutConfig{LayoutType::ncsp, LayoutType::ncsp, LayoutType::ncsp, LayoutType::ncsp};

template<dnnl::impl::cpu::x64::cpu_isa_t ISA>
struct Require {
bool operator()() {
return dnnl::impl::cpu::x64::mayiuse(ISA);
}
};

// clang-format off
static const TypeMapping dnnlFCTypeMapping {
// {src, wei, bia, dst} pt<src, wei, bias, dst>
Expand All @@ -54,7 +61,10 @@ static const TypeMapping dnnlFCTypeMapping {
{{_u8 | _i8, _i8, _any, _f16}, pt(bypass(), bypass(), just<f32>(), just<f32>())},
{{_u8 | _i8, _i8, _any, _u8 | _i8 | _i32 | _bf16 | _f32}, pt(bypass(), bypass(), use<3>(), bypass())},
// compresses int weights (@todo more strict requrements for output precision?)
{{_f32 | _bf16, _u8 | _nf4 | _u4 | _i4, _any, _any}, pt(bypass(), bypass(), use<0>(), use<0>())},
{{_bf16, _u8 | _nf4 | _u4 | _i4, _any, _any}, pt(bypass(), bypass(), use<0>(), use<0>()),
Require<dnnl::impl::cpu::x64::avx512_core_bf16>()}, // Ticket 122347
dmitry-gorokhov marked this conversation as resolved.
Show resolved Hide resolved
{{_bf16, _u8 | _nf4 | _u4 | _i4, _any, _any}, pt(just<f32>(), bypass(), just<f32>(), just<f32>())},
{{_f32, _u8 | _nf4 | _u4 | _i4, _any, _any}, pt(bypass(), bypass(), use<0>(), use<0>())},
// @todo should we fallback to FPXX instead of _f32?
{{_any, _any, _any, _any}, pt(just<f32>(), just<f32>(), just<f32>(), just<f32>())},
// @todo explicitly cover configuration limitations for oneDNN on ARM
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,14 @@ InOutTypes getTypeConfiguration(const MemoryDescArgs& descriptors, const TypeMap
});

for (const auto& entry : mapping) {
const auto& pattern = entry.first;
if (!entry.enabled())
continue;

const auto& pattern = entry.mask();
if (!match(pattern, types))
continue;

const auto& translator = entry.second;
return translator(types);
return entry.translate(types);
}

OPENVINO_THROW("Failed to create a type configuration for the provided memory descriptors");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

#include <cassert>
#include <functional>
#include <utility>
#include <vector>

#include "nodes/executors/memory_arguments.hpp"
Expand Down Expand Up @@ -82,9 +81,43 @@ struct PortsTranslation {
// pros: should be more efficient and safe
// cons: more template instances (binary size) of the translation utility functions
using InOutTypes = std::vector<ov::element::Type>;
using PortsConfigurationImpl = std::function<InOutTypes(const InOutTypes&)>;
using TypeTranslationFunction = std::function<InOutTypes(const InOutTypes&)>;
using InOutTypeMask = std::vector<TypeMask>;
using TypeMapping = std::vector<std::pair<InOutTypeMask, PortsConfigurationImpl>>;

class TypeMappingEntry {
public:
using EnabledPredicate = std::function<bool(void)>;

TypeMappingEntry(InOutTypeMask mask,
TypeTranslationFunction translation,
EnabledPredicate enabled = {})
: m_mask(std::move(mask)),
m_translation(std::move(translation)),
m_enabled(std::move(enabled)) {}

const InOutTypeMask& mask() const {
return m_mask;
}

InOutTypes translate(const InOutTypes& types) const {
if (m_translation)
return m_translation(types);
return {};
}

bool enabled() const {
if (m_enabled)
return m_enabled();
return true;
}

private:
InOutTypeMask m_mask;
TypeTranslationFunction m_translation;
EnabledPredicate m_enabled;
};

using TypeMapping = std::vector<TypeMappingEntry>;
using MappingNotation = std::vector<int>;
using pt = PortsTranslation;

Expand Down
Loading