From 48f2f4d28fe0e07e91bad9b48b3378f40849699c Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Tue, 16 Jul 2019 21:30:06 -0400 Subject: [PATCH] don't use recursive factorization if can't setindex --- src/init.jl | 3 +++ src/linear_nonlinear.jl | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/src/init.jl b/src/init.jl index c7f521faf..5ead31b04 100644 --- a/src/init.jl +++ b/src/init.jl @@ -1,5 +1,6 @@ value(x) = x cuify(x) = error("To use LinSolveGPUFactorize, you must do `using CuArrays`") +can_setindex(x) = true function __init__() @require ApproxFun="28f2ccd6-bb30-5033-b560-165f7b14dc2f" begin @@ -88,6 +89,8 @@ function __init__() value(x::Flux.Tracker.TrackedReal) = x.data value(x::Flux.Tracker.TrackedArray) = x.data + can_setindex(x::Flux.Tracker.TrackedArray) = false + # Support adaptive with non-tracked time @inline function ODE_DEFAULT_NORM(u::Flux.Tracker.TrackedArray,t) where {N} sqrt(sum(x->ODE_DEFAULT_NORM(x[1],x[2]),zip((value(x) for x in u),Iterators.repeated(t))) / length(u)) diff --git a/src/linear_nonlinear.jl b/src/linear_nonlinear.jl index 80cf9913a..0b98c77a0 100644 --- a/src/linear_nonlinear.jl +++ b/src/linear_nonlinear.jl @@ -53,7 +53,7 @@ function (p::DefaultLinSolve)(x,A,b,update_matrix=false;tol=nothing, kwargs...) if update_matrix if typeof(A) <: Matrix blasvendor = BLAS.vendor() - if (blasvendor === :openblas || blasvendor === :openblas64) && size(A,1) <= 500 # if the user doesn't use OpenBLAS, we assume that is a much better BLAS implementation like MKL + if (blasvendor === :openblas || blasvendor === :openblas64) && size(A,1) <= 500 && can_setindex(x) # if the user doesn't use OpenBLAS, we assume that is a much better BLAS implementation like MKL p.A = RecursiveFactorization.lu!(A) else p.A = lu!(A)