Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Switch from internal 5-arg searchsorted* methods to views #440

Merged
merged 2 commits into from
Sep 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 21 additions & 21 deletions src/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -453,7 +453,7 @@
i1 = ia[j]
i2 = ia[j + 1] - 1
done = unit

bj = B[joff + j]
for ii = i1:i2
jai = ja[ii]
Expand Down Expand Up @@ -483,7 +483,7 @@
i2 = ia[j + 1] - 1
akku = Z
j0 = !unit ? j : j - 1

# loop through column j of A - only structural non-zeros
for ii = i1:i2
jai = ja[ii]
Expand All @@ -509,7 +509,7 @@
i1 = ia[j]
i2 = ia[j + 1] - 1
done = unit

bj = B[joff + j]
for ii = i2:-1:i1
jai = ja[ii]
Expand Down Expand Up @@ -539,7 +539,7 @@
i2 = ia[j + 1] - 1
akku = Z
j0 = !unit ? j : j + 1

# loop through column j of A - only structural non-zeros
for ii = i2:-1:i1
jai = ja[ii]
Expand Down Expand Up @@ -582,7 +582,7 @@
i1 = ia[j]
i2 = ia[j + 1] - 1
done = unit

bj = B[joff + j]
for ii = i1:i2
jai = ja[ii]
Expand Down Expand Up @@ -610,7 +610,7 @@
i1 = ia[j]
i2 = ia[j + 1] - 1
done = unit

bj = B[joff + j]
for ii = i2:-1:i1
jai = ja[ii]
Expand Down Expand Up @@ -664,7 +664,7 @@
i2 = ia[j + 1] - one(eltype(ia))

# find diagonal element
ii = searchsortedfirst(ja, j, i1, i2, Base.Order.Forward)
ii = searchsortedfirst(view(ja, i1:i2), j) + i1 - 1
jai = ii > i2 ? zero(eltype(ja)) : ja[ii]

cj = C[j,k]
Expand Down Expand Up @@ -693,7 +693,7 @@
i2 = ia[j + 1] - 1
akku = B[j,k]
done = false

# loop through column j of A - only structural non-zeros
for ii = i2:-1:i1
jai = ja[ii]
Expand Down Expand Up @@ -721,11 +721,11 @@
for j = nrowB:-1:1
i1 = ia[j]
i2 = ia[j + 1] - one(eltype(ia))

# find diagonal element
ii = searchsortedlast(ja, j, i1, i2, Base.Order.Forward)
ii = searchsortedlast(view(ja, i1:i2), j) + i1 - 1
jai = ii < i1 ? zero(eltype(ja)) : ja[ii]

cj = C[j,k]
# check for zero pivot and divide with pivot
if jai == j
Expand All @@ -737,7 +737,7 @@
elseif !unit
throw(LinearAlgebra.SingularException(j))
end

# update remaining part
for i = ii:-1:i1
C[ja[i],k] -= cj * LinearAlgebra._ustrip(aa[i])
Expand All @@ -752,7 +752,7 @@
i2 = ia[j + 1] - 1
akku = B[j,k]
done = false

# loop through column j of A - only structural non-zeros
for ii = i1:i2
jai = ja[ii]
Expand Down Expand Up @@ -801,7 +801,7 @@
i2 = ia[j + 1] - one(eltype(ia))

# find diagonal element
ii = searchsortedfirst(ja, j, i1, i2, Base.Order.Forward)
ii = searchsortedfirst(view(ja, i1:i2), j) + i1 - 1

Check warning on line 804 in src/linalg.jl

View check run for this annotation

Codecov / codecov/patch

src/linalg.jl#L804

Added line #L804 was not covered by tests
jai = ii > i2 ? zero(eltype(ja)) : ja[ii]

cj = C[j,k]
Expand All @@ -828,11 +828,11 @@
for j = nrowB:-1:1
i1 = ia[j]
i2 = ia[j + 1] - one(eltype(ia))

# find diagonal element
ii = searchsortedlast(ja, j, i1, i2, Base.Order.Forward)
ii = searchsortedlast(view(ja, i1:i2), j) + i1 - 1

Check warning on line 833 in src/linalg.jl

View check run for this annotation

Codecov / codecov/patch

src/linalg.jl#L833

