diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 613643f091d7..c13d791cf2e2 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -573,6 +573,12 @@ def repeat_interleave(self, inputs, input_types): if isinstance(inputs[1], int): repeats = inputs[1] axis = inputs[2] + elif isinstance(inputs[1], _expr.Expr): + if isinstance(inputs[1], _expr.Constant): + repeats = int(inputs[1].data.numpy()) + else: + repeats, _ = try_infer_value(inputs[1], lambda ret: ret.tolist()) + axis = inputs[2] else: msg = "Only repeat with one value as repeat is currently supported." raise AssertionError(msg)