Skip to content

Commit

Permalink
python3Packages.jaxlib: 0.4.4 -> 0.4.14
Browse files Browse the repository at this point in the history
  • Loading branch information
GaetanLepage authored and NickCao committed Aug 2, 2023
1 parent 06ef57d commit 6232bc9
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 39 deletions.
78 changes: 40 additions & 38 deletions pkgs/development/python-modules/jaxlib/default.nix
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

# Build-time dependencies:
, addOpenGLRunpath
, bazel_5
, bazel_6
, binutils
, buildBazelPackage
, buildPythonPackage
Expand All @@ -21,11 +21,13 @@
, setuptools
, symlinkJoin
, wheel
, build
, which

# Python dependencies:
, absl-py
, flatbuffers
, ml-dtypes
, numpy
, scipy
, six
Expand All @@ -35,7 +37,6 @@
, giflib
, grpc
, libjpeg_turbo
, protobuf
, python
, snappy
, zlib
Expand All @@ -53,7 +54,7 @@ let
inherit (cudaPackages) backendStdenv cudatoolkit cudaFlags cudnn nccl;

pname = "jaxlib";
version = "0.4.4";
version = "0.4.14";

meta = with lib; {
description = "JAX is Autograd and XLA, brought together for high-performance machine learning research.";
Expand Down Expand Up @@ -99,7 +100,9 @@ let
# "com_github_googleapis_googleapis"
# "com_github_googlecloudplatform_google_cloud_cpp"
"com_github_grpc_grpc"
"com_google_protobuf"
# ERROR: /build/output/external/bazel_tools/tools/proto/BUILD:25:6: no such target '@com_google_protobuf//:cc_toolchain':
# target 'cc_toolchain' not declared in package '' defined by /build/output/external/com_google_protobuf/BUILD.bazel
# "com_google_protobuf"
# Fails with the error: external/org_tensorflow/tensorflow/core/profiler/utils/tf_op_utils.cc:46:49: error: no matching function for call to 're2::RE2::FullMatch(absl::lts_2020_02_25::string_view&, re2::RE2&)'
# "com_googlesource_code_re2"
"curl"
Expand All @@ -120,7 +123,9 @@ let
"org_sqlite"
"pasta"
"png"
"pybind11"
# ERROR: /build/output/external/pybind11/BUILD.bazel: no such target '@pybind11//:osx':
# target 'osx' not declared in package '' defined by /build/output/external/pybind11/BUILD.bazel
# "pybind11"
"six_archive"
"snappy"
"tblib_archive"
Expand All @@ -138,14 +143,15 @@ let
bazel-build = buildBazelPackage rec {
name = "bazel-build-${pname}-${version}";

bazel = bazel_5;
# See https://github.com/google/jax/blob/main/.bazelversion for the latest.
bazel = bazel_6;

src = fetchFromGitHub {
owner = "google";
repo = "jax";
# google/jax contains tags for jax and jaxlib. Only use jaxlib tags!
rev = "refs/tags/${pname}-v${version}";
hash = "sha256-DP68UwS9bg243iWU4MLHN0pwl8LaOcW3Sle1ZjsLOHo=";
hash = "sha256-0KnILQkahSiA1uuyT+kgy1XaCcZ3cpx1q114e2pecvg=";
};

nativeBuildInputs = [
Expand All @@ -154,6 +160,7 @@ let
git
setuptools
wheel
build
which
] ++ lib.optionals stdenv.isDarwin [
cctools
Expand All @@ -169,7 +176,7 @@ let
numpy
openssl
pkgs.flatbuffers
protobuf
pkgs.protobuf
pybind11
scipy
six
Expand All @@ -188,7 +195,8 @@ let
rm -f .bazelversion
'';

bazelTargets = [ "//build:build_wheel" ];
bazelRunTarget = "//jaxlib/tools:build_wheel";
runTargetFlags = [ "--output_path=$out" "--cpu=${arch}" ];

removeRulesCC = false;

Expand All @@ -207,7 +215,11 @@ let
build --action_env=PYENV_ROOT
build --python_path="${python}/bin/python"
build --distinct_host_configuration=false
build --define PROTOBUF_INCLUDE_PATH="${protobuf}/include"
build --define PROTOBUF_INCLUDE_PATH="${pkgs.protobuf}/include"
'' + lib.optionalString (stdenv.targetPlatform.avxSupport && stdenv.targetPlatform.isUnix) ''
build --config=avx_posix
'' + lib.optionalString mklSupport ''
build --config=mkl_open_source_only
'' + lib.optionalString cudaSupport ''
build --action_env CUDA_TOOLKIT_PATH="${cudatoolkit_joined}"
build --action_env CUDNN_INSTALL_PATH="${cudnn}"
Expand All @@ -234,7 +246,7 @@ let
fetchAttrs = {
TF_SYSTEM_LIBS = lib.concatStringsSep "," tf_system_libs;
# we have to force @mkl_dnn_v1 since it's not needed on darwin
bazelTargets = bazelTargets ++ [ "@mkl_dnn_v1//:mkl_dnn" ];
bazelTargets = [ bazelRunTarget "@mkl_dnn_v1//:mkl_dnn" ];
bazelFlags = bazelFlags ++ [
"--config=avx_posix"
] ++ lib.optionals cudaSupport [
Expand All @@ -249,9 +261,9 @@ let

sha256 =
if cudaSupport then
"sha256-O6bM7Lc8eaFyO4Xzl5/hvBrbPioI+Yeqx9yNC97fvKk="
"sha256-8QaXoZq6oITRsYn4RdLUXcKQv3PJ4Q3ItX9PkBwxGBI="
else
"sha256-gLMJfJSQIdGGY2Ivx4IgDWg0hc+mxzlqY11CUkSWcjI=";
"sha256-M/h5EZmyiV4QvzgKRjdz7V1LHENUJlc/ig1QAItnWVQ=";
};

buildAttrs = {
Expand All @@ -261,25 +273,13 @@ let
"nsync" # fails to build on darwin
]);

bazelFlags = bazelFlags ++ lib.optionals (stdenv.targetPlatform.isx86_64 && stdenv.targetPlatform.isUnix) [
"--config=avx_posix"
] ++ lib.optionals cudaSupport [
"--config=cuda"
] ++ lib.optionals mklSupport [
"--config=mkl_open_source_only"
];
# Note: we cannot do most of this patching at `patch` phase as the deps are not available yet.
# 1) Fix pybind11 include paths.
# 2) Link protobuf from nixpkgs (through TF_SYSTEM_LIBS when using gcc) to prevent crashes on
# 1) Link protobuf from nixpkgs (through TF_SYSTEM_LIBS when using gcc) to prevent crashes on
# loading multiple extensions in the same python program due to duplicate protobuf DBs.
# 3) Patch python path in the compiler driver.
preBuild = ''
for src in ./jaxlib/*.{cc,h} ./jaxlib/cuda/*.{cc,h}; do
sed -i 's@include/pybind11@pybind11@g' $src
done
'' + lib.optionalString cudaSupport ''
# 2) Patch python path in the compiler driver.
preBuild = lib.optionalString cudaSupport ''
export NIX_LDFLAGS+=" -L${backendStdenv.nixpkgsCompatibleLibstdcxx}/lib"
patchShebangs ../output/external/org_tensorflow/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl
patchShebangs ../output/external/xla/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl
'' + lib.optionalString stdenv.isDarwin ''
# Framework search paths aren't added by bintools hook
# https://github.com/NixOS/nixpkgs/pull/41914
Expand All @@ -289,16 +289,12 @@ let
substituteInPlace ../output/external/rules_cc/cc/private/toolchain/unix_cc_configure.bzl \
--replace "/usr/bin/libtool" "${cctools}/bin/libtool"
'' + (if stdenv.cc.isGNU then ''
sed -i 's@-lprotobuf@-l:libprotobuf.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD
sed -i 's@-lprotoc@-l:libprotoc.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD
sed -i 's@-lprotobuf@-l:libprotobuf.a@' ../output/external/xla/third_party/systemlibs/protobuf.BUILD
sed -i 's@-lprotoc@-l:libprotoc.a@' ../output/external/xla/third_party/systemlibs/protobuf.BUILD
'' else if stdenv.cc.isClang then ''
sed -i 's@-lprotobuf@${protobuf}/lib/libprotobuf.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD
sed -i 's@-lprotoc@${protobuf}/lib/libprotoc.a@' ../output/external/org_tensorflow/third_party/systemlibs/protobuf.BUILD
sed -i 's@-lprotobuf@${pkgs.protobuf}/lib/libprotobuf.a@' ../output/external/xla/third_party/systemlibs/protobuf.BUILD
sed -i 's@-lprotoc@${pkgs.protobuf}/lib/libprotoc.a@' ../output/external/xla/third_party/systemlibs/protobuf.BUILD
'' else throw "Unsupported stdenv.cc: ${stdenv.cc}");

installPhase = ''
./bazel-bin/build/build_wheel --output_path=$out --cpu=${arch}
'';
};

inherit meta;
Expand Down Expand Up @@ -345,13 +341,19 @@ buildPythonPackage {
grpc
jsoncpp
libjpeg_turbo
ml-dtypes
numpy
scipy
six
snappy
];

pythonImportsCheck = [ "jaxlib" ];
pythonImportsCheck = [
"jaxlib"
# `import jaxlib` loads surprisingly little. These imports are actually bugs that appeared in the 0.4.11 upgrade.
"jaxlib.cpu_feature_guard"
"jaxlib.xla_client"
];

# Without it there are complaints about libcudart.so.11.0 not being found
# because RPATH path entries added above are stripped.
Expand Down
1 change: 0 additions & 1 deletion pkgs/top-level/python-packages.nix
Original file line number Diff line number Diff line change
Expand Up @@ -5310,7 +5310,6 @@ self: super: with self; {
# Some platforms don't have `cudaSupport` defined, hence the need for 'or false'.
inherit (pkgs.config) cudaSupport;
IOKit = pkgs.darwin.apple_sdk_11_0.IOKit;
protobuf = pkgs.protobuf3_20; # jaxlib-build 0.3.15 won't build with protobuf 3.21
};

jaxlib = self.jaxlib-build;
Expand Down

0 comments on commit 6232bc9

Please sign in to comment.