From e9b65c798da50216c565eb2cbdd880e5d53a1d1d Mon Sep 17 00:00:00 2001 From: Joey Tsai Date: Wed, 2 Mar 2022 15:19:55 +0800 Subject: [PATCH] [Experiment] faill to mutate expression while converting LSTM * ExprMutator without any change fails to mutate expressions during converting LSTM * [1]https://github.com/apache/tvm/pull/9723#issuecomment-1020952881 * [2]https://github.com/apache/tvm/blob/122be3fb18902bf2317797fedfa867dcf9607ef9/src/relay/transforms/de_duplicate.cc#L113 * [3] https://github.com/apache/tvm/blob/8f6fa8f2c41406cb54d01647ba8731e4ceb8f4ab/src/ir/module.cc#L202 * [4] https://github.com/apache/tvm/blob/8f6fa8f2c41406cb54d01647ba8731e4ceb8f4ab/python/tvm/relay/expr_functor.py#L216 * [5]https://github.com/apache/tvm/pull/10072 * [6]https://github.com/apache/tvm/pull/9723#issuecomment-1022059117 * [7]https://github.com/apache/tvm/blob/8f6fa8f2c41406cb54d01647ba8731e4ceb8f4ab/src/relay/transforms/de_duplicate.cc --- python/tvm/relay/frontend/pytorch.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index ee75220ed3922..570888d9e8aec 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -3402,6 +3402,21 @@ def _handel_nested_input(inputs): relay_out = relay_op( inputs, _get_input_types(op_node, outputs, default_dtype=self.default_dtype) ) + + def do_mutate(sym): + class PureMutator(ExprMutator): + def __init__(self): + ExprMutator.__init__(self) + + def mutate(self, sym): + if isinstance(sym, _expr.TupleWrapper): + return _expr.TupleWrapper(self.visit(sym.tuple_value), sym.size) + if isinstance(sym, _expr.RelayExpr): + return self.visit(sym) + return sym + return PureMutator().mutate(sym) + + relay_out = do_mutate(relay_out) self.record_output_type(relay_out) if isinstance(relay_out, tuple):