Skip to content

Commit

Permalink
[bugfix][Relay] fix scatter_nd type relation
Browse files Browse the repository at this point in the history
ScatterND requires updates.shape[K:] == output.shape[M:],  
not  data.shape[K:] == output.shape[M:]
  • Loading branch information
JR4er authored May 5, 2023
1 parent eca6edf commit 5ba4e00
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions src/relay/op/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1185,7 +1185,7 @@ bool ScatterNDRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const size_t kdim = indices->shape.size() - 1;
const size_t ndim = out_shape.size();
ICHECK_LE(size_t(mdim->value), ndim)
<< "ScatterND: Given data with shape (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}), and indices "
<< "ScatterND: Given updates with shape (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}), and indices "
"with shape (M, Y_0, ..., Y_{K-1}), M must be less than or equal to N.";
// Indices: (M, Y_0, .. Y_{K-1}) data: (Y_0, .. Y_{K-1}, ...), verify Y's.
for (size_t i = 0; i < kdim; i++) {
Expand All @@ -1197,9 +1197,9 @@ bool ScatterNDRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
oshape.push_back(x);
}

// data: (Y_0, .. Y_{K-1}, X_M, .. X_{N-1}) out: (X_0, .. X_{N-1}), verify X_M to X_{N-1}
// updates: (Y_0, .. Y_{K-1}, X_M, .. X_{N-1}) out: (X_0, .. X_{N-1}), verify X_M to X_{N-1}
for (size_t i = mdim->value; i < ndim; i++) {
reporter->AssertEQ(data->shape[i - mdim->value + kdim], oshape[i]);
reporter->AssertEQ(updates->shape[i - mdim->value + kdim], oshape[i]);
}

reporter->Assign(types[3], TensorType(data->shape, data->dtype));
Expand Down

0 comments on commit 5ba4e00

Please sign in to comment.