Added line #L833 was not covered by tests
jai = ii < i1 ? zero(eltype(ja)) : ja[ii]

cj = C[j,k]
# check for zero pivot and divide with pivot
if jai == j
Expand All @@ -844,7 +844,7 @@
elseif !unit
throw(LinearAlgebra.SingularException(j))
end

# update remaining part
for i = ii:-1:i1
C[ja[i],k] -= cj * LinearAlgebra._ustrip(conj(aa[i]))
Expand Down Expand Up @@ -904,13 +904,13 @@
function nzrangeup(A, i, excl=false)
r = nzrange(A, i); r1 = r.start; r2 = r.stop
rv = rowvals(A)
@inbounds r2 < r1 || rv[r2] <= i - excl ? r : r1:searchsortedlast(rv, i - excl, r1, r2, Forward)
@inbounds r2 < r1 || rv[r2] <= i - excl ? r : r1:(searchsortedlast(view(rv, r1:r2), i - excl) + r1-1)
end
# row range from diagonal (included if excl=false) to end
function nzrangelo(A, i, excl=false)
r = nzrange(A, i); r1 = r.start; r2 = r.stop
rv = rowvals(A)
@inbounds r2 < r1 || rv[r1] >= i + excl ? r : searchsortedfirst(rv, i + excl, r1, r2, Forward):r2
@inbounds r2 < r1 || rv[r1] >= i + excl ? r : (searchsortedfirst(view(rv, r1:r2), i + excl) + r1-1):r2
end

dot(x::AbstractVector, A::RealHermSymComplexHerm{<:Any,<:AbstractSparseMatrixCSC}, y::AbstractVector) =
Expand Down Expand Up @@ -985,7 +985,7 @@
r1 = Int(Acolptr[i])
r2 = Int(Acolptr[i+1]-1)
r1 > r2 && continue
r1 = searchsortedfirst(Arowval, i, r1, r2, Forward)
r1 += searchsortedfirst(view(Arowval, r1:r2), i) - 1
((r1 > r2) || (Arowval[r1] != i)) && continue
r += dot(x[i], diagop(Anzval[r1]), y[i])
end
Expand Down
3 changes: 1 addition & 2 deletions src/solvers/cholmod.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1180,8 +1180,7 @@ function getindex(A::Sparse{T}, i0::Integer, i1::Integer) where T
r1 = Int(unsafe_load(s.p, i1) + 1)
r2 = Int(unsafe_load(s.p, i1 + 1))
(r1 > r2) && return zero(T)
r1 = Int(searchsortedfirst(unsafe_wrap(Array, s.i, (s.nzmax,), own = false),
i0 - 1, r1, r2, Base.Order.Forward))
r1 += Int(searchsortedfirst(view(unsafe_wrap(Array, s.i, (s.nzmax,), own = false), r1:r2), i0 - 1) - 1)
((r1 > r2) || (unsafe_load(s.i, r1) + 1 != i0)) ? zero(T) : unsafe_load(Ptr{T}(s.x), r1)
end

Expand Down
30 changes: 15 additions & 15 deletions src/sparsematrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1939,11 +1939,11 @@
col > size(m, 2) && return nothing

lo, hi = getcolptr(m)[col], getcolptr(m)[col+1]
n = searchsortedfirst(rowvals(m), row, lo, hi-1, Base.Order.Forward)
n = searchsortedfirst(view(rowvals(m), lo:hi-1), row) + lo - 1

Check warning on line 1942 in src/sparsematrix.jl

View check run for this annotation

Codecov / codecov/patch

src/sparsematrix.jl#L1942

Added line #L1942 was not covered by tests
if lo <= n <= hi-1
return CartesianIndex(rowvals(m)[n], col)
end
nextcol = searchsortedfirst(getcolptr(m), hi + 1, col + 1, length(getcolptr(m)), Base.Order.Forward)
nextcol = searchsortedfirst(view(getcolptr(m), col+1:length(getcolptr(m))), hi + 1) + col

Check warning on line 1946 in src/sparsematrix.jl

View check run for this annotation

Codecov / codecov/patch

src/sparsematrix.jl#L1946

Added line #L1946 was not covered by tests
nextcol > length(getcolptr(m)) && return nothing
nextlo = getcolptr(m)[nextcol-1]
return CartesianIndex(rowvals(m)[nextlo], nextcol - 1)
Expand All @@ -1954,11 +1954,11 @@
iszero(col) && return nothing

