Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Containers] use OrderedDict as the data structure for SparseAxisArray #3681

Merged
merged 7 commits into from
Feb 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
MarkdownAST = "d0879d2d-cac2-40c8-9cee-1863dc0c7391"
MathOptInterface = "b8f27783-ece8-5eb3-8dc8-9495eed66fee"
MultiObjectiveAlgorithms = "0327d340-17cd-11ea-3e99-2fd5d98cecda"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
PATHSolver = "f5f7c340-0bb3-5c69-969a-41884d311d1b"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
Expand Down
15 changes: 15 additions & 0 deletions docs/src/manual/containers.md
Original file line number Diff line number Diff line change
Expand Up @@ -315,13 +315,28 @@ JuMP.Containers.SparseAxisArray{Tuple{Int64, Symbol}, 1, Tuple{Int64}} with 2 en

Use `eachindex` to loop over the elements:
```jldoctest containers_sparse
julia> for key in eachindex(x)
println(x[key])
end
(2, :A)
(2, :B)
(3, :A)
(3, :B)
julia> for key in eachindex(y)
println(y[key])
end
(2, :B)
(3, :B)
```

!!! warning
If you use a macro to construct a `SparseAxisArray`, then the iteration
order is row-major, that is, indices are varied from right to left. As an
example, when iterating over `x` above, the `j` index is iterated, keeping
`i` constant. This order is in contrast to `Base.Array`s, which iterate in
column-major order, that is, by varying indices from left to right.

### Broadcasting

