Skip to content

Commit

Permalink
Op info test for expm1 .. fill (#7940)
Browse files Browse the repository at this point in the history
  • Loading branch information
qihqi authored Aug 31, 2024
1 parent 87fe71a commit 13affb9
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
4 changes: 0 additions & 4 deletions experimental/torch_xla2/test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,10 +292,6 @@
"cdouble",
"ceil",
"chalf", # Skip due to jax not support complex32 with backend: https://github.com/google/jax/issues/14180
"expm1",
"fft.fftshift",
"fft.ifftshift",
"fill",
"nn.functional.smooth_l1_loss",
"nn.functional.soft_margin_loss",
"nn.functional.softplus",
Expand Down
6 changes: 5 additions & 1 deletion experimental/torch_xla2/torch_xla2/ops/jaten.py
Original file line number Diff line number Diff line change
Expand Up @@ -1780,7 +1780,11 @@ def _aten_exp(input):
# aten.expm1
@op(torch.ops.aten.expm1)
def _aten_expm1(input):
return jnp.expm1(input)
res = jnp.expm1(input)
new_dtype = mappings.t2j_dtype(torch.get_default_dtype())
if input.dtype == jax.numpy.int64:
res = res.astype(new_dtype)
return res


# aten.exp2
Expand Down

0 comments on commit 13affb9

Please sign in to comment.