forked from EnzymeAD/Enzyme
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add import_frule (reprised) (EnzymeAD#1333)
* Add import frule functionality * Add wip rrule importer * move everything to extension; add tests frule * remove import_rrule * runtests * esc(fn) fixes tests * added failing batchduplicated test * remove test import * address review comments * add dollar in macro * cleanup * Fixup and cleanup --------- Co-authored-by: William S. Moses <gh@wsmoses.com> Co-authored-by: Billy Moses <wmoses@google.com>
- Loading branch information
1 parent
1c184e3
commit ff878b5
Showing
7 changed files
with
211 additions
and
10 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
module EnzymeChainRulesCoreExt | ||
|
||
using ChainRulesCore | ||
using EnzymeCore | ||
using Enzyme | ||
|
||
|
||
""" | ||
import_frule(::fn, tys...) | ||
Automatically import a `ChainRulesCore.frule`` as a custom forward mode `EnzymeRule`. When called in batch mode, this | ||
will end up calling the primal multiple times, which may result in incorrect behavior if the function mutates, | ||
and slow code, always. Importing the rule from `ChainRules` is also likely to be slower than writing your own rule, | ||
and may also be slower than not having a rule at all. | ||
Use with caution. | ||
```jldoctest | ||
Enzyme.@import_frule(typeof(Base.sort), Any); | ||
x=[1.0, 2.0, 0.0]; dx=[0.1, 0.2, 0.3]; ddx = [0.01, 0.02, 0.03]; | ||
Enzyme.autodiff(Forward, sort, Duplicated, BatchDuplicated(x, (dx,ddx))) | ||
Enzyme.autodiff(Forward, sort, DuplicatedNoNeed, BatchDuplicated(x, (dx,ddx))) | ||
Enzyme.autodiff(Forward, sort, DuplicatedNoNeed, BatchDuplicated(x, (dx,))) | ||
Enzyme.autodiff(Forward, sort, Duplicated, BatchDuplicated(x, (dx,))) | ||
# output | ||
(var"1" = [0.0, 1.0, 2.0], var"2" = (var"1" = [0.3, 0.1, 0.2], var"2" = [0.03, 0.01, 0.02])) | ||
(var"1" = (var"1" = [0.3, 0.1, 0.2], var"2" = [0.03, 0.01, 0.02]),) | ||
(var"1" = [0.3, 0.1, 0.2],) | ||
(var"1" = [0.0, 1.0, 2.0], var"2" = [0.3, 0.1, 0.2]) | ||
``` | ||
""" | ||
function Enzyme._import_frule(fn, tys...) | ||
vals = [] | ||
exprs = [] | ||
primals = [] | ||
tangents = [] | ||
tangentsi = [] | ||
anns = [] | ||
for (i, ty) in enumerate(tys) | ||
val = Symbol("arg_$i") | ||
TA = Symbol("AN_$i") | ||
e = :($val::$TA) | ||
push!(anns, :($TA <: Annotation{<:$ty})) | ||
push!(vals, val) | ||
push!(exprs, e) | ||
push!(primals, :($val.val)) | ||
push!(tangents, :($val isa Const ? $ChainRulesCore.NoTangent() : $val.dval)) | ||
push!(tangentsi, :($val isa Const ? $ChainRulesCore.NoTangent() : $val.dval[i])) | ||
end | ||
|
||
quote | ||
function EnzymeRules.forward(fn::FA, ::Type{RetAnnotation}, $(exprs...); kwargs...) where {RetAnnotation, FA<:Annotation{<:$(esc(fn))}, $(anns...)} | ||
batchsize = same_or_one(1, $(vals...)) | ||
if batchsize == 1 | ||
dfn = fn isa Const ? $ChainRulesCore.NoTangent() : fn.dval | ||
cres = $ChainRulesCore.frule((dfn, $(tangents...),), fn.val, $(primals...); kwargs...) | ||
if RetAnnotation <: Const | ||
return nothing | ||
elseif RetAnnotation <: Duplicated | ||
return Duplicated(cres[1], cres[2]) | ||
elseif RetAnnotation <: DuplicatedNoNeed | ||
return cres[2]::eltype(RetAnnotation) | ||
else | ||
@assert false | ||
end | ||
else | ||
if RetAnnotation <: Const | ||
ntuple(Val(batchsize)) do i | ||
Base.@_inline_meta | ||
dfn = fn isa Const ? $ChainRulesCore.NoTangent() : fn.dval[i] | ||
$ChainRulesCore.frule((dfn, $(tangentsi...),), fn.val, $(primals...); kwargs...) | ||
end | ||
return nothing | ||
elseif RetAnnotation <: BatchDuplicated | ||
cres1 = begin | ||
i = 1 | ||
dfn = fn isa Const ? $ChainRulesCore.NoTangent() : fn.dval[i] | ||
$ChainRulesCore.frule((dfn, $(tangentsi...),), fn.val, $(primals...); kwargs...) | ||
end | ||
batches = ntuple(Val(batchsize-1)) do j | ||
Base.@_inline_meta | ||
i = j+1 | ||
dfn = fn isa Const ? $ChainRulesCore.NoTangent() : fn.dval[i] | ||
$ChainRulesCore.frule((dfn, $(tangentsi...),), fn.val, $(primals...); kwargs...)[2] | ||
end | ||
return BatchDuplicated(cres1[1], (cres1[2], batches...)) | ||
elseif RetAnnotation <: BatchDuplicatedNoNeed | ||
ntuple(Val(batchsize)) do i | ||
Base.@_inline_meta | ||
dfn = fn isa Const ? $ChainRulesCore.NoTangent() : fn.dval[i] | ||
$ChainRulesCore.frule((dfn, $(tangentsi...),), fn.val, $(primals...); kwargs...)[2] | ||
end | ||
else | ||
@assert false | ||
end | ||
end | ||
end | ||
end # quote | ||
end | ||
|
||
|
||
end # module |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
using Enzyme | ||
using Test | ||
using ChainRules | ||
using ChainRulesCore | ||
using LinearAlgebra | ||
using EnzymeTestUtils | ||
|
||
fdiff(f, x::Number) = autodiff(Forward, f, Duplicated, Duplicated(x, one(x)))[2] | ||
|
||
@testset "import_frule" begin | ||
f1(x) = 2*x | ||
ChainRulesCore.@scalar_rule f1(x) (5*one(x),) | ||
Enzyme.@import_frule typeof(f1) Any | ||
@test fdiff(f1, 1f0) === 5f0 | ||
@test fdiff(f1, 1.0) === 5.0 | ||
|
||
# specific signature | ||
f2(x) = 2*x | ||
ChainRulesCore.@scalar_rule f2(x) (5*one(x),) | ||
Enzyme.@import_frule typeof(f2) Float32 | ||
@test fdiff(f2, 1f0) === 5f0 | ||
@test fdiff(f2, 1.0) === 2.0 | ||
|
||
# two arguments | ||
f3(x, y) = 2*x + y | ||
ChainRulesCore.@scalar_rule f3(x, y) (5*one(x), y) | ||
Enzyme.@import_frule typeof(f3) Any Any | ||
@test fdiff(x -> f3(x, 1.0), 2.) === 5.0 | ||
@test fdiff(y -> f3(1.0, y), 2.) === 2.0 | ||
|
||
@testset "batch duplicated" begin | ||
x = [1.0, 2.0, 0.0] | ||
Enzyme.@import_frule typeof(Base.sort) Any | ||
|
||
test_forward(Base.sort, Duplicated, (x, Duplicated)) | ||
# Unsupported by EnzymeTestUtils | ||
# test_forward(Base.sort, Duplicated, (x, DuplicatedNoNeed)) | ||
test_forward(Base.sort, DuplicatedNoNeed, (x, Duplicated)) | ||
# Unsupported by EnzymeTestUtils | ||
# test_forward(Base.sort, DuplicatedNoNeed, (x, DuplicatedNoNeed)) | ||
test_forward(Base.sort, Const, (x, Duplicated)) | ||
# Unsupported by EnzymeTestUtils | ||
# test_forward(Base.sort, Const, (x, DuplicatedNoNeed)) | ||
|
||
test_forward(Base.sort, Const, (x, Const)) | ||
|
||
# ChainRules does not support this case (returning notangent) | ||
# test_forward(Base.sort, Duplicated, (x, Const)) | ||
# test_forward(Base.sort, DuplicatedNoNeed, (x, Const)) | ||
|
||
test_forward(Base.sort, BatchDuplicated, (x, BatchDuplicated)) | ||
# Unsupported by EnzymeTestUtils | ||
# test_forward(Base.sort, BatchDuplicated, (x, BatchDuplicatedNoNeed)) | ||
test_forward(Base.sort, BatchDuplicatedNoNeed, (x, BatchDuplicated)) | ||
# Unsupported by EnzymeTestUtils | ||
# test_forward(Base.sort, BatchDuplicatedNoNeed, (x, BatchDuplicatedNoNeed)) | ||
test_forward(Base.sort, Const, (x, BatchDuplicated)) | ||
# Unsupported by EnzymeTestUtils | ||
# test_forward(Base.sort, Const, (x, BatchDuplicatedNoNeed)) | ||
|
||
# ChainRules does not support this case (returning notangent) | ||
# test_forward(Base.sort, BatchDuplicated, (x, Const)) | ||
# test_forward(Base.sort, BatchDuplicatedNoNeed, (x, Const)) | ||
end | ||
end | ||
|
||
|
||
|
||
|
||
|
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters