Skip to content

Commit

Permalink
Improve support for ranges with nonstandard Integers (#27302)
Browse files Browse the repository at this point in the history
* Improve support for nonstandard ranges.

Make it so that an Integer subtype, say Position <: Integer, for which a difference is of a
different type, say Displacement <: Integer, is properly handled in UnitRange, OneTo, and StepRange.

* Remove : method that promotes step.

* Add test for #26619, fix underflow bug, fix length computation.

* Clean up steprange_last.
  • Loading branch information
tkoolen authored and mbauman committed Jun 22, 2018
1 parent 2706a8d commit 63e7fae
Show file tree
Hide file tree
Showing 3 changed files with 119 additions and 21 deletions.
36 changes: 18 additions & 18 deletions base/range.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,9 @@

(:)(start::T, stop::T) where {T} = (:)(start, oftype(stop-start, 1), stop)

# first promote start and stop, leaving step alone
# promote start and stop, leaving step alone
(:)(start::A, step, stop::C) where {A<:Real,C<:Real} =
(:)(convert(promote_type(A,C),start), step, convert(promote_type(A,C),stop))
(:)(start::T, step::Real, stop::T) where {T<:Real} = (:)(promote(start, step, stop)...)

# AbstractFloat specializations
(:)(a::T, b::T) where {T<:AbstractFloat} = (:)(a, T(1), b)
Expand Down Expand Up @@ -143,18 +142,19 @@ function steprange_last(start::T, step, stop) where T
if (step > z) != (stop > start)
last = steprange_last_empty(start, step, stop)
else
diff = stop - start
if T<:Signed && (diff > zero(diff)) != (stop > start)
# handle overflowed subtraction with unsigned rem
if diff > zero(diff)
remain = -convert(T, unsigned(-diff) % step)
else
remain = convert(T, unsigned(diff) % step)
end
# Compute absolute value of difference between `start` and `stop`
# (to simplify handling both signed and unsigned T and checking for signed overflow):
absdiff, absstep = stop > start ? (stop - start, step) : (start - stop, -step)

# Compute remainder as a nonnegative number:
if T <: Signed && absdiff < zero(absdiff)
# handle signed overflow with unsigned rem
remain = convert(T, unsigned(absdiff) % absstep)
else
remain = steprem(start,stop,step)
remain = absdiff % absstep
end
last = stop - remain
# Move `stop` closer to `start` if there is a remainder:
last = stop > start ? stop - remain : stop + remain
end
end
last
Expand All @@ -175,8 +175,6 @@ end
# For types where x+oneunit(x) may not be well-defined
steprange_last_empty(start, step, stop) = start - step

steprem(start,stop,step) = (stop-start) % step

StepRange(start::T, step::S, stop::T) where {T,S} = StepRange{T,S}(start, step, stop)

struct UnitRange{T<:Real} <: AbstractUnitRange{T}
Expand Down Expand Up @@ -385,7 +383,7 @@ julia> step(range(2.5, stop=10.9, length=85))
```
"""
step(r::StepRange) = r.step
step(r::AbstractUnitRange{T}) where{T} = oneunit(T)
step(r::AbstractUnitRange{T}) where{T} = oneunit(T) - zero(T)
step(r::StepRangeLen{T}) where {T} = T(r.step)
step(r::LinRange) = (last(r)-first(r))/r.lendiv

Expand All @@ -399,8 +397,8 @@ function unsafe_length(r::StepRange)
isempty(r) ? zero(n) : n
end
length(r::StepRange) = unsafe_length(r)
unsafe_length(r::AbstractUnitRange) = Integer(last(r) - first(r) + 1)
unsafe_length(r::OneTo) = r.stop
unsafe_length(r::AbstractUnitRange) = Integer(last(r) - first(r) + step(r))
unsafe_length(r::OneTo) = Integer(r.stop - zero(r.stop))
length(r::AbstractUnitRange) = unsafe_length(r)
length(r::OneTo) = unsafe_length(r)
length(r::StepRangeLen) = r.len
Expand All @@ -412,8 +410,10 @@ function length(r::StepRange{T}) where T<:Union{Int,UInt,Int64,UInt64}
return checked_add(convert(T, div(unsigned(r.stop - r.start), r.step)), one(T))
elseif r.step < -1
return checked_add(convert(T, div(unsigned(r.start - r.stop), -r.step)), one(T))
elseif r.step > 0
return checked_add(div(checked_sub(r.stop, r.start), r.step), one(T))
else
checked_add(div(checked_sub(r.stop, r.start), r.step), one(T))
return checked_add(div(checked_sub(r.start, r.stop), -r.step), one(T))
end
end

Expand Down
27 changes: 24 additions & 3 deletions stdlib/Dates/src/ranges.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,30 @@ Base.length(r::StepRange{<:TimeType}) = isempty(r) ? Int64(0) : len(r.start, r.s
# Period ranges hook into Int64 overflow detection
Base.length(r::StepRange{<:Period}) = length(StepRange(value(r.start), value(r.step), value(r.stop)))

# Used to calculate the last valid date in the range given the start, stop, and step
# last = stop - steprem(start, stop, step)
Base.steprem(a::T, b::T, c) where {T<:TimeType} = b - (a + c * len(a, b, c))
# Overload Base.steprange_last because `rem` is not overloaded for `TimeType`s
function Base.steprange_last(start::T, step, stop) where T<:TimeType
if isa(step,AbstractFloat)
throw(ArgumentError("StepRange should not be used with floating point"))
end
z = zero(step)
step == z && throw(ArgumentError("step cannot be zero"))

if stop == start
last = stop
else
if (step > z) != (stop > start)
last = Base.steprange_last_empty(start, step, stop)
else
diff = stop - start
if (diff > zero(diff)) != (stop > start)
throw(OverflowError())
end
remain = stop - (start + step * len(start, stop, step))
last = stop - remain
end
end
last
end

import Base.in
function in(x::T, r::StepRange{T}) where T<:TimeType
Expand Down
77 changes: 77 additions & 0 deletions test/ranges.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1285,3 +1285,80 @@ end
@test step(x) == 0.0
@test x isa StepRangeLen{Float64,Base.TwicePrecision{Float64},Base.TwicePrecision{Float64}}
end

module NonStandardIntegerRangeTest

using Test

struct Position <: Integer
val::Int
end
Position(x::Position) = x # to resolve ambiguity with boot.jl:728

struct Displacement <: Integer
val::Int
end
Displacement(x::Displacement) = x # to resolve ambiguity with boot.jl:728

Base.:-(x::Displacement) = Displacement(-x.val)
Base.:-(x::Position, y::Position) = Displacement(x.val - y.val)
Base.:-(x::Position, y::Displacement) = Position(x.val - y.val)
Base.:-(x::Displacement, y::Displacement) = Displacement(x.val - y.val)
Base.:+(x::Position, y::Displacement) = Position(x.val + y.val)
Base.:+(x::Displacement, y::Displacement) = Displacement(x.val + y.val)
Base.:(<=)(x::Position, y::Position) = x.val <= y.val
Base.:(<)(x::Position, y::Position) = x.val < y.val
Base.:(<)(x::Displacement, y::Displacement) = x.val < y.val

# for StepRange computation:
Base.Unsigned(x::Displacement) = Unsigned(x.val)
Base.rem(x::Displacement, y::Displacement) = Displacement(rem(x.val, y.val))
Base.div(x::Displacement, y::Displacement) = Displacement(div(x.val, y.val))

# required for collect (summing lengths); alternatively, should unsafe_length return Int by default?
Base.promote_rule(::Type{Displacement}, ::Type{Int}) = Int
Base.convert(::Type{Int}, x::Displacement) = x.val

@testset "Ranges with nonstandard Integers" begin
for (start, stop) in [(2, 4), (3, 3), (3, -2)]
@test collect(Position(start) : Position(stop)) == Position.(start : stop)
end

for start in [3, 0, -2]
@test collect(Base.OneTo(Position(start))) == Position.(Base.OneTo(start))
end

for step in [-3, -2, -1, 1, 2, 3]
for start in [-1, 0, 2]
for stop in [start, start - 1, start + 2 * step, start + 2 * step + 1]
r1 = StepRange(Position(start), Displacement(step), Position(stop))
@test collect(r1) == Position.(start : step : stop)

r2 = Position(start) : Displacement(step) : Position(stop)
@test r1 === r2
end
end
end
end

end # module NonStandardIntegerRangeTest

@testset "Issue #26619" begin
@test length(UInt(100) : -1 : 1) === UInt(100)
@test collect(UInt(5) : -1 : 3) == [UInt(5), UInt(4), UInt(3)]

let r = UInt(5) : -2 : 2
@test r.start === UInt(5)
@test r.step === -2
@test r.stop === UInt(3)
@test collect(r) == [UInt(5), UInt(3)]
end

for step in [-3, -2, -1, 1, 2, 3]
for start in [0, 15]
for stop in [0, 15]
@test collect(UInt(start) : step : UInt(stop)) == start : step : stop
end
end
end
end

0 comments on commit 63e7fae

Please sign in to comment.