diff --git a/stdlib/SparseArrays/src/sparsematrix.jl b/stdlib/SparseArrays/src/sparsematrix.jl index 5b107ff86d258..bbc5135229beb 100644 --- a/stdlib/SparseArrays/src/sparsematrix.jl +++ b/stdlib/SparseArrays/src/sparsematrix.jl @@ -3401,37 +3401,47 @@ function sortSparseMatrixCSC!(A::SparseMatrixCSC{Tv,Ti}; sortindices::Symbol = : row = zeros(Ti, m) val = zeros(Tv, m) - for i = 1:n - @inbounds col_start = colptr[i] - @inbounds col_end = (colptr[i+1] - 1) + perm = Base.Perm(Base.ord(isless, identity, false, Base.Order.Forward), row) - numrows = col_end - col_start + 1 + @inbounds for i = 1:n + nzr = nzrange(A, i) + numrows = length(nzr) if numrows <= 1 continue elseif numrows == 2 - f = col_start + f = first(nzr) s = f+1 if rowval[f] > rowval[s] - @inbounds rowval[f], rowval[s] = rowval[s], rowval[f] - @inbounds nzval[f], nzval[s] = nzval[s], nzval[f] + rowval[f], rowval[s] = rowval[s], rowval[f] + nzval[f], nzval[s] = nzval[s], nzval[f] end continue end + resize!(row, numrows) + resize!(index, numrows) jj = 1 - @simd for j = col_start:col_end - @inbounds row[jj] = rowval[j] - @inbounds val[jj] = nzval[j] + @simd for j = nzr + row[jj] = rowval[j] + val[jj] = nzval[j] jj += 1 end - sortperm!(unsafe_wrap(Vector{Ti}, pointer(index), numrows), - unsafe_wrap(Vector{Ti}, pointer(row), numrows)) + if numrows <= 16 + alg = Base.Sort.InsertionSort + else + alg = Base.Sort.QuickSort + end + + # Reset permutation + index .= 1:numrows + + sort!(index, alg, perm) jj = 1 - @simd for j = col_start:col_end - @inbounds rowval[j] = row[index[jj]] - @inbounds nzval[j] = val[index[jj]] + @simd for j = nzr + rowval[j] = row[index[jj]] + nzval[j] = val[index[jj]] jj += 1 end end