Skip to content

Commit

Permalink
add wrapaxes
Browse files Browse the repository at this point in the history
  • Loading branch information
jishnub committed Jun 19, 2021
1 parent ae1b469 commit 100f8e5
Show file tree
Hide file tree
Showing 5 changed files with 107 additions and 33 deletions.
31 changes: 24 additions & 7 deletions base/range.jl
Original file line number Diff line number Diff line change
Expand Up @@ -823,6 +823,19 @@ end

getindex(r::AbstractRange, ::Colon) = copy(r)

# The result of the indexing operation r[s] should have the same indices as s
# However this is not possible to achieve in general without offset arrays
# To get aroudn this we introduce methods that are intended to be pirated by OffsetArrays
# This way it does not need to pirate getindex while dispatching on the second argument,
# which introduces a host of ambiguities

# Indexing with OneTo is guaranteed to produce correct indices
withindices(r, axs::Tuple{OneTo, Vararg{OneTo}}) = r
# This method is expected to be pirated by OffsetArrays to produce a result with the correct indices
# A package that seeks to produce a specific return type (and not participate in the piracy by OffsetArrays)
# may extend it themselves for their own types
withindices(r, axs) = r

function getindex(r::AbstractUnitRange, s::AbstractUnitRange{T}) where {T<:Integer}
@_inline_meta
@boundscheck checkbounds(r, s)
Expand All @@ -831,8 +844,9 @@ function getindex(r::AbstractUnitRange, s::AbstractUnitRange{T}) where {T<:Integ
range(first(s) ? first(r) : last(r), length = Int(last(s)))
else
f = first(r)
st = oftype(f, f + first(s)-1)
return range(st, length=length(s))
st = oftype(f, f + first(s)-firstindex(r))
ret = range(st, length=length(s))
return withindices(ret, axes(s))
end
end

Expand All @@ -849,7 +863,7 @@ function getindex(r::AbstractUnitRange, s::StepRange{T}) where {T<:Integer}
if T === Bool
range(first(s) ? first(r) : last(r), step=oneunit(eltype(r)), length = Int(last(s)))
else
st = oftype(first(r), first(r) + s.start-1)
st = oftype(first(r), first(r) + s.start-firstindex(r))
return range(st, step=step(s), length=length(s))
end
end
Expand All @@ -872,7 +886,8 @@ function getindex(r::StepRange, s::AbstractRange{T}) where {T<:Integer}
end
else
st = oftype(r.start, r.start + (first(s)-1)*step(r))
return range(st, step=step(r)*step(s), length=length(s))
ret = range(st, step=step(r)*step(s), length=length(s))
return withindices(ret, axes(s))
end
end

Expand All @@ -894,10 +909,11 @@ function getindex(r::StepRangeLen{T}, s::OrdinalRange{S}) where {T, S<:Integer}
end
else
# Find closest approach to offset by s
ind = LinearIndices(s)
ind = 1:length(s)
offset = max(min(1 + round(Int, (r.offset - first(s))/step(s)), last(ind)), first(ind))
ref = _getindex_hiprec(r, first(s) + (offset-1)*step(s))
return StepRangeLen{T}(ref, r.step*step(s), length(s), offset)
ret = StepRangeLen{T}(ref, r.step*step(s), length(s), offset)
return withindices(ret, axes(s))
end
end

Expand All @@ -920,7 +936,8 @@ function getindex(r::LinRange{T}, s::OrdinalRange{S}) where {T, S<:Integer}
else
vfirst = unsafe_getindex(r, first(s))
vlast = unsafe_getindex(r, last(s))
return LinRange{T}(vfirst, vlast, length(s))
ret = LinRange{T}(vfirst, vlast, length(s))
return withindices(ret, axes(s))
end
end

Expand Down
5 changes: 3 additions & 2 deletions base/twiceprecision.jl
Original file line number Diff line number Diff line change
Expand Up @@ -496,10 +496,11 @@ function getindex(r::StepRangeLen{T,<:TwicePrecision,<:TwicePrecision}, s::Ordin
newstep = twiceprecision(r.step*step(s), nbitslen(T, length(s), soffset))
end
if ioffset == r.offset
return StepRangeLen(r.ref, newstep, length(s), max(1,soffset))
ret = StepRangeLen(r.ref, newstep, length(s), max(1,soffset))
else
return StepRangeLen(r.ref + (ioffset-r.offset)*r.step, newstep, length(s), max(1,soffset))
ret = StepRangeLen(r.ref + (ioffset-r.offset)*r.step, newstep, length(s), max(1,soffset))
end
return withindices(ret, axes(s))
end
end

Expand Down
19 changes: 19 additions & 0 deletions test/offsetarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -757,6 +757,25 @@ end
for i in axes(ax,1)
@test a[ax[i]] == a[ax][i]
end

@testset "propagate indices in vector indexing" begin
ur = 5:12
idur = Base.IdentityUnitRange(ur)
ior = OffsetArrays.IdOffsetRange(ur .- 2, 2)
for r in Any[10.0:10.0:1000.0, StepRangeLen(Float64(10), Float64(1000), 1000), LinRange(10, 1000, 991),
Base.IdentityUnitRange(0:1000)]
rur = r[ur]
ridur = r[idur]
rior = r[ior]
@test all(((x,y,z),) -> x == y == z, zip(rur, ridur, rior))
@test axes(rur) == axes(ur)
@test all(i -> r[ur][i] == rur[i], eachindex(ur))
@test axes(ridur) == axes(idur)
@test all(i -> r[idur][i] == ridur[i], eachindex(idur))
@test axes(rior) == axes(ior)
@test all(i -> r[ior][i] == rior[i], eachindex(ior))
end
end
end

@testset "show OffsetMatrix" begin
Expand Down
28 changes: 28 additions & 0 deletions test/ranges.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1948,3 +1948,31 @@ end
@test_throws BoundsError r[Base.IdentityUnitRange(-1:100)]
end
end

@testset "non 1-based ranges indexing" begin
struct ZeroBasedUnitRange{T,A<:AbstractUnitRange{T}} <: AbstractUnitRange{T}
a :: A
function ZeroBasedUnitRange(a::AbstractUnitRange{T}) where {T}
@assert !Base.has_offset_axes(a)
new{T, typeof(a)}(a)
end
end

Base.parent(A::ZeroBasedUnitRange) = A.a
Base.first(A::ZeroBasedUnitRange) = first(parent(A))
Base.length(A::ZeroBasedUnitRange) = length(parent(A))
Base.last(A::ZeroBasedUnitRange) = last(parent(A))
Base.size(A::ZeroBasedUnitRange) = size(parent(A))
Base.axes(A::ZeroBasedUnitRange) = map(x -> Base.IdentityUnitRange(0:x-1), size(parent(A)))
Base.getindex(A::ZeroBasedUnitRange, i::Int) = parent(A)[i + 1]
Base.getindex(A::ZeroBasedUnitRange, i::Integer) = parent(A)[i + 1]
Base.firstindex(A::ZeroBasedUnitRange) = 0
function Base.show(io::IO, A::ZeroBasedUnitRange)
show(io, parent(A))
print(io, " with indices $(axes(A,1))")
end

r = ZeroBasedUnitRange(5:8)
@test r[0:2] == r[0]:r[2]
@test r[0:1:2] == r[0]:1:r[2]
end
57 changes: 33 additions & 24 deletions test/testhelpers/OffsetArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,24 @@ function _checkindices(N::Integer, indices, label)
N == length(indices) || throw_argumenterror(N, indices, label)
end

@inline _indexedby(r::AbstractVector, ax::Tuple{Any}) = _indexedby(r, ax[1])
@inline _indexedby(r::AbstractUnitRange{<:Integer}, ::Base.OneTo) = no_offset_view(r)
@inline _indexedby(r::AbstractUnitRange{Bool}, ::Base.OneTo) = no_offset_view(r)
@inline _indexedby(r::AbstractVector, ::Base.OneTo) = no_offset_view(r)
@inline function _indexedby(r::AbstractUnitRange{<:Integer}, ax::AbstractUnitRange)
of = convert(eltype(r), first(ax) - 1)
IdOffsetRange(_subtractoffset(r, of), of)
end
@inline _indexedby(r::AbstractUnitRange{Bool}, ax::AbstractUnitRange) = OffsetArray(r, ax)
@inline _indexedby(r::AbstractVector, ax::AbstractUnitRange) = OffsetArray(r, ax)

# These functions are equivalent to the broadcasted operation r .- of
# However these ensure that the result is an AbstractRange even if a specific
# broadcasting behavior is not defined for a custom type
_subtractoffset(r::AbstractUnitRange, of) = UnitRange(first(r) - of, last(r) - of)
_subtractoffset(r::AbstractRange, of) = range(first(r) - of, stop = last(r) - of, step = step(r))

Base.withindices(r, ax::Tuple{AbstractUnitRange, Vararg{AbstractUnitRange}}) = _indexedby(r, ax)

# Technically we know the length of CartesianIndices but we need to convert it first, so here we
# don't put it in OffsetAxisKnownLength.
Expand Down Expand Up @@ -375,22 +393,6 @@ end
@propagate_inbounds Base.getindex(a::OffsetRange, r::AbstractRange) = a.parent[r .- a.offsets[1]]
@propagate_inbounds Base.getindex(a::AbstractRange, r::OffsetRange) = OffsetArray(a[parent(r)], r.offsets)

@propagate_inbounds Base.getindex(r::UnitRange, s::IIUR) =
OffsetArray(r[s.indices], s)

@propagate_inbounds Base.getindex(r::StepRange, s::IIUR) =
OffsetArray(r[s.indices], s)

# this method is needed for ambiguity resolution
@propagate_inbounds Base.getindex(r::StepRangeLen{T,<:Base.TwicePrecision,<:Base.TwicePrecision}, s::IIUR) where T =
OffsetArray(r[s.indices], s)

@propagate_inbounds Base.getindex(r::StepRangeLen{T}, s::IIUR) where {T} =
OffsetArray(r[s.indices], s)

@propagate_inbounds Base.getindex(r::LinRange, s::IIUR) =
OffsetArray(r[s.indices], s)

function Base.show(io::IO, r::OffsetRange)
show(io, r.parent)
o = r.offsets[1]
Expand Down Expand Up @@ -429,14 +431,21 @@ function Base.replace_in_print_matrix(A::OffsetArray{<:Any,1}, i::Integer, j::In
Base.replace_in_print_matrix(parent(A), ip, j, s)
end

function no_offset_view(A::AbstractArray)
if Base.has_offset_axes(A)
OffsetArray(A, map(r->1-first(r), axes(A)))
else
A
end
end

no_offset_view(A::OffsetArray) = no_offset_view(parent(A))
if isdefined(Base, :IdentityUnitRange)
# valid only if Slice is distinguished from IdentityUnitRange
no_offset_view(a::Base.Slice{<:Base.OneTo}) = a
no_offset_view(a::Base.Slice) = Base.Slice(UnitRange(a))
no_offset_view(S::SubArray) = view(parent(S), map(no_offset_view, parentindices(S))...)
end
no_offset_view(a::Array) = a
no_offset_view(i::Number) = i
no_offset_view(A::AbstractArray) = _no_offset_view(axes(A), A)
_no_offset_view(::Tuple{}, A::AbstractArray{T,0}) where T = A
_no_offset_view(::Tuple{<:Base.OneTo,Vararg{<:Base.OneTo}}, A::AbstractArray) = A
# the following method is needed for ambiguity resolution
_no_offset_view(::Tuple{<:Base.OneTo,Vararg{<:Base.OneTo}}, A::AbstractUnitRange) = A
_no_offset_view(::Any, A::AbstractArray) = OffsetArray(A, Origin(1))
_no_offset_view(::Any, A::AbstractUnitRange) = UnitRange(A)

end # module

0 comments on commit 100f8e5

Please sign in to comment.