Skip to content

Commit

Permalink
[Containers] use OrderedDict as the data structure for SparseAxisArray (
Browse files Browse the repository at this point in the history
  • Loading branch information
odow authored Feb 29, 2024
1 parent 65a4946 commit 1195695
Show file tree
Hide file tree
Showing 8 changed files with 116 additions and 38 deletions.
1 change: 1 addition & 0 deletions docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
MarkdownAST = "d0879d2d-cac2-40c8-9cee-1863dc0c7391"
MathOptInterface = "b8f27783-ece8-5eb3-8dc8-9495eed66fee"
MultiObjectiveAlgorithms = "0327d340-17cd-11ea-3e99-2fd5d98cecda"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
PATHSolver = "f5f7c340-0bb3-5c69-969a-41884d311d1b"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
Expand Down
15 changes: 15 additions & 0 deletions docs/src/manual/containers.md
Original file line number Diff line number Diff line change
Expand Up @@ -315,13 +315,28 @@ JuMP.Containers.SparseAxisArray{Tuple{Int64, Symbol}, 1, Tuple{Int64}} with 2 en

Use `eachindex` to loop over the elements:
```jldoctest containers_sparse
julia> for key in eachindex(x)
println(x[key])
end
(2, :A)
(2, :B)
(3, :A)
(3, :B)
julia> for key in eachindex(y)
println(y[key])
end
(2, :B)
(3, :B)
```

!!! warning
If you use a macro to construct a `SparseAxisArray`, then the iteration
order is row-major, that is, indices are varied from right to left. As an
example, when iterating over `x` above, the `j` index is iterated, keeping
`i` constant. This order is in contrast to `Base.Array`s, which iterate in
column-major order, that is, by varying indices from left to right.

### Broadcasting

Broadcasting over a SparseAxisArray returns a SparseAxisArray
Expand Down
2 changes: 1 addition & 1 deletion src/Containers/Containers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ necessarily integers.
"""
module Containers

import Base.Meta.isexpr
import OrderedCollections

# Arbitrary typed indices. Linear indexing not supported.
struct IndexAnyCartesian <: Base.IndexStyle end
Expand Down
60 changes: 39 additions & 21 deletions src/Containers/SparseAxisArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,25 +5,29 @@

"""
struct SparseAxisArray{T,N,K<:NTuple{N, Any}} <: AbstractArray{T,N}
data::Dict{K,T}
data::OrderedCollections.OrderedDict{K,T}
end
`N`-dimensional array with elements of type `T` where only a subset of the
entries are defined. Each entry with index `idx = (i1, i2, ..., iN)` in
`keys(data)` has value `data[idx]`. Note that as opposed to
`SparseArrays.AbstractSparseArray`, the missing entries are not assumed to be
`zero(T)`, they are simply not part of the array. This means that the result of
`map(f, sa::SparseAxisArray)` or `f.(sa::SparseAxisArray)` has the same sparsity
structure than `sa` even if `f(zero(T))` is not zero.
`keys(data)` has value `data[idx]`.
Note that, as opposed to `SparseArrays.AbstractSparseArray`, the missing entries
are not assumed to be `zero(T)`, they are simply not part of the array. This
means that the result of `map(f, sa::SparseAxisArray)` or
`f.(sa::SparseAxisArray)` has the same sparsity structure as `sa`, even if
`f(zero(T))` is not zero.
## Example
```jldoctest
julia> dict = Dict((:a, 2) => 1.0, (:a, 3) => 2.0, (:b, 3) => 3.0)
Dict{Tuple{Symbol, Int64}, Float64} with 3 entries:
julia> using OrderedCollections: OrderedDict
julia> dict = OrderedDict((:a, 2) => 1.0, (:a, 3) => 2.0, (:b, 3) => 3.0)
OrderedDict{Tuple{Symbol, Int64}, Float64} with 3 entries:
(:a, 2) => 1.0
(:a, 3) => 2.0
(:b, 3) => 3.0
(:a, 2) => 1.0
julia> array = Containers.SparseAxisArray(dict)
SparseAxisArray{Float64, 2, Tuple{Symbol, Int64}} with 3 entries:
Expand All @@ -36,15 +40,26 @@ julia> array[:b, 3]
```
"""
struct SparseAxisArray{T,N,K<:NTuple{N,Any}} <: AbstractArray{T,N}
data::Dict{K,T}
data::OrderedCollections.OrderedDict{K,T}
names::NTuple{N,Symbol}
end

function SparseAxisArray(d::Dict{K,T}) where {T,N,K<:NTuple{N,Any}}
return SparseAxisArray(d, ntuple(n -> Symbol("#$n"), N))
# Build a `SparseAxisArray` from any `AbstractDict` whose keys are `N`-tuples,
# copying the entries of `d` into an `OrderedCollections.OrderedDict{K,T}` in
# the order in which `d` iterates them.
function SparseAxisArray(
    d::AbstractDict{K,T},
    names::NTuple{N,Symbol},
) where {T,N,K<:NTuple{N,Any}}
    # `convert(OrderedCollections.OrderedDict{K,T}, d)` is deprecated, so the
    # ordered copy is built from a generator over the key-value pairs of `d`.
    entries = (key => value for (key, value) in d)
    return SparseAxisArray(OrderedCollections.OrderedDict{K,T}(entries), names)
end

SparseAxisArray(d::Dict, ::Nothing) = SparseAxisArray(d)
# Build a `SparseAxisArray` when no axis names are given (the second argument is
# `nothing` or omitted). The axes are given the placeholder names `#1`, ...,
# `#N`.
function SparseAxisArray(
    d::AbstractDict{K,T},
    ::Nothing = nothing,
) where {T,N,K<:NTuple{N,Any}}
    default_names = ntuple(n -> Symbol("#$n"), N)
    return SparseAxisArray(d, default_names)
end

# The length of a `SparseAxisArray` is the number of entries stored in `data`.
function Base.length(sa::SparseAxisArray)
    return length(sa.data)
end

Expand All @@ -71,7 +86,7 @@ function Base.similar(
::Type{T},
length::Integer = 0,
) where {S,T,N,K}
d = Dict{K,T}()
d = OrderedCollections.OrderedDict{K,T}()
if !iszero(length)
sizehint!(d, length)
end
Expand Down Expand Up @@ -165,7 +180,7 @@ function Base.getindex(
end
K2 = _sliced_key_type(K, args...)
if K2 !== nothing
new_data = Dict{K2,T}(
new_data = OrderedCollections.OrderedDict{K2,T}(
_sliced_key(k, args) => v for (k, v) in d.data if _filter(k, args)
)
names = _sliced_key_name(K, d.names, args...)
Expand Down Expand Up @@ -293,12 +308,16 @@ end
function Base.copy(
bc::Base.Broadcast.Broadcasted{BroadcastStyle{N,K}},
) where {N,K}
dict = Dict(index => _getindex(bc, index) for index in _indices(bc.args...))
if isempty(dict) && dict isa Dict{Any,Any}
dict = OrderedCollections.OrderedDict(
index => _getindex(bc, index) for index in _indices(bc.args...)
)
if isempty(dict) && dict isa OrderedCollections.OrderedDict{Any,Any}
# If `dict` is empty (e.g., because there are no indices), then
# inference will produce a `Dict{Any,Any}`, and we won't have enough
# type information to call SparseAxisArray(dict). As a work-around, we
# explicitly construct the type of the resulting SparseAxisArray.
# inference will produce an `OrderedCollections.OrderedDict{Any,Any}`,
# and we won't have enough type information to call
# `SparseAxisArray(dict)`. As a work-around, we explicitly construct the
# type of the resulting SparseAxisArray.
#
# For more, see JuMP issue #2867.
return SparseAxisArray{Any,N,K}(dict, ntuple(n -> Symbol("#$n"), N))
end
Expand Down Expand Up @@ -448,7 +467,6 @@ function Base.show(io::IOContext, x::SparseAxisArray)
(i, (key, value)) in enumerate(x.data) if
i < half_screen_rows || i > length(x) - half_screen_rows
]
sort!(key_strings; by = x -> x[1])
pad = maximum(length(x[1]) for x in key_strings)
for (i, (key, value)) in enumerate(key_strings)
print(io, " [", rpad(key, pad), "] = ", value)
Expand Down
25 changes: 19 additions & 6 deletions src/Containers/container.jl
Original file line number Diff line number Diff line change
Expand Up @@ -142,13 +142,21 @@ function container(
end
# Same as `map` but does not allocate the resulting vector.
mappings = Base.Generator(I -> I => f(I...), indices)
# Same as `Dict(mapping)` but it will error if two indices are the same.
# Same as `OrderedCollections.OrderedDict(mappings)`, but it will error if
# two indices are the same.
data = NoDuplicateDict(mappings)
return _sparseaxisarray(data.dict, f, indices, names)
end

# The NoDuplicateDict was able to infer the element type.
_sparseaxisarray(dict::Dict, ::Any, ::Any, names) = SparseAxisArray(dict, names)
# Generic case: `data` is passed straight to the `SparseAxisArray` constructor
# together with `names`. The second and third positional arguments are accepted
# but not used by this method.
function _sparseaxisarray(
    data::OrderedCollections.OrderedDict,
    _f,
    _indices,
    names,
)
    return SparseAxisArray(data, names)
end

# @default_eltype succeeded and inferred a tuple of the appropriate size!
# Use `return_types` to get the value type of the dictionary.
Expand All @@ -159,13 +167,13 @@ function _container_dict(
) where {N}
ret = Base.return_types(f, K)
V = length(ret) == 1 ? first(ret) : Any
return Dict{K,V}()
return OrderedCollections.OrderedDict{K,V}()
end

# @default_eltype bailed and returned Any. Use an NTuple of Any of the
# appropriate size instead.
function _container_dict(::Any, ::Any, K::Type{<:NTuple{N,Any}}) where {N}
return Dict{K,Any}()
return OrderedCollections.OrderedDict{K,Any}()
end

# @default_eltype bailed and returned Union{}. Use an NTuple of Any of the
Expand All @@ -176,7 +184,7 @@ function _container_dict(
::Function,
K::Type{<:NTuple{N,Any}},
) where {N}
return Dict{K,Any}()
return OrderedCollections.OrderedDict{K,Any}()
end

# Calling `@default_eltype` on `x` isn't sufficient, because the iterator may
Expand All @@ -189,7 +197,12 @@ _default_eltype(x) = Base.@default_eltype x
# best-guess attempt, collect all of the keys excluding the conditional
# statement (these must be defined, because the conditional applies to the
# lowest-level of the index loops), then get the eltype of the result.
function _sparseaxisarray(dict::Dict{Any,Any}, f, indices, names)
function _sparseaxisarray(
dict::OrderedCollections.OrderedDict{Any,Any},
f,
indices,
names,
)
@assert isempty(dict)
d = _container_dict(_default_eltype(indices), f, _eltype_or_any(indices))
return SparseAxisArray(d, names)
Expand Down
21 changes: 15 additions & 6 deletions src/Containers/no_duplicate_dict.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,38 +5,47 @@

"""
struct NoDuplicateDict{K, V} <: AbstractDict{K, V}
dict::Dict{K, V}
dict::OrderedCollections.OrderedDict{K, V}
end
Same as `Dict{K, V}` but errors if constructed from an iterator with duplicate
keys.
Same as `OrderedCollections.OrderedDict{K, V}` but errors if constructed from an
iterator with duplicate keys.
"""
struct NoDuplicateDict{K,V} <: AbstractDict{K,V}
dict::Dict{K,V}
NoDuplicateDict{K,V}() where {K,V} = new{K,V}(Dict{K,V}())
dict::OrderedCollections.OrderedDict{K,V}

function NoDuplicateDict{K,V}() where {K,V}
return new{K,V}(OrderedCollections.OrderedDict{K,V}())
end
end

# Implementation of the `AbstractDict` API.
# `empty` ignores the contents of its first argument and returns a fresh, empty
# `NoDuplicateDict` with the requested key type `K` and value type `V`.
Base.empty(::NoDuplicateDict, ::Type{K}, ::Type{V}) where {K,V} =
    NoDuplicateDict{K,V}()

# Forward the read-only parts of the `AbstractDict` interface to the wrapped
# dictionary stored in the `dict` field.
function Base.iterate(d::NoDuplicateDict, args...)
    return iterate(d.dict, args...)
end

function Base.length(d::NoDuplicateDict)
    return length(d.dict)
end

function Base.haskey(d::NoDuplicateDict, key)
    return haskey(d.dict, key)
end

function Base.getindex(d::NoDuplicateDict, key)
    return getindex(d.dict, key)
end

# Store `value` under `key`, throwing an error instead of overwriting if `key`
# is already present.
function Base.setindex!(dict::NoDuplicateDict, value, key)
    haskey(dict, key) &&
        error("Repeated index ", key, ". Index sets must have unique elements.")
    return setindex!(dict.dict, value, key)
end

# Build a `NoDuplicateDict{K,V}` from an iterator of key-value pairs. Each
# pair is inserted through `setindex!`, so a repeated key raises an error.
function NoDuplicateDict{K,V}(it) where {K,V}
    result = NoDuplicateDict{K,V}()
    for pair in it
        key, value = pair
        setindex!(result, value, key)
    end
    return result
end

# Build a `NoDuplicateDict` from an iterator without explicit type parameters.
# The key and value types are chosen by `Base.dict_with_eltype` from
# `eltype(it)`; the callback maps a `(K, V)` pair of types to the concrete
# `NoDuplicateDict{K,V}` type to construct.
# NOTE(review): `Base.dict_with_eltype` is accessed through the `Base.` prefix;
# confirm it is supported on every Julia version this package targets.
function NoDuplicateDict(it)
    dict_type = (K, V) -> NoDuplicateDict{K,V}
    return Base.dict_with_eltype(dict_type, it, eltype(it))
end
4 changes: 2 additions & 2 deletions src/Containers/tables.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,14 @@ julia> @variable(model, x[i=1:2, j=i:2] >= 0, start = i+j);
julia> Containers.rowtable(start_value, x; header = [:i, :j, :start])
3-element Vector{@NamedTuple{i::Int64, j::Int64, start::Float64}}:
(i = 1, j = 2, start = 3.0)
(i = 1, j = 1, start = 2.0)
(i = 1, j = 2, start = 3.0)
(i = 2, j = 2, start = 4.0)
julia> Containers.rowtable(x)
3-element Vector{@NamedTuple{x1::Int64, x2::Int64, y::VariableRef}}:
(x1 = 1, x2 = 2, y = x[1,2])
(x1 = 1, x2 = 1, y = x[1,1])
(x1 = 1, x2 = 2, y = x[1,2])
(x1 = 2, x2 = 2, y = x[2,2])
```
"""
Expand Down
26 changes: 24 additions & 2 deletions test/Containers/test_SparseAxisArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ module TestContainersSparseAxisArray
using JuMP.Containers
using Test

import LinearAlgebra
import OrderedCollections

function _util_sparse_test(d, sum_d, d2, d3, dsqr, d_bads)
sqr(x) = x^2
# map
Expand Down Expand Up @@ -49,7 +52,9 @@ function _util_sparse_test(d, sum_d, d2, d3, dsqr, d_bads)
end

function test_1_dimensional()
d = @inferred SparseAxisArray(Dict((:a,) => 1, (:b,) => 2))
d = @inferred SparseAxisArray(
OrderedCollections.OrderedDict((:a,) => 1, (:b,) => 2),
)
@test sprint(summary, d) == """
$(SparseAxisArray{Int,1,Tuple{Symbol}}) with 2 entries"""
@test sprint(show, "text/plain", d) == """
Expand Down Expand Up @@ -81,7 +86,9 @@ $(SparseAxisArray{Int,1,Tuple{Symbol}}) with 2 entries:
end

function test_2_dimensional()
d = @inferred SparseAxisArray(Dict((:a, 'u') => 2.0, (:b, 'v') => 0.5))
d = @inferred SparseAxisArray(
OrderedCollections.OrderedDict((:a, 'u') => 2.0, (:b, 'v') => 0.5),
)
@test d isa SparseAxisArray{Float64,2,Tuple{Symbol,Char}}
@test_throws BoundsError(d, (:a,)) d[:a]
@test sprint(summary, d) == """
Expand Down Expand Up @@ -359,4 +366,19 @@ function test_multi_arg_eachindex()
return
end

# Check that slicing a two-dimensional `SparseAxisArray` gives the same result
# as building the corresponding one-dimensional container directly.
function test_sparseaxisarray_order()
    # One index set for `j` per value of `i`.
    index_sets = [[1, 2, 10], [2, 3, 30]]
    Containers.@container(
        x[i in 1:2, j in index_sets[i]],
        i + j,
        container = SparseAxisArray,
    )
    Containers.@container(
        x1[j in index_sets[1]],
        1 + j,
        container = SparseAxisArray,
    )
    Containers.@container(
        x2[j in index_sets[2]],
        2 + j,
        container = SparseAxisArray,
    )
    @test x[1, :] == x1
    @test x[2, :] == x2
    # The values of `x[1, :]` are 2, 3, and 11 for j = 1, 2, and 10. Pairing
    # them in that order with 1:3 gives 2 * 1 + 3 * 2 + 11 * 3 = 41.
    @test LinearAlgebra.dot(x[1, :], 1:3) == 41
    return
end

end # module

0 comments on commit 1195695

Please sign in to comment.