
conv gradient is not implemented in EnzymeJAX #214

Closed
yolhan83 opened this issue Nov 1, 2024 · 9 comments

yolhan83 commented Nov 1, 2024

Hello, I wonder if Reactant works with Conv layers. It seems to work in the forward pass but not in the gradient pass, neither on CPU nor GPU.

Version:

Julia Version 1.10.6
Commit 67dffc4a8ae (2024-10-28 12:23 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 20 × 12th Gen Intel(R) Core(TM) i7-12700H
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, alderlake)
Threads: 1 default, 0 interactive, 1 GC (on 20 virtual cores)

Code:

import Enzyme
using Lux,Reactant,Statistics,Random

const rng = Random.default_rng(123)
const dev = xla_device()
model_conv = Chain(
    Conv((3,3),1=>8,pad=SamePad(),relu), 
    Lux.FlattenLayer(),
    Dense(32*32*8,10),
    softmax
)
model_dense = Chain(
    Lux.FlattenLayer(),
    Dense(32*32*1,10),
    softmax
)

ps_conv,st_conv = Lux.setup(rng, model_conv) |> dev
ps_dense,st_dense = Lux.setup(rng, model_dense) |> dev

loss(model,x,ps,st,y) = Lux.MSELoss()(first(model(x,ps,st)),y)
x = randn(rng, Float32, 32,32,1,100) |> dev
y = randn(rng, Float32, 10,100) |> dev

function get_grad(loss,model,x,ps,st,y)
    dps = Enzyme.make_zero(ps)
    Enzyme.autodiff(Enzyme.Reverse,loss,Enzyme.Const(model),Enzyme.Const(x),Enzyme.Duplicated(ps,dps),Enzyme.Const(st),Enzyme.Const(y))
    return dps
end

loss_compile_conv = @compile loss(model_conv,x,ps_conv,st_conv,y) # works
loss_compile_dense = @compile loss(model_dense,x,ps_dense,st_dense,y) # works

grad_compile_conv = @compile get_grad(loss,model_conv,x,ps_conv,st_conv,y) #doesn't work
grad_compile_dense = @compile get_grad(loss,model_dense,x,ps_dense,st_dense,y) # works
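
(For reference, once @compile succeeds, the thunk it returns is called with the same arguments that were passed at compile time; a minimal usage sketch, assuming the definitions above:)

l_conv = loss_compile_conv(model_conv, x, ps_conv, st_conv, y)               # forward pass runs
dps_dense = grad_compile_dense(loss, model_dense, x, ps_dense, st_dense, y)  # dense gradient runs
# grad_compile_conv is never produced: compilation aborts with the error below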

Error:

error: expects input feature dimension (8) / feature_group_count = kernel input feature dimension (1). Got feature_group_count = 1.
ERROR: "failed to run pass manager on module"
Stacktrace:
  [1] run!
    @ ~/.julia/packages/Reactant/rRa4g/src/mlir/IR/Pass.jl:70 [inlined]
  [2] run_pass_pipeline!(mod::Reactant.MLIR.IR.Module, pass_pipeline::String)
    @ Reactant.Compiler ~/.julia/packages/Reactant/rRa4g/src/Compiler.jl:241
  [3] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::Function, args::Tuple{…}; optimize::Bool)
    @ Reactant.Compiler ~/.julia/packages/Reactant/rRa4g/src/Compiler.jl:272
  [4] compile_mlir!
    @ ~/.julia/packages/Reactant/rRa4g/src/Compiler.jl:256 [inlined]
  [5] (::Reactant.Compiler.var"#30#32"{typeof(get_grad), Tuple{…}})()
    @ Reactant.Compiler ~/.julia/packages/Reactant/rRa4g/src/Compiler.jl:584
  [6] context!(f::Reactant.Compiler.var"#30#32"{typeof(get_grad), Tuple{…}}, ctx::Reactant.MLIR.IR.Context)
    @ Reactant.MLIR.IR ~/.julia/packages/Reactant/rRa4g/src/mlir/IR/Context.jl:71
  [7] compile_xla(f::Function, args::Tuple{…}; client::Nothing)
    @ Reactant.Compiler ~/.julia/packages/Reactant/rRa4g/src/Compiler.jl:581
  [8] compile_xla
    @ ~/.julia/packages/Reactant/rRa4g/src/Compiler.jl:575 [inlined]
  [9] compile(f::Function, args::Tuple{…}; client::Nothing)
    @ Reactant.Compiler ~/.julia/packages/Reactant/rRa4g/src/Compiler.jl:608
 [10] compile(f::Function, args::Tuple{…})
    @ Reactant.Compiler ~/.julia/packages/Reactant/rRa4g/src/Compiler.jl:607
 [11] top-level scope
    @ ~/.julia/packages/Reactant/rRa4g/src/Compiler.jl:368
avik-pal (Collaborator) commented Nov 1, 2024

The feature_group_count is probably missing on the EnzymeJAX end?

cc @Pangoraw
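
(For context on what the verifier is checking, not part of the original comment: StableHLO's convolution op requires the input feature dimension to equal the kernel input feature dimension times feature_group_count. In the failing reverse-pass convolution the input feature dimension is 8, presumably the Conv layer's 8 output channels flowing back through the gradient, while the kernel input feature dimension and feature_group_count are both 1. A toy restatement of the failing check:)

# StableHLO convolution verifier constraint (sketch of the check reported above)
input_feature_dim = 8           # features entering the gradient convolution
kernel_input_feature_dim = 1
feature_group_count = 1
input_feature_dim == kernel_input_feature_dim * feature_group_count  # false: 8 != 1 * 1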

Pangoraw (Collaborator) commented Nov 1, 2024

The reverse is not implemented for convolution. It should be fixed in Enzyme-JAX
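
(In the meantime, a possible workaround for the reproducer above is to take the conv gradient outside Reactant on plain CPU arrays, assuming Enzyme's NNlib rules cover that path; a minimal untested sketch reusing the model and loss defined above:)

# Workaround sketch (hypothetical): differentiate on ordinary CPU arrays,
# bypassing Reactant/EnzymeJAX entirely.
ps_cpu, st_cpu = Lux.setup(rng, model_conv)     # keep parameters on the CPU
x_cpu = randn(rng, Float32, 32, 32, 1, 100)
y_cpu = randn(rng, Float32, 10, 100)
dps_cpu = Enzyme.make_zero(ps_cpu)
Enzyme.autodiff(Enzyme.Reverse, loss, Enzyme.Const(model_conv), Enzyme.Const(x_cpu),
                Enzyme.Duplicated(ps_cpu, dps_cpu), Enzyme.Const(st_cpu), Enzyme.Const(y_cpu))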

yolhan83 (Author) commented Nov 1, 2024

Oh ok, I will wait before doing my MNIST benchmark then, have a nice day.

wsmoses (Member) commented Nov 2, 2024

@Pangoraw do you want to bump the commits of Enzyme-JAX and Reactant and bump the JLL?

Setup needs to bump:
https://github.com/EnzymeAD/Enzyme-JAX/blob/6109edcd3f1d04c9f19f7a30d0e493ffae5ba417/workspace.bzl#L4

then

ENZYMEXLA_COMMIT = "2f1a70349297a21ce67f41cc94ff305dd0aef5d4"

then
https://github.com/JuliaPackaging/Yggdrasil/blob/ff9587a0df76e20171efd10076a59395c4fad5dd/R/Reactant/build_tarballs.jl#L12

mofeing (Collaborator) commented Nov 2, 2024

@wsmoses @Pangoraw can we do a Reactant release after the new JLL lands?

avik-pal changed the title from "How to use Reactant on Conv layers" to "conv gradient is not implemented in EnzymeJAX" on Nov 8, 2024
wsmoses (Member) commented Nov 22, 2024

@Pangoraw

Pangoraw (Collaborator) commented

I meant that there is a rule in https://github.com/EnzymeAD/Enzyme-JAX/blob/724be0666e0f9fbba66f6b2fddbdb222b0db5fb6/src/enzyme_ad/jax/Implementations/HLODerivatives.td#L588, but it generates invalid code instead of returning an unimplemented error.

wsmoses (Member) commented Nov 23, 2024

Ah, I read your comment of "it should be fixed in the other repo" as "it is fixed but needs a JLL bump" rather than your intended meaning: we should fix it, but the bug is in another repo.

avik-pal (Collaborator) commented Dec 6, 2024

This is fixed in the latest release

avik-pal closed this as completed on Dec 6, 2024.