Skip to content

Commit

Permalink
Merge pull request #295 from JuliaDiffEq/setindex
Browse files Browse the repository at this point in the history
don't use recursive factorization if can't setindex
  • Loading branch information
ChrisRackauckas authored Jul 17, 2019
2 parents e2b91ef + 48f2f4d commit 68fd464
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 1 deletion.
3 changes: 3 additions & 0 deletions src/init.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
value(x) = x
cuify(x) = error("To use LinSolveGPUFactorize, you must do `using CuArrays`")
can_setindex(x) = true

function __init__()
@require ApproxFun="28f2ccd6-bb30-5033-b560-165f7b14dc2f" begin
Expand Down Expand Up @@ -88,6 +89,8 @@ function __init__()
value(x::Flux.Tracker.TrackedReal) = x.data
value(x::Flux.Tracker.TrackedArray) = x.data

can_setindex(x::Flux.Tracker.TrackedArray) = false

# Support adaptive with non-tracked time
@inline function ODE_DEFAULT_NORM(u::Flux.Tracker.TrackedArray,t) where {N}
sqrt(sum(x->ODE_DEFAULT_NORM(x[1],x[2]),zip((value(x) for x in u),Iterators.repeated(t))) / length(u))
Expand Down
2 changes: 1 addition & 1 deletion src/linear_nonlinear.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ function (p::DefaultLinSolve)(x,A,b,update_matrix=false;tol=nothing, kwargs...)
if update_matrix
if typeof(A) <: Matrix
blasvendor = BLAS.vendor()
if (blasvendor === :openblas || blasvendor === :openblas64) && size(A,1) <= 500 # if the user doesn't use OpenBLAS, we assume that is a much better BLAS implementation like MKL
if (blasvendor === :openblas || blasvendor === :openblas64) && size(A,1) <= 500 && can_setindex(x) # if the user doesn't use OpenBLAS, we assume that is a much better BLAS implementation like MKL
p.A = RecursiveFactorization.lu!(A)
else
p.A = lu!(A)
Expand Down

0 comments on commit 68fd464

Please sign in to comment.