Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Enzyme support for fastpow #1072

Merged
merged 10 commits into from
Sep 28, 2024
Merged
6 changes: 4 additions & 2 deletions ext/DiffEqBaseEnzymeExt.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module DiffEqBaseEnzymeExt

using DiffEqBase
import DiffEqBase: value
import DiffEqBase: value, fastpow
using Enzyme
import Enzyme: Const
using ChainRulesCore
Expand Down Expand Up @@ -53,4 +53,6 @@ function Enzyme.EnzymeRules.reverse(config::Enzyme.EnzymeRules.RevConfigWidth{1}
return ntuple(_ -> nothing, Val(length(args) + 4))
end

end
Enzyme.Compiler.known_ops[typeof(DiffEqBase.fastpow)] = (:pow, 2, nothing)

end
66 changes: 13 additions & 53 deletions src/fastpow.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,60 +51,20 @@ const EXP2FT = (Float32(0x1.6a09e667f3bcdp-1),
Float32(0x1.3dea64c123422p+0),
Float32(0x1.4bfdad5362a27p+0),
Float32(0x1.5ab07dd485429p+0))
# Approximate `2^x` for a `Float32` argument.
#
# The result is assembled from three pieces: an entry of the `EXP2FT` lookup table,
# a degree-4 polynomial in the reduced argument `z`, and a power-of-two scale factor
# `twopk`. The statements below depend on each other's order (`t` and `i0` are
# updated in place), so they must not be rearranged.
@inline function _exp2(x::Float32)
# Number of index bits used for the table lookup, and the resulting slot count (16).
TBLBITS = UInt32(4)
TBLSIZE = UInt32(1 << TBLBITS)

# `redux` is the shift constant used by the range reduction below; P1..P4 are the
# coefficients of the polynomial p(z) = 1 + P1*z + P2*z^2 + P3*z^3 + P4*z^4.
redux = Float32(0x1.8p23) / TBLSIZE
P1 = Float32(0x1.62e430p-1)
P2 = Float32(0x1.ebfbe0p-3)
P3 = Float32(0x1.c6b348p-5)
P4 = Float32(0x1.3b2c9cp-7)

# Reduce x, computing z, i0, and k.
# Adding `redux` and subtracting it again further down leaves `t` as a rounded copy
# of `x` (presumably rounded to a multiple of 1/TBLSIZE -- TODO confirm).
t::Float32 = x + redux
# Work on the raw bit pattern of the shifted value.
i0 = reinterpret(UInt32, t)
i0 += TBLSIZE ÷ UInt32(2)
# Bits above the low TBLBITS are moved up to bit 20 and become the exponent offset.
k::UInt32 = unsafe_trunc(UInt32, (i0 >> TBLBITS) << 20)
# The low TBLBITS bits are the zero-based table index.
i0 &= TBLSIZE - UInt32(1)
t -= redux
# Remainder of `x` after removing the rounded part.
z = x - t
# Put `0x3ff00000 + k` into the upper 32 bits of a Float64 bit pattern and narrow the
# result to Float32. NOTE(review): 0x3ff00000 is the upper word of Float64 1.0, so this
# looks like it builds the power of two named in the "Scale by" comment -- confirm.
twopk = Float32(reinterpret(Float64, UInt64(0x3ff00000 + k) << 32))

# Compute r = exp2(y) = exp2ft[i0] * p(z).
# `+ UInt32(1)` converts the zero-based index to Julia's one-based indexing.
tv = EXP2FT[i0 + UInt32(1)]
u = tv * z
# Equivalent to tv * p(z), written so that the leading `tv` term is added separately.
tv = tv + u * (P1 + z * P2) + u * (z * z) * (P3 + z * P4)

# Scale by 2**(k>>20)
return tv * twopk
end

if VERSION < v"1.7.0"
"""
fastpow(x::Real, y::Real) -> Float32
"""
@inline function fastpow(x::Real, y::Real)
    # A zero base short-circuits to a Float32 zero, whatever the exponent is.
    iszero(x) && return 0.0f0

    # Infinite base together with an infinite exponent yields Float32 infinity.
    (isinf(x) && isinf(y)) && return Float32(Inf)

    # General case: x^y == 2^(y * log2(x)), evaluated entirely in Float32.
    exponent = convert(Float32, y)
    logbase = fastlog2(convert(Float32, x))
    return _exp2(exponent * logbase)
end
else
"""
fastpow(x::Real, y::Real) -> Float32
"""
@inline function fastpow(x::Real, y::Real)
if iszero(x)
return 0.0f0
elseif isinf(x) && isinf(y)
return Float32(Inf)
else
return @fastmath exp2(convert(Float32, y) * fastlog2(convert(Float32, x)))
end
"""
fastpow(x::T, y::T) where {T} -> float(T)
Trips through Float32 for performance.
"""
@inline function fastpow(x::T, y::T) where {T}
outT = float(T)
if iszero(x)
return zero(outT)
elseif isinf(x) && isinf(y)
return convert(outT,Inf)
else
return convert(outT,@fastmath exp2(convert(Float32, y) * fastlog2(convert(Float32, x))))
end
end

@inline fastpow(x, y) = x^y
1 change: 1 addition & 0 deletions test/downstream/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ Calculus = "49dc2e85-a5d0-5ad3-a950-438e2897f1b9"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7"
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
Expand Down
23 changes: 23 additions & 0 deletions test/downstream/enzyme.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
using Enzyme, EnzymeTestUtils
using DiffEqBase: fastlog2, fastpow
using Test

# Exercise forward-mode differentiation of `fastpow` at the point (3.0, 2.0) for every
# combination of return and argument activity listed below. The checking itself is
# delegated to `test_forward`, with the absolute and relative tolerances given.
@testset "Fast pow - Enzyme forward rule" begin
    @testset for RT in (Duplicated, DuplicatedNoNeed),
        Tx in (Const, Duplicated),
        Ty in (Const, Duplicated)

        test_forward(fastpow, RT, (3.0, Tx), (2.0, Ty); atol=0.005, rtol=0.005)
    end
end

# Exercise reverse-mode differentiation of `fastpow` at the point (2.0, 3.0) with an
# active return value and both arguments active. The checking itself is delegated to
# `test_reverse`, with the absolute and relative tolerances given.
@testset "Fast pow - Enzyme reverse rule" begin
    @testset for RT in (Active,), Tx in (Active,), Ty in (Active,)
        test_reverse(fastpow, RT, (2.0, Tx), (3.0, Ty); atol=0.001, rtol=0.001)
    end
end
14 changes: 4 additions & 10 deletions test/fastpow.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using DiffEqBase: fastlog2, _exp2, fastpow
using DiffEqBase: fastlog2, fastpow
using Test

@testset "Fast log2" begin
Expand All @@ -7,15 +7,9 @@ using Test
end
end

# Compare the table-based `_exp2` against `Base.exp2` on a grid covering [-100, 3]
# in steps of 0.01, allowing an absolute error of 1e-6 at every point.
@testset "Exp2" begin
    for x in -100:0.01:3
        @test isapprox(exp2(x), _exp2(Float32(x)); atol=1e-6)
    end
end

@testset "Fast pow" begin
@test fastpow(1, 1) isa Float32
@test fastpow(1.0, 1.0) isa Float32
@test fastpow(1, 1) isa Float64
@test fastpow(1.0, 1.0) isa Float64
errors = [abs(^(x, y) - fastpow(x, y)) for x in 0.001:0.001:1, y in 0.08:0.001:0.5]
@test maximum(errors) < 1e-4
end
end
Loading