Skip to content

Commit

Permalink
Add a regression test for #7461.
Browse files Browse the repository at this point in the history
Fixes #7461
  • Loading branch information
hawkinsp committed Oct 13, 2021
1 parent 4d73613 commit 2388804
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 0 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,10 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
* Breaking changes:
* Support for CUDA 10.2 and CUDA 10.1 has been dropped. Jaxlib now supports
CUDA 11.1+.
* Bug fixes:
* Fixes https://github.com/google/jax/issues/7461, which caused wrong
outputs on all platforms due to incorrect buffer aliasing inside the XLA
compiler.

## jax 0.2.21 (Sept 23, 2021)
* [GitHub
Expand Down
8 changes: 8 additions & 0 deletions tests/lax_numpy_indexing_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import itertools
import typing
from typing import Any, Optional, Tuple
import unittest
import warnings

from absl.testing import absltest
Expand Down Expand Up @@ -1216,6 +1217,13 @@ def testIndexSequenceDeprecation(self, idx, idx_type):
with self.assertNoWarnings():
x.at[normalize(idx)].set(0)

@unittest.skipIf(jax._src.lib.version < (0, 1, 72),
"Bug fixed in jaxlib 0.1.72")
def testIndexedUpdateAliasingBug(self):
# https://github.com/google/jax/issues/7461
fn = lambda x: x.at[1:].set(1 + x[:-1])
y = jnp.zeros(8)
self.assertArraysEqual(fn(y), jax.jit(fn)(y))

if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())

0 comments on commit 2388804

Please sign in to comment.