Broadcasting over a SparseAxisArray returns a SparseAxisArray
Expand Down
2 changes: 1 addition & 1 deletion src/Containers/Containers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ necessarily integers.
"""
module Containers

import Base.Meta.isexpr
import OrderedCollections

# Arbitrary typed indices. Linear indexing not supported.
struct IndexAnyCartesian <: Base.IndexStyle end
Expand Down
60 changes: 39 additions & 21 deletions src/Containers/SparseAxisArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,25 +5,29 @@

"""
struct SparseAxisArray{T,N,K<:NTuple{N, Any}} <: AbstractArray{T,N}
data::Dict{K,T}
data::OrderedCollections.OrderedDict{K,T}
end
`N`-dimensional array with elements of type `T` where only a subset of the
entries are defined. The entries with indices `idx = (i1, i2, ..., iN)` in
`keys(data)` has value `data[idx]`. Note that as opposed to
`SparseArrays.AbstractSparseArray`, the missing entries are not assumed to be
`zero(T)`, they are simply not part of the array. This means that the result of
`map(f, sa::SparseAxisArray)` or `f.(sa::SparseAxisArray)` has the same sparsity
structure than `sa` even if `f(zero(T))` is not zero.
`keys(data)` has value `data[idx]`.
Note that, as opposed to `SparseArrays.AbstractSparseArray`, the missing entries
are not assumed to be `zero(T)`, they are simply not part of the array. This
means that the result of `map(f, sa::SparseAxisArray)` or
`f.(sa::SparseAxisArray)` has the same sparsity structure as `sa`, even if
`f(zero(T))` is not zero.
## Example
```jldoctest
julia> dict = Dict((:a, 2) => 1.0, (:a, 3) => 2.0, (:b, 3) => 3.0)
Dict{Tuple{Symbol, Int64}, Float64} with 3 entries:
julia> using OrderedCollections: OrderedDict
julia> dict = OrderedDict((:a, 2) => 1.0, (:a, 3) => 2.0, (:b, 3) => 3.0)
OrderedDict{Tuple{Symbol, Int64}, Float64} with 3 entries:
(:a, 2) => 1.0
(:a, 3) => 2.0
(:b, 3) => 3.0
(:a, 2) => 1.0
julia> array = Containers.SparseAxisArray(dict)
SparseAxisArray{Float64, 2, Tuple{Symbol, Int64}} with 3 entries:
Expand All @@ -36,15 +40,26 @@ julia> array[:b, 3]
```
"""
struct SparseAxisArray{T,N,K<:NTuple{N,Any}} <: AbstractArray{T,N}
data::Dict{K,T}
data::OrderedCollections.OrderedDict{K,T}
names::NTuple{N,Symbol}
end

function SparseAxisArray(d::Dict{K,T}) where {T,N,K<:NTuple{N,Any}}
return SparseAxisArray(d, ntuple(n -> Symbol("#$n"), N))
# Build a `SparseAxisArray` from any `AbstractDict` whose keys are `N`-tuples,
# together with one `Symbol` name per axis. The entries of `d` are copied, in
# the iteration order of `d`, into a fresh `OrderedCollections.OrderedDict`.
function SparseAxisArray(
    d::AbstractDict{K,T},
    names::NTuple{N,Symbol},
) where {T,N,K<:NTuple{N,Any}}
    # convert(OrderedCollections.OrderedDict{K,T}, d) is deprecated, so feed
    # the constructor a lazy stream of key-value pairs instead.
    entries = (key => value for (key, value) in d)
    ordered = OrderedCollections.OrderedDict{K,T}(entries)
    return SparseAxisArray(ordered, names)
end

SparseAxisArray(d::Dict, ::Nothing) = SparseAxisArray(d)
# Build a `SparseAxisArray` when no axis names are supplied (the second
# argument is `nothing` or omitted). Placeholder names `#1`, `#2`, ..., one per
# dimension, are generated and the call is forwarded to the two-argument
# constructor.
function SparseAxisArray(
    d::AbstractDict{K,T},
    ::Nothing = nothing,
) where {T,N,K<:NTuple{N,Any}}
    default_names = ntuple(n -> Symbol("#$n"), N)
    return SparseAxisArray(d, default_names)
end

Base.length(sa::SparseAxisArray) = length(sa.data)

Expand All @@ -71,7 +86,7 @@ function Base.similar(
::Type{T},
length::Integer = 0,
) where {S,T,N,K}
d = Dict{K,T}()
d = OrderedCollections.OrderedDict{K,T}()
if !iszero(length)
sizehint!(d, length)
end
Expand Down Expand Up @@ -165,7 +180,7 @@ function Base.getindex(
end
K2 = _sliced_key_type(K, args...)
if K2 !== nothing
new_data = Dict{K2,T}(
new_data = OrderedCollections.OrderedDict{K2,T}(
_sliced_key(k, args) => v for (k, v) in d.data if _filter(k, args)
)
names = _sliced_key_name(K, d.names, args...)
Expand Down Expand Up @@ -293,12 +308,16 @@ end
function Base.copy(
bc::Base.Broadcast.Broadcasted{BroadcastStyle{N,K}},
) where {N,K}
dict = Dict(index => _getindex(bc, index) for index in _indices(bc.args...))
if isempty(dict) && dict isa Dict{Any,Any}
dict = OrderedCollections.OrderedDict(
index => _getindex(bc, index) for index in _indices(bc.args...)
)
if isempty(dict) && dict isa OrderedCollections.OrderedDict{Any,Any}
# If `dict` is empty (e.g., because there are no indices), then
# inference will produce a `Dict{Any,Any}`, and we won't have enough
# type information to call SparseAxisArray(dict). As a work-around, we
# explicitly construct the type of the resulting SparseAxisArray.
# inference will produce an `OrderedCollections.OrderedDict{Any,Any}`,
# and we won't have enough type information to call
# `SparseAxisArray(dict)`. As a work-around, we explicitly construct the
# type of the resulting SparseAxisArray.
#
# For more, see JuMP issue #2867.
return SparseAxisArray{Any,N,K}(dict, ntuple(n -> Symbol("#$n"), N))
end
Expand Down Expand Up @@ -448,7 +467,6 @@ function Base.show(io::IOContext, x::SparseAxisArray)
(i, (key, value)) in enumerate(x.data) if
i < half_screen_rows || i > length(x) - half_screen_rows
]
sort!(key_strings; by = x -> x[1])
odow marked this conversation as resolved.
Show resolved Hide resolved
pad = maximum(length(x[1]) for x in key_strings)
for (i, (key, value)) in enumerate(key_strings)
print(io, " [", rpad(key, pad), "] = ", value)
Expand Down
25 changes: 19 additions & 6 deletions src/Containers/container.jl
Original file line number Diff line number Diff line change
Expand Up @@ -142,13 +142,21 @@ function container(
end
# Same as `map` but does not allocate the resulting vector.
mappings = Base.Generator(I -> I => f(I...), indices)
# Same as `Dict(mapping)` but it will error if two indices are the same.
# Same as `OrderedCollections.OrderedDict(mapping)`, but it will error if
# two indices are the same.
data = NoDuplicateDict(mappings)
return _sparseaxisarray(data.dict, f, indices, names)
end

# The NoDuplicateDict was able to infer the element type.
_sparseaxisarray(dict::Dict, ::Any, ::Any, names) = SparseAxisArray(dict, names)
# General `OrderedDict` case: wrap the dictionary directly. The second and
# third positional arguments are accepted but not used by this method.
function _sparseaxisarray(
    dict::OrderedCollections.OrderedDict,
    _f::Any,
    _indices::Any,
    names,
)
    array = SparseAxisArray(dict, names)
    return array
end

# @default_eltype succeeded and inferred a tuple of the appropriate size!
# Use `return_types` to get the value type of the dictionary.
Expand All @@ -159,13 +167,13 @@ function _container_dict(
) where {N}
ret = Base.return_types(f, K)
V = length(ret) == 1 ? first(ret) : Any
return Dict{K,V}()
return OrderedCollections.OrderedDict{K,V}()
end

# @default_eltype bailed and returned Any. Use an NTuple of Any of the
# appropriate size instead.
function _container_dict(::Any, ::Any, K::Type{<:NTuple{N,Any}}) where {N}
return Dict{K,Any}()
return OrderedCollections.OrderedDict{K,Any}()
end

# @default_eltype bailed and returned Union{}. Use an NTuple of Any of the
Expand All @@ -176,7 +184,7 @@ function _container_dict(
::Function,
K::Type{<:NTuple{N,Any}},
) where {N}
return Dict{K,Any}()
return OrderedCollections.OrderedDict{K,Any}()
end

# Calling `@default_eltype` on `x` isn't sufficient, because the iterator may
Expand All @@ -189,7 +197,12 @@ _default_eltype(x) = Base.@default_eltype x
# best-guess attempt, collect all of the keys excluding the conditional
# statement (these must be defined, because the conditional applies to the
# lowest-level of the index loops), then get the eltype of the result.
function _sparseaxisarray(dict::Dict{Any,Any}, f, indices, names)
function _sparseaxisarray(
dict::OrderedCollections.OrderedDict{Any,Any},
f,
indices,
names,
)
@assert isempty(dict)
d = _container_dict(_default_eltype(indices), f, _eltype_or_any(indices))
return SparseAxisArray(d, names)
Expand Down
21 changes: 15 additions & 6 deletions src/Containers/no_duplicate_dict.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,38 +5,47 @@

"""
struct NoDuplicateDict{K, V} <: AbstractDict{K, V}
dict::Dict{K, V}
dict::OrderedCollections.OrderedDict{K, V}
end
Same as `Dict{K, V}` but errors if constructed from an iterator with duplicate
keys.
Same as `OrderedCollections.OrderedDict{K, V}` but errors if constructed from an
iterator with duplicate keys.
"""
struct NoDuplicateDict{K,V} <: AbstractDict{K,V}
dict::Dict{K,V}
NoDuplicateDict{K,V}() where {K,V} = new{K,V}(Dict{K,V}())
dict::OrderedCollections.OrderedDict{K,V}

function NoDuplicateDict{K,V}() where {K,V}
return new{K,V}(OrderedCollections.OrderedDict{K,V}())
end
end

# Implementation of the `AbstractDict` API. Apart from `empty`, each method
# forwards to the wrapped `dict` field.

# Return a new, empty `NoDuplicateDict` with key type `K` and value type `V`.
function Base.empty(::NoDuplicateDict, ::Type{K}, ::Type{V}) where {K,V}
    return NoDuplicateDict{K,V}()
end

# Iterate over the wrapped dictionary; any iteration state is passed through.
function Base.iterate(wrapper::NoDuplicateDict, state...)
    return iterate(wrapper.dict, state...)
end

# Number of entries in the wrapped dictionary.
function Base.length(wrapper::NoDuplicateDict)
    return length(wrapper.dict)
end

# Whether `key` is present in the wrapped dictionary.
function Base.haskey(wrapper::NoDuplicateDict, key)
    return haskey(wrapper.dict, key)
end

# Look up `key` in the wrapped dictionary.
function Base.getindex(wrapper::NoDuplicateDict, key)
    return getindex(wrapper.dict, key)
end

# Store `value` under `key`, refusing to overwrite: if `key` is already present
# an error is raised instead of replacing the existing entry.
function Base.setindex!(dict::NoDuplicateDict, value, key)
    haskey(dict, key) &&
        error("Repeated index ", key, ". Index sets must have unique elements.")
    return setindex!(dict.dict, value, key)
end

# Build a `NoDuplicateDict{K,V}` from an iterator of key-value pairs. Entries
# are inserted one at a time through `setindex!`, so a repeated key triggers
# the duplicate-index error.
function NoDuplicateDict{K,V}(it) where {K,V}
    result = NoDuplicateDict{K,V}()
    for (key, value) in it
        setindex!(result, value, key)
    end
    return result
end

# Build a `NoDuplicateDict` from an iterator without explicit type parameters.
# The work is delegated to `Base.dict_with_eltype`, which is handed a callback
# mapping a key type and a value type to the matching `NoDuplicateDict{K,V}`
# type, along with the iterator and its `eltype`.
function NoDuplicateDict(it)
    dict_type(K, V) = NoDuplicateDict{K,V}
    return Base.dict_with_eltype(dict_type, it, eltype(it))
end
4 changes: 2 additions & 2 deletions src/Containers/tables.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,14 @@ julia> @variable(model, x[i=1:2, j=i:2] >= 0, start = i+j);
julia> Containers.rowtable(start_value, x; header = [:i, :j, :start])
3-element Vector{@NamedTuple{i::Int64, j::Int64, start::Float64}}:
(i = 1, j = 2, start = 3.0)
(i = 1, j = 1, start = 2.0)
(i = 1, j = 2, start = 3.0)
(i = 2, j = 2, start = 4.0)
julia> Containers.rowtable(x)
3-element Vector{@NamedTuple{x1::Int64, x2::Int64, y::VariableRef}}:
(x1 = 1, x2 = 2, y = x[1,2])
(x1 = 1, x2 = 1, y = x[1,1])
(x1 = 1, x2 = 2, y = x[1,2])
(x1 = 2, x2 = 2, y = x[2,2])
```
"""
Expand Down
26 changes: 24 additions & 2 deletions test/Containers/test_SparseAxisArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ module TestContainersSparseAxisArray
using JuMP.Containers
using Test

