Skip to content

Commit

Permalink
[GPU] Fix of the processing of negative indices in Gather in dynamic …
Browse files Browse the repository at this point in the history
…case (#22716)

### Tickets:
 - *131414*
  • Loading branch information
Lyamin-Roman authored Feb 8, 2024
1 parent 133d03e commit 46a2734
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@
#include "include/batch_headers/int4_utils.cl"

#ifdef INDEX_DIM
inline uint FUNC(get_positive_index)(int in)
inline uint FUNC(get_positive_index)(OPTIONAL_SHAPE_INFO_ARG int in)
{
if(in < 0)
if (in < 0)
return in + INDEX_DIM;
else
return in;
}
#define INPUT_AXIS_INDEX (uint)FUNC_CALL(get_positive_index)(indices[indices_idx])
#define INPUT_AXIS_INDEX (uint)FUNC_CALL(get_positive_index)(OPTIONAL_SHAPE_INFO_TENSOR indices[indices_idx])
#else
#define INPUT_AXIS_INDEX (uint)(indices[indices_idx])
#endif
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -240,10 +240,17 @@ JitConstants GatherKernelRef::GetJitConstants(const gather_params& params) const

jit.AddConstant(MakeJitConstant("DICTIONARY_INDEX_ORDER", GetDictionaryIndexOrder(params, GetGatherChannelIndex(params))));
jit.AddConstant(MakeJitConstant("INDICES_INDEX_ORDER", GetIndicesIdxOrder(params, GetGatherChannelIndex(params), GetGatherBatchDim(params))));
if (params.support_neg_ind)
jit.AddConstant(MakeJitConstant("INDEX_DIM", GetGatherMaxIndexDim(params)));

if (!GetGatherIndexDim(params).is_dynamic)
bool dyn_gather_idx_dim = GetGatherIndexDim(params).is_dynamic;
if (params.support_neg_ind) {
if (!dyn_gather_idx_dim) {
jit.AddConstant(MakeJitConstant("INDEX_DIM", GetGatherMaxIndexDim(params)));
} else {
jit.AddConstant(MakeJitConstant("INDEX_DIM", "shape_info[" + std::to_string(GetGatherAxisIndexInShapeInfo(params)) + "]"));
}
}

if (!dyn_gather_idx_dim)
jit.AddConstant(MakeJitConstant("AXIS_DIM", GetGatherMaxIndexDim(params)));

if (params.is_shape_agnostic)
Expand Down
46 changes: 46 additions & 0 deletions src/plugins/intel_gpu/tests/unit/test_cases/gather_gpu_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2327,3 +2327,49 @@ TEST_F(gather_gpu_tests, compressed_scale_fp16) {
TEST_F(gather_gpu_tests, compressed_scale_fp16_cached) {
this->test_compressed_scale_fp16(true);
}

TEST(gather_gpu_fp32, dynamic_support_neg_ind) {
auto& engine = get_test_engine();

ov::Shape data_shape = { 3, 3 };
ov::Shape indices_shape = {};
int64_t axis = 1;

auto data_layout = layout{ov::PartialShape::dynamic(data_shape.size()), data_types::f32, format::bfyx};
auto indices_layout = layout{ov::PartialShape::dynamic(indices_shape.size()), data_types::i32, format::bfyx};

auto data_mem = engine.allocate_memory(layout{ov::PartialShape(data_shape), data_types::f32, format::bfyx});
auto indices_mem = engine.allocate_memory(layout{ov::PartialShape(indices_shape), data_types::i32, format::bfyx});

set_values(data_mem, { 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f });
set_values(indices_mem, { -1 });

topology topology;
topology.add(input_layout("data", data_layout));
topology.add(input_layout("indices", indices_layout));
topology.add(gather("gather", input_info("data"), input_info("indices"), axis, data_shape.size(), ov::Shape{}, 0, true));

ExecutionConfig config = get_test_default_config(engine);
config.set_property(ov::intel_gpu::allow_new_shape_infer(true));
network network(engine, topology, config);

network.set_input_data("data", data_mem);
network.set_input_data("indices", indices_mem);

auto inst = network.get_primitive("gather");
auto impl = inst->get_impl();
ASSERT_TRUE(impl != nullptr);
ASSERT_TRUE(impl->is_dynamic());

auto outputs = network.execute();

auto output = outputs.at("gather").get_memory();
cldnn::mem_lock<float> output_ptr(output, get_test_stream());

std::vector<float> expected_results = { 2.f, 5.f, 8.f };

ASSERT_EQ(expected_results.size(), output_ptr.size());
for (size_t i = 0; i < expected_results.size(); ++i) {
ASSERT_EQ(expected_results[i], output_ptr[i]) << i;
}
}

0 comments on commit 46a2734

Please sign in to comment.