-
-
Notifications
You must be signed in to change notification settings - Fork 14.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #222273 from SomeoneSerge/torch20
python3Packages.torch: 1.13.1 -> 2.0.0
- Loading branch information
Showing
8 changed files
with
4,944 additions
and
24 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
253 changes: 253 additions & 0 deletions
253
pkgs/development/python-modules/openai-triton/default.nix
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,253 @@ | ||
{ lib | ||
, buildPythonPackage | ||
, python | ||
, fetchpatch | ||
, fetchFromGitHub | ||
, addOpenGLRunpath | ||
, cmake | ||
, cudaPackages | ||
, llvmPackages | ||
, pybind11 | ||
, gtest | ||
, zlib | ||
, ncurses | ||
, libxml2 | ||
, lit | ||
, filelock | ||
, torchWithRocm | ||
, pytest | ||
, pytestCheckHook | ||
, pythonRelaxDepsHook | ||
, pkgsTargetTarget | ||
}: | ||
|
||
let | ||
pname = "triton"; | ||
version = "2.0.0"; | ||
|
||
inherit (cudaPackages) cuda_cudart backendStdenv; | ||
|
||
# A time may come we'll want to be cross-friendly | ||
# | ||
# Short explanation: we need pkgsTargetTarget, because we use string | ||
# interpolation instead of buildInputs. | ||
# | ||
# Long explanation: OpenAI/triton downloads and vendors a copy of NVidia's | ||
# ptxas compiler. We're not running this ptxas on the build machine, but on | ||
# the user's machine, i.e. our Target platform. The second "Target" in | ||
# pkgsTargetTarget maybe doesn't matter, because ptxas compiles programs to | ||
# be executed on the GPU. | ||
# Cf. https://nixos.org/manual/nixpkgs/unstable/#sec-cross-infra | ||
ptxas = "${pkgsTargetTarget.cudaPackages.cuda_nvcc}/bin/ptxas"; | ||
|
||
llvm = (llvmPackages.llvm.override { | ||
llvmTargetsToBuild = [ "NATIVE" "NVPTX" ]; | ||
# Upstream CI sets these too: | ||
# targetProjects = [ "mlir" ]; | ||
extraCMakeFlags = [ | ||
"-DLLVM_INSTALL_UTILS=ON" | ||
]; | ||
}); | ||
in | ||
buildPythonPackage { | ||
inherit pname version; | ||
|
||
format = "setuptools"; | ||
|
||
src = fetchFromGitHub { | ||
owner = "openai"; | ||
repo = pname; | ||
rev = "v${version}"; | ||
hash = "sha256-9GZzugab+Pdt74Dj6zjlEzjj4BcJ69rzMJmqcVMxsKU="; | ||
}; | ||
|
||
patches = [ | ||
# Prerequisite for llvm15 patch | ||
(fetchpatch { | ||
url = "https://github.com/openai/triton/commit/2aba985daaa70234823ea8f1161da938477d3e02.patch"; | ||
hash = "sha256-LGv0+Ut2WYPC4Ksi4803Hwmhi3FyQOF9zElJc/JCobk="; | ||
}) | ||
(fetchpatch { | ||
url = "https://github.com/openai/triton/commit/e3941f9d09cdd31529ba4a41018cfc0096aafea6.patch"; | ||
hash = "sha256-A+Gor6qzFlGQhVVhiaaYOzqqx8yO2MdssnQS6TIfUWg="; | ||
}) | ||
|
||
# Source: https://github.com/openai/triton/commit/fc7a8e35819bda632bdcf1cf75fd9abe4d4e077a.patch | ||
# The original patch adds ptxas binary, so we include our own clean copy | ||
# Drop with the next update | ||
./llvm15.patch | ||
|
||
# TODO: there have been commits upstream aimed at removing the "torch" | ||
# circular dependency, but the patches fail to apply on the release | ||
# revision. Keeping the link for future reference | ||
# Also cf. https://github.com/openai/triton/issues/1374 | ||
|
||
# (fetchpatch { | ||
# url = "https://github.com/openai/triton/commit/fc7c0b0e437a191e421faa61494b2ff4870850f1.patch"; | ||
# hash = "sha256-f0shIqHJkVvuil2Yku7vuqWFn7VCRKFSFjYRlwx25ig="; | ||
# }) | ||
]; | ||
|
||
postPatch = '' | ||
substituteInPlace python/setup.py \ | ||
--replace \ | ||
'= get_thirdparty_packages(triton_cache_path)' \ | ||
'= os.environ["cmakeFlags"].split()' | ||
'' | ||
# Wiring triton=2.0.0 with llcmPackages_rocm.llvm=5.4.3 | ||
# Revisit when updating either triton or llvm | ||
+ '' | ||
substituteInPlace CMakeLists.txt \ | ||
--replace "nvptx" "NVPTX" \ | ||
--replace "LLVM 11" "LLVM" | ||
sed -i '/AddMLIR/a set(MLIR_TABLEGEN_EXE "${llvmPackages.mlir}/bin/mlir-tblgen")' CMakeLists.txt | ||
sed -i '/AddMLIR/a set(MLIR_INCLUDE_DIR ''${MLIR_INCLUDE_DIRS})' CMakeLists.txt | ||
find -iname '*.td' -exec \ | ||
sed -i \ | ||
-e '\|include "mlir/IR/OpBase.td"|a include "mlir/IR/AttrTypeBase.td"' \ | ||
-e 's|include "mlir/Dialect/StandardOps/IR/Ops.td"|include "mlir/Dialect/Func/IR/FuncOps.td"|' \ | ||
'{}' ';' | ||
substituteInPlace unittest/CMakeLists.txt --replace "include(GoogleTest)" "find_package(GTest REQUIRED)" | ||
sed -i 's/^include.*$//' unittest/CMakeLists.txt | ||
sed -i '/LINK_LIBS/i NVPTXInfo' lib/Target/PTX/CMakeLists.txt | ||
sed -i '/LINK_LIBS/i NVPTXCodeGen' lib/Target/PTX/CMakeLists.txt | ||
'' | ||
# TritonMLIRIR already links MLIRIR. Not transitive? | ||
# + '' | ||
# echo "target_link_libraries(TritonPTX PUBLIC MLIRIR)" >> lib/Target/PTX/CMakeLists.txt | ||
# '' | ||
# Already defined in llvm, when built with -DLLVM_INSTALL_UTILS | ||
+ '' | ||
substituteInPlace bin/CMakeLists.txt \ | ||
--replace "add_subdirectory(FileCheck)" "" | ||
rm cmake/FindLLVM.cmake | ||
'' | ||
+ | ||
( | ||
let | ||
# Bash was getting weird without linting, | ||
# but basically upstream contains [cc, ..., "-lcuda", ...] | ||
# and we replace it with [..., "-lcuda", "-L/run/opengl-driver/lib", "-L$stubs", ...] | ||
old = [ "-lcuda" ]; | ||
new = [ "-lcuda" "-L${addOpenGLRunpath.driverLink}" "-L${cuda_cudart}/lib/stubs/" ]; | ||
|
||
quote = x: ''"${x}"''; | ||
oldStr = lib.concatMapStringsSep ", " quote old; | ||
newStr = lib.concatMapStringsSep ", " quote new; | ||
in | ||
'' | ||
substituteInPlace python/triton/compiler.py \ | ||
--replace '${oldStr}' '${newStr}' | ||
'' | ||
) | ||
# Triton seems to be looking up cuda.h | ||
+ '' | ||
sed -i 's|cu_include_dir = os.path.join.*$|cu_include_dir = "${cuda_cudart}/include"|' python/triton/compiler.py | ||
''; | ||
|
||
nativeBuildInputs = [ | ||
cmake | ||
pythonRelaxDepsHook | ||
|
||
# Requires torch (circular dependency) and probably needs GPUs: | ||
# pytestCheckHook | ||
|
||
# Note for future: | ||
# These *probably* should go in depsTargetTarget | ||
# ...but we cannot test cross right now anyway | ||
# because we only support cudaPackages on x86_64-linux atm | ||
lit | ||
llvm | ||
llvmPackages.mlir | ||
]; | ||
|
||
buildInputs = [ | ||
gtest | ||
libxml2.dev | ||
ncurses | ||
pybind11 | ||
zlib | ||
]; | ||
|
||
propagatedBuildInputs = [ | ||
filelock | ||
]; | ||
|
||
# Avoid GLIBCXX mismatch with other cuda-enabled python packages | ||
preConfigure = '' | ||
export CC="${backendStdenv.cc}/bin/cc"; | ||
export CXX="${backendStdenv.cc}/bin/c++"; | ||
# Upstream's setup.py tries to write cache somewhere in ~/ | ||
export HOME=$TMPDIR | ||
# Upstream's github actions patch setup.cfg to write base-dir. May be redundant | ||
echo " | ||
[build_ext] | ||
base-dir=$PWD" >> python/setup.cfg | ||
# The rest (including buildPhase) is relative to ./python/ | ||
cd python/ | ||
# Work around download_and_copy_ptxas() | ||
dst_cuda="$PWD/triton/third_party/cuda/bin" | ||
mkdir -p "$dst_cuda" | ||
ln -s "${ptxas}" "$dst_cuda/" | ||
''; | ||
|
||
# CMake is run by setup.py instead | ||
dontUseCmakeConfigure = true; | ||
cmakeFlags = [ | ||
"-DMLIR_DIR=${llvmPackages.mlir}/lib/cmake/mlir" | ||
]; | ||
|
||
postFixup = | ||
let | ||
ptxasDestination = "$out/${python.sitePackages}/triton/third_party/cuda/bin/ptxas"; | ||
in | ||
# Setuptools (?) strips runpath and +x flags. Let's just restore the symlink | ||
'' | ||
rm -f ${ptxasDestination} | ||
ln -s ${ptxas} ${ptxasDestination} | ||
''; | ||
|
||
checkInputs = [ | ||
cmake # ctest | ||
]; | ||
dontUseSetuptoolsCheck = true; | ||
preCheck = | ||
# build/temp* refers to build_ext.build_temp (looked up in the build logs) | ||
'' | ||
(cd /build/source/python/build/temp* ; ctest) | ||
'' # For pytestCheckHook | ||
+ '' | ||
cd test/unit | ||
''; | ||
pythonImportsCheck = [ | ||
# Circular dependency on torch | ||
# "triton" | ||
# "triton.language" | ||
]; | ||
|
||
# Ultimately, torch is our test suite: | ||
passthru.tests = { | ||
inherit torchWithRocm; | ||
}; | ||
|
||
pythonRemoveDeps = [ | ||
# Circular dependency, cf. https://github.com/openai/triton/issues/1374 | ||
"torch" | ||
|
||
# CLI tools without dist-info | ||
"cmake" | ||
"lit" | ||
]; | ||
meta = with lib; { | ||
description = "Development repository for the Triton language and compiler"; | ||
homepage = "https://github.com/openai/triton/"; | ||
platforms = lib.platforms.unix; | ||
license = licenses.mit; | ||
maintainers = with maintainers; [ SomeoneSerge ]; | ||
}; | ||
} |
Oops, something went wrong.