lo, hi = getcolptr(m)[col], getcolptr(m)[col+1]
n = searchsortedlast(rowvals(m), row, lo, hi-1, Base.Order.Forward)
n = searchsortedlast(view(rowvals(m), lo:hi-1), row) + lo - 1

Check warning on line 1957 in src/sparsematrix.jl

View check run for this annotation

Codecov / codecov/patch

src/sparsematrix.jl#L1957

Added line #L1957 was not covered by tests
if lo <= n <= hi-1
return CartesianIndex(rowvals(m)[n], col)
end
prevcol = searchsortedlast(getcolptr(m), lo - 1, 1, col - 1, Base.Order.Forward)
prevcol = searchsortedlast(view(getcolptr(m), 1:col-1), lo - 1)

Check warning on line 1961 in src/sparsematrix.jl

View check run for this annotation

Codecov / codecov/patch

src/sparsematrix.jl#L1961

Added line #L1961 was not covered by tests
prevcol < 1 && return nothing
prevhi = getcolptr(m)[prevcol+1]
return CartesianIndex(rowvals(m)[prevhi-1], prevcol)
Expand Down Expand Up @@ -2564,8 +2564,8 @@
r1::Int = colptr[col]
r2::Int = colptr[col+1] - 1
if !allrows && (r1 <= r2)
r1 = searchsortedfirst(rowval, rowmin, r1, r2, Forward)
(r1 <= r2 ) && (r2 = searchsortedlast(rowval, rowmax, r1, r2, Forward))
r1 += searchsortedfirst(view(rowval, r1:r2), rowmin) - 1
(r1 <= r2 ) && (r2 = searchsortedlast(view(rowval, r1:r2), rowmax) + r1 - 1)
end
row = rowmin
while (r1 <= r2) && (row == rowval[r1]) && _isnotzero(nzval[r1])
Expand Down Expand Up @@ -2677,7 +2677,7 @@
r1 = Int(@inbounds getcolptr(A)[i1])
r2 = Int(@inbounds getcolptr(A)[i1+1]-1)
(r1 > r2) && return zero(T)
r1 = searchsortedfirst(rowvals(A), i0, r1, r2, Forward)
r1 = searchsortedfirst(view(rowvals(A), r1:r2), i0) + r1 - 1
((r1 > r2) || (rowvals(A)[r1] != i0)) ? zero(T) : nonzeros(A)[r1]
end

Expand Down Expand Up @@ -2813,7 +2813,7 @@
rowI = I[ptrI]
ptrI += 1
(rowvalA[ptrA] > rowI) && continue
ptrA = searchsortedfirst(rowvalA, rowI, ptrA, stopA, Base.Order.Forward)
ptrA += searchsortedfirst(view(rowvalA, ptrA:stopA), rowI) - 1
(ptrA <= stopA) || break
if rowvalA[ptrA] == rowI
ptrS += 1
Expand All @@ -2837,7 +2837,7 @@
while ptrI <= nI
rowI = I[ptrI]
if rowvalA[ptrA] <= rowI
ptrA = searchsortedfirst(rowvalA, rowI, ptrA, stopA, Base.Order.Forward)
ptrA += searchsortedfirst(view(rowvalA, ptrA:stopA), rowI) - 1
(ptrA <= stopA) || break
if rowvalA[ptrA] == rowI
rowvalS[ptrS] = ptrI
Expand Down Expand Up @@ -2941,7 +2941,7 @@
@inbounds for j = 1:m
cval = cacheI[j]
(cval == 0) && continue
ptrI = searchsortedfirst(I, j, ptrI, nI, Base.Order.Forward)
ptrI += searchsortedfirst(view(I, ptrI:nI), j) - 1
cacheI[j] = ptrI
while ptrI <= nI && I[ptrI] == j
ptrS += cval
Expand Down Expand Up @@ -3129,7 +3129,7 @@
end
coljfirstk = Int(getcolptr(A)[j])
coljlastk = Int(getcolptr(A)[j+1] - 1)
searchk = searchsortedfirst(rowvals(A), i, coljfirstk, coljlastk, Base.Order.Forward)
searchk = searchsortedfirst(view(rowvals(A), coljfirstk:coljlastk), i) + coljfirstk - 1
if searchk <= coljlastk && rowvals(A)[searchk] == i
# Column j contains entry A[i,j]. Update and return
nonzeros(A)[searchk] = v
Expand Down Expand Up @@ -3491,7 +3491,7 @@
xidx += 1

