From ff878b53e723431de3d3f12f58472e7e0d6fbdce Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Wed, 8 May 2024 00:43:53 +0200 Subject: [PATCH] add import_frule (reprised) (#1333) * Add import frule functionality * Add wip rrule importer * move everything to extension; add tests frule * remove import_rrule * runtests * esc(fn) fixes tests * added failing batchduplicated test * remove test import * address review comments * add dollar in macro * cleanup * Fixup and cleanup --------- Co-authored-by: William S. Moses Co-authored-by: Billy Moses --- Project.toml | 8 ++ ext/EnzymeChainRulesCoreExt.jl | 107 +++++++++++++++++++++ src/Enzyme.jl | 5 + test/Project.toml | 2 + test/ext/chainrulescore.jl | 70 ++++++++++++++ test/{packages => ext}/specialfunctions.jl | 0 test/runtests.jl | 29 ++++-- 7 files changed, 211 insertions(+), 10 deletions(-) create mode 100644 ext/EnzymeChainRulesCoreExt.jl create mode 100644 test/ext/chainrulescore.jl rename test/{packages => ext}/specialfunctions.jl (100%) diff --git a/Project.toml b/Project.toml index e279c5c46c74..0f5073d450bf 100644 --- a/Project.toml +++ b/Project.toml @@ -17,17 +17,25 @@ Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" [weakdeps] +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" [extensions] EnzymeSpecialFunctionsExt = "SpecialFunctions" +EnzymeChainRulesCoreExt = "ChainRulesCore" [compat] CEnum = "0.4, 0.5" +ChainRulesCore = "1" EnzymeCore = "0.7" Enzyme_jll = "0.0.105" GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26" LLVM = "6.1" ObjectFile = "0.4" Preferences = "1.4" +SpecialFunctions = "1, 2" julia = "1.6" + +[extras] +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" \ No newline at end of file diff --git a/ext/EnzymeChainRulesCoreExt.jl b/ext/EnzymeChainRulesCoreExt.jl new file mode 100644 index 000000000000..2c8d180a5783 --- /dev/null +++ b/ext/EnzymeChainRulesCoreExt.jl @@ -0,0 +1,107 @@ +module EnzymeChainRulesCoreExt + +using ChainRulesCore +using EnzymeCore +using Enzyme + + +""" + import_frule(::fn, tys...) + +Automatically import a `ChainRulesCore.frule`` as a custom forward mode `EnzymeRule`. When called in batch mode, this +will end up calling the primal multiple times, which may result in incorrect behavior if the function mutates, +and slow code, always. Importing the rule from `ChainRules` is also likely to be slower than writing your own rule, +and may also be slower than not having a rule at all. + +Use with caution. + +```jldoctest +Enzyme.@import_frule(typeof(Base.sort), Any); + +x=[1.0, 2.0, 0.0]; dx=[0.1, 0.2, 0.3]; ddx = [0.01, 0.02, 0.03]; + +Enzyme.autodiff(Forward, sort, Duplicated, BatchDuplicated(x, (dx,ddx))) +Enzyme.autodiff(Forward, sort, DuplicatedNoNeed, BatchDuplicated(x, (dx,ddx))) +Enzyme.autodiff(Forward, sort, DuplicatedNoNeed, BatchDuplicated(x, (dx,))) +Enzyme.autodiff(Forward, sort, Duplicated, BatchDuplicated(x, (dx,))) + +# output + +(var"1" = [0.0, 1.0, 2.0], var"2" = (var"1" = [0.3, 0.1, 0.2], var"2" = [0.03, 0.01, 0.02])) +(var"1" = (var"1" = [0.3, 0.1, 0.2], var"2" = [0.03, 0.01, 0.02]),) +(var"1" = [0.3, 0.1, 0.2],) +(var"1" = [0.0, 1.0, 2.0], var"2" = [0.3, 0.1, 0.2]) + +``` +""" +function Enzyme._import_frule(fn, tys...) + vals = [] + exprs = [] + primals = [] + tangents = [] + tangentsi = [] + anns = [] + for (i, ty) in enumerate(tys) + val = Symbol("arg_$i") + TA = Symbol("AN_$i") + e = :($val::$TA) + push!(anns, :($TA <: Annotation{<:$ty})) + push!(vals, val) + push!(exprs, e) + push!(primals, :($val.val)) + push!(tangents, :($val isa Const ? $ChainRulesCore.NoTangent() : $val.dval)) + push!(tangentsi, :($val isa Const ? $ChainRulesCore.NoTangent() : $val.dval[i])) + end + + quote + function EnzymeRules.forward(fn::FA, ::Type{RetAnnotation}, $(exprs...); kwargs...) where {RetAnnotation, FA<:Annotation{<:$(esc(fn))}, $(anns...)} + batchsize = same_or_one(1, $(vals...)) + if batchsize == 1 + dfn = fn isa Const ? $ChainRulesCore.NoTangent() : fn.dval + cres = $ChainRulesCore.frule((dfn, $(tangents...),), fn.val, $(primals...); kwargs...) + if RetAnnotation <: Const + return nothing + elseif RetAnnotation <: Duplicated + return Duplicated(cres[1], cres[2]) + elseif RetAnnotation <: DuplicatedNoNeed + return cres[2]::eltype(RetAnnotation) + else + @assert false + end + else + if RetAnnotation <: Const + ntuple(Val(batchsize)) do i + Base.@_inline_meta + dfn = fn isa Const ? $ChainRulesCore.NoTangent() : fn.dval[i] + $ChainRulesCore.frule((dfn, $(tangentsi...),), fn.val, $(primals...); kwargs...) + end + return nothing + elseif RetAnnotation <: BatchDuplicated + cres1 = begin + i = 1 + dfn = fn isa Const ? $ChainRulesCore.NoTangent() : fn.dval[i] + $ChainRulesCore.frule((dfn, $(tangentsi...),), fn.val, $(primals...); kwargs...) + end + batches = ntuple(Val(batchsize-1)) do j + Base.@_inline_meta + i = j+1 + dfn = fn isa Const ? $ChainRulesCore.NoTangent() : fn.dval[i] + $ChainRulesCore.frule((dfn, $(tangentsi...),), fn.val, $(primals...); kwargs...)[2] + end + return BatchDuplicated(cres1[1], (cres1[2], batches...)) + elseif RetAnnotation <: BatchDuplicatedNoNeed + ntuple(Val(batchsize)) do i + Base.@_inline_meta + dfn = fn isa Const ? $ChainRulesCore.NoTangent() : fn.dval[i] + $ChainRulesCore.frule((dfn, $(tangentsi...),), fn.val, $(primals...); kwargs...)[2] + end + else + @assert false + end + end + end + end # quote +end + + +end # module \ No newline at end of file diff --git a/src/Enzyme.jl b/src/Enzyme.jl index 0fa3f6bb9095..748ce04c0419 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -1197,5 +1197,10 @@ end mapreduce(LinearAlgebra.adjoint, vcat, rows) end +function _import_frule end # defined in EnzymeChainRulesCoreExt extension + +macro import_frule(args...) + return _import_frule(args...) +end end # module diff --git a/test/Project.toml b/test/Project.toml index 28cd57f04997..bf44952c27b9 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,5 +1,7 @@ [deps] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" +ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" EnzymeTestUtils = "12d8515a-0907-448a-8884-5fe00fdf1c5a" Enzyme_jll = "7cc45869-7501-5eee-bdea-0790c847d4ef" diff --git a/test/ext/chainrulescore.jl b/test/ext/chainrulescore.jl new file mode 100644 index 000000000000..217176a657a2 --- /dev/null +++ b/test/ext/chainrulescore.jl @@ -0,0 +1,70 @@ +using Enzyme +using Test +using ChainRules +using ChainRulesCore +using LinearAlgebra +using EnzymeTestUtils + +fdiff(f, x::Number) = autodiff(Forward, f, Duplicated, Duplicated(x, one(x)))[2] + +@testset "import_frule" begin + f1(x) = 2*x + ChainRulesCore.@scalar_rule f1(x) (5*one(x),) + Enzyme.@import_frule typeof(f1) Any + @test fdiff(f1, 1f0) === 5f0 + @test fdiff(f1, 1.0) === 5.0 + + # specific signature + f2(x) = 2*x + ChainRulesCore.@scalar_rule f2(x) (5*one(x),) + Enzyme.@import_frule typeof(f2) Float32 + @test fdiff(f2, 1f0) === 5f0 + @test fdiff(f2, 1.0) === 2.0 + + # two arguments + f3(x, y) = 2*x + y + ChainRulesCore.@scalar_rule f3(x, y) (5*one(x), y) + Enzyme.@import_frule typeof(f3) Any Any + @test fdiff(x -> f3(x, 1.0), 2.) === 5.0 + @test fdiff(y -> f3(1.0, y), 2.) === 2.0 + + @testset "batch duplicated" begin + x = [1.0, 2.0, 0.0] + Enzyme.@import_frule typeof(Base.sort) Any + + test_forward(Base.sort, Duplicated, (x, Duplicated)) + # Unsupported by EnzymeTestUtils + # test_forward(Base.sort, Duplicated, (x, DuplicatedNoNeed)) + test_forward(Base.sort, DuplicatedNoNeed, (x, Duplicated)) + # Unsupported by EnzymeTestUtils + # test_forward(Base.sort, DuplicatedNoNeed, (x, DuplicatedNoNeed)) + test_forward(Base.sort, Const, (x, Duplicated)) + # Unsupported by EnzymeTestUtils + # test_forward(Base.sort, Const, (x, DuplicatedNoNeed)) + + test_forward(Base.sort, Const, (x, Const)) + + # ChainRules does not support this case (returning notangent) + # test_forward(Base.sort, Duplicated, (x, Const)) + # test_forward(Base.sort, DuplicatedNoNeed, (x, Const)) + + test_forward(Base.sort, BatchDuplicated, (x, BatchDuplicated)) + # Unsupported by EnzymeTestUtils + # test_forward(Base.sort, BatchDuplicated, (x, BatchDuplicatedNoNeed)) + test_forward(Base.sort, BatchDuplicatedNoNeed, (x, BatchDuplicated)) + # Unsupported by EnzymeTestUtils + # test_forward(Base.sort, BatchDuplicatedNoNeed, (x, BatchDuplicatedNoNeed)) + test_forward(Base.sort, Const, (x, BatchDuplicated)) + # Unsupported by EnzymeTestUtils + # test_forward(Base.sort, Const, (x, BatchDuplicatedNoNeed)) + + # ChainRules does not support this case (returning notangent) + # test_forward(Base.sort, BatchDuplicated, (x, Const)) + # test_forward(Base.sort, BatchDuplicatedNoNeed, (x, Const)) + end +end + + + + + diff --git a/test/packages/specialfunctions.jl b/test/ext/specialfunctions.jl similarity index 100% rename from test/packages/specialfunctions.jl rename to test/ext/specialfunctions.jl diff --git a/test/runtests.jl b/test/runtests.jl index 4ca7aa45a8dc..bf7bcfee5d6a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -100,15 +100,6 @@ end include("blas.jl") end -@static if VERSION ≥ v"1.9-" - using SpecialFunctions - @testset "SpecialFunctions ext" begin - lgabsg(x) = SpecialFunctions.logabsgamma(x)[1] - test_scalar(lgabsg, 1.0; rtol = 1.0e-5, atol = 1.0e-5) - test_scalar(lgabsg, 1.0f0; rtol = 1.0e-5, atol = 1.0e-5) - end -end - f0(x) = 1.0 + x function vrec(start, x) if start > length(x) @@ -1218,7 +1209,7 @@ end ## https://github.com/JuliaDiff/ChainRules.jl/tree/master/test/rulesets if !Sys.iswindows() - include("packages/specialfunctions.jl") + include("ext/specialfunctions.jl") end @testset "Threads" begin @@ -3032,5 +3023,23 @@ end @test res[2][5] ≈ 0 @test res[2][6] ≈ 6.0 end + +# TEST EXTENSIONS +@static if VERSION ≥ v"1.9-" + using SpecialFunctions + @testset "SpecialFunctions ext" begin + lgabsg(x) = SpecialFunctions.logabsgamma(x)[1] + test_scalar(lgabsg, 1.0; rtol = 1.0e-5, atol = 1.0e-5) + test_scalar(lgabsg, 1.0f0; rtol = 1.0e-5, atol = 1.0e-5) + end + + using ChainRulesCore + @testset "ChainRulesCore ext" begin + include("ext/chainrulescore.jl") + end +end + + + end