Skip to content

Commit

Permalink
WIP: linalg: Make bareiss more flexible
Browse files Browse the repository at this point in the history
We have several implementations of the Bareiss fraction-free row-reduction
algorithm in the Julia ecosystem. The one in Base was added in #40868
to compute exact determinants. We also have implementations in MTK [1]
and Modia [2].

The MTK and Modia versions additionally have support for custom pivot
selection, open-code a sparse matrix data structure adapted to their
domains and support rank-deficient matrices.

I would like to separate out the algorithmic and data-structures concerns
so that they may tested independently. Of course this function isn't
particularly large, but implementing it correctly and performantly
is still not trivial, so it seems better to have one implementation
rather than three.

[1] https://github.com/SciML/ModelingToolkit.jl/blob/master/src/systems/alias_elimination.jl#L236
[2] https://github.com/ModiaSim/ModiaBase.jl/blob/6c341eed72d9867553cb9565330d8ae85221b343/src/LinearIntegerEquations.jl#L204
  • Loading branch information
Keno authored and oscardssmith committed Feb 1, 2022
1 parent 0cbacf2 commit 8321718
Show file tree
Hide file tree
Showing 2 changed files with 3,401 additions and 16 deletions.
91 changes: 75 additions & 16 deletions stdlib/LinearAlgebra/src/generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1619,6 +1619,70 @@ const NumberArray{T<:Number} = AbstractArray{T}
exactdiv(a, b) = a/b
exactdiv(a::Integer, b::Integer) = div(a, b)

# Bareiss algorithm
function bareiss_update!(zero!, M::Matrix, k, swapto, pivot, prev_pivot)
for i in k+1:size(M, 2), j in k+1:size(M, 1)
M[j,i] = exactdiv(M[j,i]*pivot - M[j,k]*M[k,i], prev_pivot)
end
zero!(M, k+1:size(M, 1), k)
end

function bareiss_update!(zero!, M::AbstractMatrix, k, swapto, pivot, prev_pivot)
V = @view M[k+1:end, k+1:end]
V .= exactdiv.(V * pivot - M[k+1:end, k] * M[k, k+1:end]', prev_pivot)
zero!(M, k+1:size(M, 1), k)
end

function bareiss_update_virtual_colswap!(zero!, M::AbstractMatrix, k, swapto, pivot, prev_pivot)
V = @view M[k+1:end, :]
V .= exactdiv.(V * pivot - M[k+1:end, swapto[2]] * M[k, :]', prev_pivot)
zero!(M, k+1:size(M, 1), swapto[2])
end

bareiss_zero!(M, i, j) = M[i,j] .= zero(eltype(M))

function find_pivot_col(M, i)
p = findfirst(!iszero, @view M[i,i:end])
p === nothing && return nothing
idx = CartesianIndex(i, p + i - 1)
(idx, M[idx])
end

function find_pivot_any(M, i)
p = findfirst(!iszero, @view M[i:end,i:end])
p === nothing && return nothing
idx = p + CartesianIndex(i - 1, i - 1)
(idx, M[idx])
end

const bareiss_colswap = (Base.swapcols!, Base.swaprows!, bareiss_update!, bareiss_zero!)
const bareiss_virtcolswap = ((M,i,j)->nothing, Base.swaprows!, bareiss_update_virtual_colswap!, bareiss_zero!)

"""
bareiss!(M)
Perform Bareiss's fraction-free row-reduction algorithm on the matrix `M`.
Optionally, a specific pivoting method may be specified.
"""
function bareiss!(M::AbstractMatrix,
(swapcols!, swaprows!, update!, zero!) = bareiss_colswap;
find_pivot=find_pivot_any)
prev = one(eltype(M))
n = size(M, 1)
for k in 1:n
r = find_pivot(M, k)
r === nothing && return k - 1
(swapto, pivot) = r
if CartesianIndex(k, k) != swapto
swapcols!(M, k, swapto[2])
swaprows!(M, k, swapto[1])
end
update!(zero!, M, k, swapto, pivot, prev)
prev = pivot
end
return n
end

"""
det_bareiss!(M)
Expand All @@ -1639,21 +1703,18 @@ julia> LinearAlgebra.det_bareiss!(M)
"""
function det_bareiss!(M)
n = checksquare(M)
sign, prev = Int8(1), one(eltype(M))
for i in 1:n-1
if iszero(M[i,i]) # swap with another col to make nonzero
swapto = findfirst(!iszero, @view M[i,i+1:end])
isnothing(swapto) && return zero(prev)
sign = -sign
Base.swapcols!(M, i, i + swapto)
end
for k in i+1:n, j in i+1:n
M[j,k] = exactdiv(M[j,k]*M[i,i] - M[j,i]*M[i,k], prev)
end
prev = M[i,i]
end
return sign * M[end,end]
parity = true
swaprows!(M, i, j) = (i != j && (parity = !parity); Base.swaprows!(M, i, j))
swapcols!(M, i, j) = (i != j && (parity = !parity); Base.swapcols!(M, i, j))
# We only look at the last entry, so we don't care that the sub-diagonals are
# garbage.
zero!(M, i, j) = nothing
rank = bareiss!(M, (swapcols!, swaprows!, bareiss_update!, zero!);
find_pivot=find_pivot_col)
rank != n && return zero(eltype(M))
return parity ? M[n,n] : -M[n, n]
end

"""
LinearAlgebra.det_bareiss(M)
Expand All @@ -1663,8 +1724,6 @@ Also refer to [`det_bareiss!`](@ref).
"""
det_bareiss(M) = det_bareiss!(copy(M))



"""
promote_leaf_eltypes(itr)
Expand Down
Loading

0 comments on commit 8321718

Please sign in to comment.