Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

move [l]gamma, [l]beta and lfact from base #92

Merged
merged 2 commits into from
Jun 13, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
225 changes: 225 additions & 0 deletions src/gamma.jl
Original file line number Diff line number Diff line change
Expand Up @@ -539,3 +539,228 @@ function polygamma(m::Integer, z::Number)
# There is nothing to fallback to, since this didn't work
polygamma(m, x)
end

@static if VERSION >= v"0.7.0-alpha.69"

export gamma, lgamma, beta, lbeta, lfactorial

## from base/special/gamma.jl

gamma(x::Float64) = nan_dom_err(ccall((:tgamma,libm), Float64, (Float64,), x), x)
gamma(x::Float32) = nan_dom_err(ccall((:tgammaf,libm), Float32, (Float32,), x), x)

"""
gamma(x)

Compute the gamma function of `x`.
"""
gamma(x::Real) = gamma(float(x))

function lgamma_r(x::Float64)
signp = Ref{Int32}()
y = ccall((:lgamma_r,libm), Float64, (Float64, Ptr{Int32}), x, signp)
return y, signp[]
end
function lgamma_r(x::Float32)
signp = Ref{Int32}()
y = ccall((:lgammaf_r,libm), Float32, (Float32, Ptr{Int32}), x, signp)
return y, signp[]
end
lgamma_r(x::Real) = lgamma_r(float(x))
lgamma_r(x::Number) = lgamma(x), 1 # lgamma does not take abs for non-real x
"""
lgamma_r(x)

Return L,s such that `gamma(x) = s * exp(L)`
"""
lgamma_r

"""
lfactorial(x)

Compute the logarithmic factorial of a nonnegative integer `x`.
Equivalent to [`lgamma`](@ref) of `x + 1`, but `lgamma` extends this function
to non-integer `x`.
"""
lfactorial(x::Integer) = x < 0 ? throw(DomainError(x, "`x` must be non-negative.")) : lgamma(x + oneunit(x))
Base.@deprecate lfact lfactorial

"""
lgamma(x)

Compute the logarithm of the absolute value of [`gamma`](@ref) for
[`Real`](@ref) `x`, while for [`Complex`](@ref) `x` compute the
principal branch cut of the logarithm of `gamma(x)` (defined for negative `real(x)`
by analytic continuation from positive `real(x)`).
"""
function lgamma end

# asymptotic series for log(gamma(z)), valid for sufficiently large real(z) or |imag(z)|
@inline function lgamma_asymptotic(z::Complex{Float64})
zinv = inv(z)
t = zinv*zinv
# coefficients are bernoulli[2:n+1] .// (2*(1:n).*(2*(1:n) - 1))
return (z-0.5)*log(z) - z + 9.1893853320467274178032927e-01 + # <-- log(2pi)/2
zinv*@evalpoly(t, 8.3333333333333333333333368e-02,-2.7777777777777777777777776e-03,
7.9365079365079365079365075e-04,-5.9523809523809523809523806e-04,
8.4175084175084175084175104e-04,-1.9175269175269175269175262e-03,
6.4102564102564102564102561e-03,-2.9550653594771241830065352e-02)
end

