Skip to content

Commit

Permalink
promote ranges to the largest of the the start, step, or length (as a…
Browse files Browse the repository at this point in the history
…pplicable) (#43059)

Be careful to use `oneunit` instead of `1`, so that arithmetic on
user-given types does not promote first to Int.

Fixes #35711
Fixes #10554
  • Loading branch information
vtjnash authored Jan 12, 2022
1 parent 08ff456 commit fd8b2ab
Show file tree
Hide file tree
Showing 7 changed files with 108 additions and 72 deletions.
1 change: 1 addition & 0 deletions base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1130,6 +1130,7 @@ broadcasted(::DefaultArrayStyle{1}, ::typeof(*), r::LinRange, x::Number) = LinRa
broadcasted(::DefaultArrayStyle{1}, ::typeof(*), r::OrdinalRange, x::AbstractFloat) =
Base.range_start_step_length(first(r)*x, step(r)*x, length(r))

#broadcasted(::DefaultArrayStyle{1}, ::typeof(/), r::AbstractRange, x::Number) = range(first(r)/x, last(r)/x, length=length(r))
broadcasted(::DefaultArrayStyle{1}, ::typeof(/), r::AbstractRange, x::Number) = range(first(r)/x, step=step(r)/x, length=length(r))
broadcasted(::DefaultArrayStyle{1}, ::typeof(/), r::StepRangeLen{T}, x::Number) where {T} =
StepRangeLen{typeof(T(r.ref)/x)}(r.ref/x, r.step/x, length(r), r.offset)
Expand Down
4 changes: 2 additions & 2 deletions base/math.jl
Original file line number Diff line number Diff line change
Expand Up @@ -682,7 +682,7 @@ function _hypot(x, y)

# Return Inf if either or both inputs is Inf (Compliance with IEEE754)
if isinf(ax) || isinf(ay)
return oftype(axu, Inf)
return typeof(axu)(Inf)
end

# Order the operands
Expand Down Expand Up @@ -745,7 +745,7 @@ _hypot(x::ComplexF16, y::ComplexF16) = Float16(_hypot(ComplexF32(x), ComplexF32(
function _hypot(x...)
maxabs = maximum(abs, x)
if isnan(maxabs) && any(isinf, x)
return oftype(maxabs, Inf)
return typeof(maxabs)(Inf)
elseif (iszero(maxabs) || isinf(maxabs))
return maxabs
else
Expand Down
112 changes: 56 additions & 56 deletions base/range.jl
Original file line number Diff line number Diff line change
@@ -1,21 +1,21 @@
# This file is a part of Julia. License is MIT: https://julialang.org/license

(:)(a::Real, b::Real) = (:)(promote(a,b)...)
(:)(a::Real, b::Real) = (:)(promote(a, b)...)

(:)(start::T, stop::T) where {T<:Real} = UnitRange{T}(start, stop)

(:)(start::T, stop::T) where {T} = (:)(start, oftype(stop >= start ? stop - start : start - stop, 1), stop)

# promote start and stop, leaving step alone
(:)(start::A, step, stop::C) where {A<:Real,C<:Real} =
(:)(convert(promote_type(A,C),start), step, convert(promote_type(A,C),stop))
(:)(start::A, step, stop::C) where {A<:Real, C<:Real} =
(:)(convert(promote_type(A, C), start), step, convert(promote_type(A, C), stop))

# AbstractFloat specializations
(:)(a::T, b::T) where {T<:AbstractFloat} = (:)(a, T(1), b)

(:)(a::T, b::AbstractFloat, c::T) where {T<:Real} = (:)(promote(a,b,c)...)
(:)(a::T, b::AbstractFloat, c::T) where {T<:AbstractFloat} = (:)(promote(a,b,c)...)
(:)(a::T, b::Real, c::T) where {T<:AbstractFloat} = (:)(promote(a,b,c)...)
(:)(a::T, b::AbstractFloat, c::T) where {T<:Real} = (:)(promote(a, b, c)...)
(:)(a::T, b::AbstractFloat, c::T) where {T<:AbstractFloat} = (:)(promote(a, b, c)...)
(:)(a::T, b::Real, c::T) where {T<:AbstractFloat} = (:)(promote(a, b, c)...)

(:)(start::T, step::T, stop::T) where {T<:AbstractFloat} =
_colon(OrderStyle(T), ArithmeticStyle(T), start, step, stop)
Expand Down Expand Up @@ -167,49 +167,39 @@ range_length(len::Integer) = OneTo(len)
range_stop(stop) = range_start_stop(oftype(stop, 1), stop)
range_stop(stop::Integer) = range_length(stop)

# Stop and length as the only argument
range_stop_length(a::Real, len::Integer) = UnitRange{typeof(a)}(oftype(a, a-len+1), a)
range_stop_length(a::AbstractFloat, len::Integer) = range_step_stop_length(oftype(a, 1), a, len)
range_stop_length(a, len::Integer) = range_step_stop_length(oftype(a-a, 1), a, len)

range_step_stop_length(step, stop, length) = reverse(range_start_step_length(stop, -step, length))

range_start_length(a::Real, len::Integer) = UnitRange{typeof(a)}(a, oftype(a, a+len-1))
range_start_length(a::AbstractFloat, len::Integer) = range_start_step_length(a, oftype(a, 1), len)
range_start_length(a, len::Integer) = range_start_step_length(a, oftype(a-a, 1), len)

range_start_stop(start, stop) = start:stop

function range_start_step_length(a::AbstractFloat, step::AbstractFloat, len::Integer)
range_start_step_length(promote(a, step)..., len)
end

function range_start_step_length(a::Real, step::AbstractFloat, len::Integer)
range_start_step_length(float(a), step, len)
end

function range_start_step_length(a::AbstractFloat, step::Real, len::Integer)
range_start_step_length(a, float(step), len)
# Stop and length as the only argument
function range_stop_length(a, len::Integer)
step = oftype(a - a, 1) # assert that step is representable
start = a - (len - oneunit(len))
if start isa Signed
# overflow in recomputing length from stop is okay
return UnitRange(start, oftype(start, a))
end
return range_step_stop_length(oftype(a - a, 1), a, len)
end

function range_start_step_length(a::T, step::T, len::Integer) where {T <: AbstractFloat}
_rangestyle(OrderStyle(T), ArithmeticStyle(T), a, step, len)
# Start and length as the only argument
function range_start_length(a, len::Integer)
step = oftype(a - a, 1) # assert that step is representable
stop = a + (len - oneunit(len))
if stop isa Signed
# overflow in recomputing length from stop is okay
return UnitRange(oftype(stop, a), stop)
end
return range_start_step_length(a, oftype(a - a, 1), len)
end

function range_start_step_length(a::T, step, len::Integer) where {T}
_rangestyle(OrderStyle(T), ArithmeticStyle(T), a, step, len)
end
range_start_stop(start, stop) = start:stop

function _rangestyle(::Ordered, ::ArithmeticWraps, a, step, len::Integer)
start = a + zero(step)
stop = a + step * (len - 1)
T = typeof(start)
return StepRange{T,typeof(step)}(start, step, convert(T, stop))
end
function _rangestyle(::Any, ::Any, a, step, len::Integer)
start = a + zero(step)
T = typeof(a)
return StepRangeLen{typeof(start),T,typeof(step)}(a, step, len)
function range_start_step_length(a, step, len::Integer)
stop = a + step * (len - oneunit(len))
if stop isa Signed
# overflow in recomputing length from stop is okay
return StepRange{typeof(stop),typeof(step)}(convert(typeof(stop), a), step, stop)
end
return StepRangeLen{typeof(stop),typeof(a),typeof(step)}(a, step, len)
end

range_start_step_stop(start, step, stop) = start:step:stop
Expand Down Expand Up @@ -893,7 +883,7 @@ _in_unit_range(v::UnitRange, val, i::Integer) = i > 0 && val <= v.stop && val >=
function getindex(v::UnitRange{T}, i::Integer) where T
@inline
i isa Bool && throw(ArgumentError("invalid index: $i of type Bool"))
val = convert(T, v.start + (i - 1))
val = convert(T, v.start + (i - oneunit(i)))
@boundscheck _in_unit_range(v, val, i) || throw_boundserror(v, i)
val
end
Expand All @@ -904,7 +894,7 @@ const OverflowSafe = Union{Bool,Int8,Int16,Int32,Int64,Int128,
function getindex(v::UnitRange{T}, i::Integer) where {T<:OverflowSafe}
@inline
i isa Bool && throw(ArgumentError("invalid index: $i of type Bool"))
val = v.start + (i - 1)
val = v.start + (i - oneunit(i))
@boundscheck _in_unit_range(v, val, i) || throw_boundserror(v, i)
val % T
end
Expand All @@ -919,7 +909,7 @@ end
function getindex(v::AbstractRange{T}, i::Integer) where T
@inline
i isa Bool && throw(ArgumentError("invalid index: $i of type Bool"))
ret = convert(T, first(v) + (i - 1)*step_hp(v))
ret = convert(T, first(v) + (i - oneunit(i))*step_hp(v))
ok = ifelse(step(v) > zero(step(v)),
(ret <= last(v)) & (ret >= first(v)),
(ret <= first(v)) & (ret >= last(v)))
Expand Down Expand Up @@ -949,13 +939,14 @@ end

function unsafe_getindex(r::LinRange, i::Integer)
i isa Bool && throw(ArgumentError("invalid index: $i of type Bool"))
lerpi(i-1, r.lendiv, r.start, r.stop)
lerpi(i-oneunit(i), r.lendiv, r.start, r.stop)
end

function lerpi(j::Integer, d::Integer, a::T, b::T) where T
@inline
t = j/d
T((1-t)*a + t*b)
t = j/d # ∈ [0,1]
# compute approximately fma(t, b, -fma(t, a, a))
return T((1-t)*a + t*b)
end

getindex(r::AbstractRange, ::Colon) = copy(r)
Expand All @@ -968,8 +959,10 @@ function getindex(r::AbstractUnitRange, s::AbstractUnitRange{T}) where {T<:Integ
return range(first(s) ? first(r) : last(r), length = last(s))
else
f = first(r)
start = oftype(f, f + first(s)-firstindex(r))
return range(start, length=length(s))
start = oftype(f, f + first(s) - firstindex(r))
len = length(s)
stop = oftype(f, start + (len - oneunit(len)))
return range(start, stop)
end
end

Expand All @@ -984,11 +977,14 @@ function getindex(r::AbstractUnitRange, s::StepRange{T}) where {T<:Integer}
@boundscheck checkbounds(r, s)

if T === Bool
return range(first(s) ? first(r) : last(r), step=oneunit(eltype(r)), length = last(s))
return range(first(s) ? first(r) : last(r), step=oneunit(eltype(r)), length=last(s))
else
f = first(r)
start = oftype(f, f + s.start-firstindex(r))
return range(start, step=step(s), length=length(s))
start = oftype(f, f + s.start - firstindex(r))
st = step(s)
len = length(s)
stop = oftype(f, start + (len - oneunit(len)) * st)
return range(start, stop; step=st)
end
end

Expand All @@ -1011,9 +1007,13 @@ function getindex(r::StepRange, s::AbstractRange{T}) where {T<:Integer}
return range(start, step=step(r); length=len)
else
f = r.start
fs = first(s)
st = r.step
start = oftype(f, f + (first(s)-oneunit(first(s)))*st)
return range(start; step=st*step(s), length=length(s))
start = oftype(f, f + (fs - oneunit(fs)) * st)
st = st * step(s)
len = length(s)
stop = oftype(f, start + (len - oneunit(len)) * st)
return range(start, stop; step=st)
end
end

Expand Down Expand Up @@ -1042,7 +1042,7 @@ function getindex(r::StepRangeLen{T}, s::OrdinalRange{S}) where {T, S<:Integer}
# Find closest approach to offset by s
ind = LinearIndices(s)
offset = L(max(min(1 + round(L, (r.offset - first(s))/sstep), last(ind)), first(ind)))
ref = _getindex_hiprec(r, first(s) + (offset-1)*sstep)
ref = _getindex_hiprec(r, first(s) + (offset - oneunit(offset)) * sstep)
return StepRangeLen{T}(ref, rstep*sstep, len, offset)
end
end
Expand Down
6 changes: 2 additions & 4 deletions stdlib/Dates/test/periods.jl
Original file line number Diff line number Diff line change
Expand Up @@ -179,10 +179,8 @@ end
@test_throws InexactError y * 3//4
@test (1:1:5)*Second(5) === Second(5)*(1:1:5) === Second(5):Second(5):Second(25) === (1:5)*Second(5)
@test collect(1:1:5)*Second(5) == Second(5)*collect(1:1:5) == (1:5)*Second(5)
@test (Second(2):Second(2):Second(10))/Second(2) === 1.0:1.0:5.0
@test collect(Second(2):Second(2):Second(10))/Second(2) == 1:1:5
@test (Second(2):Second(2):Second(10)) / 2 === Second(1):Second(1):Second(5)
@test collect(Second(2):Second(2):Second(10)) / 2 == Second(1):Second(1):Second(5)
@test (Second(2):Second(2):Second(10))/Second(2) === 1.0:1.0:5.0 == collect(Second(2):Second(2):Second(10))/Second(2)
@test (Second(2):Second(2):Second(10)) / 2 == Second(1):Second(1):Second(5) == collect(Second(2):Second(2):Second(10)) / 2
@test Dates.Year(4) / 2 == Dates.Year(2)
@test Dates.Year(4) / 2f0 == Dates.Year(2)
@test Dates.Year(4) / 0.5 == Dates.Year(8)
Expand Down
18 changes: 12 additions & 6 deletions stdlib/LinearAlgebra/test/special.jl
Original file line number Diff line number Diff line change
Expand Up @@ -375,12 +375,18 @@ using .Main.Furlongs
@test one(S) isa SymTridiagonal

# eltype with dimensions
D = Diagonal{Furlong{2, Int64}}([1, 2, 3, 4])
Bu = Bidiagonal{Furlong{2, Int64}}([1, 2, 3, 4], [1, 2, 3], 'U')
Bl = Bidiagonal{Furlong{2, Int64}}([1, 2, 3, 4], [1, 2, 3], 'L')
T = Tridiagonal{Furlong{2, Int64}}([1, 2, 3], [1, 2, 3, 4], [1, 2, 3])
S = SymTridiagonal{Furlong{2, Int64}}([1, 2, 3, 4], [1, 2, 3])
mats = [D, Bu, Bl, T, S]
D0 = Diagonal{Furlong{0, Int64}}([1, 2, 3, 4])
Bu0 = Bidiagonal{Furlong{0, Int64}}([1, 2, 3, 4], [1, 2, 3], 'U')
Bl0 = Bidiagonal{Furlong{0, Int64}}([1, 2, 3, 4], [1, 2, 3], 'L')
T0 = Tridiagonal{Furlong{0, Int64}}([1, 2, 3], [1, 2, 3, 4], [1, 2, 3])
S0 = SymTridiagonal{Furlong{0, Int64}}([1, 2, 3, 4], [1, 2, 3])
F2 = Furlongs.Furlong{2}(1)
D2 = Diagonal{Furlong{2, Int64}}([1, 2, 3, 4].*F2)
Bu2 = Bidiagonal{Furlong{2, Int64}}([1, 2, 3, 4].*F2, [1, 2, 3].*F2, 'U')
Bl2 = Bidiagonal{Furlong{2, Int64}}([1, 2, 3, 4].*F2, [1, 2, 3].*F2, 'L')
T2 = Tridiagonal{Furlong{2, Int64}}([1, 2, 3].*F2, [1, 2, 3, 4].*F2, [1, 2, 3].*F2)
S2 = SymTridiagonal{Furlong{2, Int64}}([1, 2, 3, 4].*F2, [1, 2, 3].*F2)
mats = Any[D0, Bu0, Bl0, T0, S0, D2, Bu2, Bl2, T2, S2]
for A in mats
@test iszero(zero(A))
@test isone(one(A))
Expand Down
22 changes: 20 additions & 2 deletions test/ranges.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1525,8 +1525,10 @@ isdefined(Main, :Furlongs) || @eval Main include("testhelpers/Furlongs.jl")
using .Main.Furlongs

@testset "dimensional correctness" begin
@test length(Vector(Furlong(2):Furlong(10))) == 9
@test length(range(Furlong(2), length=9)) == checked_length(range(Furlong(2), length=9)) == 9
@test_throws TypeError Furlong(2):Furlong(10)
@test_throws TypeError range(Furlong(2), length=9)
@test length(Vector(Furlong(2):Furlong(1):Furlong(10))) == 9
@test length(range(Furlong(2), step=Furlong(1), length=9)) == checked_length(range(Furlong(2), step=Furlong(1), length=9)) == 9
@test @inferred(length(StepRange(Furlong(2), Furlong(1), Furlong(1)))) == 0
@test Vector(Furlong(2):Furlong(1):Furlong(10)) == Vector(range(Furlong(2), step=Furlong(1), length=9)) == Furlong.(2:10)
@test Vector(Furlong(1.0):Furlong(0.5):Furlong(10.0)) ==
Expand Down Expand Up @@ -2265,3 +2267,19 @@ let r = Ptr{Cvoid}(20):-UInt(2):Ptr{Cvoid}(10)
@test step(r) === -UInt(2)
@test last(r) === Ptr{Cvoid}(10)
end

# test behavior of wrap-around and promotion of empty ranges (#35711)
@test length(range(0, length=UInt(0))) === UInt(0)
@test isempty(range(0, length=UInt(0)))
@test length(range(typemax(Int), length=UInt(0))) === UInt(0)
@test isempty(range(typemax(Int), length=UInt(0)))
@test length(range(0, length=UInt(0), step=UInt(2))) == UInt(0)
@test isempty(range(0, length=UInt(0), step=UInt(2)))
@test length(range(typemax(Int), length=UInt(0), step=UInt(2))) === UInt(0)
@test isempty(range(typemax(Int), length=UInt(0), step=UInt(2)))
@test length(range(typemax(Int), length=UInt(0), step=2)) === UInt(0)
@test isempty(range(typemax(Int), length=UInt(0), step=2))
@test length(range(typemax(Int), length=0, step=UInt(2))) === 0
@test isempty(range(typemax(Int), length=0, step=UInt(2)))

@test length(range(1, length=typemax(Int128))) === typemax(Int128)
17 changes: 15 additions & 2 deletions test/testhelpers/Furlongs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,26 @@ end
Furlong(x::T) where {T<:Number} = Furlong{1,T}(x)
Furlong(x::Furlong) = x
(::Type{T})(x::Furlong{0}) where {T<:Number} = T(x.val)::T
(::Type{T})(x::Furlong{0}) where {T<:Furlong{0}} = T(x.val)::T
(::Type{T})(x::Furlong{0}) where {T<:Furlong} = typeassert(x, T)
Furlong{p}(v::Number) where {p} = Furlong{p,typeof(v)}(v)
Furlong{p}(x::Furlong{q}) where {p,q} = (@assert(p==q); Furlong{p,typeof(x.val)}(x.val))
Furlong{p,T}(x::Furlong{q}) where {T,p,q} = (@assert(p==q); Furlong{p,T}(T(x.val)))
Furlong{p}(x::Furlong{q}) where {p,q} = (typeassert(x, Furlong{p}); Furlong{p,typeof(x.val)}(x.val))
Furlong{p,T}(x::Furlong{q}) where {T,p,q} = (typeassert(x, Furlong{p}); Furlong{p,T}(T(x.val)))

Base.promote_type(::Type{Furlong{p,T}}, ::Type{Furlong{p,S}}) where {p,T,S} =
Furlong{p,promote_type(T,S)}

# only Furlong{0} forms a ring and isa Number
Base.convert(::Type{T}, y::Number) where {T<:Furlong{0}} = T(y)
Base.convert(::Type{Furlong}, y::Number) = Furlong{0}(y)
Base.convert(::Type{Furlong{<:Any,T}}, y::Number) where {T<:Number} = Furlong{0,T}(y)
Base.convert(::Type{T}, y::Number) where {T<:Furlong} = typeassert(y, T) # throws, since cannot convert a Furlong{0} to a Furlong{p}
# other Furlong{p} form a group
Base.convert(::Type{T}, y::Furlong) where {T<:Furlong{0}} = T(y)
Base.convert(::Type{Furlong}, y::Furlong) = y
Base.convert(::Type{Furlong{<:Any,T}}, y::Furlong{p}) where {p,T<:Number} = Furlong{p,T}(y)
Base.convert(::Type{T}, y::Furlong) where {T<:Furlong} = T(y)

Base.one(x::Furlong{p,T}) where {p,T} = one(T)
Base.one(::Type{Furlong{p,T}}) where {p,T} = one(T)
Base.oneunit(x::Furlong{p,T}) where {p,T} = Furlong{p,T}(one(T))
Expand Down

0 comments on commit fd8b2ab

Please sign in to comment.