Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

python3Packages.{jax,jaxlib}: update to 0.4.14 #246712

Merged
merged 6 commits into from
Aug 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 29 additions & 10 deletions pkgs/build-support/build-bazel-package/default.nix
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,12 @@ args@{
, bazelFlags ? []
, bazelBuildFlags ? []
, bazelTestFlags ? []
, bazelRunFlags ? []
, runTargetFlags ? []
, bazelFetchFlags ? []
, bazelTargets
, bazelTargets ? []
, bazelTestTargets ? []
, bazelRunTarget ? null
, buildAttrs
, fetchAttrs

Expand Down Expand Up @@ -46,17 +49,23 @@ args@{

let
fArgs = removeAttrs args [ "buildAttrs" "fetchAttrs" "removeRulesCC" ] // {
name = name;
bazelFlags = bazelFlags;
bazelBuildFlags = bazelBuildFlags;
bazelTestFlags = bazelTestFlags;
bazelFetchFlags = bazelFetchFlags;
bazelTestTargets = bazelTestTargets;
dontAddBazelOpts = dontAddBazelOpts;
inherit
name
bazelFlags
bazelBuildFlags
bazelTestFlags
bazelRunFlags
runTargetFlags
bazelFetchFlags
bazelTargets
bazelTestTargets
bazelRunTarget
dontAddBazelOpts
;
};
fBuildAttrs = fArgs // buildAttrs;
fFetchAttrs = fArgs // removeAttrs fetchAttrs [ "sha256" ];
bazelCmd = { cmd, additionalFlags, targets }:
bazelCmd = { cmd, additionalFlags, targets, targetRunFlags ? [ ] }:
lib.optionalString (targets != [ ]) ''
# See footnote called [USER and BAZEL_USE_CPP_ONLY_TOOLCHAIN variables]
BAZEL_USE_CPP_ONLY_TOOLCHAIN=1 \
Expand All @@ -73,7 +82,8 @@ let
"''${host_linkopts[@]}" \
$bazelFlags \
${lib.strings.concatStringsSep " " additionalFlags} \
${lib.strings.concatStringsSep " " targets}
${lib.strings.concatStringsSep " " targets} \
${lib.optionalString (targetRunFlags != []) " -- " + lib.strings.concatStringsSep " " targetRunFlags}
'';
# we need this to chmod dangling symlinks on darwin, gnu coreutils refuses to do so:
# chmod: cannot operate on dangling symlink '$symlink'
Expand Down Expand Up @@ -262,6 +272,15 @@ stdenv.mkDerivation (fBuildAttrs // {
targets = fBuildAttrs.bazelTargets;
}
}
${
bazelCmd {
cmd = "run";
additionalFlags = fBuildAttrs.bazelRunFlags ++ [ "--jobs" "$NIX_BUILD_CORES" ];
# Bazel run only accepts a single target, but `bazelCmd` expects `targets` to be a list.
targets = lib.optionals (fBuildAttrs.bazelRunTarget != null) [ fBuildAttrs.bazelRunTarget ];
targetRunFlags = fBuildAttrs.runTargetFlags;
}
}
runHook postBuild
'';
})
Expand Down
40 changes: 15 additions & 25 deletions pkgs/development/python-modules/jax/default.nix
Original file line number Diff line number Diff line change
@@ -1,21 +1,21 @@
{ lib
, absl-py
, blas
, buildPythonPackage
, etils
, setuptools
, importlib-metadata
, fetchFromGitHub
, jaxlib
, jaxlib-bin
, lapack
, matplotlib
, ml-dtypes
, numpy
, opt-einsum
, pytestCheckHook
, pytest-xdist
, pythonOlder
, scipy
, stdenv
, typing-extensions
}:

