Skip to content

Commit

Permalink
Make MatrixModificationCache more efficient.
Browse files Browse the repository at this point in the history
  • Loading branch information
tkoolen committed Feb 16, 2018
1 parent 776db15 commit f4331dc
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 47 deletions.
78 changes: 39 additions & 39 deletions src/MathOptInterfaceOSQP.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,63 +49,62 @@ end

# TODO: consider not storing modifications in a SparseVector
struct MatrixModificationCache{T}
cartesian_to_nzval_index::Dict{CartesianIndex{2}, Int}
modifications::SparseVector{T, Int}
cartesian_indices::Vector{CartesianIndex{2}}
cartesian_indices_set::Set{CartesianIndex{2}} # to speed up checking whether indices are set out of bounds
modifications::Dict{CartesianIndex{2}, T}
vals::Vector{T}
inds::Vector{Int}

function MatrixModificationCache(S::SparseMatrixCSC{T}) where T
cartesian_to_nzval_index = Dict{CartesianIndex{2}, Int}()
sizehint!(cartesian_to_nzval_index, nnz(S))
cartesian_indices = Vector{CartesianIndex{2}}(uninitialized, nnz(S))
@inbounds for col = 1 : S.n, k = S.colptr[col] : (S.colptr[col+1]-1) # from sparse findn
row = S.rowval[k]
cartesian_to_nzval_index[CartesianIndex(row, col)] = k
I = CartesianIndex(row, col)
cartesian_indices[k] = I
end
modifications = spzeros(length(cartesian_to_nzval_index))
new{T}(cartesian_to_nzval_index, modifications)
modifications = Dict{CartesianIndex{2}, Int}()
new{T}(cartesian_indices, Set(cartesian_indices), modifications, T[], Int[])
end
end

@inline function modification_index(cache::MatrixModificationCache, row::Integer, col::Integer)
function Base.setindex!(cache::MatrixModificationCache, x, row::Integer, col::Integer)
I = CartesianIndex(row, col)
@boundscheck haskey(cache.cartesian_to_nzval_index, I) || throw(ArgumentError("Changing the sparsity pattern is not allowed."))
cache.cartesian_to_nzval_index[I]
end

Base.@propagate_inbounds function Base.setindex!(cache::MatrixModificationCache, x, row::Integer, col::Integer)
cache.modifications[modification_index(cache, row, col)] = x
I cache.cartesian_indices_set || throw(ArgumentError("Changing the sparsity pattern is not allowed."))
cache.modifications[I] = x
end

function Base.setindex!(cache::MatrixModificationCache, x::Real, ::Colon)
x == 0 || throw(ArgumentError("Changing the sparsity pattern is not allowed."))
# using internals of SparseVector here...
ninds = length(cache.cartesian_to_nzval_index)
nzind = cache.modifications.nzind
nzval = cache.modifications.nzval
resize!(nzind, ninds)
resize!(nzval, ninds)
cache.modifications.nzind[:] = 1 : ninds
cache.modifications.nzval[:] = 0
end

Base.@propagate_inbounds function Base.getindex(cache::MatrixModificationCache, row::Integer, col::Integer)
cache.modifications[modification_index(cache, row, col)]
for I in cache.cartesian_indices
cache.modifications[I] = 0
end
end

Base.@propagate_inbounds function isassigned(cache::MatrixModificationCache, row::Integer, col::Integer)
modification_index(cache, row, col) cache.modifications.nzind
function Base.getindex(cache::MatrixModificationCache, row::Integer, col::Integer)
cache.modifications[CartesianIndex(row, col)]
end

function clearmodifications!(cache::MatrixModificationCache)
empty!(cache.modifications.nzind)
empty!(cache.modifications.nzval)
cache
function isassigned(cache::MatrixModificationCache, row::Integer, col::Integer)
haskey(cache.modifications, CartesianIndex(row, col))
end

