Discussion: what does Enzyme.jl need for DL in Julia? #805

Closed
ToucheSir opened this issue May 3, 2023 · 27 comments

@ToucheSir

ToucheSir commented May 3, 2023

Copying over and summarizing some discussion from Slack:

@wsmoses:

  1. Type-stable deep learning libraries. So far, all the DL library code I've seen people try to AD with Enzyme is very, very type-unstable. Enzyme has been adding support for type-unstable code (I'm even writing the type-unstable get rule right now on the plane). That said, for the performance reason above, the deep learning library being AD'd really should not be type-unstable, and therefore shouldn't need full support for type-unstable code.
  2. CUBLAS/cudaMemCopy/etc rules. These can be written at the Julia level with EnzymeRules. Some of these may make sense to do in Enzyme proper (cc Manuel Drehwald). Either way, we haven't written these yet because we've had bigger features to do that require V/I, since they need knowledge of Enzyme/Julia internals (GC, more type-unstable support, etc.). However, doing these shouldn't require internals knowledge, just someone willing to play with EnzymeRules in Julia [aka please help wanted!!!!]
  3. Scheduling. The two above are sufficient for Enzyme to AD through DL stuff (though to be clear, that doesn't mean DL alone is complete; it may need fast data loading or some other non-AD-specific stuff). However, if we really want super fast backprop, we'll also want to do scheduling. To be clear, this can be independent of AD (and that's how existing stuff like JAX/TF/PyTorch do it; they run scheduling via XLA). Enzyme/AD doesn't need to be responsible for the scheduling, but not doing it at all may be slow. I add this point here, however, because the EnzymeMLIR and related work we're doing will actually add scheduling to Enzyme itself. From a research standpoint, we think this is theoretically better (in performance and otherwise) than keeping the two separate. Of course, this can be separate, though.
  4. Maybe exposing the internal Enzyme cache allocation mechanism, so that Julia-level or other algorithms could be used, which might give different memory reuse/other behaviors. I haven't thought much about this yet, but it may be helpful for perf (though not necessary to run).
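To make point 1 concrete, here is a minimal plain-Julia sketch (no Flux or Enzyme involved; `relu_unstable` and `relu_stable` are hypothetical names) of what type instability looks like and one way to check for it. `Base.return_types` shows what the compiler infers, and `Test.@inferred` throws when the inferred return type does not match the type actually returned. Code shaped like `relu_unstable` leaves an AD tool to deal with a `Union` (or worse, `Any`) at runtime.

```julia
using Test  # provides @inferred

# Hypothetical example: one branch returns the Int literal 0 and the other returns x,
# so for a Float64 argument inference yields Union{Float64, Int} (type-unstable).
relu_unstable(x) = x > 0 ? x : 0

# zero(x) has the same type as x, so the result is always typeof(x) (type-stable).
relu_stable(x) = x > 0 ? x : zero(x)

# Show what the compiler infers for a Float64 argument.
@show Base.return_types(relu_unstable, (Float64,))
@show Base.return_types(relu_stable, (Float64,))

# @inferred returns the value when the inferred and actual return types agree
# and throws otherwise; this call passes, the same call on relu_unstable would throw.
@inferred relu_stable(-1.0)
```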

A type stable example for 1):

using Flux, Random, Enzyme
rng = Random.default_rng()

loss(model, x) = sum(model(x))

model = Chain(Dense(2 => 4), BatchNorm(4), Dense(4 => 2))
x = randn(rng, Float32, 2, 1)

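# Build a shadow with the same structure as `model`: arrays are zeroed, other leaves are reused as-is
dmodel = Flux.fmap(model) do p
    p isa Array ? zero(p) : p
end

# Reverse mode: the scalar loss is Active, gradients for `model` accumulate into `dmodel`, and `x` is Const
Enzyme.autodiff(Reverse, loss, Active, Duplicated(model, dmodel), Const(x))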
println(dmodel)

As mentioned on Slack, I'd be happy to provide more if people have ideas of model types they'd like to see.

Riffing on 2):
@ToucheSir:

To my knowledge, no CR-based AD can deal with the semantics of Duplicated right now, because they lack any notion of aliasing. So whether you wrap ChainRules in EnzymeRules, EnzymeRules in ChainRules, or have ChainRules with in-place and out-of-place gradient accumulation, I don't see a way around having two separate systems (in the short- to medium-term).
...
Here's all the relevant talk I remember from the CR side: JuliaDiff/ChainRulesCore.jl#591, JuliaDiff/ChainRulesCore.jl#242, JuliaDiff/ChainRulesCore.jl#578
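A minimal plain-Julia sketch of the aliasing problem, assuming a toy function and two hypothetical rule styles (`f_pullback_inplace!` and `f_pullback_outofplace` are illustrative names, not the real EnzymeRules or ChainRules API). With Duplicated-style semantics the shadows mirror the primal memory, so when both arguments are the same array, both gradient contributions accumulate into the same shadow. An out-of-place pullback returns separate cotangents and leaves it to the caller to notice the aliasing and add them up.

```julia
# Hypothetical toy function: f(a, b) = sum(a .* b), so ∂f/∂a = b and ∂f/∂b = a.
f(a, b) = sum(a .* b)

# Duplicated-style rule (hypothetical): gradients are accumulated in place
# into shadow arrays that mirror the primal arrays.
function f_pullback_inplace!(da, db, a, b, dret)
    da .+= dret .* b
    db .+= dret .* a
    return nothing
end

# ChainRules-style rule (hypothetical): gradients are returned as fresh arrays.
f_pullback_outofplace(a, b, dret) = (dret .* b, dret .* a)

x = [1.0, 2.0, 3.0]

# Aliased call f(x, x) == sum(x .^ 2), whose true gradient is 2x.
dx = zero(x)
f_pullback_inplace!(dx, dx, x, x, 1.0)  # the same shadow is passed twice, mirroring the aliasing
println(dx)                             # both contributions landed in dx, giving 2x

ga, gb = f_pullback_outofplace(x, x, 1.0)
println(ga)                             # just x: the caller must know a === b and add gb
println(ga .+ gb)                       # 2x once the aliasing is accounted for
```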

@MasonProtter:

Interesting. Seems like we could maybe deal with it via some sort of dispatch object that contains info about the AD package in use. So you could write a rule that dispatches on whether or not the AD system supports Duplicated.
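One way to read this suggestion, sketched with Julia's trait-dispatch ("Holy traits") pattern. Every name here (`ADBackend`, `EnzymeLike`, `ZygoteLike`, `shadow_style`, `double_rule`) is hypothetical and not part of ChainRules or Enzyme: a rule is written once per shadow style and picks the in-place or out-of-place version by querying the backend's trait.

```julia
# Hypothetical backend types standing in for real AD packages.
abstract type ADBackend end
struct EnzymeLike <: ADBackend end   # an AD with Duplicated-style shadows
struct ZygoteLike <: ADBackend end   # an AD with out-of-place cotangents

# Hypothetical trait describing how a backend wants gradients delivered.
abstract type ShadowStyle end
struct InPlaceShadows <: ShadowStyle end
struct ReturnedCotangents <: ShadowStyle end

shadow_style(::ADBackend) = ReturnedCotangents()   # conservative default
shadow_style(::EnzymeLike) = InPlaceShadows()

# Entry point for a rule of y = 2x: look up the trait, then dispatch on it.
double_rule(backend::ADBackend, args...) = double_rule(shadow_style(backend), args...)

# In-place flavour: accumulate into the shadow dx and return nothing.
function double_rule(::InPlaceShadows, dx::AbstractArray, dy::AbstractArray)
    dx .+= 2 .* dy
    return nothing
end

# Out-of-place flavour: return a fresh cotangent.
double_rule(::ReturnedCotangents, dy::AbstractArray) = 2 .* dy
```

Usage: `double_rule(EnzymeLike(), dx, dy)` updates `dx` in place, while `double_rule(ZygoteLike(), dy)` returns a new array.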

@ToucheSir changed the title "Discussion: what does Enzyme need for DL?" to "Discussion: what does Enzyme.jl need for DL in Julia?" on May 3, 2023
@wsmoses
Member

wsmoses commented May 3, 2023

cc @sethaxen re (2) since they were playing around with enzyme rules for BLAS. Might you be interested in helping add the cudaMemcpy/cublas/NNlib/etc stuff?

@ToucheSir
Author

You reminded me that I forgot to link FluxML/NNlib.jl#503 :)

@sethaxen
Collaborator

sethaxen commented May 3, 2023

> cc @sethaxen re (2) since they were playing around with enzyme rules for BLAS. Might you be interested in helping add the cudaMemcpy/cublas/NNlib/etc stuff?

Interested, yes, but I am still lobbying to take it on as a work project. The BLAS/Enzyme work is more aligned with the other things I work on and is probably all I can focus on right now.

@sethaxen
Collaborator

sethaxen commented May 3, 2023

Either way, figuring out where the rules should go and at what level they are needed is useful for whoever takes this on.
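As an illustration of the math such a BLAS-level reverse rule encodes, here is a plain-Julia sketch for `C = A * B` (gemm) with Duplicated-style accumulation into shadows. `gemm_reverse!` is a hypothetical helper, not Enzyme's actual BLAS rule; it uses the 5-argument `LinearAlgebra.mul!`, so the reverse pass is itself just two gemm calls that accumulate into the shadows.

```julia
using LinearAlgebra  # mul!

# Hypothetical reverse rule for C = A * B.
# For an incoming cotangent dC: dA += dC * B' and dB += A' * dC.
function gemm_reverse!(dA, dB, A, B, dC)
    # 5-arg mul!(Y, X1, X2, α, β) computes Y = X1 * X2 * α + Y * β;
    # α = β = true (i.e. 1) accumulates into the existing shadow.
    mul!(dA, dC, B', true, true)
    mul!(dB, A', dC, true, true)
    return nothing
end

A = [1.0 2.0; 3.0 4.0]
B = [0.5 -1.0; 2.0 0.0]
dA, dB = zero(A), zero(B)   # zero-initialized shadows
dC = ones(2, 2)             # cotangent of sum(A * B)

gemm_reverse!(dA, dB, A, B, dC)
```

Calling `gemm_reverse!` a second time adds the same contributions again, which is the accumulation behaviour Duplicated shadows expect.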

@ChrisRackauckas
Contributor

I think the right way to do it is in three steps.

  1. Define a macro that defines the Enzyme rules for a given function based on its ChainRules rule. For example, @ruletransfer conv(x)
  2. Use said macro to define overloads for things that have nice high-level rules or that call into binaries (like cuDNN).
  3. Make better rules in the future.

I honestly don't think this would take all that long. The reason not to carry over all of ChainRules is that many of the rules aren't necessary for Enzyme, but doing it like this would allow one to pick out just the necessary rules and get NNlib converted rather quickly. Then, of course, there can always be improvements to use less memory and such. I'd say we do this conversion, after which Enzyme is at least strictly better than Zygote for DL, which helps the ecosystem move, and then worry about grabbing the last bit of performance out of each rule.
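A minimal sketch of the core of such a rule transfer, written as a plain function rather than the proposed macro. `toy_rrule` imitates the shape of a ChainRules `rrule` (primal result plus a pullback that returns one cotangent per argument, the function's own first), and `accumulate_pullback!` is a hypothetical adapter that adds those cotangents into Duplicated-style shadows. Real rrules can return `NoTangent()` or thunks, which an actual macro would need to skip or unthunk; this sketch only treats `nothing` as "no tangent".

```julia
# Hand-written stand-in for an rrule of y = W * x (not ChainRulesCore's actual rule).
function toy_rrule(::typeof(*), W::AbstractMatrix, x::AbstractVector)
    y = W * x
    # Returns cotangents for (the function itself, W, x).
    pullback(dy) = (nothing, dy * x', W' * dy)
    return y, pullback
end

# Hypothetical adapter: run the out-of-place pullback and add each cotangent
# into the matching shadow. A `nothing` shadow marks a constant argument.
function accumulate_pullback!(shadows::Tuple, pullback, dy)
    cotangents = Base.tail(pullback(dy))   # drop the cotangent for the function itself
    for (shadow, c) in zip(shadows, cotangents)
        (shadow === nothing || c === nothing) && continue
        shadow .+= c
    end
    return nothing
end

W = [1.0 2.0; 3.0 4.0]
x = [1.0, -1.0]
dW, dx = zero(W), zero(x)   # zero-initialized shadows

y, pb = toy_rrule(*, W, x)
accumulate_pullback!((dW, dx), pb, ones(2))   # seed dy = ones, i.e. the gradient of sum(y)
```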

@wsmoses
Member

wsmoses commented Jul 13, 2023

@ToucheSir, so in the example you gave above, Flux has some runtime-activity return mismatches (now with the better backtraces that have landed on main):

julia> Enzyme.autodiff(Reverse, loss, Active, Duplicated(model, dmodel), Const(x))
┌ Warning: Using fallback BLAS replacements, performance may be degraded
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/YO8Uj/src/utils.jl:56
ERROR: Enzyme execution failed.
Mismatched activity for:   ret {} addrspace(10)* %186, !dbg !418 const val:   %186 = call fastcc noalias nonnull {} addrspace(10)* @julia_Array_2108({} addrspace(10)* noalias nocapture nofree nonnull readnone align 64 undef, [2 x i64] addrspace(11)* nocapture noundef nonnull readonly align 8 dereferenceable(16) %185) #403, !dbg !497
Type tree: {}
You may be using a constant variable as temporary storage for active memory (https://enzyme.mit.edu/julia/stable/#Activity-of-temporary-storage). If not, please open an issue, and either rewrite this variable to not be conditionally active or use Enzyme.API.runtimeActivity!(true) as a workaround for now

Stacktrace:
 [1] Dense
   @ ~/.julia/packages/Flux/n3cOc/src/layers/basic.jl:174

Stacktrace:
  [1] throwerr(cstr::Cstring)
    @ Enzyme.Compiler ~/git/Enzyme.jl/src/compiler.jl:2790
  [2] Dense
    @ ~/.julia/packages/Flux/n3cOc/src/layers/basic.jl:174
  [3] macro expansion
    @ ~/.julia/packages/Flux/n3cOc/src/layers/basic.jl:53 [inlined]
  [4] _applychain
    @ ~/.julia/packages/Flux/n3cOc/src/layers/basic.jl:53 [inlined]
  [5] Chain
    @ ~/.julia/packages/Flux/n3cOc/src/layers/basic.jl:51 [inlined]
  [6] loss
    @ ./REPL[5]:1 [inlined]
  [7] loss
    @ ./REPL[5]:0 [inlined]
  [8] diffejulia_loss_1881_inner_8wrap
    @ ./REPL[5]:0
  [9] macro expansion
    @ Enzyme.Compiler ~/git/Enzyme.jl/src/compiler.jl:9369 [inlined]
 [10] enzyme_call(::Val{…}, ::Ptr{…}, ::Type{…}, ::Type{…}, ::Val{…}, ::Type{…}, ::Type{…}, ::Const{…}, ::Type{…}, ::Duplicated{…}, ::Const{…}, ::Float32)
    @ Enzyme.Compiler ~/git/Enzyme.jl/src/compiler.jl:9047
 [11] (::Enzyme.Compiler.CombinedAdjointThunk{…})(::Const{…}, ::Duplicated{…}, ::Vararg{…})
    @ Enzyme.Compiler ~/git/Enzyme.jl/src/compiler.jl:9010
 [12] autodiff(::ReverseMode{false, FFIABI}, ::Const{typeof(loss)}, ::Type{Active}, ::Duplicated{Chain{Tuple{…}}}, ::Vararg{Any})
    @ Enzyme ~/git/Enzyme.jl/src/Enzyme.jl:213
 [13] autodiff(::ReverseMode{false, FFIABI}, ::typeof(loss), ::Type, ::Duplicated{Chain{Tuple{…}}}, ::Vararg{Any})
    @ Enzyme ~/git/Enzyme.jl/src/Enzyme.jl:222
 [14] top-level scope
    @ REPL[9]:1
 [15] top-level scope
    @ ~/.julia/packages/CUDA/tVtYo/src/initialization.jl:185
Some type information was truncated. Use `show(err)` to see complete types.

This was then fixed by marking the Dense function at basic.jl:170 as @inline.

Of course, this then hit a second mismatch:

julia> Enzyme.autodiff(Reverse, loss, Active, Duplicated(model, dmodel), Const(x))

┌ Warning: Using fallback BLAS replacements, performance may be degraded
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/YO8Uj/src/utils.jl:56
ERROR: Enzyme execution failed.
Mismatched activity for:   ret {} addrspace(10)* %106, !dbg !398 const val:   %106 = call fastcc noalias nonnull {} addrspace(10)* @julia_Array_2004({} addrspace(10)* noalias nocapture nofree nonnull readnone align 64 undef, [2 x i64] addrspace(11)* nocapture noundef nonnull readonly align 8 dereferenceable(16) %105) #399, !dbg !459
Type tree: {}
You may be using a constant variable as temporary storage for active memory (https://enzyme.mit.edu/julia/stable/#Activity-of-temporary-storage). If not, please open an issue, and either rewrite this variable to not be conditionally active or use Enzyme.API.runtimeActivity!(true) as a workaround for now

Stacktrace:
 [1] -
   @ ./abstractarraymath.jl:218

Stacktrace:
  [1] throwerr(cstr::Cstring)
    @ Enzyme.Compiler ~/git/Enzyme.jl/src/compiler.jl:2790
  [2] -
    @ ./abstractarraymath.jl:218
  [3] #_norm_layer_forward#302
    @ ~/.julia/packages/Flux/n3cOc/src/layers/normalise.jl:247
  [4] _norm_layer_forward
    @ ~/.julia/packages/Flux/n3cOc/src/layers/normalise.jl:225 [inlined]
  [5] BatchNorm
    @ ~/.julia/packages/Flux/n3cOc/src/layers/normalise.jl:351
  [6] macro expansion
    @ ~/.julia/packages/Flux/n3cOc/src/layers/basic.jl:53 [inlined]
  [7] _applychain
    @ ~/.julia/packages/Flux/n3cOc/src/layers/basic.jl:53
  [8] Chain
    @ ~/.julia/packages/Flux/n3cOc/src/layers/basic.jl:51 [inlined]
  [9] loss
    @ ./REPL[4]:1 [inlined]
 [10] loss
    @ ./REPL[4]:0 [inlined]
 [11] diffejulia_loss_1755_inner_8wrap
    @ ./REPL[4]:0
 [12] macro expansion
    @ Enzyme.Compiler ~/git/Enzyme.jl/src/compiler.jl:9369 [inlined]
 [13] enzyme_call(::Val{…}, ::Ptr{…}, ::Type{…}, ::Type{…}, ::Val{…}, ::Type{…}, ::Type{…}, ::Const{…}, ::Type{…}, ::Duplicated{…}, ::Const{…}, ::Float32)
    @ Enzyme.Compiler ~/git/Enzyme.jl/src/compiler.jl:9047
 [14] (::Enzyme.Compiler.CombinedAdjointThunk{…})(::Const{…}, ::Duplicated{…}, ::Vararg{…})
    @ Enzyme.Compiler ~/git/Enzyme.jl/src/compiler.jl:9010
 [15] autodiff(::ReverseMode{false, FFIABI}, ::Const{typeof(loss)}, ::Type{Active}, ::Duplicated{Chain{Tuple{…}}}, ::Vararg{Any})
    @ Enzyme ~/git/Enzyme.jl/src/Enzyme.jl:213
 [16] autodiff(::ReverseMode{false, FFIABI}, ::typeof(loss), ::Type, ::Duplicated{Chain{Tuple{…}}}, ::Vararg{Any})
    @ Enzyme ~/git/Enzyme.jl/src/Enzyme.jl:222
 [17] top-level scope
    @ REPL[8]:1
 [18] top-level scope
    @ ~/.julia/packages/CUDA/tVtYo/src/initialization.jl:185
Some type information was truncated. Use `show(err)` to see complete types.

How does Flux feel about adding some relevant @inline annotations?
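For reference, this is what the annotation looks like on a callable layer struct. `ToyDense` is a hypothetical stand-in, not Flux's `Dense`. `@inline` is only a hint to the compiler, and whether it helps Enzyme's activity analysis has to be checked case by case, as in the two traces above.

```julia
# Hypothetical toy layer (not Flux's Dense): a callable struct.
struct ToyDense{M<:AbstractMatrix, V<:AbstractVector, F}
    weight::M
    bias::V
    σ::F
end

# @inline asks the compiler to inline this call method into its callers.
@inline function (d::ToyDense)(x::AbstractVecOrMat)
    return d.σ.(d.weight * x .+ d.bias)
end

layer = ToyDense([1.0 2.0; 3.0 4.0], [0.5, -0.5], identity)
layer([1.0, 1.0])
```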

@wsmoses
Member

wsmoses commented Jul 24, 2023

Okay, now after fixing KA.jl (JuliaGPU/KernelAbstractions.jl#412), @ToucheSir, your snippet runs successfully (I didn't check the values):

julia> Enzyme.autodiff(Reverse, loss, Active, Duplicated(model, dmodel), Const(x))
┌ Warning: Using fallback BLAS replacements, performance may be degraded
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/YO8Uj/src/utils.jl:56
┌ Warning: Unknown concrete type
│   tt = "{[]:Pointer}"
│   orig = "  %77 = call noalias nonnull {} addrspace(10)* @ijl_array_copy({} addrspace(10)* nonnull %40) #249, !dbg !470"
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/YO8Uj/src/utils.jl:56
┌ Warning: Unknown concrete type
│   tt = "{[]:Pointer}"
│   orig = "  %120 = call noalias nonnull {} addrspace(10)* @ijl_array_copy({} addrspace(10)* nonnull %49) #249, !dbg !591"
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/YO8Uj/src/utils.jl:56
┌ Warning: Unknown concrete type
│   tt = "{[]:Pointer}"
│   orig = "  %124 = call noalias nonnull {} addrspace(10)* @ijl_array_copy({} addrspace(10)* nonnull %80) #249, !dbg !472"
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/YO8Uj/src/utils.jl:56
┌ Warning: Unknown concrete type
│   tt = "{[]:Pointer}"
│   orig = "  %130 = call noalias nonnull {} addrspace(10)* @ijl_array_copy({} addrspace(10)* nonnull %value_phi17) #249, !dbg !518"
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/YO8Uj/src/utils.jl:56
┌ Warning: Unknown concrete type
│   tt = "{[]:Pointer}"
│   orig = "  %224 = call noalias nonnull {} addrspace(10)* @ijl_array_copy({} addrspace(10)* nonnull %value_phi18) #250, !dbg !695"
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/YO8Uj/src/utils.jl:56
┌ Warning: Unknown concrete type
│   tt = "{[]:Pointer}"
│   orig = "  %230 = call noalias nonnull {} addrspace(10)* @ijl_array_copy({} addrspace(10)* nonnull %86) #250, !dbg !720"
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/YO8Uj/src/utils.jl:56
┌ Warning: Unknown concrete type
│   tt = "{[]:Pointer}"
│   orig = "  %397 = call noalias nonnull {} addrspace(10)* @ijl_array_copy({} addrspace(10)* nonnull %value_phi18) #251, !dbg !1047"
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/YO8Uj/src/utils.jl:56
┌ Warning: Unknown concrete type
│   tt = "{[]:Pointer}"
│   orig = "  %403 = call noalias nonnull {} addrspace(10)* @ijl_array_copy({} addrspace(10)* nonnull %value_phi17) #251, !dbg !1076"
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/YO8Uj/src/utils.jl:56
┌ Warning: Unknown concrete type
│   tt = "{[]:Pointer}"
│   orig = "  %77 = call noalias nonnull {} addrspace(10)* @ijl_array_copy({} addrspace(10)* nonnull %40) #255, !dbg !470"
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/YO8Uj/src/utils.jl:56
┌ Warning: Unknown concrete type
│   tt = "{[]:Pointer}"
│   orig = "  %120 = call noalias nonnull {} addrspace(10)* @ijl_array_copy({} addrspace(10)* nonnull %49) #255, !dbg !591"
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/YO8Uj/src/utils.jl:56
┌ Warning: Unknown concrete type
│   tt = "{[]:Pointer}"
│   orig = "  %124 = call noalias nonnull {} addrspace(10)* @ijl_array_copy({} addrspace(10)* nonnull %80) #255, !dbg !472"
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/YO8Uj/src/utils.jl:56
┌ Warning: Unknown concrete type
│   tt = "{[]:Pointer}"
│   orig = "  %130 = call noalias nonnull {} addrspace(10)* @ijl_array_copy({} addrspace(10)* nonnull %value_phi17) #255, !dbg !518"
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/YO8Uj/src/utils.jl:56
┌ Warning: Unknown concrete type
│   tt = "{[]:Pointer}"
│   orig = "  %224 = call noalias nonnull {} addrspace(10)* @ijl_array_copy({} addrspace(10)* nonnull %value_phi18) #255, !dbg !695"
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/YO8Uj/src/utils.jl:56
┌ Warning: Unknown concrete type
│   tt = "{[]:Pointer}"
│   orig = "  %230 = call noalias nonnull {} addrspace(10)* @ijl_array_copy({} addrspace(10)* nonnull %86) #255, !dbg !720"
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/YO8Uj/src/utils.jl:56
┌ Warning: Unknown concrete type
│   tt = "{[]:Pointer}"
│   orig = "  %397 = call noalias nonnull {} addrspace(10)* @ijl_array_copy({} addrspace(10)* nonnull %value_phi18) #255, !dbg !1047"
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/YO8Uj/src/utils.jl:56
┌ Warning: Unknown concrete type
│   tt = "{[]:Pointer}"
│   orig = "  %403 = call noalias nonnull {} addrspace(10)* @ijl_array_copy({} addrspace(10)* nonnull %value_phi17) #255, !dbg !1076"
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/YO8Uj/src/utils.jl:56
warning: didn't implement memmove, using memcpy as fallback which can result in errors
((nothing, nothing),)

julia> println(dmodel)
Chain(Dense(2 => 4), BatchNorm(4), Dense(4 => 2))

@wsmoses
Member

wsmoses commented Jul 24, 2023

In the fast-BLAS-enabled mode, the performance for your snippet is as follows (though note the numbers didn't seem to match Zygote, so @ToucheSir, if you have some cycles, could you identify what code causes the divergence?).

julia> @btime Enzyme.autodiff(Reverse, loss, Active, Duplicated(model, dmodel), Const(x))
  9.090 μs (77 allocations: 5.42 KiB)
((nothing, nothing),)

julia> Zygote.gradient(model->loss(model, x), model)
((layers = ((weight = Float32[0.0 0.0; 0.0 0.0; 0.0 0.0; 0.0 0.0], bias = Float32[0.0, 0.0, 0.0, 0.0], σ = nothing), (λ = nothing, β = Float32[0.38328338, -0.49341357, -0.9959768, -0.5516981], γ = Float32[0.0, 0.0, 0.0, 0.0], μ = nothing, σ² = nothing, ϵ = -0.0f0, momentum = nothing, affine = nothing, track_stats = nothing, active = nothing, chs = nothing), (weight = Float32[0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0], bias = Fill(1.0f0, 2), σ = nothing)),),)

julia> @btime Zygote.gradient(model->loss(model, x), model)
  144.153 μs (652 allocations: 39.97 KiB)
((layers = ((weight = Float32[0.0 0.0; 0.0 0.0; 0.0 0.0; 0.0 0.0], bias = Float32[0.0, 0.0, 0.0, 0.0], σ = nothing), (λ = nothing, β = Float32[0.38328338, -0.49341357, -0.9959768, -0.5516981], γ = Float32[0.0, 0.0, 0.0, 0.0], μ = nothing, σ² = nothing, ϵ = -0.0f0, momentum = nothing, affine = nothing, track_stats = nothing, active = nothing, chs = nothing), (weight = Float32[0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0], bias = Fill(1.0f0, 2), σ = nothing)),),)

@ToucheSir
Author

I don't have access to a machine to test this right now, but the difference is likely due to BatchNorm being in auto train mode in Zygote and not in Enzyme. Running Flux.trainmode!(model) before the autodiff call should help confirm that.
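A minimal plain-Julia sketch of why the mode matters, assuming a simplified, hypothetical `toy_norm` function rather than Flux's actual BatchNorm implementation. With a batch of size 1, train mode normalizes each feature against its own batch mean, so the output collapses to β whatever the input is, while test mode uses the running statistics and essentially passes the input through. This appears consistent with the zero gradients for the first `Dense` and for γ in the Zygote output above.

```julia
using Statistics  # mean, var

# Hypothetical, simplified normalization (not Flux's BatchNorm):
# each feature (row) is normalized over the batch (columns).
function toy_norm(x, γ, β, μ_running, σ²_running; train::Bool, ϵ = 1f-5)
    if train
        # Train mode: statistics of the current batch.
        μ = mean(x; dims = 2)
        σ² = var(x; dims = 2, corrected = false)
    else
        # Test mode: stored running statistics.
        μ, σ² = μ_running, σ²_running
    end
    return γ .* (x .- μ) ./ sqrt.(σ² .+ ϵ) .+ β
end

x = reshape(Float32[0.5, -1.0, 2.0, 0.0], 4, 1)   # 4 features, batch size 1
γ, β = ones(Float32, 4), zeros(Float32, 4)
μr, σ²r = zeros(Float32, 4), ones(Float32, 4)      # freshly initialized running stats

y_train = toy_norm(x, γ, β, μr, σ²r; train = true)   # equals β: the input has no effect
y_test = toy_norm(x, γ, β, μr, σ²r; train = false)   # approximately x
```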

@wsmoses
Member

wsmoses commented Oct 8, 2023

Given the continued progress on the Enzyme side of things, @ToucheSir, I think the next step here would be to isolate what part of the Flux code causes the Enzyme and Zygote answers to differ, so it can be fixed.

@ToucheSir
Author

I just realized the example above doesn't actually show the discrepancy; what code are you running that does? As I mentioned earlier, the likeliest culprit is not manually running Flux.trainmode!(model) before using Enzyme, since each AD currently needs to opt in to the automatic train-mode-when-differentiating mechanism that Flux uses for norm layers.

@wsmoses
Member

wsmoses commented Oct 8, 2023

Oh maybe I did that incorrectly then.

Nevertheless, it would be interesting to start doing some Flux + Enzyme correctness tests; then we can start diving into the performance (where I already see a lot of optimizations we should be applying but aren't, so I'm hopeful we can iterate on it).
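One common way to build such correctness tests without relying on a second AD system is a finite-difference check. This is a swapped-in technique (the rest of the thread compares Enzyme against Zygote instead), and `loss_w`, `grad_w` and `fd_grad` are hypothetical helpers around a tiny, hand-differentiated loss. In a real test, the gradient returned by the AD under test would take the place of `grad_w`.

```julia
# Tiny loss with a known gradient: L(W) = sum(tanh.(W * x)).
loss_w(W, x) = sum(tanh.(W * x))

# Hand-derived gradient: ∂L/∂W = (1 .- tanh.(W * x) .^ 2) * x'
grad_w(W, x) = (1 .- tanh.(W * x) .^ 2) * x'

# Central finite differences, perturbing one entry of W at a time.
function fd_grad(f, W; h = 1e-6)
    g = similar(W)
    for i in eachindex(W)
        Wp = copy(W); Wp[i] += h
        Wm = copy(W); Wm[i] -= h
        g[i] = (f(Wp) - f(Wm)) / (2h)
    end
    return g
end

W = [0.1 -0.2; 0.3 0.4; -0.5 0.6]
x = [1.0, 2.0]

g_analytic = grad_w(W, x)
g_fd = fd_grad(Wv -> loss_w(Wv, x), W)
maximum(abs.(g_analytic .- g_fd))   # should be small if the gradient is correct
```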

@jgreener64
Contributor

I tried some Flux models with Enzyme to see whether the gradients match Zygote.

The following, based off the above, is with Julia 1.10.0, Enzyme main (283a1c5) and Flux 0.14.10.

using Enzyme, Flux, Random, Test

Enzyme.API.runtimeActivity!(true)

loss(model, x) = sum(model(x))

function test_model(model, x, mi)
    println(model)
    l = loss(model, x)

    Flux.reset!(model)
    grads_flux = Flux.gradient(m -> loss(m, x), model)[1]

    grads_enzyme = Flux.fmap(model) do x
        x isa Array ? zero(x) : x
    end
    Flux.reset!(model)
    Enzyme.autodiff(Reverse, loss, Active, Duplicated(model, grads_enzyme), Const(x))

    @testset "Model $mi" begin
        Flux.reset!(model)
        @test loss(model, x) == l # Check loss doesn't change with multiple runs

        for i in eachindex(grads_flux.layers)
            layer_flux = grads_flux.layers[i]
            layer_enzyme = grads_enzyme.layers[i]
            for field in (:weight, :bias, :scale)
                if hasfield(typeof(layer_flux), field)
                    @test isapprox(getfield(layer_flux, field), getfield(layer_enzyme, field))
                end
            end
            if hasfield(typeof(layer_flux), :cell)
                for field in (:Wh, :Wi, :b)
                    @test isapprox(getfield(layer_flux.cell, field), getfield(layer_enzyme.cell, field))
                end
            end
        end
    end
end

The good news is that the following all work. I steered clear of normalisation or anything that changes with train/test for now.

models_xs = [
    [
        Chain(Dense(2 => 4), Dense(4 => 2)),
        randn(Float32, 2, 1),
    ],
    [
        f64(Chain(Dense(2 => 4), Dense(4 => 2))),
        randn(Float64, 2, 1),
    ],
    [
        Chain(Dense(2 => 4, relu), Dense(4 => 2)),
        randn(Float32, 2, 1),
    ],
    [
        Chain(Flux.Scale([1.0f0 2.0f0 3.0f0 4.0f0], true, abs2)),
        randn(Float32, 2),
    ],
    [
        Chain(Conv((3, 3), 3 => 7, relu), Conv((3, 3), 7 => 7, relu)),
        rand(Float32, 10, 10, 3, 50),
    ],
    [
        Chain(Conv((5, 5), 3 => 7, pad=SamePad()), MaxPool((5, 5), pad=SamePad())),
        rand(Float32, 100, 100, 3, 50),
    ],
    [
        Maxout(() -> Dense(5 => 7, tanh), 3),
        randn(Float32, 5, 1),
    ],
    [
        Chain(RNN(3 => 5), RNN(5 => 3)),
        randn(Float32, 3, 10),
    ],
    [
        Chain(LSTM(3 => 5), LSTM(5 => 3)),
        randn(Float32, 3, 10),
    ],
]

for (mi, (model, x)) in enumerate(models_xs)
    test_model(model, x, mi)
end

The following error in Enzyme:

models_xs = [
    [
        SkipConnection(Chain(Dense(5 => 20, tanh), Dense(20 => 9, tanh)), Flux.Bilinear((9, 5) => 3, bias=false)),
        randn(Float32, 5, 1),
    ],
    [
        Chain(ConvTranspose((3, 3), 3 => 7, stride=2)),
        rand(Float32, 10, 10, 3, 50),
    ],
    [
        Chain(GRU(3 => 5)),
        randn(Float32, 3, 10),
    ],
    [
        fmap(cu, Chain(Dense(2 => 4), Dense(4 => 2))), # Requires using CUDA
        cu(randn(Float32, 2, 1)),
    ],
]

And this one gives slightly different gradients:

models_xs = [
    [
        Chain(Conv((5, 5), 3 => 7), MeanPool((5,5), pad=SamePad())),
        rand(Float32, 100, 100, 3, 50),
    ],
]

If it is helpful I can open individual issues and add the working cases to the Enzyme tests.

@wsmoses
Member

wsmoses commented Jan 29, 2024

Yes individual issues (with corresponding error traces), would be highly helpful!

@wsmoses
Member

wsmoses commented Jan 29, 2024

I'd also separately be interested in which ones fail if runtime activity is off.

@wsmoses
Member

wsmoses commented Feb 12, 2024

@ToucheSir the SkipConnection one seems to be a pure flux issue potentially?

julia> models_xs = [
           [
               SkipConnection(Chain(Dense(5 => 20, tanh), Dense(20 => 9, tanh)), Flux.Bilinear((9, 5) => 3, bias=false)),
               randn(Float32, 5, 1),
           ],]
1-element Vector{Vector{Any}}:
 [SkipConnection(Chain(Dense(5 => 20, tanh), Dense(20 => 9, tanh)), Bilinear((9, 5) => 3; bias=false)), Float32[0.56165487; 1.2769437; … ; 0.798284; 0.12582794;;]]

julia> for (mi, (model, x)) in enumerate(models_xs)
           test_model(model, x, mi)
       end
SkipConnection(Chain(Dense(5 => 20, tanh), Dense(20 => 9, tanh)), Bilinear((9, 5) => 3; bias=false))
Model 1: Error During Test at REPL[11]:14
  Got exception outside of a @test
  MethodError: no method matching getindex(::Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}, ::Symbol)
  
  Closest candidates are:
    getindex(::Tuple, ::Colon)
     @ Base tuple.jl:37
    getindex(::Tuple, ::Int64)
     @ Base tuple.jl:31
    getindex(::Tuple, ::CartesianIndex{1})
     @ Base multidimensional.jl:882
    ...
  
  Stacktrace:
    [1] #getindex#173
      @ Flux ~/.julia/packages/MacroTools/Cf2ok/src/examples/forward.jl:18 [inlined]
    [2] getindex(x::Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}, args::Symbol)
      @ Flux ~/.julia/packages/MacroTools/Cf2ok/src/examples/forward.jl:17
    [3] macro expansion
      @ ./REPL[11]:20 [inlined]
    [4] macro expansion
      @ ~/git/Enzyme.jl/julia-1.10.0-rc2/share/julia/stdlib/v1.10/Test/src/Test.jl:1577 [inlined]
    [5] test_model(model::SkipConnection{Chain{Tuple{Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}, Dense{typeof(tanh), Matrix{Float32}, Vector{Float32}}}}, Flux.Bilinear{typeof(identity), Array{Float32, 3}, Bool}}, x::Matrix{Float32}, mi::Int64)
      @ Main ./REPL[11]:15
    [6] top-level scope
      @ ./REPL[20]:2
    [7] eval
      @ Core ./boot.jl:385 [inlined]
    [8] eval_user_input(ast::Any, backend::REPL.REPLBackend, mod::Module)
      @ REPL ~/git/Enzyme.jl/julia-1.10.0-rc2/share/julia/stdlib/v1.10/REPL/src/REPL.jl:150
    [9] repl_backend_loop(backend::REPL.REPLBackend, get_module::Function)
      @ REPL ~/git/Enzyme.jl/julia-1.10.0-rc2/share/julia/stdlib/v1.10/REPL/src/REPL.jl:246
   [10] start_repl_backend(backend::REPL.REPLBackend, consumer::Any; get_module::Function)
      @ REPL ~/git/Enzyme.jl/julia-1.10.0-rc2/share/julia/stdlib/v1.10/REPL/src/REPL.jl:231
   [11] run_repl(repl::REPL.AbstractREPL, consumer::Any; backend_on_current_task::Bool, backend::Any)
      @ REPL ~/git/Enzyme.jl/julia-1.10.0-rc2/share/julia/stdlib/v1.10/REPL/src/REPL.jl:389
   [12] run_repl(repl::REPL.AbstractREPL, consumer::Any)
      @ REPL ~/git/Enzyme.jl/julia-1.10.0-rc2/share/julia/stdlib/v1.10/REPL/src/REPL.jl:375
   [13] (::Base.var"#1013#1015"{Bool, Bool, Bool})(REPL::Module)
      @ Base ./client.jl:432
   [14] #invokelatest#2
      @ Base ./essentials.jl:887 [inlined]
   [15] invokelatest
      @ Base ./essentials.jl:884 [inlined]
   [16] run_main_repl(interactive::Bool, quiet::Bool, banner::Bool, history_file::Bool, color_set::Bool)
      @ Base ./client.jl:416
   [17] exec_options(opts::Base.JLOptions)
      @ Base ./client.jl:333
   [18] _start()
      @ Base ./client.jl:552
Test Summary: | Pass  Error  Total  Time
Model 1       |    1      1      2  0.2s
ERROR: Some tests did not pass: 1 passed, 0 failed, 1 errored, 0 broken.

@wsmoses
Member

wsmoses commented Feb 12, 2024

The ConvTranspose issue has been posted to NNlib.jl (FluxML/NNlib.jl#565). It requires a rule implementation/extension for conv.

@ToucheSir
Author

> @ToucheSir the SkipConnection one seems to be a pure flux issue potentially?

Doesn't look like it:

julia> model = SkipConnection(Chain(Dense(5 => 20, tanh), Dense(20 => 9, tanh)), Flux.Bilinear((9, 5) => 3, bias=false))
SkipConnection(
  Chain(
    Dense(5 => 20, tanh),               # 120 parameters
    Dense(20 => 9, tanh),               # 189 parameters
  ),
  Bilinear((9, 5) => 3; bias=false),    # 135 parameters
)                   # Total: 5 arrays, 444 parameters, 2.094 KiB.

julia> x = randn(Float32, 5, 1)
5×1 Matrix{Float32}:
 -0.6271728
 -0.5722281
 -1.7240835
  0.43075645
  0.044463925

julia> model(x)
3×1 Matrix{Float32}:
  0.39723554
  0.15903589
 -0.38918468

I believe the code here is making invalid assumptions about the structure of models and their gradients (i.e. model is always a single-layer Chain of Denses, which is not true for SkipConnection), so the error is caused by the test code instead of Flux.

        for i in eachindex(grads_flux.layers)
            layer_flux = grads_flux.layers[i]
            layer_enzyme = grads_enzyme.layers[i]
            for field in (:weight, :bias, :scale)
                if hasfield(typeof(layer_flux), field)
                    @test isapprox(getfield(layer_flux, field), getfield(layer_enzyme, field))
                end
            end
            if hasfield(typeof(layer_flux), :cell)
                for field in (:Wh, :Wi, :b)
                    @test isapprox(getfield(layer_flux.cell, field), getfield(layer_enzyme.cell, field))
                end
            end
        end
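A sketch of a structure-agnostic replacement for this loop, assuming gradients in the Zygote/Flux form (nested `NamedTuple`s and `Tuple`s, with `nothing` for non-differentiable fields) and an Enzyme shadow built with `fmap` as in the snippets above. `compare_grads` is a hypothetical helper: it walks both trees by field name, so it does not need a top-level `.layers` field and would also cover `SkipConnection`'s `layers` and `connection`. Scalar fields are skipped because the `fmap`-built shadow only zeroes arrays and keeps other leaves at their primal values.

```julia
# Hypothetical helper: recursively compare a Zygote-style gradient `g`
# with an Enzyme-style shadow (a copy of the model holding gradients).
function compare_grads(g, shadow; atol = 1f-5)
    if g === nothing
        return true   # no gradient recorded for this field
    elseif g isa Number
        return true   # scalars are not tracked in an fmap-built shadow, so skip them
    elseif g isa AbstractArray{<:Number}
        return isapprox(g, shadow; atol = atol)
    elseif g isa Tuple
        # e.g. the `layers` tuple of a Chain
        return all(compare_grads(gi, si; atol = atol) for (gi, si) in zip(g, shadow))
    elseif g isa NamedTuple
        # Match struct fields by name, e.g. `weight`, `bias`, `layers`, `connection`.
        return all(compare_grads(getfield(g, k), getfield(shadow, k); atol = atol) for k in keys(g))
    else
        error("compare_grads: unhandled gradient type $(typeof(g))")
    end
end
```

Inside test_model, `@test compare_grads(grads_flux, grads_enzyme)` could then stand in for the `.layers` loop.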

@wsmoses
Member

wsmoses commented Feb 12, 2024

I suppose it's that little test code erring, then, which means maybe Enzyme works on it :) [or maybe not].

@jgreener64
Contributor

For me, the test code runs the Flux version fine but errors on the Enzyme version. The Enzyme error is attached as error.txt. I notice your stacktrace contains julia-1.10.0-rc2, Billy; I don't know if that affects things.

The gradient checking code makes assumptions and might fail to check all gradients but ideally shouldn't error itself for anything with indexable .layers since it uses hasfield.

@wsmoses
Member

wsmoses commented Feb 12, 2024

@jgreener64, can you try on latest main (i.e. not the release)? Fixes for precisely that error have landed on main since.

@jgreener64
Contributor

Still errors for me on main (877e1d9).

@wsmoses
Member

wsmoses commented Feb 12, 2024

Can you confirm your jll version via st? (It should be version 0.0.100.)

@jgreener64
Contributor

Yes, it is with jll version 0.0.100. The error message looked similar but may have changed a bit; I can't check right now.

@wsmoses
Member

wsmoses commented May 11, 2024

Given that many of the core points here have been resolved, I'm going to propose closing this issue unless there are any objections.

  1. Type-stable deep learning libraries. All the Enzyme tests on Flux now pass, due both to Enzyme gaining more type-unstable support and to Flux being mostly type-stable now. I'm not sure what the exact status is here for Lux (e.g. Meta Issue for proper Enzyme Integration into Lux LuxDL/Lux.jl#605). But given that at least one DL library is happy, I'll call this good for now (and presume others can follow suit shortly).

  2. CUBLAS/cudaMemCopy/etc rules. NNlib now has EnzymeRules for relevant functions, and we have also added cuBLAS rules in our existing blas support. There is still a need for cudaMemcpy and some Julia-side JIT CUDA fixups for Enzyme, but we need to see if the existing support is sufficient for DL with tests, and the broader CUDA runtime function support can be separated to its own issue for what fails.

  3. Scheduling. Like I said at the top, this is key to good performance, and incidentally distinct from AD. I have started playing with a repo, Reactant.jl (https://github.com/EnzymeAD/Reactant.jl), which aims to resolve this. It can take a Julia function, compile it into MLIR, run fancy optimizations on top of it (including using EnzymeMLIR for AD), and create relevant executables for CPU/GPU/TPU via XLA. It is very much a work in progress, but it is nevertheless a problem outside Enzyme now.

@avik-pal
Contributor

> I'm not sure what the exact status here is for Lux

The tests are in, and 1 small patch was needed. The release should be available in an hour or so.

@wsmoses
Member

wsmoses commented May 13, 2024

As mentioned, closing this now for the reasons above.

Labels
None yet
Projects
None yet
Development

No branches or pull requests

6 participants