Skip to content

Commit

Permalink
[Torch] Fix ELU conversion
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Aug 10, 2021
1 parent 39571c1 commit 089b90c
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -771,7 +771,7 @@ def leaky_relu(self, inputs, input_types):
def elu(self, inputs, input_types):
data = inputs[0]
dtype = input_types[0]
alpha = _expr.const(float(inputs[1]), dtype=dtype)
alpha = _expr.const(-float(inputs[1]), dtype=dtype)
return alpha * _op.nn.relu(_expr.const(1, dtype=dtype) - _op.exp(data)) + _op.nn.relu(data)

def celu(self, inputs, input_types):
Expand Down
2 changes: 1 addition & 1 deletion tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -666,7 +666,7 @@ def test_forward_leakyrelu():
def test_forward_elu():
torch.set_grad_enabled(False)
input_shape = [1, 3, 10, 10]
input_data = torch.rand(input_shape).float()
input_data = torch.randn(input_shape).float()
verify_model(torch.nn.ELU().eval(), input_data=input_data)
verify_model(torch.nn.ELU(alpha=0.3).eval(), input_data=input_data)
verify_model(torch.nn.ELU(alpha=1.0).eval(), input_data=input_data)
Expand Down

0 comments on commit 089b90c

Please sign in to comment.