# Compute the logΓ(z) function using a combination of the asymptotic series,
# the Taylor series around z=1 and z=2, the reflection formula, and the shift formula.
# Many details of these techniques are discussed in D. E. G. Hare,
# "Computing the principal branch of log-Gamma," J. Algorithms 25, pp. 221-236 (1997),
# and similar techniques are used (in a somewhat different way) by the
# SciPy loggamma function. The key identities are also described
# at http://functions.wolfram.com/GammaBetaErf/LogGamma/
function lgamma(z::Complex{Float64})
x = real(z)
y = imag(z)
yabs = abs(y)
if !isfinite(x) || !isfinite(y) # Inf or NaN
if isinf(x) && isfinite(y)
return Complex(x, x > 0 ? (y == 0 ? y : copysign(Inf, y)) : copysign(Inf, -y))
elseif isfinite(x) && isinf(y)
return Complex(-Inf, y)
else
return Complex(NaN, NaN)
end
elseif x > 7 || yabs > 7 # use the Stirling asymptotic series for sufficiently large x or |y|
return lgamma_asymptotic(z)
elseif x < 0.1 # use reflection formula to transform to x > 0
if x == 0 && y == 0 # return Inf with the correct imaginary part for z == 0
return Complex(Inf, signbit(x) ? copysign(oftype(x, pi), -y) : -y)
end
# the 2pi * floor(...) stuff is to choose the correct branch cut for log(sinpi(z))
return Complex(1.1447298858494001741434262, # log(pi)
copysign(6.2831853071795864769252842, y) # 2pi
* floor(0.5*x+0.25)) -
log(sinpi(z)) - lgamma(1-z)
elseif abs(x - 1) + yabs < 0.1
# taylor series around zero at z=1
# ... coefficients are [-eulergamma; [(-1)^k * zeta(k)/k for k in 2:15]]
w = Complex(x - 1, y)
return w * @evalpoly(w, -5.7721566490153286060651188e-01,8.2246703342411321823620794e-01,
-4.0068563438653142846657956e-01,2.705808084277845478790009e-01,
-2.0738555102867398526627303e-01,1.6955717699740818995241986e-01,
-1.4404989676884611811997107e-01,1.2550966952474304242233559e-01,
-1.1133426586956469049087244e-01,1.000994575127818085337147e-01,
-9.0954017145829042232609344e-02,8.3353840546109004024886499e-02,
-7.6932516411352191472827157e-02,7.1432946295361336059232779e-02,
-6.6668705882420468032903454e-02)
elseif abs(x - 2) + yabs < 0.1
# taylor series around zero at z=2
# ... coefficients are [1-eulergamma; [(-1)^k * (zeta(k)-1)/k for k in 2:12]]
w = Complex(x - 2, y)
return w * @evalpoly(w, 4.2278433509846713939348812e-01,3.2246703342411321823620794e-01,
-6.7352301053198095133246196e-02,2.0580808427784547879000897e-02,
-7.3855510286739852662729527e-03,2.8905103307415232857531201e-03,
-1.1927539117032609771139825e-03,5.0966952474304242233558822e-04,
-2.2315475845357937976132853e-04,9.9457512781808533714662972e-05,
-4.4926236738133141700224489e-05,2.0507212775670691553131246e-05)
end
# use recurrence relation lgamma(z) = lgamma(z+1) - log(z) to shift to x > 7 for asymptotic series
shiftprod = Complex(x,yabs)
x += 1
sb = false # == signbit(imag(shiftprod)) == signbit(yabs)
# To use log(product of shifts) rather than sum(logs of shifts),
# we need to keep track of the number of + to - sign flips in
# imag(shiftprod), as described in Hare (1997), proposition 2.2.
signflips = 0
while x <= 7
shiftprod *= Complex(x,yabs)
sb′ = signbit(imag(shiftprod))
signflips += sb′ & (sb′ != sb)
sb = sb′
x += 1
end
shift = log(shiftprod)
if signbit(y) # if y is negative, conjugate the shift
shift = Complex(real(shift), signflips*-6.2831853071795864769252842 - imag(shift))
else
shift = Complex(real(shift), imag(shift) + signflips*6.2831853071795864769252842)
end
return lgamma_asymptotic(Complex(x,y)) - shift
end
lgamma(z::Complex{T}) where {T<:Union{Integer,Rational}} = lgamma(float(z))
lgamma(z::Complex{T}) where {T<:Union{Float32,Float16}} = Complex{T}(lgamma(Complex{Float64}(z)))

gamma(z::Complex) = exp(lgamma(z))

"""
beta(x, y)

Euler integral of the first kind ``\\operatorname{B}(x,y) = \\Gamma(x)\\Gamma(y)/\\Gamma(x+y)``.
"""
function beta(x::Number, w::Number)
yx, sx = lgamma_r(x)
yw, sw = lgamma_r(w)
yxw, sxw = lgamma_r(x+w)
return exp(yx + yw - yxw) * (sx*sw*sxw)
end

"""
lbeta(x, y)

Natural logarithm of the absolute value of the [`beta`](@ref)
function ``\\log(|\\operatorname{B}(x,y)|)``.
"""
lbeta(x::Number, w::Number) = lgamma(x)+lgamma(w)-lgamma(x+w)

## from base/mpfr.jl

# Functions for which NaN results are converted to DomainError, following Base
function gamma(x::BigFloat)
isnan(x) && return x
z = BigFloat()
ccall((:mpfr_lgamma, :libmpfr), Int32, (Ref{BigFloat}, Ref{BigFloat}, Int32), z, x, ROUNDING_MODE[])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Typo in ccall target?

isnan(z) && throw(DomainError(x, "NaN result for non-NaN input."))
return z
end

# log of absolute value of gamma function
function lgamma_r(x::BigFloat)
z = BigFloat()
lgamma_signp = Ref{Cint}()
ccall((:mpfr_lgamma,:libmpfr), Cint, (Ref{BigFloat}, Ref{Cint}, Ref{BigFloat}, Int32), z, lgamma_signp, x, ROUNDING_MODE[])
return z, lgamma_signp[]
end