import LinearAlgebra
import OrderedCollections

function _util_sparse_test(d, sum_d, d2, d3, dsqr, d_bads)
sqr(x) = x^2
# map
Expand Down Expand Up @@ -49,7 +52,9 @@ function _util_sparse_test(d, sum_d, d2, d3, dsqr, d_bads)
end

function test_1_dimensional()
d = @inferred SparseAxisArray(Dict((:a,) => 1, (:b,) => 2))
d = @inferred SparseAxisArray(
OrderedCollections.OrderedDict((:a,) => 1, (:b,) => 2),
)
@test sprint(summary, d) == """
$(SparseAxisArray{Int,1,Tuple{Symbol}}) with 2 entries"""
@test sprint(show, "text/plain", d) == """
Expand Down Expand Up @@ -81,7 +86,9 @@ $(SparseAxisArray{Int,1,Tuple{Symbol}}) with 2 entries:
end

function test_2_dimensional()
d = @inferred SparseAxisArray(Dict((:a, 'u') => 2.0, (:b, 'v') => 0.5))
d = @inferred SparseAxisArray(
OrderedCollections.OrderedDict((:a, 'u') => 2.0, (:b, 'v') => 0.5),
)
@test d isa SparseAxisArray{Float64,2,Tuple{Symbol,Char}}
@test_throws BoundsError(d, (:a,)) d[:a]
@test sprint(summary, d) == """
Expand Down Expand Up @@ -359,4 +366,19 @@ function test_multi_arg_eachindex()
return
end

# Check that slicing a two-dimensional `SparseAxisArray` yields the same
# result as building the corresponding one-dimensional container directly,
# and that the entries of a slice come out in the order of `A[i]`.
function test_sparseaxisarray_order()
# Index sets for `j`, one per value of `i`; deliberately not contiguous.
A = [[1, 2, 10], [2, 3, 30]]
# Two-dimensional container with entries `i + j`.
Containers.@container(
x[i in 1:2, j in A[i]],
i + j,
container = SparseAxisArray,
)
# One-dimensional reference containers for `i = 1` and `i = 2`.
Containers.@container(x1[j in A[1]], 1 + j, container = SparseAxisArray)
Containers.@container(x2[j in A[2]], 2 + j, container = SparseAxisArray)
@test x[1, :] == x1
@test x[2, :] == x2
# For `i = 1` the values are 2, 3, 11 (j = 1, 2, 10). Paired in that order
# with 1:3 the dot product is 2*1 + 3*2 + 11*3 = 41, so this assertion
# depends on the slice iterating in the order of `A[1]`.
@test LinearAlgebra.dot(x[1, :], 1:3) == 41
return
end

end # module
Loading