From 5a113aff98ce42420891c724843ccb30691dc24a Mon Sep 17 00:00:00 2001 From: JackCaoG <59073027+JackCaoG@users.noreply.github.com> Date: Fri, 1 Mar 2024 17:26:32 -0800 Subject: [PATCH] Make XLAShardedTensor use default __torch_function__ (#6625) --- torch_xla/distributed/spmd/xla_sharded_tensor.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/torch_xla/distributed/spmd/xla_sharded_tensor.py b/torch_xla/distributed/spmd/xla_sharded_tensor.py index 44377027f5b..aedfd6a801e 100644 --- a/torch_xla/distributed/spmd/xla_sharded_tensor.py +++ b/torch_xla/distributed/spmd/xla_sharded_tensor.py @@ -168,3 +168,7 @@ def wrap(elem): rs = tree_map(wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))) return rs + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + return super().__torch_function__(func, types, args, kwargs)