From 1156e34838380695355a1e56251abb2f2d2d2f18 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 16 Jan 2024 09:00:01 -0500 Subject: [PATCH 1/3] Test master From 4ad919ae7d666ec5a23c5844925b355f3ef6bb8f Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 16 Jan 2024 10:50:04 -0500 Subject: [PATCH 2/3] Fix num_vecjac --- Project.toml | 2 +- ext/SparseDiffToolsZygoteExt.jl | 3 ++- src/differentiation/jaches_products.jl | 7 ++++--- src/differentiation/vecjac_products.jl | 11 ++++++----- 4 files changed, 13 insertions(+), 10 deletions(-) diff --git a/Project.toml b/Project.toml index 6537d778..1055d5cf 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "SparseDiffTools" uuid = "47a9eef4-7e08-11e9-0b38-333d64bd3804" authors = ["Pankaj Mishra ", "Chris Rackauckas "] -version = "2.15.0" +version = "2.15.1" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/ext/SparseDiffToolsZygoteExt.jl b/ext/SparseDiffToolsZygoteExt.jl index 49c3fbf0..1563c216 100644 --- a/ext/SparseDiffToolsZygoteExt.jl +++ b/ext/SparseDiffToolsZygoteExt.jl @@ -45,7 +45,8 @@ end ### Jac, Hes products -function numback_hesvec!(dy, f::F, x, v, cache1 = similar(v), cache2 = similar(v), cache3 = similar(v)) where {F} +function numback_hesvec!(dy, f::F, x, v, cache1 = similar(v), cache2 = similar(v), + cache3 = similar(v)) where {F} g = let f = f (dx, x) -> dx .= first(Zygote.gradient(f, x)) end diff --git a/src/differentiation/jaches_products.jl b/src/differentiation/jaches_products.jl index 604a10ce..dead4290 100644 --- a/src/differentiation/jaches_products.jl +++ b/src/differentiation/jaches_products.jl @@ -32,8 +32,8 @@ function auto_jacvec(f, x, v) vec(partials.(vec(f(y)), 1)) end -function num_jacvec!(dy, f, x, v, cache1 = similar(v), cache2 = similar(v), cache3 = similar(v); - compute_f0 = true) +function num_jacvec!(dy, f, x, v, cache1 = similar(v), cache2 = similar(v), + cache3 = similar(v); compute_f0 = true) vv = reshape(v, axes(x)) compute_f0 && (f(cache1, x)) T = eltype(x) @@ -134,7 +134,8 @@ function autonum_hesvec(f, x, v) partials.(g(Dual{DeivVecTag}.(x, v)), 1) end -function num_hesvecgrad!(dy, g, x, v, cache1 = similar(v), cache2 = similar(v), cache3 = similar(v)) +function num_hesvecgrad!(dy, g, x, v, cache1 = similar(v), cache2 = similar(v), + cache3 = similar(v)) T = eltype(x) # Should it be min? max? mean? ϵ = sqrt(eps(real(T))) * max(one(real(T)), abs(norm(x))) diff --git a/src/differentiation/vecjac_products.jl b/src/differentiation/vecjac_products.jl index 7f827583..dd559643 100644 --- a/src/differentiation/vecjac_products.jl +++ b/src/differentiation/vecjac_products.jl @@ -1,5 +1,5 @@ -function num_vecjac!(du, f::F, x, v, cache1 = similar(v), cache2 = similar(v), cache3 = similar(v); - compute_f0 = true) where {F} +function num_vecjac!(du, f::F, x, v, cache1 = similar(v), cache2 = similar(v), + cache3 = similar(x); compute_f0 = true) where {F} compute_f0 && (f(cache1, x)) T = eltype(x) # Should it be min? max? mean? @@ -22,10 +22,11 @@ function num_vecjac(f::F, x, v, f0 = nothing) where {F} # Should it be min? max? mean? ϵ = sqrt(eps(real(T))) * max(one(real(T)), abs(norm(x))) du = similar(x) - cache = copy(x) + cache = similar(x) + copyto!(cache, x) for i in 1:length(x) cache[i] += ϵ - f0 = f(x) + f0 = f(cache) cache[i] = x[i] du[i] = (((f0 .- _f0) ./ ϵ)' * vv)[1] end @@ -93,7 +94,7 @@ function VecJac(f, u::AbstractArray, p = nothing, t = nothing; fu = nothing, end function _vecjac(f::F, fu, u, autodiff::AutoFiniteDiff) where {F} - cache = (similar(fu), similar(fu), similar(fu)) + cache = (similar(fu), similar(fu), similar(u)) pullback = nothing return AutoDiffVJP(f, u, cache, autodiff, pullback) end From 7e06e45e3ad2e3f5c9a9159ffd86dcec46fa4462 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 17 Jan 2024 10:41:13 -0500 Subject: [PATCH 3/3] Use allowed_setindex! for cache --- src/SparseDiffTools.jl | 2 +- src/differentiation/vecjac_products.jl | 21 ++++++++++++++++++--- test/test_vecjac_products.jl | 17 ++++++++++++++++- 3 files changed, 35 insertions(+), 5 deletions(-) diff --git a/src/SparseDiffTools.jl b/src/SparseDiffTools.jl index 238c9323..7390ddcc 100644 --- a/src/SparseDiffTools.jl +++ b/src/SparseDiffTools.jl @@ -14,7 +14,7 @@ import ADTypes: AbstractADType, AutoSparseZygote, AbstractSparseForwardMode, import ForwardDiff: Dual, jacobian, partials, DEFAULT_CHUNK_THRESHOLD # Array Packages using ArrayInterface, SparseArrays -import ArrayInterface: matrix_colors +import ArrayInterface: matrix_colors, allowed_setindex! import StaticArrays import StaticArrays: StaticArray, SArray, MArray, Size # Others diff --git a/src/differentiation/vecjac_products.jl b/src/differentiation/vecjac_products.jl index dd559643..4113e655 100644 --- a/src/differentiation/vecjac_products.jl +++ b/src/differentiation/vecjac_products.jl @@ -15,6 +15,21 @@ function num_vecjac!(du, f::F, x, v, cache1 = similar(v), cache2 = similar(v), return du end +# Special Non-Allocating case for StaticArrays +function num_vecjac(f::F, x::SArray, v::SArray, f0 = nothing) where {F} + f0 === nothing ? (_f0 = f(x)) : (_f0 = f0) + vv = reshape(v, axes(_f0)) + T = eltype(x) + ϵ = sqrt(eps(real(T))) * max(one(real(T)), abs(norm(x))) + du = zeros(typeof(x)) + for i in 1:length(x) + cache = Base.setindex(x, x[i] + ϵ, i) + f0 = f(cache) + du = Base.setindex(du, (((f0 .- _f0) ./ ϵ)' * vv), i) + end + return du +end + function num_vecjac(f::F, x, v, f0 = nothing) where {F} f0 === nothing ? (_f0 = f(x)) : (_f0 = f0) vv = reshape(v, axes(_f0)) @@ -25,10 +40,10 @@ function num_vecjac(f::F, x, v, f0 = nothing) where {F} cache = similar(x) copyto!(cache, x) for i in 1:length(x) - cache[i] += ϵ + cache = allowed_setindex!(cache, x[i] + ϵ, i) f0 = f(cache) - cache[i] = x[i] - du[i] = (((f0 .- _f0) ./ ϵ)' * vv)[1] + cache = allowed_setindex!(cache, x[i], i) + du = allowed_setindex!(du, (((f0 .- _f0) ./ ϵ)' * vv)[1], i) end return vec(du) end diff --git a/test/test_vecjac_products.jl b/test/test_vecjac_products.jl index 44dadca9..cc221ba6 100644 --- a/test/test_vecjac_products.jl +++ b/test/test_vecjac_products.jl @@ -1,5 +1,6 @@ -using SparseDiffTools, Zygote +using SparseDiffTools, Zygote, ForwardDiff using LinearAlgebra, Test +using StaticArrays using Random Random.seed!(123) @@ -170,3 +171,17 @@ L = VecJac(f3_iip, copy(x); autodiff = AutoFiniteDiff(), fu = copy(y)) L = VecJac(f3_oop, copy(x); autodiff = AutoZygote()) @test size(L) == (length(x), length(y)) @test L * y ≈ Zygote.jacobian(f3_oop, copy(x))[1]' * y + +@info "Testing StaticArrays" + +const A_sa = rand(SMatrix{4, 4, Float32}) +_f_sa(x) = A_sa * (x .^ 2) + +u = rand(SVector{4, Float32}) +v = rand(SVector{4, Float32}) + +J = ForwardDiff.jacobian(_f_sa, u) +Jᵀv_true = J' * v + +@test num_vecjac(_f_sa, u, v) isa SArray +@test num_vecjac(_f_sa, u, v)≈Jᵀv_true atol=1e-3