function processupdates!(model::OSQP.Model, cache::MatrixModificationCache, updatefun::Function)
dirty = nnz(cache.modifications) > 0
dirty = length(cache.modifications) > 0
if dirty
# using internals of SparseVector here...
updatefun(model, cache.modifications.nzval, cache.modifications.nzind)
clearmodifications!(cache)
nmods = length(cache.modifications)
resize!(cache.vals, nmods)
resize!(cache.inds, nmods)
count = 1
for i in eachindex(cache.cartesian_indices)
I = cache.cartesian_indices[i]
if haskey(cache.modifications, I)
cache.vals[count] = cache.modifications[I]
cache.inds[count] = i
count += 1
end
end
updatefun(model, cache.vals, cache.inds)
empty!(cache.modifications)
end
end

Expand Down Expand Up @@ -442,15 +441,16 @@ function MOI.set!(optimizer::OSQPOptimizer, ::MOI.ObjectiveFunction{Quadratic},
coeffs = obj.quadratic_coefficients
n = length(coeffs)
@assert length(rows) == length(cols) == n

for i = 1 : n
row = rows[i].value
col = cols[i].value
coeff = coeffs[i]
row > col && ((row, col) = (col, row)) # upper triangle only
if isassigned(Pcache, row, col)
@inbounds Pcache[row, col] += coeff
Pcache[row, col] += coeff
else
@inbounds Pcache[row, col] = coeff
Pcache[row, col] = coeff
end
end
processlinearterms!(optimizer.modcache.qcache, obj.affine_variables, obj.affine_coefficients)
Expand Down Expand Up @@ -533,7 +533,7 @@ function MOI.get(optimizer::OSQPOptimizer, ::MOI.DualStatus)
end


## Variables and constraints:
## Variables:
MOI.isvalid(optimizer::OSQPOptimizer, vi::VI) = MOI.canget(optimizer, MOI.NumberOfVariables()) && vi.value 1 : get(optimizer, MOI.NumberOfVariables())
MOI.canaddvariable(optimizer::OSQPOptimizer) = false # TODO: currently required by tests; should there be a default fallback in MOI, similar to canget?

Expand Down
11 changes: 3 additions & 8 deletions test/MathOptInterfaceOSQP.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,11 @@ const MOIU = MathOptInterfaceUtilities
baseresults = OSQP.solve!(model)
@test baseresults.info.status == :Solved

# VectorModificationCache basics
# Modify q, ensure that updating results in the same solution as calling setup! with the modified q
@test !modcache.qcache.dirty
modcache.qcache[3] = 5.
@test modcache.qcache.dirty
@test modcache.qcache.data[3] == 5.

# Process q modifications, ensure that updating results in the same solution as calling setup! with the modified q
MathOptInterfaceOSQP.processupdates!(model, modcache)
@test !modcache.qcache.dirty
qmod_update_results = OSQP.solve!(model)
Expand All @@ -44,18 +42,15 @@ const MOIU = MathOptInterfaceUtilities
qmod_setup_results = OSQP.solve!(model2)
@test qmod_update_results.x qmod_setup_results.x atol = 1e-8

# MatrixModificationCache basics
# Modify A, ensure that updating results in the same solution as calling setup! with the modified A and q
(I, J) = findn(A)
Amodindex = rand(rng, 1 : nnz(A))
row = I[Amodindex]
col = J[Amodindex]
val = randn(rng)
modcache.Acache[row, col] = val
@test any(x -> x == val, modcache.Acache.modifications)

# Process A modifications, ensure that updating results in the same solution as calling setup! with the modified A and q
MathOptInterfaceOSQP.processupdates!(model, modcache)
@test all(iszero, modcache.Acache.modifications)
@test isempty(modcache.Acache.modifications)
Amod_update_results = OSQP.solve!(model)
@test !isapprox(baseresults.x, Amod_update_results.x; atol = 1e-1) # ensure that new results are significantly different
@test !isapprox(qmod_update_results.x, Amod_update_results.x; atol = 1e-1) # ensure that new results are significantly different
Expand Down

0 comments on commit f4331dc

Please sign in to comment.