getting a minimal Lux example working (with CUDA) #1392

Closed
ExpandingMan opened this issue Apr 18, 2024 · 6 comments

Comments

@ExpandingMan
Contributor

ExpandingMan commented Apr 18, 2024

This issue provides a minimal example of neural network training with Lux, to hopefully make it easier for developers to work toward making this use case viable. It probably isn't news to anyone here that this example fails, but it was indicated to me on Slack that it would still be helpful to have this issue for reference.

In this example, we train a neural network with a single hidden layer (16 units, gelu activation) to approximate the polynomial $x^2 - 2x$. It is trivial to generalize this to deeper neural networks, but that probably isn't useful for this demonstration. This roughly follows the Lux tutorial, but I have stripped out the opaque Lux training utilities so that it's clearer what's going on. I expect this example to be simpler to diagnose than the equivalent with Flux, since Lux's explicit parameterization makes it easier to reason about, and I also expect that if this example worked, the analogous Flux example surely would. Indeed, I think this example is a good proxy for a huge number of common use cases.
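The target in makedata below is built with Base.evalpoly. As a minimal sketch, independent of Lux and CUDA, this checks that the coefficient tuple (0, -2, 1) really encodes $x^2 - 2x$ and shows why the tuple is wrapped in another tuple when broadcasting (the name f is illustrative only):

```julia
# Base.evalpoly takes coefficients in ascending order of degree, so the
# tuple (0, -2, 1) encodes 0 + (-2)x + 1x^2, i.e. x^2 - 2x.
f(x) = evalpoly(x, (0, -2, 1))
@assert f(3) == 3   # 3^2 - 2*3

# Same construction as makedata, minus the noise term. Wrapping the tuple
# in a 1-tuple makes broadcasting treat it as a single argument instead
# of iterating over the three coefficients.
X = reshape(collect(range(-2.0f0, 2.0f0, 128)), (1, 128))
Y = evalpoly.(X, ((0, -2, 1),))
@assert size(Y) == (1, 128) && eltype(Y) == Float32
@assert Y ≈ X .^ 2 .- 2 .* X
```

Because the coefficients are integers and X is Float32, the result stays Float32, which keeps the data consistent with the Float32 parameters produced by Lux.setup.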

using LinearAlgebra, Random, Statistics, Optimisers
using CUDA
using Lux, LuxCUDA
import Zygote, Enzyme

const dev = gpu_device()

function makedata(rng::AbstractRNG)
    X = reshape(collect(range(-2.0f0, 2.0f0, 128)), (1, 128))
    y = evalpoly.(X, ((0, -2, 1),)) .+ randn(rng, Float32, (1, 128)) .* 0.1f0
    (X, y)
end

function loss(model, θ, ψ, (X, y))
    (ŷ, ψ) = Lux.apply(model, X, θ, ψ)
    mean(abs2, ŷ .- y)
end

function gradloss_zygote(model, θ, ψ, (X, y))
    (∇ℓ,) = Zygote.gradient(θ) do ϑ
        loss(model, ϑ, ψ, (X, y))
    end
    ∇ℓ
end

function gradloss_enzyme(model, θ, ψ, (X, y))
    ℓ = ϑ -> begin
        loss(model, ϑ, ψ, (X, y))
    end
    Enzyme.gradient(Enzyme.Reverse, ℓ, θ)
end

function main(rng=Random.Xoshiro(999),
              model=Chain(Dense(1=>16, gelu), Dense(16=>1)),
              (X, y)=makedata(rng) |> dev;
              nepochs=300,
             )
    (θ, ψ) = Lux.setup(rng, model) |> dev

    opts = Optimisers.setup(Adam(0.01f0), θ)

    for j ∈ 1:nepochs
        ∇ℓ = gradloss_enzyme(model, θ, ψ, (X, y))
        (opts, θ) = Optimisers.update!(opts, θ, ∇ℓ)
    end

    (ŷ, _) = Lux.apply(model, X, θ, ψ)

    (y, ŷ)
end

Note that

  • This works with both Zygote and Enzyme if dev = cpu_device() (i.e. no GPU is involved at all).
  • This works with gradloss_zygote on either cpu_device() or gpu_device().
  • This fails rather spectacularly with gradloss_enzyme on gpu_device().

The error output is so verbose that I won't try to reproduce it all here (it goes nuts and starts dumping LLVM IR). I expect others will be able to reproduce the same or a similar error; the stack trace is:

Stacktrace:
  [1] julia_error(cstr::Cstring, val::Ptr{…}, errtype::Enzyme.API.ErrorType, data::Ptr{…}, data2::Ptr{…}, B::Ptr{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/MIIMf/src/compiler.jl:1684
  [2] EnzymeCreatePrimalAndGradient(logic::Enzyme.Logic, todiff::LLVM.Function, retType::Enzyme.API.CDIFFE_TYPE, constant_args::Vector{…}, TA::Enzyme.TypeAnalysis, returnValue::Bool, dretUsed::Bool, mode::Enzyme.API.CDerivativeMode, width::Int64, additionalArg::Ptr{…}, forceAnonymousTape::Bool, typeInfo::Enzyme.FnTypeInfo, uncacheable_args::Vector{…}, augmented::Ptr{…}, atomicAdd::Bool)
    @ Enzyme.API ~/.julia/packages/Enzyme/MIIMf/src/api.jl:154
  [3] enzyme!(job::GPUCompiler.CompilerJob{…}, mod::LLVM.Module, primalf::LLVM.Function, TT::Type, mode::Enzyme.API.CDerivativeMode, width::Int64, parallel::Bool, actualRetType::Type, wrap::Bool, modifiedBetween::Tuple{…}, returnPrimal::Bool, expectedTapeType::Type, loweredArgs::Set{…}, boxedArgs::Set{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/MIIMf/src/compiler.jl:3109
  [4] codegen(output::Symbol, job::GPUCompiler.CompilerJob{…}; libraries::Bool, deferred_codegen::Bool, optimize::Bool, toplevel::Bool, strip::Bool, validate::Bool, only_entry::Bool, parent_job::Nothing)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/MIIMf/src/compiler.jl:4964
  [5] codegen
    @ ~/.julia/packages/Enzyme/MIIMf/src/compiler.jl:4391 [inlined]
  [6] _thunk(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}, postopt::Bool)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/MIIMf/src/compiler.jl:5646
  [7] _thunk
    @ ~/.julia/packages/Enzyme/MIIMf/src/compiler.jl:5646 [inlined]
  [8] cached_compilation
    @ ~/.julia/packages/Enzyme/MIIMf/src/compiler.jl:5680 [inlined]
  [9] (::Enzyme.Compiler.var"#532#533"{…})(ctx::LLVM.Context)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/MIIMf/src/compiler.jl:5746
 [10] JuliaContext(f::Enzyme.Compiler.var"#532#533"{…}; kwargs::@Kwargs{})
    @ GPUCompiler ~/.julia/packages/GPUCompiler/kqxyC/src/driver.jl:52
 [11] JuliaContext(f::Function)
    @ GPUCompiler ~/.julia/packages/GPUCompiler/kqxyC/src/driver.jl:42
 [12] #s1926#531
    @ ~/.julia/packages/Enzyme/MIIMf/src/compiler.jl:5698 [inlined]
 [13]
    @ Enzyme.Compiler ./none:0
 [14] (::Core.GeneratedFunctionStub)(::UInt64, ::LineNumberNode, ::Any, ::Vararg{Any})
    @ Core ./boot.jl:602
 [15] autodiff
    @ ~/.julia/packages/Enzyme/MIIMf/src/Enzyme.jl:270 [inlined]
 [16] autodiff
    @ ~/.julia/packages/Enzyme/MIIMf/src/Enzyme.jl:287 [inlined]
 [17] gradient
    @ ~/.julia/packages/Enzyme/MIIMf/src/Enzyme.jl:938 [inlined]
 [18] gradloss_enzyme(model::Chain{…}, θ::@NamedTuple{…}, ψ::@NamedTuple{…}, ::Tuple{…})
    @ Main ~/src/autodiff/zygote_enzyme_minimal.jl:31
 [19] main(rng::Xoshiro, model::Chain{@NamedTuple{…}, Nothing}, ::Tuple{CuArray{…}, CuArray{…}}; nepochs::Int64)
    @ Main ~/src/autodiff/zygote_enzyme_minimal.jl:44
 [20] main(rng::Xoshiro, model::Chain{@NamedTuple{…}, Nothing}, ::Tuple{CuArray{…}, CuArray{…}})
    @ Main ~/src/autodiff/zygote_enzyme_minimal.jl:34
 [21] top-level scope
    @ REPL[2]:1
 [22] top-level scope
    @ ~/.julia/packages/CUDA/fGE8R/src/initialization.jl:206

@wsmoses
Member

wsmoses commented Apr 18, 2024 via email

@ExpandingMan
Contributor Author

Full log attached.
enzyme_crash_full.log

@wsmoses
Member

wsmoses commented Apr 20, 2024

@ExpandingMan it looks like CUDA.jl needs a rule added, specifically like the one below. What happens if you add this to your file before any AD?

function Enzyme.EnzymeRules.inactive(::typeof(CUDA.CUBLAS.handle))
    return nothing
end

cc @vchuravy

@ExpandingMan
Contributor Author

Result looks the same to me, at least superficially. Log attached.
crash.log

Out of curiosity, if Enzyme is merely inactive for CuBLAS, but CuBLAS is being used in the function it's trying to differentiate (as I think would be the case here), wouldn't it, at best, return an incorrect result? I would have thought that CuBLAS would act as a frightful barrier to ever getting this working.
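One backend-independent way to answer that kind of correctness question is to compare a gradient against central finite differences. The sketch below does this for a hypothetical toy problem (a linear fit with an MSE loss and a hand-derived gradient; the names mse, mse_grad and fd_grad are made up for illustration). It does not use Lux, Zygote or Enzyme; the same comparison could in principle be applied to a flattened gradient from gradloss_zygote or gradloss_enzyme.

```julia
# Hypothetical toy problem: fit y ≈ w*x + b with parameters θ = [w, b]
# and a mean-squared-error loss.
mse(θ, x, y) = sum(abs2, (θ[1] .* x .+ θ[2]) .- y) / length(x)

# Hand-derived gradient of mse with respect to θ.
function mse_grad(θ, x, y)
    r = (θ[1] .* x .+ θ[2]) .- y   # residuals
    n = length(x)
    return [2 * sum(r .* x) / n, 2 * sum(r) / n]
end

# Central finite differences: g[i] ≈ (f(θ + h*e_i) - f(θ - h*e_i)) / 2h.
function fd_grad(f, θ; h=1e-6)
    g = similar(θ)
    for i in eachindex(θ)
        e = zero(θ)
        e[i] = h
        g[i] = (f(θ .+ e) - f(θ .- e)) / (2h)
    end
    return g
end

# Same target as the reproducer (x^2 - 2x on [-2, 2]), without noise.
x = collect(range(-2.0, 2.0, 128))
y = x .^ 2 .- 2 .* x
θ = [0.5, -0.3]

g_exact = mse_grad(θ, x, y)
g_fd = fd_grad(t -> mse(t, x, y), θ)
println(isapprox(g_exact, g_fd; rtol=1e-6))
```

If some contribution to the loss were silently treated as constant, the gradient under test would typically disagree with the finite-difference estimate by far more than this tolerance.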

@wsmoses
Member

wsmoses commented Apr 20, 2024 via email

@wsmoses
Member

wsmoses commented Sep 3, 2024

This now hits:

julia> main()
ERROR: 
No augmented forward pass found for cublasLtMatmulDescCreate
 at context:   %133 = call i32 @cublasLtMatmulDescCreate(i64 %bitcast_coercion, i32 %unbox32, i32 noundef 0) #482 [ "jl_roots"({} addrspace(10)* %126) ], !dbg !544

Stacktrace:
 [1] macro expansion
   @ ~/.julia/packages/CUDA/Tl08O/lib/utils/call.jl:218
 [2] macro expansion
   @ ~/.julia/packages/CUDA/Tl08O/lib/cublas/libcublasLt.jl:400
 [3] #1158
   @ ~/.julia/packages/CUDA/Tl08O/lib/utils/call.jl:35
 [4] retry_reclaim
   @ ~/.julia/packages/CUDA/Tl08O/src/memory.jl:434
 [5] check
   @ ~/.julia/packages/CUDA/Tl08O/lib/cublas/libcublas.jl:24
 [6] cublasLtMatmulDescCreate
   @ ~/.julia/packages/CUDA/Tl08O/lib/utils/call.jl:34
 [7] cublaslt_matmul_fused!
   @ ~/.julia/packages/LuxLib/ZEWr3/ext/LuxLibCUDAExt/cublaslt.jl:62


Stacktrace:
  [1] macro expansion
    @ ~/.julia/packages/CUDA/Tl08O/lib/utils/call.jl:218 [inlined]
  [2] macro expansion
    @ ~/.julia/packages/CUDA/Tl08O/lib/cublas/libcublasLt.jl:400 [inlined]
  [3] #1158
    @ ~/.julia/packages/CUDA/Tl08O/lib/utils/call.jl:35 [inlined]
  [4] retry_reclaim
    @ ~/.julia/packages/CUDA/Tl08O/src/memory.jl:434 [inlined]
  [5] check
    @ ~/.julia/packages/CUDA/Tl08O/lib/cublas/libcublas.jl:24 [inlined]
  [6] cublasLtMatmulDescCreate
    @ ~/.julia/packages/CUDA/Tl08O/lib/utils/call.jl:34 [inlined]
  [7] cublaslt_matmul_fused!
    @ ~/.julia/packages/LuxLib/ZEWr3/ext/LuxLibCUDAExt/cublaslt.jl:62
  [8] cublaslt_matmul_fused!
    @ ~/.julia/packages/LuxLib/ZEWr3/ext/LuxLibCUDAExt/cublaslt.jl:13 [inlined]
  [9] cublasLt_fused_dense!
    @ ~/.julia/packages/LuxLib/ZEWr3/ext/LuxLibCUDAExt/cublaslt.jl:195
 [10] cublasLt_fused_dense!
    @ ~/.julia/packages/LuxLib/ZEWr3/ext/LuxLibCUDAExt/cublaslt.jl:193 [inlined]
 [11] fused_dense!
    @ ~/.julia/packages/LuxLib/ZEWr3/src/impl/dense.jl:38 [inlined]
 [12] fused_dense
    @ ~/.julia/packages/LuxLib/ZEWr3/src/impl/dense.jl:24 [inlined]
 [13] fused_dense
    @ ~/.julia/packages/LuxLib/ZEWr3/src/impl/dense.jl:11 [inlined]
 [14] fused_dense_bias_activation
    @ ~/.julia/packages/LuxLib/ZEWr3/src/api/dense.jl:30 [inlined]
 [15] Dense
    @ ~/.julia/packages/Lux/a2Wcp/src/layers/basic.jl:366 [inlined]
 [16] apply
    @ ~/.julia/packages/LuxCore/yzx6E/src/LuxCore.jl:171 [inlined]
 [17] macro expansion
    @ ~/.julia/packages/Lux/a2Wcp/src/layers/containers.jl:0 [inlined]
 [18] applychain
    @ ~/.julia/packages/Lux/a2Wcp/src/layers/containers.jl:520
 [19] Chain
    @ ~/.julia/packages/Lux/a2Wcp/src/layers/containers.jl:518 [inlined]
 [20] apply
    @ ~/.julia/packages/LuxCore/yzx6E/src/LuxCore.jl:171 [inlined]
 [21] loss
    @ ./REPL[9]:2
 [22] #8
    @ ./REPL[16]:3 [inlined]
 [23] diffejulia__8_37808_inner_1wrap
    @ ./REPL[16]:0
 [24] macro expansion
    @ ~/Enzyme.jl/src/compiler.jl:7172 [inlined]
 [25] enzyme_call
    @ ~/Enzyme.jl/src/compiler.jl:6781 [inlined]
 [26] CombinedAdjointThunk
    @ ~/Enzyme.jl/src/compiler.jl:6658 [inlined]
 [27] autodiff
    @ ~/Enzyme.jl/src/Enzyme.jl:320 [inlined]
 [28] gradient
    @ ~/Enzyme.jl/src/Enzyme.jl:1049 [inlined]
 [29] gradloss_enzyme(model::Chain{@NamedTuple{…}, Nothing}, θ::@NamedTuple{layer_1::@NamedTuple{…}, layer_2::@NamedTuple{…}}, ψ::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}}, ::Tuple{CuArray{…}, CuArray{…}})
    @ Main ./REPL[16]:5
 [30] main(rng::Xoshiro, model::Chain{@NamedTuple{layer_1::Dense{…}, layer_2::Dense{…}}, Nothing}, ::Tuple{CuArray{Float32, 2, CUDA.DeviceMemory}, CuArray{Float32, 2, CUDA.DeviceMemory}}; nepochs::Int64)
    @ Main ./REPL[12]:11
 [31] main(rng::Xoshiro, model::Chain{@NamedTuple{layer_1::Dense{…}, layer_2::Dense{…}}, Nothing}, ::Tuple{CuArray{Float32, 2, CUDA.DeviceMemory}, CuArray{Float32, 2, CUDA.DeviceMemory}})
    @ Main ./REPL[12]:1
 [32] top-level scope
    @ REPL[17]:1
Some type information was truncated. Use `show(err)` to see complete types.

This is equivalent to LuxDL/LuxLib.jl#148, so I'm moving the issue there.

@wsmoses wsmoses closed this as completed Sep 3, 2024