lgamma(x::BigFloat) = lgamma_r(x)[1]

if Base.MPFR.version() >= v"4.0.0"
function beta(y::BigFloat, x::BigFloat)
z = BigFloat()
ccall((:mpfr_beta, :libmpfr), Int32, (Ref{BigFloat}, Ref{BigFloat}, Ref{BigFloat}, Int32), z, y, x, ROUNDING_MODE[])
return z
end
end

## from base/combinatorics.jl'

function gamma(n::Union{Int8,UInt8,Int16,UInt16,Int32,UInt32,Int64,UInt64})
n < 0 && throw(DomainError(n, "`n` must not be negative."))
n == 0 && return Inf
n <= 2 && return 1.0
n > 20 && return gamma(Float64(n))
@inbounds return Float64(Base._fact_table64[n-1])
end

## from base/math.jl

@inline lgamma(x::Float64) = nan_dom_err(ccall((:lgamma, libm), Float64, (Float64,), x), x)
@inline lgamma(x::Float32) = nan_dom_err(ccall((:lgammaf, libm), Float32, (Float32,), x), x)
@inline lgamma(x::Real) = lgamma(float(x))

## from base/numbers.jl
# TODO: deprecate instead of doing this type-piracy here?
Base.factorial(x::Number) = gamma(x + 1) # fallback for x not Integer

else # @static if

import Base: gamma, lgamma, beta, lbeta, lfact
const lfactorial = lfact
export lfactorial

end # @static if
81 changes: 81 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -555,3 +555,84 @@ end
end

@test sprint(showerror, SF.AmosException(1)) == "AmosException with id 1: input error."

