Upgrade torch dependency to 2.4.0 plus misc changes: #7821

Merged 1 commit on Aug 9, 2024
2 changes: 1 addition & 1 deletion experimental/torch_xla2/README.md
@@ -199,7 +199,7 @@ def model_func(param, inputs):
Now, we can apply `jax_jit`

```python
-from torch_xla2.extra import jax_jit
+from torch_xla2.interop import jax_jit
model_func_jitted = jax_jit(model_func)
print(model_func_jitted(new_state_dict, inputs))
```
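The only change here is the import path (`torch_xla2.extra` → `torch_xla2.interop`). For context, `jax_jit` wraps a purely functional callable like the `model_func(param, inputs)` named in the hunk header. Below is a minimal sketch of that pattern, not the verbatim README code: `torch_xla2.default_env()` and `env.to_xla` are assumed to behave as in `test_functions.py` further down, and the toy `torch.nn.Linear` plus `torch.func.functional_call` stand in for whatever model the actual README section uses.

```python
# Minimal sketch under the assumptions stated above.
import torch
import torch_xla2
from torch_xla2.interop import jax_jit  # moved from torch_xla2.extra in this PR

env = torch_xla2.default_env()
m = torch.nn.Linear(4, 2)

# Convert weights and a sample input into torch_xla2 tensors.
new_state_dict = {k: env.to_xla(v) for k, v in m.state_dict().items()}
inputs = env.to_xla(torch.randn(3, 4))

def model_func(param, inputs):
    # Purely functional forward pass: parameters are explicit arguments,
    # which is what jax_jit needs in order to trace and jit the call.
    return torch.func.functional_call(m, param, inputs)

model_func_jitted = jax_jit(model_func)
print(model_func_jitted(new_state_dict, inputs))
```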
2 changes: 1 addition & 1 deletion experimental/torch_xla2/dev-requirements.txt
@@ -1,3 +1,3 @@
-f https://download.pytorch.org/whl/torch
-torch==2.3.0+cpu
+torch==2.4.0+cpu
ruff~=0.3.5
2 changes: 2 additions & 0 deletions experimental/torch_xla2/test-requirements.txt
@@ -1,4 +1,6 @@
-r dev-requirements.txt
absl-py
immutabledict
pytest
pytest-xdist
sentencepiece
6 changes: 4 additions & 2 deletions experimental/torch_xla2/test/test_exports.py
@@ -4,6 +4,7 @@
import jax
import jax.export
import torch_xla2
import torch_xla2.export
from torch_xla2 import tensor
from torch_xla2.ops import mappings

@@ -107,7 +108,8 @@ def test_export_dtypes(self):
torch.long : "i64",
# NO_MAPPING : "ui4"
torch.uint8 : "ui8",
-torch.uint16 : "ui16",
+# NOTE(qihqi): torch export for uint16 seems broken at torch 2.4
+# torch.uint16 : "ui16",
torch.uint32 : "ui32",
torch.uint64 : "ui64",
# NO_MAPPING : "f8E4M3B11FNUZ"
@@ -127,7 +129,7 @@ def test_export_dtypes(self):
}

model = TensorConstant()
-for torch_dtype in mappings.TORCH_DTYPE_TO_JAX.keys():
+for torch_dtype in DTYPE_TO_MLIR_STR.keys():
if torch_dtype == None:
## TODO: Figure out what the None mapping should be, seems like:
## torch.tensor(dtype=None) maps to f32
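The `uint16` entry is disabled with a note that `torch.export` appears broken for that dtype at torch 2.4, and the loop now iterates the test's own `DTYPE_TO_MLIR_STR` table instead of `mappings.TORCH_DTYPE_TO_JAX`. A standalone probe along the following lines can check a single dtype outside this test; `ProbeConstant` is a hypothetical module, not part of the PR, and whether export actually fails for `uint16` is only what the NOTE above suggests.

```python
# Hedged sketch: probe whether torch.export handles a tensor constant of a
# given dtype. ProbeConstant is hypothetical; the PR's own test uses
# TensorConstant and the DTYPE_TO_MLIR_STR table instead.
import torch
from torch.export import export


class ProbeConstant(torch.nn.Module):
    def __init__(self, dtype):
        super().__init__()
        self.dtype = dtype

    def forward(self, x):
        # Materialize a constant of the dtype under test inside the graph.
        return x.to(self.dtype) + torch.tensor(1, dtype=self.dtype)


for dtype in (torch.uint8, torch.uint16, torch.uint32):
    try:
        export(ProbeConstant(dtype), (torch.randn(3),))
        print(dtype, "export ok")
    except Exception as e:
        print(dtype, "export failed:", type(e).__name__)
```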
13 changes: 13 additions & 0 deletions experimental/torch_xla2/test/test_functions.py
@@ -34,6 +34,19 @@ def test_dont_capture_conversion(self):
t2 = self.env.to_xla(t)
# assert no exceptions

def test_brackets(self):
with self.env:
a = torch.randn((2,3))
a[1] = 9
self.assertEqual(a[1, 0].item(), 9)

def test_bernoulli_inplace(self):
with self.env:
a = torch.randn((2,3))
a.bernoulli_(0.4)




if __name__ == '__main__':
absltest.main()
12 changes: 12 additions & 0 deletions experimental/torch_xla2/torch_xla2/decompositions.py
@@ -103,6 +103,16 @@ def rand_like(self, **kwargs):
dtype = kwargs.get('dtype')
return torch.rand(self.shape, dtype=dtype)

def channel_shuffle(self, groups):
batchsize, channels, height, width = self.shape
channels_per_group = channels // groups
self = self.reshape(batchsize, groups, channels_per_group, height, width)
self = self.transpose(1, 2)
self = self.reshape(batchsize, channels, height, width)
return self

_try_register(aten.channel_shuffle, channel_shuffle)

_try_register(aten.bernoulli, bernoulli)
_try_register(aten.rand_like, rand_like)

@@ -121,4 +131,6 @@ def rand_like(self, **kwargs):
torch.ops.aten.replication_pad3d,
torch.ops.aten.bernoulli,
torch.ops.aten.rand_like,
torch.ops.aten._batch_norm_with_update,
torch.ops.aten.channel_shuffle,
])
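The new `channel_shuffle` decomposition is the usual reshape/transpose/reshape formulation from ShuffleNet. A quick sanity check of that formulation against PyTorch's built-in `torch.nn.ChannelShuffle`, in plain CPU PyTorch and independent of torch_xla2 (illustrative only, not part of the PR):

```python
# Sketch: verify the reshape/transpose/reshape decomposition above matches
# torch.nn.ChannelShuffle on a small example.
import torch

def channel_shuffle_decomp(x, groups):
    # Same steps as the decomposition registered above.
    batchsize, channels, height, width = x.shape
    channels_per_group = channels // groups
    x = x.reshape(batchsize, groups, channels_per_group, height, width)
    x = x.transpose(1, 2)
    return x.reshape(batchsize, channels, height, width)

x = torch.randn(2, 6, 4, 4)
expected = torch.nn.ChannelShuffle(3)(x)
actual = channel_shuffle_decomp(x, groups=3)
print(torch.equal(expected, actual))  # expected: True
```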
21 changes: 15 additions & 6 deletions experimental/torch_xla2/torch_xla2/ops/jaten.py
@@ -23,6 +23,7 @@
torch.ops.aten.add_: torch.ops.aten.add,
torch.ops.aten.sub_: torch.ops.aten.sub,
torch.ops.aten.mul_: torch.ops.aten.mul,
torch.ops.aten.div_: torch.ops.aten.div,
torch.ops.aten.pow_: torch.ops.aten.pow,
torch.ops.aten.lt_: torch.ops.aten.lt,
torch.ops.aten.le_: torch.ops.aten.le,
@@ -34,6 +35,7 @@
torch.ops.aten.relu_: torch.ops.aten.relu,
torch.ops.aten.normal_: torch.ops.aten.normal,
torch.ops.aten.squeeze_: torch.ops.aten.squeeze,
torch.ops.aten.bernoulli_: torch.ops.aten.bernoulli.p,
}


@@ -277,12 +279,6 @@ def _aten_div(x, y, rounding_mode=""):
return res


-@op(torch.ops.aten.div_, is_jax_function=False)
-def _aten_div_(x, y, rounding_mode=""):
-x._elem = _aten_div(x._elem, y._elem, rounding_mode)
-return x
-
-
@op(torch.ops.aten.true_divide)
def _aten_true_divide(x, y):
return x / y
@@ -2219,6 +2215,19 @@ def _randn(
res = res.astype(dtype)
return res

@op(torch.ops.aten.bernoulli.p, needs_env=True)
def _bernoulli(
self,
p = 0.5,
*,
generator=None,
env=None,
):
key = env.get_and_rotate_prng_key(generator)
res = jax.random.uniform(key, self.shape) < p
return res



@op(torch.ops.aten.randn_like, needs_env=True)
@op_base.convert_dtype()
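The in-place `bernoulli_` is routed through the mutation table above to `bernoulli.p` (alongside the new `div_` entry, which replaces the hand-written `_aten_div_`), and the new `_bernoulli` lowering draws its sample by thresholding a uniform draw at `p`. The sampling step in plain JAX, with a fixed key standing in for the one torch_xla2 obtains from `env.get_and_rotate_prng_key`:

```python
# Sketch of the sampling used by _bernoulli above, in plain JAX.
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)   # torch_xla2 gets this from its env instead
shape = (2, 3)
p = 0.4

# uniform(0, 1) < p gives a boolean Bernoulli(p) sample per element.
sample = jax.random.uniform(key, shape) < p
print(sample)
print(sample.astype(jnp.float32).mean())  # roughly p for large shapes
```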
4 changes: 2 additions & 2 deletions experimental/torch_xla2/torch_xla2/tensor.py
@@ -104,8 +104,8 @@ def flatten(self, start_dim=0, end_dim=-1):
# return torch.reshape(self, new_shape)

def __setitem__(self, key, val):
-key = unwrap(key)
-self._elem = self._elem.at[key].set(val._elem)
+key, val = self._env.t2j_iso((key, val))
+self._elem = self._elem.at[key].set(val)

def type_as(self, other):
self._elem = self._elem.astype(other._elem.dtype)
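Background for the `__setitem__` change: JAX arrays are immutable, so item assignment has to become a functional `.at[...].set(...)` update, and both the key and the value may themselves be wrapped tensors that `t2j_iso` converts to raw JAX values first. The functional-update part in plain JAX (illustration only, not code from the PR):

```python
# Sketch: what a[1] = 9 must turn into on the JAX side, since jnp arrays
# cannot be mutated in place.
import jax.numpy as jnp

a = jnp.zeros((2, 3))
# a[1] = 9  would raise: JAX arrays are immutable.
a = a.at[1].set(9)          # functional update: returns a new array
print(a)

# The key can also be an array (e.g. advanced indexing), which is why the
# wrapped-tensor key in __setitem__ above is converted to a JAX value too.
idx = jnp.array([0, 2])
a = a.at[0, idx].set(7)
print(a)
```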