Skip to content

Commit

Permalink
[CPU] [ARM64] int8 support
Browse files Browse the repository at this point in the history
  • Loading branch information
eshoguli committed Mar 7, 2024
1 parent 1c5a736 commit 24db6cc
Show file tree
Hide file tree
Showing 13 changed files with 475 additions and 230 deletions.
2 changes: 1 addition & 1 deletion src/core/src/op/divide.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ bool Divide::evaluate(TensorVector& outputs, const TensorVector& inputs) const {
this,
outputs,
inputs,
OV_PP_ET_LIST(f32, i32, i64, u32, u64),
OV_PP_ET_LIST(f32, i8, i32, i64, u8, u32, u64),
divide::Evaluate,
inputs[0].get_element_type(),
inputs[0],
Expand Down
2 changes: 1 addition & 1 deletion src/core/src/op/multiply.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ bool Multiply::evaluate(TensorVector& outputs, const TensorVector& inputs) const
this,
outputs,
inputs,
OV_PP_ET_LIST(f32, f64, i32, i64, u32, u64),
OV_PP_ET_LIST(f32, f64, i8, i32, i64, u8, u32, u64),
multiply::Evaluate,
inputs[0].get_element_type(),
inputs[0],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,9 @@ bool JitEltwiseExecutor::isSupported(
static const std::set<ov::element::Type> supported_precisions = {
ov::element::f16,
ov::element::f32,
ov::element::i32
ov::element::i32,
ov::element::i8,
ov::element::u8
};

if (!check_precisions(input_precisions, output_precisions, supported_precisions)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,32 @@ void jit_uni_eltwise_generic<isa>::generate() {
}
}

namespace utils {
template <typename T1, typename T2>
void load_vector(const T1& data_lane,
const T2& data_lanes,
const Xbyak_aarch64::XReg &ptr_reg,
const int64_t offset,
const bool broadcast,
jit_generator* h) {
if (broadcast) {
if (offset == 0) {
h->ld1r(data_lane, ptr(ptr_reg));
} else {
h->add_imm(h->X_DEFAULT_ADDR, ptr_reg, offset, h->X_TMP_0);
h->ld1r(data_lane, ptr(h->X_DEFAULT_ADDR));
}
} else {
if (offset == 0) {
h->ld1(data_lanes, Xbyak_aarch64::ptr(ptr_reg));
} else {
h->add_imm(h->X_DEFAULT_ADDR, ptr_reg, offset, h->X_TMP_0);
h->ld1(data_lanes, Xbyak_aarch64::ptr(h->X_DEFAULT_ADDR));
}
}
}
} // namespace utils

template <dnnl::impl::cpu::aarch64::cpu_isa_t isa>
void jit_uni_eltwise_generic<isa>::load_vector(const TReg& data,
const XReg& ptr_reg,
Expand All @@ -281,16 +307,7 @@ void jit_uni_eltwise_generic<isa>::load_vector(const TReg& data,
const int32_t ptr_offset) {
switch (src_prc) {
case ov::element::f16: {
if (broadcast) {
if (ptr_offset == 0) {
ld1r(data.h, ptr(ptr_reg));
} else {
add_imm(ptr_reg, ptr_reg, ptr_offset, X_DEFAULT_ADDR);
ld1r(data.h, ptr(ptr_reg));
}
} else {
ldr(Xbyak_aarch64::DReg(data.getIdx()), Xbyak_aarch64::ptr(ptr_reg, ptr_offset));
}
utils::load_vector(data.h, data.h4, ptr_reg, ptr_offset, broadcast, this);
break;
}
case ov::element::f32:
Expand All @@ -302,6 +319,11 @@ void jit_uni_eltwise_generic<isa>::load_vector(const TReg& data,
}
break;
}
case ov::element::i8:
case ov::element::u8: {
utils::load_vector(data.b, data.s, ptr_reg, ptr_offset, broadcast, this);
break;
}
default: {
OPENVINO_THROW("src_prc " + src_prc.to_string() + " is not supported, dst_prc is " + dst_prc.to_string());
}
Expand All @@ -319,6 +341,18 @@ void jit_uni_eltwise_generic<isa>::load_vector(const TReg& data,
scvtf(data.s, data.s);
break;
}
case ov::element::i8: {
sshll(data.h8, data.b8, 0);
sshll(data.s4, data.h4, 0);
scvtf(data.s, data.s);
break;
}
case ov::element::u8: {
ushll(data.h8, data.b8, 0);
ushll(data.s4, data.h4, 0);
ucvtf(data.s, data.s);
break;
}
default:
OPENVINO_THROW("src_prc " + src_prc.to_string() + " is not supported, dst_prc is " + dst_prc.to_string());
}
Expand All @@ -345,6 +379,24 @@ void jit_uni_eltwise_generic<isa>::load_scalar(const SReg& data,
ldr(data, Xbyak_aarch64::ptr(ptr, ptr_offset));
break;
}
case ov::element::i8: {
ldr(Xbyak_aarch64::BReg(data.getIdx()), Xbyak_aarch64::ptr(ptr, ptr_offset));

// scalar is loaded, operates with vector
TReg vec(data.getIdx());
sshll(vec.h8, vec.b8, 0);
sshll(vec.s4, vec.h4, 0);
break;
}
case ov::element::u8: {
ldr(Xbyak_aarch64::BReg(data.getIdx()), Xbyak_aarch64::ptr(ptr, ptr_offset));

// scalar is loaded, operates with vector
TReg vec(data.getIdx());
ushll(vec.h8, vec.b8, 0);
ushll(vec.s4, vec.h4, 0);
break;
}
default: {
OPENVINO_THROW("src_prc " + src_prc.to_string() + " is not supported, dst_prc is " + dst_prc.to_string());
}
Expand All @@ -358,10 +410,15 @@ void jit_uni_eltwise_generic<isa>::load_scalar(const SReg& data,
fcvt(Xbyak_aarch64::SReg(data.getIdx()), Xbyak_aarch64::HReg(data.getIdx()));
break;
}
case ov::element::i32: {
case ov::element::i32:
case ov::element::i8: {
scvtf(Xbyak_aarch64::SReg(data.getIdx()), Xbyak_aarch64::SReg(data.getIdx()));
break;
}
case ov::element::u8: {
ucvtf(Xbyak_aarch64::SReg(data.getIdx()), Xbyak_aarch64::SReg(data.getIdx()));
break;
}
default:
OPENVINO_THROW("src_prc " + src_prc.to_string() + " is not supported, dst_prc is " + dst_prc.to_string());
}
Expand Down Expand Up @@ -390,6 +447,18 @@ void jit_uni_eltwise_generic<isa>::store_vector(const XReg& ptr,
fcvtns(data.s, data.s);
break;
}
case ov::element::i8: {
fcvtns(data.s, data.s);
xtn(data.h4, data.s4);
xtn(data.b8, data.h8);
break;
}
case ov::element::u8: {
fcvtnu(data.s, data.s);
xtn(data.h4, data.s4);
xtn(data.b8, data.h8);
break;
}
default: {
OPENVINO_THROW("dst_prc " + dst_prc.to_string() + " is not supported, src_prc is " + src_prc.to_string());
}
Expand All @@ -412,6 +481,11 @@ void jit_uni_eltwise_generic<isa>::store_vector(const XReg& ptr,
str(Xbyak_aarch64::QReg(data.getIdx()), Xbyak_aarch64::ptr(ptr, ptr_offset));
break;
}
case ov::element::i8:
case ov::element::u8: {
str(Xbyak_aarch64::SReg(data.getIdx()), Xbyak_aarch64::ptr(ptr, ptr_offset));
break;
}
default: {
OPENVINO_THROW("dst_prc " + dst_prc.to_string() + " is not supported, src_ptr is " + src_prc.to_string());
}
Expand All @@ -436,6 +510,20 @@ void jit_uni_eltwise_generic<isa>::store_scalar(const XReg& ptr,
fcvtns(data, data);
break;
}
case ov::element::i8: {
TReg vec_data(data.getIdx());
fcvtns(vec_data.s, vec_data.s);
xtn(vec_data.h4, vec_data.s4);
xtn(vec_data.b8, vec_data.h8);
break;
}
case ov::element::u8: {
TReg vec_data(data.getIdx());
fcvtnu(vec_data.s, vec_data.s);
xtn(vec_data.h4, vec_data.s4);
xtn(vec_data.b8, vec_data.h8);
break;
}
default: {
OPENVINO_THROW("dst_prc " + dst_prc.to_string() + " is not supported, src_prc is " + src_prc.to_string());
}
Expand All @@ -458,6 +546,11 @@ void jit_uni_eltwise_generic<isa>::store_scalar(const XReg& ptr,
str(data, Xbyak_aarch64::ptr(ptr, ptr_offset));
break;
}
case ov::element::i8:
case ov::element::u8: {
str(Xbyak_aarch64::BReg(data.getIdx()), Xbyak_aarch64::ptr(ptr, ptr_offset));
break;
}
default: {
OPENVINO_THROW("dst_prc " + src_prc.to_string() + " is not supported, src_prc is " + src_prc.to_string());
}
Expand Down
8 changes: 5 additions & 3 deletions src/plugins/intel_cpu/tests/functional/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,12 @@ else()
file(GLOB_RECURSE TMP_LIST_OF_TEST_CLASSES ${CMAKE_CURRENT_SOURCE_DIR}/custom/single_layer_tests/classes/*.cpp)
file(GLOB_RECURSE TMP_LIST_OF_COMMON_TEST_INSTANCES ${CMAKE_CURRENT_SOURCE_DIR}/custom/single_layer_tests/instances/common/*.cpp)
file(GLOB_RECURSE TMP_LIST_OF_ARM_TEST_INSTANCES ${CMAKE_CURRENT_SOURCE_DIR}/custom/single_layer_tests/instances/arm/*.cpp)
file(GLOB_RECURSE TMP_LIST_OF_ARM_SUBGRAPH_TESTS ${CMAKE_CURRENT_SOURCE_DIR}/custom/subgraph_tests/arm/*.cpp)
file(GLOB_RECURSE TMP_LIST_OF_ARM_SUBGRAPH_TESTS ${CMAKE_CURRENT_SOURCE_DIR}/custom/subgraph_tests/src/common/*.cpp)
file(GLOB_RECURSE TMP_LIST_OF_ARM_SUBGRAPH_TESTS ${CMAKE_CURRENT_SOURCE_DIR}/custom/subgraph_tests/src/arm/*.cpp)
file(GLOB_RECURSE TMP_LIST_OF_COMMON_SUBGRAPH_TESTS ${CMAKE_CURRENT_SOURCE_DIR}/custom/subgraph_tests/src/common/*.cpp)
file(GLOB_RECURSE TMP_LIST_OF_SUBGRAPH_TEST_CLASSES ${CMAKE_CURRENT_SOURCE_DIR}/custom/subgraph_tests/src/classes/*.*)

list(APPEND TMP_LIST_OF_EXPLICITLY_ENABLED_TESTS
${TMP_LIST_OF_TEST_CLASSES} ${TMP_LIST_OF_COMMON_TEST_INSTANCES} ${TMP_LIST_OF_ARM_TEST_INSTANCES} ${TMP_LIST_OF_ARM_SUBGRAPH_TESTS})
${TMP_LIST_OF_TEST_CLASSES} ${TMP_LIST_OF_COMMON_TEST_INSTANCES} ${TMP_LIST_OF_ARM_TEST_INSTANCES} ${TMP_LIST_OF_ARM_SUBGRAPH_TESTS} ${TMP_LIST_OF_COMMON_SUBGRAPH_TESTS} ${TMP_LIST_OF_SUBGRAPH_TEST_CLASSES})
set(TMP_EXPLICITLY_ENABLED_TESTS "${TMP_LIST_OF_EXPLICITLY_ENABLED_TESTS}")
endif()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,18 @@ ov::Tensor EltwiseLayerCPUTest::generate_eltwise_input(const ov::element::Type&
} else {
switch (type) {
case ov::element::i8:
params = gen_params(INT8_MAX, INT8_MIN);
if (adopt_intervals) {
params = gen_params(11 * 2, -11);
} else {
params = gen_params(INT8_MAX, INT8_MIN);
}
break;
case ov::element::u8:
params = gen_params(UINT8_MAX, 0);
if (adopt_intervals) {
params = gen_params(15, 0);
} else {
params = gen_params(UINT8_MAX, 0);
}
break;
case ov::element::i16:
params = gen_params(INT16_MAX, INT16_MIN);
Expand Down Expand Up @@ -109,7 +117,8 @@ void EltwiseLayerCPUTest::generate_inputs(const std::vector<ov::Shape>& targetIn
inputs.insert({funcInput.get_node_shared_ptr(), generate_eltwise_input(
funcInput.get_element_type(),
targetInputStaticShapes[i],
(funcInput.get_element_type() == element::i32) || (funcInput.get_element_type() == element::u32))});
(funcInput.get_element_type() == element::i32) || (funcInput.get_element_type() == element::u32) ||
(funcInput.get_element_type() == element::i8) || (funcInput.get_element_type() == element::u8))});
}
}

Expand Down Expand Up @@ -199,7 +208,11 @@ void EltwiseLayerCPUTest::SetUp() {
}
}

auto data_tensor = generate_eltwise_input(netType, shape, (netType == element::i32) || (netType == element::u32));
auto data_tensor = generate_eltwise_input(
netType,
shape,
(netType == element::i32) || (netType == element::u32) ||
(netType == element::i8) || (netType == element::u8));
if ((netType == ElementType::i8) || (netType == ElementType::u8)) {
auto data_ptr = reinterpret_cast<uint8_t*>(data_tensor.data());
std::vector<uint8_t> data(data_ptr, data_ptr + ov::shape_size(shape));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ const auto params_4D_int_jit = ::testing::Combine(
::testing::ValuesIn({ utils::EltwiseTypes::ADD, utils::EltwiseTypes::MULTIPLY }),
::testing::ValuesIn(secondaryInputTypes()),
::testing::ValuesIn(opTypes()),
::testing::ValuesIn({ ElementType::i32, ElementType::f32 }),
::testing::ValuesIn({ ElementType::i8, ElementType::u8, ElementType::f16, ElementType::i32, ElementType::f32 }),
::testing::Values(ov::element::undefined),
::testing::Values(ov::element::undefined),
::testing::Values(ov::test::utils::DEVICE_CPU),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include <tuple>
#include <string>
#include <vector>

#include "custom/subgraph_tests/src/classes/eltwise_chain.hpp"

#include "shared_test_classes/base/ov_subgraph.hpp"
#include "common_test_utils/node_builders/constant.hpp"
#include "common_test_utils/node_builders/eltwise.hpp"
#include "common_test_utils/ov_tensor_utils.hpp"

using namespace CPUTestUtils;

namespace ov {
namespace test {
using namespace ov::test::utils;
using namespace ov::test::eltwise_chain;

namespace {

std::vector<std::vector<EltwiseTypes>> eltwiseOpsConvertInt8 = {
{ EltwiseTypes::MULTIPLY },
{ EltwiseTypes::ADD },
{ EltwiseTypes::DIVIDE }
};

INSTANTIATE_TEST_SUITE_P(smoke_EltwiseChain_MergeConvert_int8, EltwiseChainTest,
::testing::Combine(
::testing::ValuesIn(static_shapes_to_test_representation(inputShapesConvert())),
::testing::Values(InputLayerType::CONSTANT),
::testing::ValuesIn(inputPrecisionsConvert()),
::testing::ValuesIn(eltwiseOpsConvertInt8),
::testing::Values(false),
::testing::ValuesIn({ov::element::i8, ov::element::u8}),
::testing::Values(ov::test::utils::DEVICE_CPU)),
EltwiseChainTest::getTestCaseName);

} // namespace
} // namespace test
} // namespace ov
Loading

0 comments on commit 24db6cc

Please sign in to comment.