@testset "gamma and friends" begin
@testset "gamma, lgamma (complex argument)" begin
if Base.Math.libm == "libopenlibm"
@test gamma.(Float64[1:25;]) == gamma.(1:25)
else
@test gamma.(Float64[1:25;]) ≈ gamma.(1:25)
end
for elty in (Float32, Float64)
@test gamma(convert(elty,1/2)) ≈ convert(elty,sqrt(π))
@test gamma(convert(elty,-1/2)) ≈ convert(elty,-2sqrt(π))
@test lgamma(convert(elty,-1/2)) ≈ convert(elty,log(abs(gamma(-1/2))))
end
@test lgamma(1.4+3.7im) ≈ -3.7094025330996841898 + 2.4568090502768651184im
@test lgamma(1.4+3.7im) ≈ log(gamma(1.4+3.7im))
@test lgamma(-4.2+0im) ≈ lgamma(-4.2)-5pi*im
@test factorial(3.0) == gamma(4.0) == factorial(3)
for x in (3.2, 2+1im, 3//2, 3.2+0.1im)
@test factorial(x) == gamma(1+x)
end
@test lfactorial(0) == lfactorial(1) == 0
@test lfactorial(2) == lgamma(3)
# Ensure that the domain of lfactorial matches that of factorial (issue #21318)
@test_throws DomainError lfactorial(-3)
@test_throws MethodError lfactorial(1.0)
end

# lgamma test cases (from Wolfram Alpha)
@test lgamma(-300im) ≅ -473.17185074259241355733179182866544204963885920016823743 - 1410.3490664555822107569308046418321236643870840962522425im
@test lgamma(3.099) ≅ lgamma(3.099+0im) ≅ 0.786413746900558058720665860178923603134125854451168869796
@test lgamma(1.15) ≅ lgamma(1.15+0im) ≅ -0.06930620867104688224241731415650307100375642207340564554
@test lgamma(0.89) ≅ lgamma(0.89+0im) ≅ 0.074022173958081423702265889979810658434235008344573396963
@test lgamma(0.91) ≅ lgamma(0.91+0im) ≅ 0.058922567623832379298241751183907077883592982094770449167
@test lgamma(0.01) ≅ lgamma(0.01+0im) ≅ 4.599479878042021722513945411008748087261001413385289652419
@test lgamma(-3.4-0.1im) ≅ -1.1733353322064779481049088558918957440847715003659143454 + 12.331465501247826842875586104415980094316268974671819281im
@test lgamma(-13.4-0.1im) ≅ -22.457344044212827625152500315875095825738672314550695161 + 43.620560075982291551250251193743725687019009911713182478im
@test lgamma(-13.4+0.0im) ≅ conj(lgamma(-13.4-0.0im)) ≅ -22.404285036964892794140985332811433245813398559439824988 - 43.982297150257105338477007365913040378760371591251481493im
@test lgamma(-13.4+8im) ≅ -44.705388949497032519400131077242200763386790107166126534 - 22.208139404160647265446701539526205774669649081807864194im
@test lgamma(1+exp2(-20)) ≅ lgamma(1+exp2(-20)+0im) ≅ -5.504750066148866790922434423491111098144565651836914e-7
@test lgamma(1+exp2(-20)+exp2(-19)*im) ≅ -5.5047799872835333673947171235997541985495018556426e-7 - 1.1009485171695646421931605642091915847546979851020e-6im
@test lgamma(-300+2im) ≅ -1419.3444991797240659656205813341478289311980525970715668 - 932.63768120761873747896802932133229201676713644684614785im
@test lgamma(300+2im) ≅ 1409.19538972991765122115558155209493891138852121159064304 + 11.4042446282102624499071633666567192538600478241492492652im
@test lgamma(1-6im) ≅ -7.6099596929506794519956058191621517065972094186427056304 - 5.5220531255147242228831899544009162055434670861483084103im
@test lgamma(1-8im) ≅ -10.607711310314582247944321662794330955531402815576140186 - 9.4105083803116077524365029286332222345505790217656796587im
@test lgamma(1+6.5im) ≅ conj(lgamma(1-6.5im)) ≅ -8.3553365025113595689887497963634069303427790125048113307 + 6.4392816159759833948112929018407660263228036491479825744im
@test lgamma(1+1im) ≅ conj(lgamma(1-1im)) ≅ -0.6509231993018563388852168315039476650655087571397225919 - 0.3016403204675331978875316577968965406598997739437652369im
@test lgamma(-pi*1e7 + 6im) ≅ -5.10911758892505772903279926621085326635236850347591e8 - 9.86959420047365966439199219724905597399295814979993e7im
@test lgamma(-pi*1e7 + 8im) ≅ -5.10911765175690634449032797392631749405282045412624e8 - 9.86959074790854911974415722927761900209557190058925e7im
@test lgamma(-pi*1e14 + 6im) ≅ -1.0172766411995621854526383224252727000270225301426e16 - 9.8696044010873714715264929863618267642124589569347e14im
@test lgamma(-pi*1e14 + 8im) ≅ -1.0172766411995628137711690403794640541491261237341e16 - 9.8696044010867038531027376655349878694397362250037e14im
@test lgamma(2.05 + 0.03im) ≅ conj(lgamma(2.05 - 0.03im)) ≅ 0.02165570938532611215664861849215838847758074239924127515 + 0.01363779084533034509857648574107935425251657080676603919im
@test lgamma(2+exp2(-20)+exp2(-19)*im) ≅ 4.03197681916768997727833554471414212058404726357753e-7 + 8.06398296652953575754782349984315518297283664869951e-7im

@testset "lgamma for non-finite arguments" begin
@test lgamma(Inf + 0im) === Inf + 0im
@test lgamma(Inf - 0.0im) === Inf - 0.0im
@test lgamma(Inf + 1im) === Inf + Inf*im
@test lgamma(Inf - 1im) === Inf - Inf*im
@test lgamma(-Inf + 0.0im) === -Inf - Inf*im
@test lgamma(-Inf - 0.0im) === -Inf + Inf*im
@test lgamma(Inf*im) === -Inf + Inf*im
@test lgamma(-Inf*im) === -Inf - Inf*im
@test lgamma(Inf + Inf*im) === lgamma(NaN + 0im) === lgamma(NaN*im) === NaN + NaN*im
end
end

@testset "beta, lbeta" begin
@test beta(3/2,7/2) ≈ 5π/128
@test beta(3,5) ≈ 1/105
@test lbeta(5,4) ≈ log(beta(5,4))
@test beta(5,4) ≈ beta(4,5)
@test beta(-1/2, 3) ≈ beta(-1/2 + 0im, 3 + 0im) ≈ -16/3
@test lbeta(-1/2, 3) ≈ log(16/3)
@test beta(Float32(5),Float32(4)) == beta(Float32(4),Float32(5))
@test beta(3,5) ≈ beta(3+0im,5+0im)
@test(beta(3.2+0.1im,5.3+0.3im) ≈ exp(lbeta(3.2+0.1im,5.3+0.3im)) ≈
0.00634645247782269506319336871208405439180447035257028310080 -
0.00169495384841964531409376316336552555952269360134349446910im)

@test beta(big(1.0),big(1.2)) ≈ beta(1.0,1.2) rtol=4*eps()
end