Skip to content

Commit

Permalink
fix lint
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed May 20, 2021
1 parent 8bfd736 commit e6876dc
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 5 deletions.
4 changes: 1 addition & 3 deletions include/tvm/relay/attrs/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -148,9 +148,7 @@ struct GatherNDAttrs : public tvm::AttrsNode<GatherNDAttrs> {
Integer batch_dims;

TVM_DECLARE_ATTRS(GatherAttrs, "relay.attrs.GatherNDAttrs") {
TVM_ATTR_FIELD(batch_dims)
.set_default(Integer(0))
.describe("The number of batch dimensions.");
TVM_ATTR_FIELD(batch_dims).set_default(Integer(0)).describe("The number of batch dimensions.");
}
};
struct TakeAttrs : public tvm::AttrsNode<TakeAttrs> {
Expand Down
4 changes: 2 additions & 2 deletions tests/python/relay/test_op_level3.py
Original file line number Diff line number Diff line change
Expand Up @@ -1276,11 +1276,11 @@ def gather_nd_batch_dims_1_ref(data, indices):

if batch_dims > 1:
x_data_reshape = np.reshape(x_data, (-1,) + xshape[batch_dims:])
y_data_reshape = np.reshape(y_data, (yshape[0], -1) + yshape[(batch_dims + 1):])
y_data_reshape = np.reshape(y_data, (yshape[0], -1) + yshape[(batch_dims + 1) :])

ref_res = gather_nd_batch_dims_1_ref(x_data_reshape, y_data_reshape)

out_shape = yshape[1: (batch_dims + 1)] + ref_res.shape[1:]
out_shape = yshape[1 : (batch_dims + 1)] + ref_res.shape[1:]
ref_res = np.reshape(ref_res, out_shape)
elif batch_dims == 1:
ref_res = gather_nd_batch_dims_1_ref(x_data, y_data)
Expand Down

0 comments on commit e6876dc

Please sign in to comment.