Discussion: what does Enzyme.jl need for DL in Julia? #805
Comments
cc @sethaxen re (2) since they were playing around with enzyme rules for BLAS. Might you be interested in helping add the cudaMemcpy/cublas/NNlib/etc stuff? |
You reminded me that I forgot to link FluxML/NNlib.jl#503 :) |
Interested, yes, but am still lobbying to take it on as a work project. The BLAS/Enzyme work is more aligned with the other things I work on and is probably all I can focus on right now. |
Either way, figuring out where the rules should go and at what level they are needed is useful for whoever takes this on. |
I think the right way to do it is in two steps.
I honestly don't think this would take all that long. The reason not to carry over the ChainRules rules wholesale is that many of them aren't necessary for Enzyme; doing it this way would let one pick out just the necessary rules and get NNlib converted rather quickly. There can always be later improvements to use less memory and such, but I'd say we do this conversion so Enzyme is at least strictly better than Zygote for DL, that helps the ecosystem move, and then we worry about grabbing the last bit of performance out of each rule. |
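For concreteness, a minimal sketch of what hooking a custom reverse rule into Enzyme looks like, loosely following the pattern in the Enzyme.jl custom-rules documentation; mysquaresum is purely illustrative (not an actual NNlib kernel), and the exact EnzymeRules signatures can differ between Enzyme versions:
using Enzyme
import Enzyme.EnzymeRules: augmented_primal, reverse
using Enzyme.EnzymeRules: AugmentedReturn, needs_primal

# Purely illustrative stand-in for an NNlib-style kernel.
mysquaresum(x) = sum(abs2, x)

# Forward sweep of the custom rule: run the primal, keep no tape
# (this sketch assumes x is not overwritten before the reverse sweep).
function augmented_primal(config, func::Const{typeof(mysquaresum)}, ::Type{<:Active}, x::Duplicated)
    primal = needs_primal(config) ? func.val(x.val) : nothing
    return AugmentedReturn(primal, nothing, nothing)
end

# Reverse sweep: d(sum(x .^ 2))/dx = 2x, scaled by the incoming adjoint and accumulated in place.
function reverse(config, func::Const{typeof(mysquaresum)}, dret::Active, tape, x::Duplicated)
    x.dval .+= 2 .* x.val .* dret.val
    return (nothing,)
end
Once such a rule is registered, Enzyme calls it instead of differentiating through the primal, which is the mechanism an NNlib conversion would hook into for conv/pooling/softmax and friends.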
@ToucheSir so in the example you gave above, Flux has some runtime activity return mismatches (now featuring the better backtraces which have landed on main).
This was then fixed by marking the Dense call on basic.jl:170 as @inline. Of course this then hit a second one below:
How does Flux feel about adding some relevant @inlines? |
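For reference, a rough sketch of the kind of annotation being asked about; TinyDense is a stand-in for illustration, not Flux's actual Dense definition:
# Stand-in layer, only to show where the annotation would go.
struct TinyDense{M<:AbstractMatrix, B, F}
    weight::M
    bias::B
    σ::F
end

# Marking the call operator @inline mirrors the fix described above, where inlining
# the layer call avoided the runtime-activity return mismatch.
@inline function (d::TinyDense)(x::AbstractVecOrMat)
    return d.σ.(d.weight * x .+ d.bias)
end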
Okay, now after fixing KA.jl (JuliaGPU/KernelAbstractions.jl#412), @ToucheSir your snippet runs successfully (didn't check values):
julia> Enzyme.autodiff(Reverse, loss, Active, Duplicated(model, dmodel), Const(x))
┌ Warning: Using fallback BLAS replacements, performance may be degraded
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/YO8Uj/src/utils.jl:56
┌ Warning: Unknown concrete type
│ tt = "{[]:Pointer}"
│ orig = " %77 = call noalias nonnull {} addrspace(10)* @ijl_array_copy({} addrspace(10)* nonnull %40) #249, !dbg !470"
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/YO8Uj/src/utils.jl:56
┌ Warning: Unknown concrete type
│ tt = "{[]:Pointer}"
│ orig = " %120 = call noalias nonnull {} addrspace(10)* @ijl_array_copy({} addrspace(10)* nonnull %49) #249, !dbg !591"
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/YO8Uj/src/utils.jl:56
┌ Warning: Unknown concrete type
│ tt = "{[]:Pointer}"
│ orig = " %124 = call noalias nonnull {} addrspace(10)* @ijl_array_copy({} addrspace(10)* nonnull %80) #249, !dbg !472"
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/YO8Uj/src/utils.jl:56
┌ Warning: Unknown concrete type
│ tt = "{[]:Pointer}"
│ orig = " %130 = call noalias nonnull {} addrspace(10)* @ijl_array_copy({} addrspace(10)* nonnull %value_phi17) #249, !dbg !518"
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/YO8Uj/src/utils.jl:56
┌ Warning: Unknown concrete type
│ tt = "{[]:Pointer}"
│ orig = " %224 = call noalias nonnull {} addrspace(10)* @ijl_array_copy({} addrspace(10)* nonnull %value_phi18) #250, !dbg !695"
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/YO8Uj/src/utils.jl:56
┌ Warning: Unknown concrete type
│ tt = "{[]:Pointer}"
│ orig = " %230 = call noalias nonnull {} addrspace(10)* @ijl_array_copy({} addrspace(10)* nonnull %86) #250, !dbg !720"
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/YO8Uj/src/utils.jl:56
┌ Warning: Unknown concrete type
│ tt = "{[]:Pointer}"
│ orig = " %397 = call noalias nonnull {} addrspace(10)* @ijl_array_copy({} addrspace(10)* nonnull %value_phi18) #251, !dbg !1047"
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/YO8Uj/src/utils.jl:56
┌ Warning: Unknown concrete type
│ tt = "{[]:Pointer}"
│ orig = " %403 = call noalias nonnull {} addrspace(10)* @ijl_array_copy({} addrspace(10)* nonnull %value_phi17) #251, !dbg !1076"
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/YO8Uj/src/utils.jl:56
┌ Warning: Unknown concrete type
│ tt = "{[]:Pointer}"
│ orig = " %77 = call noalias nonnull {} addrspace(10)* @ijl_array_copy({} addrspace(10)* nonnull %40) #255, !dbg !470"
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/YO8Uj/src/utils.jl:56
┌ Warning: Unknown concrete type
│ tt = "{[]:Pointer}"
│ orig = " %120 = call noalias nonnull {} addrspace(10)* @ijl_array_copy({} addrspace(10)* nonnull %49) #255, !dbg !591"
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/YO8Uj/src/utils.jl:56
┌ Warning: Unknown concrete type
│ tt = "{[]:Pointer}"
│ orig = " %124 = call noalias nonnull {} addrspace(10)* @ijl_array_copy({} addrspace(10)* nonnull %80) #255, !dbg !472"
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/YO8Uj/src/utils.jl:56
┌ Warning: Unknown concrete type
│ tt = "{[]:Pointer}"
│ orig = " %130 = call noalias nonnull {} addrspace(10)* @ijl_array_copy({} addrspace(10)* nonnull %value_phi17) #255, !dbg !518"
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/YO8Uj/src/utils.jl:56
┌ Warning: Unknown concrete type
│ tt = "{[]:Pointer}"
│ orig = " %224 = call noalias nonnull {} addrspace(10)* @ijl_array_copy({} addrspace(10)* nonnull %value_phi18) #255, !dbg !695"
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/YO8Uj/src/utils.jl:56
┌ Warning: Unknown concrete type
│ tt = "{[]:Pointer}"
│ orig = " %230 = call noalias nonnull {} addrspace(10)* @ijl_array_copy({} addrspace(10)* nonnull %86) #255, !dbg !720"
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/YO8Uj/src/utils.jl:56
┌ Warning: Unknown concrete type
│ tt = "{[]:Pointer}"
│ orig = " %397 = call noalias nonnull {} addrspace(10)* @ijl_array_copy({} addrspace(10)* nonnull %value_phi18) #255, !dbg !1047"
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/YO8Uj/src/utils.jl:56
┌ Warning: Unknown concrete type
│ tt = "{[]:Pointer}"
│ orig = " %403 = call noalias nonnull {} addrspace(10)* @ijl_array_copy({} addrspace(10)* nonnull %value_phi17) #255, !dbg !1076"
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/YO8Uj/src/utils.jl:56
warning: didn't implement memmove, using memcpy as fallback which can result in errors
((nothing, nothing),)
julia> println(dmodel)
Chain(Dense(2 => 4), BatchNorm(4), Dense(4 => 2)) |
With the fast BLAS mode enabled, the performance for your snippet is as follows (though note the numbers didn't seem to match Zygote, so @ToucheSir if you have some cycles it would help to identify what code causes the divergence).
julia> @btime Enzyme.autodiff(Reverse, loss, Active, Duplicated(model, dmodel), Const(x))
9.090 μs (77 allocations: 5.42 KiB)
((nothing, nothing),)
julia> Zygote.gradient(model->loss(model, x), model)
((layers = ((weight = Float32[0.0 0.0; 0.0 0.0; 0.0 0.0; 0.0 0.0], bias = Float32[0.0, 0.0, 0.0, 0.0], σ = nothing), (λ = nothing, β = Float32[0.38328338, -0.49341357, -0.9959768, -0.5516981], γ = Float32[0.0, 0.0, 0.0, 0.0], μ = nothing, σ² = nothing, ϵ = -0.0f0, momentum = nothing, affine = nothing, track_stats = nothing, active = nothing, chs = nothing), (weight = Float32[0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0], bias = Fill(1.0f0, 2), σ = nothing)),),)
julia> @btime Zygote.gradient(model->loss(model, x), model)
144.153 μs (652 allocations: 39.97 KiB)
((layers = ((weight = Float32[0.0 0.0; 0.0 0.0; 0.0 0.0; 0.0 0.0], bias = Float32[0.0, 0.0, 0.0, 0.0], σ = nothing), (λ = nothing, β = Float32[0.38328338, -0.49341357, -0.9959768, -0.5516981], γ = Float32[0.0, 0.0, 0.0, 0.0], μ = nothing, σ² = nothing, ϵ = -0.0f0, momentum = nothing, affine = nothing, track_stats = nothing, active = nothing, chs = nothing), (weight = Float32[0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0], bias = Fill(1.0f0, 2), σ = nothing)),),) |
I don't have access to a machine to test this right now, but the difference is likely due to BatchNorm being in auto train mode in Zygote and not in Enzyme. Running |
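A minimal sketch of how one could pin both backends to the same mode before comparing, assuming Flux.testmode! is enough to avoid the automatic train-mode switch described above:
using Flux, Zygote, Enzyme

model = Chain(Dense(2 => 4), BatchNorm(4), Dense(4 => 2))
x = randn(Float32, 2, 1)
loss(m, x) = sum(m(x))

Flux.testmode!(model)  # freeze BatchNorm so Zygote doesn't switch it into train mode

g_zygote = Zygote.gradient(m -> loss(m, x), model)

# Zeroed shadow with the same structure as the model, as in the snippets above.
dmodel = Flux.fmap(a -> a isa Array ? zero(a) : a, model)
Enzyme.autodiff(Reverse, loss, Active, Duplicated(model, dmodel), Const(x))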
Given the continued progress on the Enzyme side of things, @ToucheSir I think the next step here would be to isolate what part of the Flux code causes the Enzyme/Zygote answers to differ, so it can be fixed |
I just realized the example above doesn't actually show the discrepancy; what code are you running that does? As I mentioned earlier, the likeliest culprit is manually running |
Oh, maybe I did that incorrectly then. Nevertheless, it would be interesting to start doing some Flux + Enzyme correctness tests; then we can start diving into the performance (I already see a lot of optimizations we should be applying but aren't, so I'm hopeful we can iterate). |
I tried some Flux models with Enzyme to see whether the gradients match Zygote. The following, based on the above, is with Julia 1.10.0, Enzyme main (283a1c5) and Flux 0.14.10.
using Enzyme, Flux, Random, Test
Enzyme.API.runtimeActivity!(true)
loss(model, x) = sum(model(x))
function test_model(model, x, mi)
println(model)
l = loss(model, x)
Flux.reset!(model)
grads_flux = Flux.gradient(m -> loss(m, x), model)[1]
grads_enzyme = Flux.fmap(model) do x
x isa Array ? zero(x) : x
end
Flux.reset!(model)
Enzyme.autodiff(Reverse, loss, Active, Duplicated(model, grads_enzyme), Const(x))
@testset "Model $mi" begin
Flux.reset!(model)
@test loss(model, x) == l # Check loss doesn't change with multiple runs
for i in eachindex(grads_flux.layers)
layer_flux = grads_flux.layers[i]
layer_enzyme = grads_enzyme.layers[i]
for field in (:weight, :bias, :scale)
if hasfield(typeof(layer_flux), field)
@test isapprox(getfield(layer_flux, field), getfield(layer_enzyme, field))
end
end
if hasfield(typeof(layer_flux), :cell)
for field in (:Wh, :Wi, :b)
@test isapprox(getfield(layer_flux.cell, field), getfield(layer_enzyme.cell, field))
end
end
end
end
end
The good news is that the following all work. I steered clear of normalisation or anything that changes with train/test for now.
models_xs = [
[
Chain(Dense(2 => 4), Dense(4 => 2)),
randn(Float32, 2, 1),
],
[
f64(Chain(Dense(2 => 4), Dense(4 => 2))),
randn(Float64, 2, 1),
],
[
Chain(Dense(2 => 4, relu), Dense(4 => 2)),
randn(Float32, 2, 1),
],
[
Chain(Flux.Scale([1.0f0 2.0f0 3.0f0 4.0f0], true, abs2)),
randn(Float32, 2),
],
[
Chain(Conv((3, 3), 3 => 7, relu), Conv((3, 3), 7 => 7, relu)),
rand(Float32, 10, 10, 3, 50),
],
[
Chain(Conv((5, 5), 3 => 7, pad=SamePad()), MaxPool((5, 5), pad=SamePad())),
rand(Float32, 100, 100, 3, 50),
],
[
Maxout(() -> Dense(5 => 7, tanh), 3),
randn(Float32, 5, 1),
],
[
Chain(RNN(3 => 5), RNN(5 => 3)),
randn(Float32, 3, 10),
],
[
Chain(LSTM(3 => 5), LSTM(5 => 3)),
randn(Float32, 3, 10),
],
]
for (mi, (model, x)) in enumerate(models_xs)
test_model(model, x, mi)
end
The following error in Enzyme:
models_xs = [
[
SkipConnection(Chain(Dense(5 => 20, tanh), Dense(20 => 9, tanh)), Flux.Bilinear((9, 5) => 3, bias=false)),
randn(Float32, 5, 1),
],
[
Chain(ConvTranspose((3, 3), 3 => 7, stride=2)),
rand(Float32, 10, 10, 3, 50),
],
[
Chain(GRU(3 => 5)),
randn(Float32, 3, 10),
],
[
fmap(cu, Chain(Dense(2 => 4), Dense(4 => 2))), # Requires using CUDA
cu(randn(Float32, 2, 1)),
],
]
And this one gives slightly different gradients:
models_xs = [
[
Chain(Conv((5, 5), 3 => 7), MeanPool((5,5), pad=SamePad())),
rand(Float32, 100, 100, 3, 50),
],
]
If it is helpful I can open individual issues and add the working cases to the Enzyme tests. |
Yes, individual issues (with corresponding error traces) would be highly helpful! |
I'd also separately be interested in which ones fail if runtime activity is off |
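That's a one-line toggle with the same global flag the test script above sets to true; re-running the loop (ideally in a fresh session, since the flag affects how code is compiled) would look like:
using Enzyme
Enzyme.API.runtimeActivity!(false)  # disable runtime activity this time
for (mi, (model, x)) in enumerate(models_xs)
    test_model(model, x, mi)
end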
@ToucheSir the SkipConnection one potentially seems to be a pure Flux issue?
|
The ConvTranspose issue has been posted to NNlib.jl as FluxML/NNlib.jl#565. It requires a rule implementation/extension for conv. |
Doesn't look like it:
julia> model = SkipConnection(Chain(Dense(5 => 20, tanh), Dense(20 => 9, tanh)), Flux.Bilinear((9, 5) => 3, bias=false))
SkipConnection(
Chain(
Dense(5 => 20, tanh), # 120 parameters
Dense(20 => 9, tanh), # 189 parameters
),
Bilinear((9, 5) => 3; bias=false), # 135 parameters
) # Total: 5 arrays, 444 parameters, 2.094 KiB.
julia> x = randn(Float32, 5, 1)
5×1 Matrix{Float32}:
-0.6271728
-0.5722281
-1.7240835
0.43075645
0.044463925
julia> model(x)
3×1 Matrix{Float32}:
0.39723554
0.15903589
-0.38918468
I believe the code here is making invalid assumptions about the structure of models and their gradients (i.e. model is always a single-layer
|
I suppose it's that little test code erring then, which means maybe Enzyme works on it :) [or maybe not]. |
For me the test code runs the Flux version fine but errors on the Enzyme version. The Enzyme error is attached as error.txt. I notice your stacktrace contains The gradient checking code makes assumptions and might fail to check all gradients but ideally shouldn't error itself for anything with indexable |
@jgreener64 can you try on latest main (i.e. not the release)? Fixes for precisely that error have landed on main since. |
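For reference, one way to track the development branch rather than the registered release (a sketch using the standard Pkg API):
using Pkg
Pkg.add(PackageSpec(name = "Enzyme", rev = "main"))  # follow Enzyme.jl's main branch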
Still errors for me on main (877e1d9). |
Can you confirm your jll version via st? (it should be version 0.0.100) |
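For example, from the Pkg API (st is the Pkg REPL shorthand for status):
using Pkg
Pkg.status("Enzyme_jll")  # should report 0.0.100 once the fixes above are picked up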
Yes it is with jll version 0.0.100. The error message looked similar but may have changed a bit, I can't check right now. |
Given that many of the core points here have been resolved, I'm going to propose closing this issue unless there are any objections.
|
The tests are in, and 1 small patch was needed. The release should be available in an hour or so. |
As mentioned, closing this now for the reasons above. |
Copying over and summarizing some discussion from Slack:
@wsmoses:
A type stable example for 1):
As mentioned on Slack, I'd be happy to provide more if people have ideas of model types they'd like to see.
Riffing on 2):
@ToucheSir:
@MasonProtter: