Skip to content

Commit

Permalink
WIP: Add internal reverse-mode rules for ranges (#1656)
Browse files Browse the repository at this point in the history
* WIP: Add internal reverse-mode rules for ranges

This is the second PR to fix #274. It's separated as I think the forward mode one can just be merged no problem, and this one may take a little bit more time.

The crux of why this one is hard is because of how Julia deals with malformed ranges.

```
Basically dret.val = 182.0:156.0:26.0, the 26.0 is not the true value. Same as

julia> 10:1:1
10:1:9
```

Because of that behavior, the reverse `dret` does not actually have the information as to what its final point is, and its length is "incorrect" as it's changed by the constructor. In order to "fix" the reverse, we'd want to swap the `step` to negative and then use the same start/stop, but that information is already lost so it cannot be fixed within the rule. You can see the commented out code that would do the fixing if the information is there, and without that we cannot get a correctly sized reversed range for the rule.

But it's a bit puzzling to figure out how to remove that behavior. In Base Julia it seems to be done in the `function (:)(start::T, step::T, stop::T) where T<:IEEEFloat`, and as I showed in the issue, I can overload that function and the behavior goes away, but Enzyme's constructed range still has that truncation behavior, which means I missed spot or something.

namespace ConfigWidth

namespace

namespace needs_primal

namespace AugmentedReturn

* Complete implementation

* fix

* fix

---------

Co-authored-by: Billy Moses <wmoses@google.com>
  • Loading branch information
ChrisRackauckas and wsmoses authored Aug 26, 2024
1 parent c1e98c9 commit 44febc5
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 0 deletions.
53 changes: 53 additions & 0 deletions src/internal_rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -861,6 +861,59 @@ function EnzymeRules.forward(func::Const{Colon},
end
end



function EnzymeRules.augmented_primal(config, func::Const{Colon}, ::Type{<:Active},
start::Annotation{<:AbstractFloat}, step::Annotation{<:AbstractFloat}, stop::Annotation{<:AbstractFloat})

if EnzymeRules.needs_primal(config)
primal = func.val(start.val, step.val, stop.val)
else
primal = nothing
end
return EnzymeRules.AugmentedReturn(primal, nothing, nothing)
end

function EnzymeRules.reverse(config, func::Const{Colon}, dret, tape::Nothing,
start::Annotation{T1}, step::Annotation{T2}, stop::Annotation{T3}) where {T1<:AbstractFloat, T2<:AbstractFloat, T3<:AbstractFloat}

dstart = if start isa Const
nothing
elseif EnzymeRules.width(config) == 1
T1(dret.val.ref.hi)
else
ntuple(Val(EnzymeRules.width(config))) do i
Base.@_inline_meta
T1(dret.val[i].ref.hi)
end
end

dstep = if step isa Const
nothing
elseif EnzymeRules.width(config) == 1
T2(dret.val.step.hi)
else
ntuple(Val(EnzymeRules.width(config))) do i
Base.@_inline_meta
T2(dret.val[i].step.hi)
end
end

dstop = if stop isa Const
nothing
elseif EnzymeRules.width(config) == 1
zero(T3)
else
ntuple(Val(EnzymeRules.width(config))) do i
Base.@_inline_meta
zero(T3)
end
end

return (dstart, dstep, dstop)
end


function EnzymeRules.forward(
Ty::Const{Type{BigFloat}},
RT::Type{<:Union{DuplicatedNoNeed, Duplicated, BatchDuplicated, BatchDuplicatedNoNeed}};
Expand Down
11 changes: 11 additions & 0 deletions test/internal_rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -668,6 +668,17 @@ end
((var"1"=75.0, var"2"=150.0),)
@test Enzyme.autodiff(Forward, f4, BatchDuplicated(0.12, (1.0, 2.0))) ==
((var"1"=0.0, var"2"=0.0),)

@test Enzyme.autodiff(Reverse, f1, Active, Active(0.1)) == ((25.0,),)
@test Enzyme.autodiff(Reverse, f2, Active, Active(0.1)) == ((25.0,),)
@test Enzyme.autodiff(Reverse, f3, Active, Active(0.1)) == ((75.0,),)
@test Enzyme.autodiff(Reverse, f4, Active, Active(0.12)) == ((0.0,),)

# Batch active rule isnt setup
# @test Enzyme.autodiff(Reverse, (x, y) -> begin y[] = f1(x); nothing end, Active(1.1), BatchDuplicated(Ref(0.0), (Ref(1.0), Ref(2.0)))) == (((25.0,50.0)),)
# @test Enzyme.autodiff(Reverse, (x, y) -> begin y[] = f2(x); nothing end, Active(0.1), BatchDuplicated(Ref(0.0), (Ref(1.0), Ref(2.0)))) == (((25.0,50.0)),)
# @test Enzyme.autodiff(Reverse, (x, y) -> begin y[] = f3(x); nothing end, Active(0.1), BatchDuplicated(Ref(0.0), (Ref(1.0), Ref(2.0)))) == (((75.0,150.0)),)
# @test Enzyme.autodiff(Reverse, (x, y) -> begin y[] = f4(x); nothing end, Active(0.1), BatchDuplicated(Ref(0.0), (Ref(1.0), Ref(2.0)))) == (((0.0,0.0)),)
end

end # InternalRules

0 comments on commit 44febc5

Please sign in to comment.