python3Packages.torch: 1.13.1 -> 2.0.0 #222273

Merged: 10 commits, Apr 8, 2023
6 changes: 5 additions & 1 deletion pkgs/development/compilers/llvm/rocm/llvm.nix
@@ -24,6 +24,8 @@
, targetDir ? "llvm"
, targetProjects ? [ ]
, targetRuntimes ? [ ]
# "NATIVE" resolves into x86 or aarch64 depending on stdenv
, llvmTargetsToBuild ? [ "NATIVE" ]
Contributor Author:

CC @NixOS/rocm-maintainers

I merely needed a way to add the NVPTX target when overriding llvm for triton. I'm happy to change the interface to whatever works best.

, extraPatches ? [ ]
, extraNativeBuildInputs ? [ ]
, extraBuildInputs ? [ ]
@@ -46,6 +48,8 @@ let
if stdenv.isx86_64 then "X86"
else if stdenv.isAarch64 then "AArch64"
else throw "Unsupported ROCm LLVM platform";
inferNativeTarget = t: if t == "NATIVE" then llvmNativeTarget else t;
llvmTargetsToBuild' = [ "AMDGPU" ] ++ builtins.map inferNativeTarget llvmTargetsToBuild;
in stdenv.mkDerivation (finalAttrs: {
pname = "rocm-llvm-${targetName}";
version = "5.4.4";
@@ -98,7 +102,7 @@ in stdenv.mkDerivation (finalAttrs: {
sourceRoot = "${finalAttrs.src.name}/${targetDir}";

cmakeFlags = [
"-DLLVM_TARGETS_TO_BUILD=AMDGPU;${llvmNativeTarget}"
"-DLLVM_TARGETS_TO_BUILD=${builtins.concatStringsSep ";" llvmTargetsToBuild'}"
] ++ lib.optionals (finalAttrs.passthru.isLLVM && targetProjects != [ ]) [
"-DLLVM_ENABLE_PROJECTS=${lib.concatStringsSep ";" targetProjects}"
] ++ lib.optionals ((finalAttrs.passthru.isLLVM || targetDir == "runtimes") && targetRuntimes != [ ]) [
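For illustration, a minimal sketch of how the new target list resolves on an x86_64 build host (this merely re-evaluates the logic above and is not part of the PR):

let
  llvmNativeTarget = "X86"; # what stdenv.isx86_64 selects above
  inferNativeTarget = t: if t == "NATIVE" then llvmNativeTarget else t;
  # the default [ "NATIVE" ] extended with the NVPTX entry the triton expression below passes in
  llvmTargetsToBuild' = [ "AMDGPU" ] ++ builtins.map inferNativeTarget [ "NATIVE" "NVPTX" ];
in
  "-DLLVM_TARGETS_TO_BUILD=${builtins.concatStringsSep ";" llvmTargetsToBuild'}"
  # evaluates to "-DLLVM_TARGETS_TO_BUILD=AMDGPU;X86;NVPTX"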
13 changes: 11 additions & 2 deletions pkgs/development/ocaml-modules/torch/default.nix
@@ -2,6 +2,7 @@
, stdenv
, buildDunePackage
, fetchFromGitHub
, fetchpatch
, cmdliner
, ctypes
, dune-configurator
@@ -24,11 +25,19 @@ buildDunePackage rec {

src = fetchFromGitHub {
owner = "LaurentMazare";
repo = "ocaml-${pname}";
rev = version;
repo = "ocaml-${pname}";
rev = version;
hash = "sha256-z/9NUBjeFWE63Z/e8OyzDiy8hrn6qzjaiBH8G9MPeos=";
};

patches = [
# Pytorch 2.0 support. Drop when it reaches a release
(fetchpatch {
url = "https://github.com/LaurentMazare/ocaml-torch/commit/ef7ef30cafecb09e45ec1ed8ce4bedae5947cfa5.patch";
hash = "sha256-smdwKy40iIISp/25L2J4az6KmqFS1soeChBElUyhl5A=";
})
];

buildInputs = [ dune-configurator ];

propagatedBuildInputs = [
253 changes: 253 additions & 0 deletions pkgs/development/python-modules/openai-triton/default.nix
@@ -0,0 +1,253 @@
{ lib
, buildPythonPackage
, python
, fetchpatch
, fetchFromGitHub
, addOpenGLRunpath
, cmake
, cudaPackages
, llvmPackages
, pybind11
, gtest
, zlib
, ncurses
, libxml2
, lit
, filelock
, torchWithRocm
, pytest
, pytestCheckHook
, pythonRelaxDepsHook
, pkgsTargetTarget
}:

let
pname = "triton";
version = "2.0.0";

inherit (cudaPackages) cuda_cudart backendStdenv;

# A time may come when we'll want to be cross-friendly
#
# Short explanation: we need pkgsTargetTarget, because we use string
# interpolation instead of buildInputs.
#
# Long explanation: OpenAI/triton downloads and vendors a copy of NVidia's
# ptxas compiler. We're not running this ptxas on the build machine, but on
# the user's machine, i.e. our Target platform. The second "Target" in
# pkgsTargetTarget maybe doesn't matter, because ptxas compiles programs to
# be executed on the GPU.
# Cf. https://nixos.org/manual/nixpkgs/unstable/#sec-cross-infra
ptxas = "${pkgsTargetTarget.cudaPackages.cuda_nvcc}/bin/ptxas";
SomeoneSerge marked this conversation as resolved.

llvm = (llvmPackages.llvm.override {
llvmTargetsToBuild = [ "NATIVE" "NVPTX" ];
# Upstream CI sets these too:
# targetProjects = [ "mlir" ];
extraCMakeFlags = [
"-DLLVM_INSTALL_UTILS=ON"
];
});
Comment on lines +43 to +50
Contributor Author @SomeoneSerge (Mar 29, 2023):
Don't override at the package level, instead expose e.g. at llvmPackages_rocm.llvmWithCuda and consume as an argument so that users can substitute it? Not sure if it's worth the effort
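A rough sketch of what that interface could look like, assuming a hypothetical llvmWithCuda attribute and an overlay as the delivery mechanism (neither is part of this PR):

# Hypothetical overlay exposing a CUDA-capable ROCm LLVM that users can consume or substitute:
final: prev: {
  llvmPackages_rocm = prev.llvmPackages_rocm // {
    llvmWithCuda = prev.llvmPackages_rocm.llvm.override {
      llvmTargetsToBuild = [ "NATIVE" "NVPTX" ];
      extraCMakeFlags = [ "-DLLVM_INSTALL_UTILS=ON" ];
    };
  };
}

openai-triton/default.nix would then take that package as a function argument instead of calling llvmPackages.llvm.override in place, so users could swap in their own LLVM build.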

in
buildPythonPackage {
inherit pname version;

format = "setuptools";

src = fetchFromGitHub {
owner = "openai";
Contributor Author:
I only now realize that upstream is actually using a fork of triton when building for ROCm: triton-lang/triton#46 (comment)

Member:
Is this a blocker for merging? It seems like torch.compile is working without their custom fork?

Contributor Author:
I only tested the runtime for CPU and CUDA. I just started working on making cudaSupport in the triton expression optional, but I stopped myself because that meant growing the scope of the PR and delaying the merge even further.
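A rough sketch of what optional CUDA support could look like, following the usual nixpkgs cudaSupport convention (hypothetical; none of this is in the PR):

# Hypothetical fragment: gate the CUDA-only pieces behind a cudaSupport argument.
{ lib
, config
, cudaPackages
, cudaSupport ? config.cudaSupport or false
, ...
}:

{
  # The ptxas symlink, the -lcuda substitution, and the cuda.h include path
  # from the expression below would each be guarded like this:
  preConfigure = lib.optionalString cudaSupport ''
    mkdir -p "$PWD/triton/third_party/cuda/bin"
    ln -s "${cudaPackages.cuda_nvcc}/bin/ptxas" "$PWD/triton/third_party/cuda/bin/"
  '';
}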

Contributor Author:
For that matter, we also need llvmPackages_17 with mlir enabled, which hasn't been packaged yet either.
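For reference, upstream CI enables MLIR in the LLVM build itself (see the commented-out targetProjects above). With the arguments this PR adds to the ROCm llvm.nix, such a build could in principle be requested as below, though this is untested and no substitute for a proper llvmPackages_17 with MLIR:

# Untested sketch, reusing the arguments exposed by the ROCm llvm.nix in this PR:
llvmWithMlir = llvmPackages.llvm.override {
  llvmTargetsToBuild = [ "NATIVE" "NVPTX" ];
  targetProjects = [ "mlir" ];
  extraCMakeFlags = [ "-DLLVM_INSTALL_UTILS=ON" ];
};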

repo = pname;
rev = "v${version}";
hash = "sha256-9GZzugab+Pdt74Dj6zjlEzjj4BcJ69rzMJmqcVMxsKU=";
};

patches = [
# Prerequisite for llvm15 patch
(fetchpatch {
url = "https://github.com/openai/triton/commit/2aba985daaa70234823ea8f1161da938477d3e02.patch";
hash = "sha256-LGv0+Ut2WYPC4Ksi4803Hwmhi3FyQOF9zElJc/JCobk=";
})
(fetchpatch {
url = "https://github.com/openai/triton/commit/e3941f9d09cdd31529ba4a41018cfc0096aafea6.patch";
hash = "sha256-A+Gor6qzFlGQhVVhiaaYOzqqx8yO2MdssnQS6TIfUWg=";
})

# Source: https://github.com/openai/triton/commit/fc7a8e35819bda632bdcf1cf75fd9abe4d4e077a.patch
# The original patch adds ptxas binary, so we include our own clean copy
# Drop with the next update
./llvm15.patch

# TODO: there have been commits upstream aimed at removing the "torch"
# circular dependency, but the patches fail to apply on the release
# revision. Keeping the link for future reference
# Also cf. https://github.com/openai/triton/issues/1374

# (fetchpatch {
# url = "https://github.com/openai/triton/commit/fc7c0b0e437a191e421faa61494b2ff4870850f1.patch";
# hash = "sha256-f0shIqHJkVvuil2Yku7vuqWFn7VCRKFSFjYRlwx25ig=";
# })
];

postPatch = ''
substituteInPlace python/setup.py \
--replace \
'= get_thirdparty_packages(triton_cache_path)' \
'= os.environ["cmakeFlags"].split()'
''
# Wiring triton=2.0.0 with llvmPackages_rocm.llvm=5.4.3
# Revisit when updating either triton or llvm
+ ''
substituteInPlace CMakeLists.txt \
--replace "nvptx" "NVPTX" \
--replace "LLVM 11" "LLVM"
sed -i '/AddMLIR/a set(MLIR_TABLEGEN_EXE "${llvmPackages.mlir}/bin/mlir-tblgen")' CMakeLists.txt
sed -i '/AddMLIR/a set(MLIR_INCLUDE_DIR ''${MLIR_INCLUDE_DIRS})' CMakeLists.txt
find -iname '*.td' -exec \
sed -i \
-e '\|include "mlir/IR/OpBase.td"|a include "mlir/IR/AttrTypeBase.td"' \
Contributor Author:
These paths and class names seem to change a lot between releases, so, admittedly, the next update is probably going to be involved

-e 's|include "mlir/Dialect/StandardOps/IR/Ops.td"|include "mlir/Dialect/Func/IR/FuncOps.td"|' \
'{}' ';'
substituteInPlace unittest/CMakeLists.txt --replace "include(GoogleTest)" "find_package(GTest REQUIRED)"
sed -i 's/^include.*$//' unittest/CMakeLists.txt
sed -i '/LINK_LIBS/i NVPTXInfo' lib/Target/PTX/CMakeLists.txt
sed -i '/LINK_LIBS/i NVPTXCodeGen' lib/Target/PTX/CMakeLists.txt
''
# TritonMLIRIR already links MLIRIR. Not transitive?
# + ''
# echo "target_link_libraries(TritonPTX PUBLIC MLIRIR)" >> lib/Target/PTX/CMakeLists.txt
# ''
# Already defined in llvm, when built with -DLLVM_INSTALL_UTILS
+ ''
substituteInPlace bin/CMakeLists.txt \
--replace "add_subdirectory(FileCheck)" ""

rm cmake/FindLLVM.cmake
''
+
(
let
# Bash was getting weird without linting,
# but basically upstream contains [cc, ..., "-lcuda", ...]
# and we replace it with [..., "-lcuda", "-L/run/opengl-driver/lib", "-L$stubs", ...]
old = [ "-lcuda" ];
new = [ "-lcuda" "-L${addOpenGLRunpath.driverLink}" "-L${cuda_cudart}/lib/stubs/" ];

quote = x: ''"${x}"'';
oldStr = lib.concatMapStringsSep ", " quote old;
newStr = lib.concatMapStringsSep ", " quote new;
in
''
substituteInPlace python/triton/compiler.py \
--replace '${oldStr}' '${newStr}'
''
)
# Triton seems to be looking up cuda.h
+ ''
sed -i 's|cu_include_dir = os.path.join.*$|cu_include_dir = "${cuda_cudart}/include"|' python/triton/compiler.py
'';

nativeBuildInputs = [
cmake
pythonRelaxDepsHook

# Requires torch (circular dependency) and probably needs GPUs:
# pytestCheckHook

# Note for future:
# These *probably* should go in depsTargetTarget
# ...but we cannot test cross right now anyway
# because we only support cudaPackages on x86_64-linux atm
lit
llvm
llvmPackages.mlir
];

buildInputs = [
gtest
libxml2.dev
ncurses
pybind11
zlib
];

propagatedBuildInputs = [
filelock
];

# Avoid GLIBCXX mismatch with other cuda-enabled python packages
preConfigure = ''
export CC="${backendStdenv.cc}/bin/cc";
export CXX="${backendStdenv.cc}/bin/c++";

# Upstream's setup.py tries to write cache somewhere in ~/
export HOME=$TMPDIR

# Upstream's github actions patch setup.cfg to write base-dir. May be redundant
echo "
[build_ext]
base-dir=$PWD" >> python/setup.cfg

# The rest (including buildPhase) is relative to ./python/
cd python/

# Work around download_and_copy_ptxas()
dst_cuda="$PWD/triton/third_party/cuda/bin"
mkdir -p "$dst_cuda"
ln -s "${ptxas}" "$dst_cuda/"
'';

# CMake is run by setup.py instead
dontUseCmakeConfigure = true;
cmakeFlags = [
"-DMLIR_DIR=${llvmPackages.mlir}/lib/cmake/mlir"
];

postFixup =
let
ptxasDestination = "$out/${python.sitePackages}/triton/third_party/cuda/bin/ptxas";
in
# Setuptools (?) strips runpath and +x flags. Let's just restore the symlink
''
rm -f ${ptxasDestination}
ln -s ${ptxas} ${ptxasDestination}
'';

checkInputs = [
cmake # ctest
];
dontUseSetuptoolsCheck = true;
preCheck =
# build/temp* refers to build_ext.build_temp (looked up in the build logs)
''
(cd /build/source/python/build/temp* ; ctest)
'' # For pytestCheckHook
+ ''
cd test/unit
'';
pythonImportsCheck = [
# Circular dependency on torch
# "triton"
# "triton.language"
];

# Ultimately, torch is our test suite:
passthru.tests = {
inherit torchWithRocm;
};

pythonRemoveDeps = [
# Circular dependency, cf. https://github.com/openai/triton/issues/1374
Contributor Author:
The dependency was removed from setup.py in triton-lang/triton#1389, but the imports are still there

"torch"

# CLI tools without dist-info
"cmake"
"lit"
];
meta = with lib; {
description = "Development repository for the Triton language and compiler";
homepage = "https://github.com/openai/triton/";
platforms = lib.platforms.unix;
SomeoneSerge marked this conversation as resolved.
license = licenses.mit;
maintainers = with maintainers; [ SomeoneSerge ];
};
}