Hi. I am getting a deprecation warning when doing JAX/PyTorch interoperability via:

```python
import jax.dlpack as jax_dlpack
import torch.utils.dlpack as torch_dlpack
# ...
jax_dlpack.from_dlpack(torch_dlpack.to_dlpack(tensor.contiguous()))
```

Question: What, if any, is the "new" way of doing JAX/PyTorch interoperability?
These two packages might help: https://github.com/rdyro/torch2jax. I have personally used the first one a lot. It now uses the new JAX FFI, I have found it extremely efficient, and the maintainer is very helpful. Going the other way, …
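The reply is cut off before naming a tool for the reverse direction. As an aside (not necessarily what the commenter had in mind), PyTorch's own `torch.utils.dlpack.from_dlpack` accepts any object implementing the DLPack protocol (`__dlpack__` / `__dlpack_device__`), and recent `jax.Array` objects implement it. Here is a minimal sketch of JAX to PyTorch, assuming reasonably recent `jax` and `torch` versions and a CPU array. The array `x` is hypothetical.

```python
import jax.numpy as jnp
import torch
import torch.utils.dlpack as torch_dlpack

# Hypothetical JAX array (CPU, float32).
x = jnp.linspace(0.0, 1.0, 5, dtype=jnp.float32)

# from_dlpack accepts objects implementing __dlpack__, which jax.Array does,
# so the array is passed directly instead of an explicit DLPack capsule.
t = torch_dlpack.from_dlpack(x)

print(t.shape, t.dtype)
```

The resulting tensor may share memory with the JAX buffer. Because JAX treats its arrays as immutable, it is safer to avoid in-place operations on `t`.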
IIRC, it's just `jax_dlpack.from_dlpack(tensor.contiguous())` these days.
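Here is a minimal sketch of that suggestion. It assumes a recent JAX in which `jax.dlpack.from_dlpack` accepts any object implementing `__dlpack__`, as `torch.Tensor` does. The example tensor is hypothetical. The point is that the tensor itself is passed, rather than the capsule returned by `torch_dlpack.to_dlpack`, which is the form the deprecation warning targets.

```python
import jax.dlpack as jax_dlpack
import jax.numpy as jnp
import torch

# Hypothetical example tensor on the CPU.
tensor = torch.arange(6, dtype=torch.float32).reshape(2, 3)

# New style: hand the tensor itself to from_dlpack. JAX queries the tensor's
# __dlpack_device__ / __dlpack__ methods, so no to_dlpack() capsule is needed.
jax_array = jax_dlpack.from_dlpack(tensor.contiguous())

print(jax_array.shape, jax_array.dtype)
```

PyTorch refuses to export a tensor that requires gradients through DLPack. In that case, call `.detach()` on it first.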