Skip to content

Commit

Permalink
add import_frule (reprised) (EnzymeAD#1333)
Browse files Browse the repository at this point in the history
* Add import frule functionality

* Add wip rrule importer

* move everything to extension; add tests frule

* remove import_rrule

* runtests

* esc(fn) fixes tests

* added failing batchduplicated test

* remove test import

* address review comments

* add dollar in macro

* cleanup

* Fixup and cleanup

---------

Co-authored-by: William S. Moses <gh@wsmoses.com>
Co-authored-by: Billy Moses <wmoses@google.com>
  • Loading branch information
3 people authored May 7, 2024
1 parent 1c184e3 commit ff878b5
Show file tree
Hide file tree
Showing 7 changed files with 211 additions and 10 deletions.
8 changes: 8 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,25 @@ Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

[weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"

[extensions]
EnzymeSpecialFunctionsExt = "SpecialFunctions"
EnzymeChainRulesCoreExt = "ChainRulesCore"

[compat]
CEnum = "0.4, 0.5"
ChainRulesCore = "1"
EnzymeCore = "0.7"
Enzyme_jll = "0.0.105"
GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26"
LLVM = "6.1"
ObjectFile = "0.4"
Preferences = "1.4"
SpecialFunctions = "1, 2"
julia = "1.6"

[extras]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
107 changes: 107 additions & 0 deletions ext/EnzymeChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
module EnzymeChainRulesCoreExt

using ChainRulesCore
using EnzymeCore
using Enzyme


"""
import_frule(::fn, tys...)
Automatically import a `ChainRulesCore.frule`` as a custom forward mode `EnzymeRule`. When called in batch mode, this
will end up calling the primal multiple times, which may result in incorrect behavior if the function mutates,
and slow code, always. Importing the rule from `ChainRules` is also likely to be slower than writing your own rule,
and may also be slower than not having a rule at all.
Use with caution.
```jldoctest
Enzyme.@import_frule(typeof(Base.sort), Any);
x=[1.0, 2.0, 0.0]; dx=[0.1, 0.2, 0.3]; ddx = [0.01, 0.02, 0.03];
Enzyme.autodiff(Forward, sort, Duplicated, BatchDuplicated(x, (dx,ddx)))
Enzyme.autodiff(Forward, sort, DuplicatedNoNeed, BatchDuplicated(x, (dx,ddx)))
Enzyme.autodiff(Forward, sort, DuplicatedNoNeed, BatchDuplicated(x, (dx,)))
Enzyme.autodiff(Forward, sort, Duplicated, BatchDuplicated(x, (dx,)))
# output
(var"1" = [0.0, 1.0, 2.0], var"2" = (var"1" = [0.3, 0.1, 0.2], var"2" = [0.03, 0.01, 0.02]))
(var"1" = (var"1" = [0.3, 0.1, 0.2], var"2" = [0.03, 0.01, 0.02]),)
(var"1" = [0.3, 0.1, 0.2],)
(var"1" = [0.0, 1.0, 2.0], var"2" = [0.3, 0.1, 0.2])
```
"""
function Enzyme._import_frule(fn, tys...)
vals = []
exprs = []
primals = []
tangents = []
tangentsi = []
anns = []
for (i, ty) in enumerate(tys)
val = Symbol("arg_$i")
TA = Symbol("AN_$i")
e = :($val::$TA)
push!(anns, :($TA <: Annotation{<:$ty}))
push!(vals, val)
push!(exprs, e)
push!(primals, :($val.val))
push!(tangents, :($val isa Const ? $ChainRulesCore.NoTangent() : $val.dval))
push!(tangentsi, :($val isa Const ? $ChainRulesCore.NoTangent() : $val.dval[i]))
end

quote
function EnzymeRules.forward(fn::FA, ::Type{RetAnnotation}, $(exprs...); kwargs...) where {RetAnnotation, FA<:Annotation{<:$(esc(fn))}, $(anns...)}
batchsize = same_or_one(1, $(vals...))
if batchsize == 1
dfn = fn isa Const ? $ChainRulesCore.NoTangent() : fn.dval
cres = $ChainRulesCore.frule((dfn, $(tangents...),), fn.val, $(primals...); kwargs...)
if RetAnnotation <: Const
return nothing
elseif RetAnnotation <: Duplicated
return Duplicated(cres[1], cres[2])
elseif RetAnnotation <: DuplicatedNoNeed
return cres[2]::eltype(RetAnnotation)
else
@assert false
end
else
if RetAnnotation <: Const
ntuple(Val(batchsize)) do i
Base.@_inline_meta
dfn = fn isa Const ? $ChainRulesCore.NoTangent() : fn.dval[i]
$ChainRulesCore.frule((dfn, $(tangentsi...),), fn.val, $(primals...); kwargs...)
end
return nothing
elseif RetAnnotation <: BatchDuplicated
cres1 = begin
i = 1
dfn = fn isa Const ? $ChainRulesCore.NoTangent() : fn.dval[i]
$ChainRulesCore.frule((dfn, $(tangentsi...),), fn.val, $(primals...); kwargs...)
end
batches = ntuple(Val(batchsize-1)) do j
Base.@_inline_meta
i = j+1
dfn = fn isa Const ? $ChainRulesCore.NoTangent() : fn.dval[i]
$ChainRulesCore.frule((dfn, $(tangentsi...),), fn.val, $(primals...); kwargs...)[2]
end
return BatchDuplicated(cres1[1], (cres1[2], batches...))
elseif RetAnnotation <: BatchDuplicatedNoNeed
ntuple(Val(batchsize)) do i
Base.@_inline_meta
dfn = fn isa Const ? $ChainRulesCore.NoTangent() : fn.dval[i]
$ChainRulesCore.frule((dfn, $(tangentsi...),), fn.val, $(primals...); kwargs...)[2]
end
else
@assert false
end
end
end
end # quote
end


end # module
5 changes: 5 additions & 0 deletions src/Enzyme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1197,5 +1197,10 @@ end
mapreduce(LinearAlgebra.adjoint, vcat, rows)
end

function _import_frule end # defined in EnzymeChainRulesCoreExt extension

macro import_frule(args...)
return _import_frule(args...)
end

end # module
2 changes: 2 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
[deps]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
EnzymeTestUtils = "12d8515a-0907-448a-8884-5fe00fdf1c5a"
Enzyme_jll = "7cc45869-7501-5eee-bdea-0790c847d4ef"
Expand Down
70 changes: 70 additions & 0 deletions test/ext/chainrulescore.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
using Enzyme
using Test
using ChainRules
using ChainRulesCore
using LinearAlgebra
using EnzymeTestUtils

fdiff(f, x::Number) = autodiff(Forward, f, Duplicated, Duplicated(x, one(x)))[2]

@testset "import_frule" begin
f1(x) = 2*x
ChainRulesCore.@scalar_rule f1(x) (5*one(x),)
Enzyme.@import_frule typeof(f1) Any
@test fdiff(f1, 1f0) === 5f0
@test fdiff(f1, 1.0) === 5.0

# specific signature
f2(x) = 2*x
ChainRulesCore.@scalar_rule f2(x) (5*one(x),)
Enzyme.@import_frule typeof(f2) Float32
@test fdiff(f2, 1f0) === 5f0
@test fdiff(f2, 1.0) === 2.0

# two arguments
f3(x, y) = 2*x + y
ChainRulesCore.@scalar_rule f3(x, y) (5*one(x), y)
Enzyme.@import_frule typeof(f3) Any Any
@test fdiff(x -> f3(x, 1.0), 2.) === 5.0
@test fdiff(y -> f3(1.0, y), 2.) === 2.0

@testset "batch duplicated" begin
x = [1.0, 2.0, 0.0]
Enzyme.@import_frule typeof(Base.sort) Any

test_forward(Base.sort, Duplicated, (x, Duplicated))
# Unsupported by EnzymeTestUtils
# test_forward(Base.sort, Duplicated, (x, DuplicatedNoNeed))
test_forward(Base.sort, DuplicatedNoNeed, (x, Duplicated))
# Unsupported by EnzymeTestUtils
# test_forward(Base.sort, DuplicatedNoNeed, (x, DuplicatedNoNeed))
test_forward(Base.sort, Const, (x, Duplicated))
# Unsupported by EnzymeTestUtils
# test_forward(Base.sort, Const, (x, DuplicatedNoNeed))

test_forward(Base.sort, Const, (x, Const))

# ChainRules does not support this case (returning notangent)
# test_forward(Base.sort, Duplicated, (x, Const))
# test_forward(Base.sort, DuplicatedNoNeed, (x, Const))

test_forward(Base.sort, BatchDuplicated, (x, BatchDuplicated))
# Unsupported by EnzymeTestUtils
# test_forward(Base.sort, BatchDuplicated, (x, BatchDuplicatedNoNeed))
test_forward(Base.sort, BatchDuplicatedNoNeed, (x, BatchDuplicated))
# Unsupported by EnzymeTestUtils
# test_forward(Base.sort, BatchDuplicatedNoNeed, (x, BatchDuplicatedNoNeed))
test_forward(Base.sort, Const, (x, BatchDuplicated))
# Unsupported by EnzymeTestUtils
# test_forward(Base.sort, Const, (x, BatchDuplicatedNoNeed))

# ChainRules does not support this case (returning notangent)
# test_forward(Base.sort, BatchDuplicated, (x, Const))
# test_forward(Base.sort, BatchDuplicatedNoNeed, (x, Const))
end
end





File renamed without changes.
29 changes: 19 additions & 10 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,15 +100,6 @@ end
include("blas.jl")
end

@static if VERSION v"1.9-"
using SpecialFunctions
@testset "SpecialFunctions ext" begin
lgabsg(x) = SpecialFunctions.logabsgamma(x)[1]
test_scalar(lgabsg, 1.0; rtol = 1.0e-5, atol = 1.0e-5)
test_scalar(lgabsg, 1.0f0; rtol = 1.0e-5, atol = 1.0e-5)
end
end

f0(x) = 1.0 + x
function vrec(start, x)
if start > length(x)
Expand Down Expand Up @@ -1218,7 +1209,7 @@ end

## https://github.com/JuliaDiff/ChainRules.jl/tree/master/test/rulesets
if !Sys.iswindows()
include("packages/specialfunctions.jl")
include("ext/specialfunctions.jl")
end

@testset "Threads" begin
Expand Down Expand Up @@ -3032,5 +3023,23 @@ end
@test res[2][5] 0
@test res[2][6] 6.0
end

# TEST EXTENSIONS
@static if VERSION v"1.9-"
using SpecialFunctions
@testset "SpecialFunctions ext" begin
lgabsg(x) = SpecialFunctions.logabsgamma(x)[1]
test_scalar(lgabsg, 1.0; rtol = 1.0e-5, atol = 1.0e-5)
test_scalar(lgabsg, 1.0f0; rtol = 1.0e-5, atol = 1.0e-5)
end

using ChainRulesCore
@testset "ChainRulesCore ext" begin
include("ext/chainrulescore.jl")
end
end



end

0 comments on commit ff878b5

Please sign in to comment.