Showing 4 changed files with 186 additions and 0 deletions.
17 changes: 17 additions & 0 deletions python/tvm/relay/frontend/pytorch.py
@@ -3402,6 +3402,23 @@ def _handel_nested_input(inputs):
relay_out = relay_op(
    inputs, _get_input_types(op_node, outputs, default_dtype=self.default_dtype)
)

def do_mutate(sym):
    class PureMutator(ExprMutator):
        def __init__(self):
            ExprMutator.__init__(self)

        def mutate(self, sym):
            # PureMutator overrides no visit_* methods, so visiting
            # simply rebuilds the expression node by node.
            if isinstance(sym, _expr.TupleWrapper):
                return _expr.TupleWrapper(self.visit(sym.tuple_value), sym.size)
            if isinstance(sym, _expr.RelayExpr):
                return self.visit(sym)
            return sym

    return PureMutator().mutate(sym)

# Comment out the following line to get the original relay IR.
relay_out = do_mutate(relay_out)

self.record_output_type(relay_out)

if isinstance(relay_out, tuple):
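For reference, a minimal standalone sketch (not part of this commit; names and shapes are illustrative) of the behaviour the patch relies on: an ExprMutator with no overridden visit_* methods rebuilds every node it visits, returning a structurally equal but physically new expression.

import tvm
from tvm import relay
from tvm.relay.expr_functor import ExprMutator

# Build a tiny Relay function and rebuild it with a bare ExprMutator.
x = relay.var("x", shape=(1, 5, 10))
func = relay.Function([x], relay.tanh(relay.add(x, x)))

rebuilt = ExprMutator().visit(func)
print(tvm.ir.structural_equal(func, rebuilt))  # True: same structure
print(rebuilt.same_as(func))                   # False: new nodes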
85 changes: 85 additions & 0 deletions tests/python/frontend/pytorch/after_mutator
@@ -0,0 +1,85 @@
fn (%input: Tensor[(1, 5, 10), float32], %lstm.weight_ih_l0: Tensor[(40, 10), float32], %lstm.weight_hh_l0: Tensor[(40, 10), float32], %lstm.bias_ih_l0: Tensor[(40), float32], %lstm.bias_hh_l0: Tensor[(40), float32], %linear.weight: Tensor[(10, 20), float32], %linear.bias: Tensor[(10), float32]) {
%0 = split(%input, indices_or_sections=1);
%1 = %0.0;
%2 = squeeze(%1, axis=[0]);
%3 = (%2,);
%4 = full(0, shape=[1, 5, 10], dtype="float32");
%5 = split(%4, indices_or_sections=1);
%6 = %5.0;
%7 = squeeze(%6, axis=[0]);
%8 = (%7,);
%9 = %3.0;
%10 = %8.0;
%11 = (%9, %10);
%12 = (%lstm.weight_ih_l0, %lstm.weight_hh_l0);
%13 = concatenate(%11, axis=1);
%14 = concatenate(%12, axis=1);
%15 = nn.dense(%13, %14, units=None);
%16 = add(%15, %lstm.bias_ih_l0);
%17 = add(%16, %lstm.bias_hh_l0);
%18 = split(%17, indices_or_sections=4, axis=-1);
%19 = %18.3;
%20 = %18.1;
%21 = full(0, shape=[1, 5, 10], dtype="float32");
%22 = split(%21, indices_or_sections=1);
%23 = %22.0;
%24 = squeeze(%23, axis=[0]);
%25 = (%24,);
%26 = sigmoid(%20);
%27 = %25.0;
%28 = %18.0;
%29 = %18.2;
%30 = sigmoid(%28);
%31 = tanh(%29);
%32 = multiply(%26, %27);
%33 = multiply(%30, %31);
%34 = add(%32, %33);
%35 = sigmoid(%19);
%36 = tanh(%34);
%37 = multiply(%35, %36);
%38 = (%37,);
%39 = stack(%38);
%40 = split(%input, indices_or_sections=1);
%41 = %40.0;
%42 = squeeze(%41, axis=[0]);
%43 = (%42,);
%44 = full(0, shape=[1, 5, 10], dtype="float32");
%45 = split(%44, indices_or_sections=1);
%46 = %45.0;
%47 = squeeze(%46, axis=[0]);
%48 = (%47,);
%49 = %43.0;
%50 = %48.0;
%51 = (%49, %50);
%52 = (%lstm.weight_ih_l0, %lstm.weight_hh_l0);
%53 = concatenate(%51, axis=1);
%54 = concatenate(%52, axis=1);
%55 = nn.dense(%53, %54, units=None);
%56 = add(%55, %lstm.bias_ih_l0);
%57 = add(%56, %lstm.bias_hh_l0);
%58 = split(%57, indices_or_sections=4, axis=-1);
%59 = %58.1;
%60 = full(0, shape=[1, 5, 10], dtype="float32");
%61 = split(%60, indices_or_sections=1);
%62 = %61.0;
%63 = squeeze(%62, axis=[0]);
%64 = (%63,);
%65 = sigmoid(%59);
%66 = %64.0;
%67 = %58.0;
%68 = %58.2;
%69 = sigmoid(%67);
%70 = tanh(%68);
%71 = multiply(%65, %66);
%72 = multiply(%69, %70);
%73 = add(%71, %72);
%74 = (%73,);
%75 = stack(%74);
%76 = take(%39, -1, axis=0, mode="wrap");
%77 = take(%75, -1, axis=0, mode="wrap");
%78 = (%76, %77);
%79 = concatenate(%78, axis=1);
%80 = nn.dense(%79, %linear.weight, units=None);
%81 = nn.bias_add(%80, %linear.bias, axis=-1);
nn.log_softmax(%81)
}
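As a reading aid, %11-%37 above compute one LSTM step in PyTorch's i, f, g, o gate order. A NumPy transcription (illustration only, not part of the commit; variable names are assumptions):

import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x, h, c, w_ih, w_hh, b_ih, b_hh):
    # %13-%15: one dense over the concatenated inputs and weights
    gates = np.concatenate([x, h], axis=1) @ np.concatenate([w_ih, w_hh], axis=1).T
    gates = gates + b_ih + b_hh                        # %16, %17
    i, f, g, o = np.split(gates, 4, axis=-1)           # %18
    c_next = sigmoid(f) * c + sigmoid(i) * np.tanh(g)  # %32-%34
    h_next = sigmoid(o) * np.tanh(c_next)              # %35-%37
    return h_next, c_next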
33 changes: 33 additions & 0 deletions tests/python/frontend/pytorch/my_test.py
@@ -0,0 +1,33 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

import tvm
from tvm import relay


class LSTM(nn.Module):
    def __init__(self, input_dim, hidden_dim, target_size):
        super(LSTM, self).__init__()
        self.hidden_dim = hidden_dim
        self.lstm = nn.LSTM(input_dim, hidden_dim)
        # The classifier consumes the last LSTM output concatenated with
        # the last cell state, hence hidden_dim * 2 input features.
        self.linear = nn.Linear(hidden_dim * 2, target_size)

    def forward(self, inputs):
        lstm_out, (hidden, cell) = self.lstm(inputs)
        x = torch.cat((lstm_out[-1], cell[-1]), 1)
        logits = self.linear(x)
        log_probs = F.log_softmax(logits, dim=-1)
        return log_probs


if __name__ == "__main__":
    model = LSTM(10, 10, 10)
    torch.save(model, "./lstm.pt")
    model = model.eval()
    input_shape = [1, 5, 10]
    input_data = torch.randn(input_shape)
    # Trace the eager model so the frontend receives a static TorchScript graph.
    scripted_model = torch.jit.trace(model, input_data).eval()
    mod, params = relay.frontend.from_pytorch(scripted_model, [("input", tuple(input_shape))])
    print(mod["main"])
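A hedged usage sketch (not in this commit; assumes an LLVM-enabled TVM build) to confirm the converted module still compiles and runs after the mutation:

from tvm.contrib import graph_executor

with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target="llvm", params=params)

runtime = graph_executor.GraphModule(lib["default"](tvm.cpu()))
runtime.set_input("input", input_data.numpy())
runtime.run()
print(runtime.get_output(0).numpy().shape)  # expected: (5, 10)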
51 changes: 51 additions & 0 deletions tests/python/frontend/pytorch/ori_IR
@@ -0,0 +1,51 @@
fn (%input: Tensor[(1, 5, 10), float32], %lstm.weight_ih_l0: Tensor[(40, 10), float32], %lstm.weight_hh_l0: Tensor[(40, 10), float32], %lstm.bias_ih_l0: Tensor[(40), float32], %lstm.bias_hh_l0: Tensor[(40), float32], %linear.weight: Tensor[(10, 20), float32], %linear.bias: Tensor[(10), float32]) {
%0 = split(%input, indices_or_sections=1);
%1 = %0.0;
%2 = squeeze(%1, axis=[0]);
%3 = (%2,);
%4 = full(0, shape=[1, 5, 10], dtype="float32");
%5 = split(%4, indices_or_sections=1);
%6 = %5.0;
%7 = squeeze(%6, axis=[0]);
%8 = (%7,);
%9 = %3.0;
%10 = %8.0;
%11 = (%9, %10);
%12 = (%lstm.weight_ih_l0, %lstm.weight_hh_l0);
%13 = concatenate(%11, axis=1);
%14 = concatenate(%12, axis=1);
%15 = nn.dense(%13, %14, units=None);
%16 = add(%15, %lstm.bias_ih_l0);
%17 = add(%16, %lstm.bias_hh_l0);
%18 = split(%17, indices_or_sections=4, axis=-1);
%19 = %18.3;
%20 = %18.1;
%21 = full(0, shape=[1, 5, 10], dtype="float32");
%22 = split(%21, indices_or_sections=1);
%23 = %22.0;
%24 = squeeze(%23, axis=[0]);
%25 = (%24,);
%26 = sigmoid(%20);
%27 = %25.0;
%28 = %18.0;
%29 = %18.2;
%30 = sigmoid(%28);
%31 = tanh(%29);
%32 = multiply(%26, %27);
%33 = multiply(%30, %31);
%34 = add(%32, %33);
%35 = sigmoid(%19);
%36 = tanh(%34);
%37 = multiply(%35, %36);
%38 = (%37,);
%39 = stack(%38);
%40 = (%34,);
%41 = stack(%40);
%42 = take(%39, -1, axis=0, mode="wrap");
%43 = take(%41, -1, axis=0, mode="wrap");
%44 = (%42, %43);
%45 = concatenate(%44, axis=1);
%46 = nn.dense(%45, %linear.weight, units=None);
%47 = nn.bias_add(%46, %linear.bias, axis=-1);
nn.log_softmax(%47)
}
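Note the mutated dump is longer (85 lines versus 51): here the second stack (%41) reuses the shared cell state %34, while after do_mutate the whole cell is recomputed (%40-%75). A small sketch (illustration only, reusing mod from my_test.py above) to quantify the lost sharing by counting Call nodes:

from tvm import relay

def count_calls(expr):
    calls = []
    relay.analysis.post_order_visit(
        expr, lambda node: calls.append(node) if isinstance(node, relay.Call) else None
    )
    return len(calls)

# With do_mutate enabled the count grows, since shared subgraphs are duplicated.
print(count_calls(mod["main"]))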