let
Expand All @@ -27,30 +27,32 @@ let
in
buildPythonPackage rec {
pname = "jax";
version = "0.4.5";
format = "setuptools";
version = "0.4.14";
format = "pyproject";

disabled = pythonOlder "3.7";
disabled = pythonOlder "3.9";

src = fetchFromGitHub {
owner = "google";
repo = pname;
# google/jax contains tags for jax and jaxlib. Only use jax tags!
rev = "refs/tags/${pname}-v${version}";
hash = "sha256-UJzX8zP3qaEUIV5hPJhiGiLJO7k8p962MHWxIHDY1ZA=";
hash = "sha256-0KnILQkahSiA1uuyT+kgy1XaCcZ3cpx1q114e2pecvg=";
};

nativeBuildInputs = [
setuptools
];

# jaxlib is _not_ included in propagatedBuildInputs because there are
# different versions of jaxlib depending on the desired target hardware. The
# JAX project ships separate wheels for CPU, GPU, and TPU.
propagatedBuildInputs = [
absl-py
etils
ml-dtypes
numpy
opt-einsum
scipy
typing-extensions
] ++ etils.optional-dependencies.epath;
] ++ lib.optional (pythonOlder "3.10") importlib-metadata;

nativeCheckInputs = [
jaxlib'
Expand Down Expand Up @@ -96,24 +98,12 @@ buildPythonPackage rec {
"testScanGrad_jit_scan"
];

# See https://github.com/google/jax/issues/11722. This is a temporary fix in
# order to unblock etils, and upgrading jax/jaxlib to the latest version. See
# https://github.com/NixOS/nixpkgs/issues/183173#issuecomment-1204074993.
disabledTestPaths = [
"tests/api_test.py"
"tests/core_test.py"
"tests/lax_numpy_indexing_test.py"
"tests/lax_numpy_test.py"
"tests/nn_test.py"
"tests/random_test.py"
"tests/sparse_test.py"
] ++ lib.optionals (stdenv.isDarwin && stdenv.isAarch64) [
disabledTestPaths = lib.optionals (stdenv.isDarwin && stdenv.isAarch64) [
# RuntimeWarning: invalid value encountered in cast
"tests/lax_test.py"
];

# As of 0.3.22, `import jax` does not work without jaxlib being installed.
pythonImportsCheck = [ ];
pythonImportsCheck = [ "jax" ];

meta = with lib; {
description = "Differentiate, compile, and transform Numpy code";
Expand Down
84 changes: 51 additions & 33 deletions pkgs/development/python-modules/jaxlib/bin.nix
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,12 @@
, autoPatchelfHook
, buildPythonPackage
, config
, cudnn ? cudaPackages.cudnn
, fetchPypi
, fetchurl
, flatbuffers
, isPy39
, jaxlib-build
, lib
, ml-dtypes
, python
, scipy
, stdenv
Expand All @@ -35,46 +36,57 @@ let
inherit (cudaPackages) cudatoolkit cudnn;
in

assert cudaSupport -> lib.versionAtLeast cudatoolkit.version "11.1";
assert cudaSupport -> lib.versionAtLeast cudnn.version "8.2";
assert cudaSupport -> lib.versionAtLeast cudatoolkit.version "11.1" && lib.versionAtLeast cudnn.version "8.2" && stdenv.isLinux;

let
version = "0.4.4";
version = "0.4.14";

inherit (python) pythonVersion;

# As of 2023-06-06, google/jax upstream is no longer publishing CPU-only wheels to their GCS bucket. Instead the
# official instructions recommend installing CPU-only versions via PyPI.
cpuSrcs =
let
getSrcFromPypi = { platform, hash }: fetchPypi {
inherit version platform hash;
pname = "jaxlib";
format = "wheel";
# See the `disabled` attr comment below.
dist = "cp310";
python = "cp310";
abi = "cp310";
};
in
{
"x86_64-linux" = getSrcFromPypi {
platform = "manylinux2014_x86_64";
hash = "sha256-nyylSZfqHeftlvVgJZFCN1ldjluZVJIYu4ZSsVxvXf8=";
};
"aarch64-darwin" = getSrcFromPypi {
platform = "macosx_11_0_arm64";
hash = "sha256-La3wYbGCjWTl7krBD6BaBRqyBD8R530Lckbz0AWv0FM=";
};
"x86_64-darwin" = getSrcFromPypi {
platform = "macosx_10_14_x86_64";
hash = "sha256-hDg5+qisgtgOrdvbjxsUgI73cW6Aah8NLjhPe4kMAsM=";
};
};

pythonVersion = python.pythonVersion;

# Find new releases at https://storage.googleapis.com/jax-releases/jax_releases.html.
# When upgrading, you can get these hashes from prefetch.sh. See
# https://github.com/google/jax/issues/12879 as to why this specific URL is
# the correct index.
cpuSrcs = {
"x86_64-linux" = fetchurl {
url = "https://storage.googleapis.com/jax-releases/nocuda/jaxlib-${version}-cp310-cp310-manylinux2014_x86_64.whl";
hash = "sha256-4VT909AB+ti5HzQvsaZWNY6MS/GItlVEFH9qeZnUuKQ=";
};
"aarch64-darwin" = fetchurl {
url = "https://storage.googleapis.com/jax-releases/mac/jaxlib-${version}-cp310-cp310-macosx_11_0_arm64.whl";
hash = "sha256-wuOmoCeTldslSa0MommQeTe+RYKhUMam1ZXrgSov+8U=";
};
"x86_64-darwin" = fetchurl {
url = "https://storage.googleapis.com/jax-releases/mac/jaxlib-${version}-cp310-cp310-macosx_10_14_x86_64.whl";
hash = "sha256-arfiTw8yafJwjRwJhKby2O7y3+4ksh3PjaKW9JgJ1ok=";
};
};

# https://github.com/google/jax/issues/12879 as to why this specific URL is the correct index.
gpuSrc = fetchurl {
url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn82-cp310-cp310-manylinux2014_x86_64.whl";
hash = "sha256-bJ62DdzuPSV311ZI2R/LJQ3fOkDibtz2+8wDKw31FLk=";
url = "https://storage.googleapis.com/jax-releases/cuda11/jaxlib-${version}+cuda11.cudnn86-cp310-cp310-manylinux2014_x86_64.whl";
hash = "sha256-CcQ5kjp4XfUX4/RwFY3T5G3kVKAeyoCTXu1Lo4O16Qo=";
};

in
buildPythonPackage rec {
buildPythonPackage {
pname = "jaxlib";
inherit version;
format = "wheel";

# At the time of writing (2022-10-19), there are releases for <=3.10.
# Supporting all of them is a pain, so we focus on 3.10, the current nixpkgs
# python version.
disabled = !(pythonVersion == "3.10");

# See https://discourse.nixos.org/t/ofborg-does-not-respect-meta-platforms/27019/6.
Expand All @@ -87,9 +99,10 @@ buildPythonPackage rec {

# Prebuilt wheels are dynamically linked against things that nix can't find.
# Run `autoPatchelfHook` to automagically fix them.
nativeBuildInputs = lib.optionals cudaSupport [ autoPatchelfHook addOpenGLRunpath ];
nativeBuildInputs = lib.optionals stdenv.isLinux [ autoPatchelfHook ]
++ lib.optionals cudaSupport [ addOpenGLRunpath ];
# Dynamic link dependencies
buildInputs = [ stdenv.cc.cc ];
buildInputs = [ stdenv.cc.cc.lib ];

# jaxlib contains shared libraries that open other shared libraries via dlopen
# and these implicit dependencies are not recognized by ldd or
Expand All @@ -113,7 +126,12 @@ buildPythonPackage rec {
done
'';

propagatedBuildInputs = [ absl-py flatbuffers scipy ];
propagatedBuildInputs = [
absl-py
flatbuffers
ml-dtypes
scipy
];

# Note that cudatoolkit is snecessary since jaxlib looks for "ptxas" in $PATH.
# See https://github.com/NixOS/nixpkgs/pull/164176#discussion_r828801621 for
Expand All @@ -123,7 +141,7 @@ buildPythonPackage rec {
ln -s ${cudatoolkit}/bin/ptxas $out/bin/ptxas
'';

pythonImportsCheck = [ "jaxlib" ];
inherit (jaxlib-build) pythonImportsCheck;

meta = with lib; {
description = "XLA library for JAX";
Expand Down
Loading