if r1 <= r2
copylen = searchsortedfirst(rowvalA, row, r1, r2, Forward) - r1
copylen = searchsortedfirst(view(rowvalA, r1:r2), row) - 1
if (copylen > 0)
if (nadd > 0)
copyto!(rowvalB, bidx, rowvalA, r1, copylen)
Expand Down Expand Up @@ -3621,7 +3621,7 @@
end

if r1 <= r2
copylen = searchsortedfirst(rowvalA, row, r1, r2, Forward) - r1
copylen = searchsortedfirst(view(rowvalA, r1:r2), row) - 1
if (copylen > 0)
if (nadd > 0)
copyto!(rowvalB, bidx, rowvalA, r1, copylen)
Expand Down Expand Up @@ -3705,7 +3705,7 @@
end
coljfirstk = Int(getcolptr(A)[j])
coljlastk = Int(getcolptr(A)[j+1] - 1)
searchk = searchsortedfirst(rowvals(A), i, coljfirstk, coljlastk, Base.Order.Forward)
searchk = searchsortedfirst(view(rowvals(A), coljfirstk:coljlastk), i) + coljfirstk - 1
if searchk <= coljlastk && rowvals(A)[searchk] == i
# Entry A[i,j] is stored. Drop and return.
deleteat!(rowvals(A), searchk)
Expand Down Expand Up @@ -4222,7 +4222,7 @@
r1 = Int(getcolptr(A)[c])
r2 = Int(getcolptr(A)[c+1]-1)
r1 > r2 && continue
r1 = searchsortedfirst(rowvals(A), r, r1, r2, Forward)
r1 += searchsortedfirst(view(rowvals(A), r1:r2), r) - 1
((r1 > r2) || (rowvals(A)[r1] != r)) && continue
push!(ind, i)
push!(val, nonzeros(A)[r1])
Expand Down
8 changes: 4 additions & 4 deletions src/sparsevector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -639,8 +639,8 @@ function getindex(x::AbstractSparseMatrixCSC, I::AbstractUnitRange, j::Integer)
c1 = convert(Int, getcolptr(x)[j])
c2 = convert(Int, getcolptr(x)[j+1]) - 1
# Restrict to the selected rows
r1 = searchsortedfirst(rowvals(x), first(I), c1, c2, Forward)
r2 = searchsortedlast(rowvals(x), last(I), c1, c2, Forward)
r1 = searchsortedfirst(view(rowvals(x), c1:c2), first(I)) + c1 - 1
r2 = searchsortedlast(view(rowvals(x), c1:c2), last(I)) + c1 - 1
return @if_move_fixed x SparseVector(length(I), [rowvals(x)[i] - first(I) + 1 for i = r1:r2], nonzeros(x)[r1:r2])
end

Expand Down Expand Up @@ -670,7 +670,7 @@ function Base.getindex(A::AbstractSparseMatrixCSC{Tv,Ti}, i::Integer, J::Abstrac
stopA = Int(colptrA[col+1]-1)
if ptrA <= stopA
if rowvalA[ptrA] <= rowI
ptrA = searchsortedfirst(rowvalA, rowI, ptrA, stopA, Base.Order.Forward)
ptrA += searchsortedfirst(view(rowvalA, ptrA:stopA), rowI) - 1
if ptrA <= stopA && rowvalA[ptrA] == rowI
push!(nzinds, j)
push!(nzvals, nzvalA[ptrA])
Expand Down Expand Up @@ -959,7 +959,7 @@ function getindex(x::AbstractSparseVector{Tv,Ti}, I::AbstractUnitRange) where {T
# locate the first j0, s.t. xnzind[j0] >= i0
j0 = searchsortedfirst(xnzind, i0)
# locate the last j1, s.t. xnzind[j1] <= i1
j1 = searchsortedlast(xnzind, i1, j0, m, Forward)
j1 = searchsortedlast(view(xnzind, j0:m), i1) + j0 - 1

# compute the number of non-zeros
jrgn = j0:j1
Expand Down
Loading