Skip to content

Commit

Permalink
[CPU] [ARM64] int8 support: comment fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
eshoguli committed Mar 20, 2024
1 parent efe6c69 commit 2fc13c1
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -289,10 +289,10 @@ void load_vector(const T1& data_lane,
}
} else {
if (offset == 0) {
h->ld1(data_lanes, Xbyak_aarch64::ptr(ptr_reg));
h->ld1(data_lanes, ptr(ptr_reg));
} else {
h->add_imm(h->X_DEFAULT_ADDR, ptr_reg, offset, h->X_TMP_0);
h->ld1(data_lanes, Xbyak_aarch64::ptr(h->X_DEFAULT_ADDR));
h->ld1(data_lanes, ptr(h->X_DEFAULT_ADDR));
}
}
}
Expand All @@ -319,9 +319,16 @@ void jit_uni_eltwise_generic<isa>::load_vector(const TReg& data,
}
break;
}
case ov::element::i8:
case ov::element::i8: {
utils::load_vector(data.b, data.s, ptr_reg, ptr_offset, broadcast, this);
sshll(data.h8, data.b8, 0);
sshll(data.s4, data.h4, 0);
break;
}
case ov::element::u8: {
utils::load_vector(data.b, data.s, ptr_reg, ptr_offset, broadcast, this);
ushll(data.h8, data.b8, 0);
ushll(data.s4, data.h4, 0);
break;
}
default: {
Expand All @@ -342,14 +349,10 @@ void jit_uni_eltwise_generic<isa>::load_vector(const TReg& data,
break;
}
case ov::element::i8: {
sshll(data.h8, data.b8, 0);
sshll(data.s4, data.h4, 0);
scvtf(data.s, data.s);
break;
}
case ov::element::u8: {
ushll(data.h8, data.b8, 0);
ushll(data.s4, data.h4, 0);
ucvtf(data.s, data.s);
break;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ std::string EltwiseLayerCPUTest::getTestCaseName(testing::TestParamInfo<EltwiseL
return result.str();
}

// If adopt_intervals is true then:
// 1) the generated tensor value range is limited by operation result value (especially for multiply)
// which has to be in signed/unsigned int8 type range,
// 2) start value is defined by type sign: for signed int8 it's zero to have symmetric interval.
ov::Tensor EltwiseLayerCPUTest::generate_eltwise_input(const ov::element::Type& type, const ov::Shape& shape, const bool adopt_intervals) {
struct gen_params {
uint32_t range;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ std::string EltwiseChainTest::getTestCaseName(const testing::TestParamInfo<Eltwi
return results.str();
}

ov::Tensor EltwiseChainTest::generate_eltwise_input(const ov::element::Type& type, const ov::Shape& shape, const bool adopt_intervals) {
ov::Tensor EltwiseChainTest::generate_eltwise_input(const ov::element::Type& type, const ov::Shape& shape) {
struct gen_params {
uint32_t range;
int32_t start_from;
Expand All @@ -62,7 +62,7 @@ ov::Tensor EltwiseChainTest::generate_eltwise_input(const ov::element::Type& typ
: range(range), start_from(start_from), resolution(resolution) {}
};

gen_params params = type.is_real() ? gen_params(10, 1) : gen_params(10, 5);
gen_params params = type.is_real() ? gen_params(10, 1) : gen_params(10, 10);

ov::test::utils::InputGenerateData in_data;
in_data.start_from = params.start_from;
Expand All @@ -79,9 +79,7 @@ void EltwiseChainTest::generate_inputs(const std::vector<ov::Shape>& targetInput
const auto& funcInput = funcInputs[i];
inputs.insert({funcInput.get_node_shared_ptr(), generate_eltwise_input(
funcInput.get_element_type(),
targetInputStaticShapes[i],
(funcInput.get_element_type() == element::i32) || (funcInput.get_element_type() == element::u32) ||
(funcInput.get_element_type() == element::i8) || (funcInput.get_element_type() == element::u8))});
targetInputStaticShapes[i])});
}
}

Expand Down Expand Up @@ -217,7 +215,6 @@ std::vector<std::vector<ElementType>> inputPrecisionsConvert() {
{ElementType::i16, ElementType::f32, ElementType::f32},
{ElementType::u16, ElementType::f32, ElementType::f32},
{ElementType::i32, ElementType::f32, ElementType::f32},
// { ElementType::u32, ElementType::f32, ElementType::f32 }, // plugin doesn't support
{ElementType::f16, ElementType::f32, ElementType::f32},
{ElementType::f32, ElementType::f32, ElementType::f32},
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class EltwiseChainTest : public testing::WithParamInterface<EltwiseChainTuple>,
virtual public SubgraphBaseTest {
public:
static std::string getTestCaseName(const testing::TestParamInfo<EltwiseChainTuple> &obj);
ov::Tensor generate_eltwise_input(const ov::element::Type& type, const ov::Shape& shape, const bool adopt_intervals);
ov::Tensor generate_eltwise_input(const ov::element::Type& type, const ov::Shape& shape);
void generate_inputs(const std::vector<ov::Shape>& targetInputStaticShapes) override;

protected:
Expand Down

0 comments on commit 2fc13c1

Please sign in to comment.