diff --git a/base/fastmath.jl b/base/fastmath.jl index 7d51eba54a6d42..f14b1d24156471 100644 --- a/base/fastmath.jl +++ b/base/fastmath.jl @@ -111,6 +111,9 @@ function make_fastmath(expr::Expr) $arrvar[$(indvars...)] = $op($arrvar[$(indvars...)], $rhs) end end + elseif is_literal_int_pow(expr) + expr = Expr(expr.head, :(Base.literal_pow), + expr.args[1], expr.args[2], Val{expr.args[3]}) end Expr(make_fastmath(expr.head), map(make_fastmath, expr.args)...) end @@ -127,6 +130,10 @@ macro fastmath(expr) make_fastmath(esc(expr)) end +function is_literal_int_pow(ex::Expr) + return (ex.head === :call && length(ex.args) == 3 && + ex.args[1] === :^ && isa(ex.args[3], Integer)) +end # Basic arithmetic diff --git a/test/fastmath.jl b/test/fastmath.jl index 754e84362c5e6c..932d0ab65fc9c9 100644 --- a/test/fastmath.jl +++ b/test/fastmath.jl @@ -201,3 +201,10 @@ let a = ones(2,2), b = ones(2,2) local c = 0 @test @fastmath(c |= 1) == 1 end + +struct LitPowTest end +Base.literal_pow{p}(::typeof(Base.FastMath.pow_fast), ::LitPowTest, ::Type{Val{p}}) = 1 +Base.literal_pow{p}(::typeof(^), ::LitPowTest, ::Type{Val{p}}) = 2 +LPT = LitPowTest() +@test (@fastmath LPT^2) == 1 +@test LPT^2 == 2