From 4cb838a2b4f3604b64efd6e95977ce396a06ee6c Mon Sep 17 00:00:00 2001
From: JackCaoG
Date: Tue, 27 Feb 2024 20:20:37 +0000
Subject: [PATCH 1/4] Make XLAShardedTensor use default __torch_function__

---
 torch_xla/distributed/spmd/xla_sharded_tensor.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/torch_xla/distributed/spmd/xla_sharded_tensor.py b/torch_xla/distributed/spmd/xla_sharded_tensor.py
index 44377027f5b..b4c3b939184 100644
--- a/torch_xla/distributed/spmd/xla_sharded_tensor.py
+++ b/torch_xla/distributed/spmd/xla_sharded_tensor.py
@@ -108,6 +108,9 @@ def __new__(cls, elem: torch.Tensor, *args, **kwargs):
     r.global_tensor = elem.detach() if r.requires_grad else elem
     return r
 
+  def __torch_function__(cls, func, types, args=(), kwargs=None):
+    return super().__torch_function__(func, types, args, kwargs)
+
   # Shards on the devices are materialized/available after the lazy
   # execution of the partitioned HLO graph. Each XLAShard points
   # to torch.Tensor. The shards represent a snapshot on CPU, detached

From 1cfec8825ee5ba167de28fa35c9c9160eaec3fda Mon Sep 17 00:00:00 2001
From: JackCaoG
Date: Tue, 27 Feb 2024 20:21:33 +0000
Subject: [PATCH 2/4] torch_pin

---
 torch_patches/.torch_pin | 1 +
 1 file changed, 1 insertion(+)
 create mode 100644 torch_patches/.torch_pin

diff --git a/torch_patches/.torch_pin b/torch_patches/.torch_pin
new file mode 100644
index 00000000000..af9104c8ede
--- /dev/null
+++ b/torch_patches/.torch_pin
@@ -0,0 +1 @@
+#120632

From 640b43112297e70c6fe582e14583d326b401dd4a Mon Sep 17 00:00:00 2001
From: JackCaoG
Date: Wed, 28 Feb 2024 18:55:25 +0000
Subject: [PATCH 3/4] add classmethod

---
 torch_xla/distributed/spmd/xla_sharded_tensor.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/torch_xla/distributed/spmd/xla_sharded_tensor.py b/torch_xla/distributed/spmd/xla_sharded_tensor.py
index b4c3b939184..aedfd6a801e 100644
--- a/torch_xla/distributed/spmd/xla_sharded_tensor.py
+++ b/torch_xla/distributed/spmd/xla_sharded_tensor.py
@@ -108,9 +108,6 @@ def __new__(cls, elem: torch.Tensor, *args, **kwargs):
     r.global_tensor = elem.detach() if r.requires_grad else elem
     return r
 
-  def __torch_function__(cls, func, types, args=(), kwargs=None):
-    return super().__torch_function__(func, types, args, kwargs)
-
   # Shards on the devices are materialized/available after the lazy
   # execution of the partitioned HLO graph. Each XLAShard points
   # to torch.Tensor. The shards represent a snapshot on CPU, detached
@@ -171,3 +168,7 @@ def wrap(elem):
     rs = tree_map(wrap,
                   func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)))
     return rs
+
+  @classmethod
+  def __torch_function__(cls, func, types, args=(), kwargs=None):
+    return super().__torch_function__(func, types, args, kwargs)

From 625d026d6371c92e50933f2a25c3a727b6dfdf89 Mon Sep 17 00:00:00 2001
From: JackCaoG <59073027+JackCaoG@users.noreply.github.com>
Date: Wed, 28 Feb 2024 16:44:25 -0800
Subject: [PATCH 4/4] Delete torch_patches/.torch_pin

---
 torch_patches/.torch_pin | 1 -
 1 file changed, 1 deletion(-)
 delete mode 100644 torch_patches/.torch_pin

diff --git a/torch_patches/.torch_pin b/torch_patches/.torch_pin
deleted file mode 100644
index af9104c8ede..00000000000
--- a/torch_patches/.torch_pin
+++ /dev/null
@@ -1 +0,0 @@
-#120632
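
Note (not part of the patch series above): the net effect of patches 1 and 3 is that XLAShardedTensor defines __torch_function__ as a classmethod that simply delegates to the default implementation. Below is a minimal, standalone sketch of that pattern, using a hypothetical MyTensor subclass in place of XLAShardedTensor. PyTorch expects __torch_function__ overrides on torch.Tensor subclasses to be classmethods, and delegating to super() keeps the default behavior, which also re-wraps returned tensors in the subclass type.

import torch


class MyTensor(torch.Tensor):

  # __torch_function__ must be a classmethod; delegating to the default
  # implementation keeps every torch.* op working and re-wraps returned
  # tensors as MyTensor rather than plain torch.Tensor.
  @classmethod
  def __torch_function__(cls, func, types, args=(), kwargs=None):
    return super().__torch_function__(func, types, args, kwargs)


t = torch.randn(2, 2).as_subclass(MyTensor)
out = t + 1
print(type(out))  # <class '__main__.MyTensor'>

Without the @classmethod decorator, Python would treat __torch_function__ as an instance method, which is presumably what patch 3 ("add classmethod") corrects relative to patch 1.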