Skip to content

Commit

Permalink
Add optimized arithmethic with rationals and integers
Browse files Browse the repository at this point in the history
  • Loading branch information
Joel-Dahne committed Aug 20, 2023
1 parent 9a10de8 commit dc93618
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 5 deletions.
15 changes: 12 additions & 3 deletions src/arithmetic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,12 @@ for (jf, af) in [(:+, :add!), (:-, :sub!), (:*, :mul!), (:/, :div!)]
@eval Base.$jf(x::AcbOrRef, y::Union{AcbOrRef,ArbOrRef,_BitInteger}) =
$af(Acb(prec = _precision(x, y)), x, y)

# Avoid one allocation for operations on irrationals
@eval function Base.$jf(x::Union{ArbOrRef,AcbOrRef}, y::Irrational)
# Avoid one allocation for operations on irrationals or BitInteger
# rationals
@eval function Base.$jf(
x::Union{ArbOrRef,AcbOrRef},
y::Union{Irrational,Rational{<:_BitInteger}},
)
z = zero(x)
z[] = y
return $af(z, x, z)
Expand All @@ -33,7 +37,12 @@ for (jf, af) in [(:+, :add!), (:-, :sub!), (:*, :mul!), (:/, :div!)]

@eval Base.$jf(x::Union{ArbOrRef,_BitInteger,Irrational}, y::AcbOrRef) = $jf(y, x)
else
@eval function Base.$jf(x::Irrational, y::Union{ArbOrRef,AcbOrRef})
# Subtraction and division also need optimizations for when
# left argument is an integer
@eval function Base.$jf(
x::Union{Irrational,Rational{<:_BitInteger},_BitInteger},
y::Union{ArbOrRef,AcbOrRef},
)
z = zero(y)
z[] = x
return $af(z, z, y)
Expand Down
18 changes: 16 additions & 2 deletions test/arithmetic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@
end

@testset "$T" for T in [Arb, Acb]
# +, -, *, /
@test T(1) + T(2) ==
T(1) + 2 ==
2 + T(1) ==
Expand All @@ -82,9 +83,22 @@
T(1) + UInt8(2) ==
UInt8(2) + T(1) ==
3
@test T(1) - T(2) == T(1) - 2 == T(1) - UInt(2) == T(-1)
@test T(1) - T(2) == T(1) - 2 == T(1) - UInt(2) == 1 - T(2) == T(-1)
@test T(2) * T(3) == T(2) * 3 == 3 * T(2) == T(2) * UInt(3) == UInt(3) * T(2) == 6
@test T(6) / T(2) == T(6) / 2 == T(6) / UInt(2) == 3
@test T(6) / T(2) == T(6) / 2 == T(6) / UInt(2) == 6 / T(2) == 3

@test isequal(T(1) + π, T(1) + T(π))
@test isequal+ T(1), T(1) + T(π))
@test T(1) + 3 // 2 == 3 // 2 + T(1) == 5 // 2
@test isequal(T(1) - π, T(1) - T(π))
@test isequal- T(1), T(π) - T(1))
@test T(1) - 3 // 2 == -(3 // 2 - T(1)) == -1 // 2
@test isequal(T(2) * π, T(2) * T(π))
@test isequal* T(2), T(π) * T(2))
@test T(2) * 3 // 2 == 3 // 2 * T(2) == 3
@test isequal(T(2) / π, T(2) / T(π))
@test isequal/ T(2), T(π) / T(2))
@test T(4) / 4 // 1 == 4 // 1 / T(4) == 1

# ^
@test Base.literal_pow(^, T(2), Val(-2)) ==
Expand Down

0 comments on commit dc93618

Please sign in to comment.