```julia
julia> using Lux, Random, Enzyme, Reactant

julia> layer = Conv((3, 3), 4 => 8; groups=4)
Conv((3, 3), 4 => 8, groups=4)  # 80 parameters

julia> ps, st = Lux.setup(Random.default_rng(), layer) |> Reactant.to_rarray;

julia> x = rand(Float32, 4, 4, 4, 1) |> Reactant.to_rarray;

julia> function run_conv_grad(x, w, gn)
           cdims = DenseConvDims(x, w; groups=gn)
           Enzyme.gradient(Enzyme.set_abi(Reverse, Reactant.ReactantABI),
                           (args...) -> sum(conv(args...)),
                           x, w, Const(DenseConvDims(x, w; groups=gn)))
       end
run_conv_grad (generic function with 1 method)

julia> @code_hlo run_conv_grad(x, ps.weight, 4)
error: expects input feature dimension (8) / feature_group_count = kernel input feature dimension (8). Got feature_group_count = 4.
ERROR: "failed to run pass manager on module"
Stacktrace:
  [1] run!
    @ /mnt/software/lux/Reactant.jl/src/mlir/IR/Pass.jl:79 [inlined]
  [2] run_pass_pipeline!(mod::Reactant.MLIR.IR.Module, pass_pipeline::String; enable_verifier::Bool)
    @ Reactant.Compiler /mnt/software/lux/Reactant.jl/src/Compiler.jl:263
  [3] run_pass_pipeline!
    @ /mnt/software/lux/Reactant.jl/src/Compiler.jl:258 [inlined]
  [4] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::Function, args::Tuple{ConcreteRArray{Float32, 4}, ConcreteRArray{Float32, 4}, Int64}; optimize::Bool)
    @ Reactant.Compiler /mnt/software/lux/Reactant.jl/src/Compiler.jl:309
  [5] compile_mlir!
    @ /mnt/software/lux/Reactant.jl/src/Compiler.jl:289 [inlined]
  [6] #6
    @ /mnt/software/lux/Reactant.jl/src/Compiler.jl:284 [inlined]
  [7] context!(f::Reactant.Compiler.var"#6#7"{@Kwargs{optimize::Bool}, typeof(run_conv_grad), Tuple{ConcreteRArray{Float32, 4}, ConcreteRArray{Float32, 4}, Int64}}, ctx::Reactant.MLIR.IR.Context)
    @ Reactant.MLIR.IR /mnt/software/lux/Reactant.jl/src/mlir/IR/Context.jl:76
  [8] compile_mlir(f::Function, args::Tuple{ConcreteRArray{Float32, 4}, ConcreteRArray{Float32, 4}, Int64}; kwargs::@Kwargs{optimize::Bool})
    @ Reactant.Compiler /mnt/software/lux/Reactant.jl/src/Compiler.jl:282
  [9] top-level scope
    @ /mnt/software/lux/Reactant.jl/src/Compiler.jl:474
 [10] top-level scope
    @ none:1

julia> @code_hlo optimize=false run_conv_grad(x, ps.weight, 4)
module {
  func.func private @identity_broadcast_scalar(%arg0: tensor<f32>) -> tensor<f32> {
    %0 = stablehlo.transpose %arg0, dims = [] : (tensor<f32>) -> tensor<f32>
    %1 = stablehlo.transpose %0, dims = [] : (tensor<f32>) -> tensor<f32>
    return %1 : tensor<f32>
  }
  func.func private @"Const{var\22#10#11\22}(var\22#10#11\22())_autodiff"(%arg0: tensor<1x4x4x4xf32>, %arg1: tensor<8x1x3x3xf32>) -> (tensor<f32>, tensor<1x4x4x4xf32>, tensor<8x1x3x3xf32>) {
    %0 = stablehlo.transpose %arg0, dims = [3, 2, 1, 0] : (tensor<1x4x4x4xf32>) -> tensor<4x4x4x1xf32>
    %1 = stablehlo.transpose %arg1, dims = [3, 2, 1, 0] : (tensor<8x1x3x3xf32>) -> tensor<3x3x1x8xf32>
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<2x2x8x1xf32>
    %2 = stablehlo.broadcast_in_dim %0, dims = [0, 1, 2, 3] : (tensor<4x4x4x1xf32>) -> tensor<4x4x4x1xf32>
    %3 = stablehlo.broadcast_in_dim %1, dims = [0, 1, 2, 3] : (tensor<3x3x1x8xf32>) -> tensor<3x3x1x8xf32>
    %4 = stablehlo.reverse %3, dims = [0, 1] : tensor<3x3x1x8xf32>
    %5 = stablehlo.convolution(%2, %4) dim_numbers = [0, 1, f, b]x[0, 1, i, o]->[0, 1, f, b], window = {stride = [1, 1], pad = [[0, 0], [0, 0]], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 4 : i64} : (tensor<4x4x4x1xf32>, tensor<3x3x1x8xf32>) -> tensor<2x2x8x1xf32>
    %cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %6 = stablehlo.broadcast_in_dim %5, dims = [0, 1, 2, 3] : (tensor<2x2x8x1xf32>) -> tensor<2x2x8x1xf32>
    %7 = enzyme.batch @identity_broadcast_scalar(%6) {batch_shape = array<i64: 2, 2, 8, 1>} : (tensor<2x2x8x1xf32>) -> tensor<2x2x8x1xf32>
    %8 = stablehlo.reduce(%7 init: %cst_0) applies stablehlo.add across dimensions = [0, 1, 2, 3] : (tensor<2x2x8x1xf32>, tensor<f32>) -> tensor<f32>
    %9 = stablehlo.transpose %8, dims = [] : (tensor<f32>) -> tensor<f32>
    %10 = stablehlo.transpose %0, dims = [3, 2, 1, 0] : (tensor<4x4x4x1xf32>) -> tensor<1x4x4x4xf32>
    %11 = stablehlo.transpose %1, dims = [3, 2, 1, 0] : (tensor<3x3x1x8xf32>) -> tensor<8x1x3x3xf32>
    return %9, %10, %11 : tensor<f32>, tensor<1x4x4x4xf32>, tensor<8x1x3x3xf32>
  }
  func.func @main(%arg0: tensor<1x4x4x4xf32>, %arg1: tensor<8x1x3x3xf32>) -> (tensor<1x4x4x4xf32>, tensor<8x1x3x3xf32>, tensor<1x4x4x4xf32>, tensor<8x1x3x3xf32>) {
    %0 = stablehlo.transpose %arg0, dims = [3, 2, 1, 0] : (tensor<1x4x4x4xf32>) -> tensor<4x4x4x1xf32>
    %1 = stablehlo.transpose %arg1, dims = [3, 2, 1, 0] : (tensor<8x1x3x3xf32>) -> tensor<3x3x1x8xf32>
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<4x4x4x1xf32>
    %cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<3x3x1x8xf32>
    %cst_1 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
    %2 = stablehlo.transpose %0, dims = [3, 2, 1, 0] : (tensor<4x4x4x1xf32>) -> tensor<1x4x4x4xf32>
    %3 = stablehlo.transpose %1, dims = [3, 2, 1, 0] : (tensor<3x3x1x8xf32>) -> tensor<8x1x3x3xf32>
    %4 = stablehlo.transpose %cst_1, dims = [] : (tensor<f32>) -> tensor<f32>
    %5 = stablehlo.transpose %cst, dims = [3, 2, 1, 0] : (tensor<4x4x4x1xf32>) -> tensor<1x4x4x4xf32>
    %6 = stablehlo.transpose %cst_0, dims = [3, 2, 1, 0] : (tensor<3x3x1x8xf32>) -> tensor<8x1x3x3xf32>
    %7:4 = enzyme.autodiff @"Const{var\22#10#11\22}(var\22#10#11\22())_autodiff"(%2, %3, %4, %5, %6) {activity = [#enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>], ret_activity = [#enzyme<activity enzyme_activenoneed>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>]} : (tensor<1x4x4x4xf32>, tensor<8x1x3x3xf32>, tensor<f32>, tensor<1x4x4x4xf32>, tensor<8x1x3x3xf32>) -> (tensor<1x4x4x4xf32>, tensor<8x1x3x3xf32>, tensor<1x4x4x4xf32>, tensor<8x1x3x3xf32>)
    %8 = stablehlo.transpose %7#0, dims = [3, 2, 1, 0] : (tensor<1x4x4x4xf32>) -> tensor<4x4x4x1xf32>
    %9 = stablehlo.transpose %7#1, dims = [3, 2, 1, 0] : (tensor<8x1x3x3xf32>) -> tensor<3x3x1x8xf32>
    %10 = stablehlo.transpose %7#2, dims = [3, 2, 1, 0] : (tensor<1x4x4x4xf32>) -> tensor<4x4x4x1xf32>
    %11 = stablehlo.transpose %7#3, dims = [3, 2, 1, 0] : (tensor<8x1x3x3xf32>) -> tensor<3x3x1x8xf32>
    %12 = stablehlo.transpose %10, dims = [3, 2, 1, 0] : (tensor<4x4x4x1xf32>) -> tensor<1x4x4x4xf32>
    %13 = stablehlo.transpose %11, dims = [3, 2, 1, 0] : (tensor<3x3x1x8xf32>) -> tensor<8x1x3x3xf32>
    %14 = stablehlo.transpose %8, dims = [3, 2, 1, 0] : (tensor<4x4x4x1xf32>) -> tensor<1x4x4x4xf32>
    %15 = stablehlo.transpose %9, dims = [3, 2, 1, 0] : (tensor<3x3x1x8xf32>) -> tensor<8x1x3x3xf32>
    return %12, %13, %14, %15 : tensor<1x4x4x4xf32>, tensor<8x1x3x3xf32>, tensor<1x4x4x4xf32>, tensor<8x1x3x3xf32>
  }
}
```
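Note that the failing convolution does not appear in the unoptimized module above: the forward conv there has an input feature dimension of 4, while the error complains about an input feature dimension of 8, so the invalid op appears to be one that the autodiff pass generates for the reverse pass. To help narrow it down, it may be worth checking whether the ungrouped case goes through the same pipeline cleanly. Below is a minimal sketch of that control case, reusing the `run_conv_grad` definition and packages from the snippet above; the `x1`/`w1` names and shapes are illustrative and not part of the original report:

```julia
# Hypothetical ungrouped control case (not from the original report):
# with groups = 1 the kernel's input-channel dimension equals C_in, so both the
# forward and the generated reverse convolution should use feature_group_count = 1.
w1 = rand(Float32, 3, 3, 4, 8) |> Reactant.to_rarray   # (kW, kH, C_in, C_out)
x1 = rand(Float32, 4, 4, 4, 1) |> Reactant.to_rarray   # (W, H, C_in, N)
@code_hlo run_conv_grad(x1, w1, 1)
```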
For reference, the forward grouped convolution alone compiles fine with the same `groups=4`:

```julia
julia> @code_hlo conv(x, ps.weight, DenseConvDims(x, ps.weight; groups=4))
module {
  func.func @main(%arg0: tensor<1x4x4x4xf32>, %arg1: tensor<8x1x3x3xf32>) -> tensor<1x8x2x2xf32> {
    %0 = stablehlo.transpose %arg1, dims = [3, 2, 1, 0] : (tensor<8x1x3x3xf32>) -> tensor<3x3x1x8xf32>
    %1 = stablehlo.reverse %0, dims = [0, 1] : tensor<3x3x1x8xf32>
    %2 = stablehlo.convolution(%arg0, %1) dim_numbers = [b, f, 1, 0]x[0, 1, i, o]->[b, f, 1, 0], window = {stride = [1, 1], pad = [[0, 0], [0, 0]], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 4 : i64} : (tensor<1x4x4x4xf32>, tensor<3x3x1x8xf32>) -> tensor<1x8x2x2xf32>
    return %2 : tensor<1x8x2x2xf32>
  }
}
```
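If it helps when testing an eventual fix, reference gradient shapes for the same grouped convolution can be obtained outside the Reactant/MLIR path, e.g. through Zygote's NNlib rules. This is only a sketch, assuming NNlib and Zygote are available in the environment; none of it appears in the original report:

```julia
using NNlib, Zygote   # assumed to be available; not used in the original report

xc = rand(Float32, 4, 4, 4, 1)   # plain CPU arrays, same shapes as above
wc = rand(Float32, 3, 3, 1, 8)   # grouped kernel: C_in / groups = 1, C_out = 8
cdims = DenseConvDims(xc, wc; groups=4)

# Zygote uses NNlib's ∇conv_data / ∇conv_filter rules here, giving reference
# gradients for the same sum(conv(...)) loss without going through MLIR.
dx, dw = Zygote.gradient((x, w) -> sum(conv(x, w, cdims)), xc, wc)
size(dx), size(dw)   # expected: (4, 4, 4, 1) and (3, 3, 1, 8)
```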
@Pangoraw since iirc you added the conv fix earlier