Skip to content

Commit

Permalink
python3Packages.jax: 0.4.1 -> 0.4.5 and fix aarch64-darwin build
Browse files Browse the repository at this point in the history
  • Loading branch information
samuela committed Mar 30, 2023
1 parent 75658f7 commit d70820e
Showing 1 changed file with 17 additions and 6 deletions.
23 changes: 17 additions & 6 deletions pkgs/development/python-modules/jax/default.nix
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
, etils
, fetchFromGitHub
, jaxlib
, jaxlib-bin
, lapack
, matplotlib
, numpy
Expand All @@ -13,30 +14,35 @@
, pytest-xdist
, pythonOlder
, scipy
, stdenv
, typing-extensions
}:

let
usingMKL = blas.implementation == "mkl" || lapack.implementation == "mkl";
# jaxlib is broken on aarch64-* as of 2023-03-05, but the binary wheels work
# fine. jaxlib is only used in the checkPhase, so switching backends does not
# impact package behavior. Get rid of this once jaxlib is fixed on aarch64-*.
jaxlib' = if jaxlib.meta.broken then jaxlib-bin else jaxlib;
in
buildPythonPackage rec {
pname = "jax";
version = "0.4.1";
version = "0.4.5";
format = "setuptools";

disabled = pythonOlder "3.7";

src = fetchFromGitHub {
owner = "google";
repo = pname;
rev = "refs/tags/jaxlib-v${version}";
hash = "sha256-ajLI0iD0YZRK3/uKSbhlIZGc98MdW174vA34vhoy7Iw=";
# google/jax contains tags for jax and jaxlib. Only use jax tags!
rev = "refs/tags/${pname}-v${version}";
hash = "sha256-UJzX8zP3qaEUIV5hPJhiGiLJO7k8p962MHWxIHDY1ZA=";
};

# jaxlib is _not_ included in propagatedBuildInputs because there are
# different versions of jaxlib depending on the desired target hardware. The
# JAX project ships separate wheels for CPU, GPU, and TPU. Currently only the
# CPU wheel is packaged.
# JAX project ships separate wheels for CPU, GPU, and TPU.
propagatedBuildInputs = [
absl-py
etils
Expand All @@ -47,7 +53,7 @@ buildPythonPackage rec {
] ++ etils.optional-dependencies.epath;

nativeCheckInputs = [
jaxlib
jaxlib'
matplotlib
pytestCheckHook
pytest-xdist
Expand Down Expand Up @@ -83,6 +89,11 @@ buildPythonPackage rec {
"test_custom_linear_solve_cholesky"
"test_custom_root_with_aux"
"testEigvalsGrad_shape"
] ++ lib.optionals (stdenv.isAarch64 && stdenv.isDarwin) [
# See https://github.com/google/jax/issues/14793.
"test_for_loop_fixpoint_correctly_identifies_loop_varying_residuals_unrolled_for_loop"
"testQdwhWithRandomMatrix3"
"testScanGrad_jit_scan"
];

# See https://github.com/google/jax/issues/11722. This is a temporary fix in
Expand Down

0 comments on commit d70820e

Please sign in to comment.