Update references to JAX's GitHub repo
JAX has moved from https://github.com/google/jax to https://github.com/jax-ml/jax

PiperOrigin-RevId: 702886640
jakeharmon8 authored and TF2JAXDev committed Dec 5, 2024
1 parent 171d7e3 commit afea2f9
Showing 6 changed files with 13 additions and 13 deletions.
README.md (8 changes: 4 additions & 4 deletions)
@@ -325,7 +325,7 @@ performance.

### Platform Specificity

- Natively serialized JAX programs are platform specific ([link](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#natively-serialized-jax-modules-are-platform-specific)). Executing a natively
+ Natively serialized JAX programs are platform specific ([link](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#natively-serialized-jax-modules-are-platform-specific)). Executing a natively
serialized program on platforms other than the one for which it was lowered
would raise a ValueError, e.g.:
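
The README's own example is elided from this hunk. Purely as an illustrative sketch (the `native_serialization` and `native_serialization_platforms` arguments and the exact error type are assumptions about the jax2tf API at the time of this commit, not part of this diff), triggering such an error might look like:

```python
# Hypothetical sketch, not the README's elided snippet: lower a JAX function
# natively for CUDA, then call the resulting TensorFlow function on a host
# without a CUDA device. Argument names are assumptions about jax2tf.convert.
import jax.numpy as jnp
import tensorflow as tf
from jax.experimental import jax2tf

def cos_sin(x):
  return jnp.cos(jnp.sin(x))

# Natively serialize (lower) for a platform other than the one we run on.
cos_sin_tf = jax2tf.convert(
    cos_sin,
    native_serialization=True,
    native_serialization_platforms=("cuda",),
)

try:
  tf.function(cos_sin_tf, autograph=False)(tf.constant([0.1, 0.2]))
except Exception as e:  # expected to fail when no CUDA device is available
  print(f"{type(e).__name__}: {e}")
```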

@@ -399,8 +399,8 @@ ops.

[DeepMind JAX Ecosystem]: https://deepmind.com/blog/article/using-jax-to-accelerate-our-research "DeepMind JAX Ecosystem"
[DeepMind JAX Ecosystem citation]: https://github.com/google-deepmind/jax/blob/main/deepmind2020jax.txt "Citation"
- [JAX]: https://github.com/google/jax "JAX on GitHub"
+ [JAX]: https://github.com/jax-ml/jax "JAX on GitHub"
[TensorFlow]: https://github.com/tensorflow/tensorflow "TensorFlow on GitHub"
- [jax2tf documentation]: https://github.com/google/jax/blob/master/jax/experimental/jax2tf/README.md#calling-tensorflow-functions-from-jax "jax2tf documentation"
- [jax2tf_cumulative_reduction]: https://github.com/google/jax/blob/main/jax/experimental/jax2tf/jax2tf.py#L2172
+ [jax2tf documentation]: https://github.com/jax-ml/jax/blob/master/jax/experimental/jax2tf/README.md#calling-tensorflow-functions-from-jax "jax2tf documentation"
+ [jax2tf_cumulative_reduction]: https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/jax2tf.py#L2172
[StableHLO]: https://github.com/openxla/stablehlo
test.sh (2 changes: 1 addition & 1 deletion)
@@ -70,7 +70,7 @@ CHECK_CUSTOM_CALLS_TEST=0 pytest -n "${N_JOBS}" --pyargs tf2jax
# Native lowering is in active development so we test against nightly and github head.
pip uninstall --yes tensorflow
pip install tf-nightly
- pip install git+https://github.com/google/jax.git
+ pip install git+https://github.com/jax-ml/jax.git
pip install -U --pre jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
CHECK_CUSTOM_CALLS_TEST=0 pytest -n "${N_JOBS}" --pyargs tf2jax._src.roundtrip_test
cd ..
tf2jax/_src/numpy_compat.py (2 changes: 1 addition & 1 deletion)
@@ -76,7 +76,7 @@ def is_poly_dim(x) -> bool:
return export.is_symbolic_dim(x)

# This should reflect is_poly_dim() at
- # https://github.com/google/jax/blob/main/jax/experimental/jax2tf/shape_poly.py#L676
+ # https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/shape_poly.py#L676
# Array types.
if isinstance(x, (np.ndarray, jax.core.Tracer, xc.ArrayImpl)): # pylint: disable=isinstance-second-argument-not-valid-type
return False
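
For context on the distinction the helper above draws, here is a hypothetical usage sketch; the `jax.export` calls are assumptions about the surrounding JAX API and are not part of this commit. A dimension produced by shape-polymorphic export is symbolic, while a plain integer is a concrete size.

```python
# Hypothetical sketch: symbolic vs. concrete dimensions, the distinction the
# is_poly_dim helper above is built around. jax.export.symbolic_shape and
# jax.export.is_symbolic_dim are assumptions about the surrounding JAX API.
from jax import export

batch, width = export.symbolic_shape("b, 8")
print(export.is_symbolic_dim(batch))  # True  -> "b" is a symbolic dimension
print(export.is_symbolic_dim(width))  # False -> 8 is a concrete size
```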
tf2jax/_src/ops.py (6 changes: 3 additions & 3 deletions)
@@ -844,7 +844,7 @@ def _func(x: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
evals, evecs = jnp.linalg.eigh(x, symmetrize_input=False)
else:
# symmetrize_input does not exist for eigvalsh.
- # See https://github.com/google/jax/issues/9473
+ # See https://github.com/jax-ml/jax/issues/9473
evals, evecs = jnp.linalg.eigvalsh(symmetrize(x)), None

# Sorting by eigenvalues to match tf.raw_ops.Eig better.
@@ -2655,7 +2655,7 @@ def _xla_variadic_sort(proto):
return _XlaVariadicSort(dict(comparator=comparator), is_stable=is_stable)


- # Taken from https://github.com/google/jax/blob/main/jax/_src/lax/lax.py#L1056
+ # Taken from https://github.com/jax-ml/jax/blob/main/jax/_src/lax/lax.py#L1056
def _get_max_identity(dtype):
if jax.dtypes.issubdtype(dtype, np.inexact):
return np.array(-np.inf, dtype)
@@ -2665,7 +2665,7 @@ def _get_max_identity(dtype):
return np.array(False, np.bool_)


- # Taken from https://github.com/google/jax/blob/main/jax/_src/lax/lax.py#L1064
+ # Taken from https://github.com/jax-ml/jax/blob/main/jax/_src/lax/lax.py#L1064
def _get_min_identity(dtype):
if jax.dtypes.issubdtype(dtype, np.inexact):
return np.array(np.inf, dtype)
tf2jax/experimental/mhlo.py (4 changes: 2 additions & 2 deletions)
@@ -61,7 +61,7 @@ def mhlo_apply_impl(*args, module: MhloModule):
mhlo_apply_p.def_impl(mhlo_apply_impl)


- # See https://github.com/google/jax/blob/main/jax/_src/interpreters/mlir.py#L115
+ # See https://github.com/jax-ml/jax/blob/main/jax/_src/interpreters/mlir.py#L115
# for reference
def ir_type_to_dtype(ir_type: ir.Type) -> jnp.dtype:
"""Converts MLIR type to JAX dtype."""
@@ -154,7 +154,7 @@ def mhlo_apply_abstract_eval(


# Taken from
- # github.com/google/jax/blob/main/jax/experimental/jax2tf/jax_export.py#L859
+ # github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/jax_export.py#L859
def refine_polymorphic_shapes(
module: ir.Module, validate_static_shapes: bool
) -> ir.Module:
tf2jax/experimental/ops.py (4 changes: 2 additions & 2 deletions)
@@ -31,7 +31,7 @@


# See canonicalize_platform for reference
- # https://github.com/google/jax/blob/main/jax/_src/xla_bridge.py#L344
+ # https://github.com/jax-ml/jax/blob/main/jax/_src/xla_bridge.py#L344
def _platform_to_alias(platform: str) -> str:
aliases = {
"cuda": "gpu",
@@ -41,7 +41,7 @@ def _platform_to_alias(platform: str) -> str:


# Adapted from
- # https://github.com/google/jax/commit/ec8b855fa16962b1394716622c8cbc006ce76b1c
+ # https://github.com/jax-ml/jax/commit/ec8b855fa16962b1394716622c8cbc006ce76b1c
@functools.lru_cache(None)
def _refine_with_static_input_shapes(
module_text: str, operands: Tuple[jax.core.ShapedArray, ...]
