Skip to content

Commit

Permalink
Make XLAShardedTensor use default __torch_function__ (#6625)
Browse files Browse the repository at this point in the history
  • Loading branch information
JackCaoG authored Mar 2, 2024
1 parent 508aa26 commit 5a113af
Showing 1 changed file with 4 additions and 0 deletions.
4 changes: 4 additions & 0 deletions torch_xla/distributed/spmd/xla_sharded_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,3 +168,7 @@ def wrap(elem):
rs = tree_map(wrap,
func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)))
return rs

@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
return super().__torch_function__(func, types, args, kwargs)

0 comments on commit 5a113af

Please sign in to comment.