Skip to content

Commit

Permalink
[TF parser] Handle int64 dtype in range (apache#6918)
Browse files Browse the repository at this point in the history
  • Loading branch information
anijain2305 authored and Trevor Morris committed Dec 4, 2020
1 parent 1552a76 commit a9017dc
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 9 deletions.
10 changes: 5 additions & 5 deletions python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -1454,9 +1454,9 @@ def _impl(inputs, attr, params, mod):
break

if is_symbolic_shape:
ret = _op.shape_of(inputs[0], dtype="int32")
ret = _op.shape_of(inputs[0], dtype=attr["out_type"].name)
else:
ret = np.array(input_shape, dtype="int32")
ret = np.array(input_shape, dtype=attr["out_type"].name)
return ret

return _impl
Expand Down Expand Up @@ -1862,11 +1862,11 @@ def _impl(inputs, attr, params, mod):

dtype = attr["Tidx"].name if "Tidx" in attr else str(start.dtype)
if isinstance(start, (np.int32, np.int64, int, np.float32, np.float64, float)):
start = _expr.const(start)
start = _expr.const(start, dtype=dtype)
if isinstance(limit, (np.int32, np.int64, int, np.float32, np.float64, float)):
limit = _expr.const(limit)
limit = _expr.const(limit, dtype=dtype)
if isinstance(delta, (np.int32, np.int64, int, np.float32, np.float64, float)):
delta = _expr.const(delta)
delta = _expr.const(delta, dtype=dtype)

return AttrCvt(
op_name="arange",
Expand Down
9 changes: 5 additions & 4 deletions tests/python/frontend/tensorflow/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -2783,10 +2783,11 @@ def test_forward_unpack():

def test_forward_range():
"""test operator Range"""
tf.reset_default_graph()
with tf.Graph().as_default():
tf.range(1, 18, 3, name="range")
compare_tf_with_tvm([], [], "range:0")
for dtype in [tf.int32, tf.int64]:
tf.reset_default_graph()
with tf.Graph().as_default():
tf.range(1, 18, 3, name="range", dtype=dtype)
compare_tf_with_tvm([], [], "range:0")

"""test type assignment for operator Range"""
tf.reset_default_graph()
Expand Down

0 comments on commit a9017dc

Please sign in to comment.