Skip to content

Commit

Permalink
Merge pull request #32 from thomvet/allocations
Browse files Browse the repository at this point in the history
Dispatches get_tmp on wrapper type of cache,
  • Loading branch information
ChrisRackauckas authored Aug 10, 2022
2 parents addfd17 + 86f9832 commit 3de0cde
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 67 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "PreallocationTools"
uuid = "d236fae5-4411-538c-8e31-a6e3d9e00b46"
authors = ["Chris Rackauckas <accounts@chrisrackauckas.com>"]
version = "0.4.0"
version = "0.5.0"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down
12 changes: 10 additions & 2 deletions src/PreallocationTools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,20 +46,28 @@ function get_tmp(dc::DiffCache, u::T) where {T <: ForwardDiff.Dual}
if nelem > length(dc.dual_du)
enlargedualcache!(dc, nelem)
end
ArrayInterfaceCore.restructure(dc.du, reinterpret(T, view(dc.dual_du, 1:nelem)))
_restructure(dc.du, reinterpret(T, view(dc.dual_du, 1:nelem)))
end

function get_tmp(dc::DiffCache, u::AbstractArray{T}) where {T <: ForwardDiff.Dual}
nelem = div(sizeof(T), sizeof(eltype(dc.dual_du))) * length(dc.du)
if nelem > length(dc.dual_du)
enlargedualcache!(dc, nelem)
end
ArrayInterfaceCore.restructure(dc.du, reinterpret(T, view(dc.dual_du, 1:nelem)))
_restructure(dc.du, reinterpret(T, view(dc.dual_du, 1:nelem)))
end

get_tmp(dc::DiffCache, u::Number) = dc.du
get_tmp(dc::DiffCache, u::AbstractArray) = dc.du

function _restructure(normal_cache::Array, duals)
reshape(duals, size(normal_cache)...)
end

function _restructure(normal_cache::AbstractArray, duals)
ArrayInterfaceCore.restructure(normal_cache, duals)
end

function enlargedualcache!(dc, nelem) #warning comes only once per dualcache.
chunksize = div(nelem, length(dc.du)) - 1
@warn "The supplied dualcache was too small and was enlarged. This incurrs allocations
Expand Down
148 changes: 85 additions & 63 deletions test/core_dispatch.jl
Original file line number Diff line number Diff line change
@@ -1,74 +1,96 @@
using LinearAlgebra, OrdinaryDiffEq, Test, PreallocationTools, ForwardDiff, LabelledArrays,
using LinearAlgebra, Test, PreallocationTools, ForwardDiff, LabelledArrays,
RecursiveArrayTools

#Base Array tests
function test(u0, dual, chunk_size)
cache = PreallocationTools.dualcache(u0, chunk_size)
allocs_normal1 = @allocated get_tmp(cache, u0)
allocs_normal2 = @allocated get_tmp(cache, first(u0))
allocs_dual1 = @allocated get_tmp(cache, dual)
allocs_dual2 = @allocated get_tmp(cache, first(dual))
result_normal1 = get_tmp(cache, u0)
result_normal2 = get_tmp(cache, first(u0))
result_dual1 = get_tmp(cache, dual)
result_dual2 = get_tmp(cache, first(dual))
return allocs_normal1, allocs_normal2, allocs_dual1, allocs_dual2, result_normal1,
result_normal2, result_dual1,
result_dual2
end

#Setup Base Array tests
chunk_size = 5
u0_B = ones(5, 5)
dual_B = zeros(ForwardDiff.Dual{ForwardDiff.Tag{typeof(something), Float64}, Float64,
chunk_size}, 2, 2)
cache_B = dualcache(u0_B, chunk_size)
tmp_du_BA = get_tmp(cache_B, u0_B)
tmp_dual_du_BA = get_tmp(cache_B, dual_B)
tmp_du_BN = get_tmp(cache_B, u0_B[1])
tmp_dual_du_BN = get_tmp(cache_B, dual_B[1])
@test size(tmp_du_BA) == size(u0_B)
@test typeof(tmp_du_BA) == typeof(u0_B)
@test eltype(tmp_du_BA) == eltype(u0_B)
@test size(tmp_dual_du_BA) == size(u0_B)
@test typeof(tmp_dual_du_BA) == typeof(dual_B)
@test eltype(tmp_dual_du_BA) == eltype(dual_B)
@test size(tmp_du_BN) == size(u0_B)
@test typeof(tmp_du_BN) == typeof(u0_B)
@test eltype(tmp_du_BN) == eltype(u0_B)
@test size(tmp_dual_du_BN) == size(u0_B)
@test typeof(tmp_dual_du_BN) == typeof(dual_B)
@test eltype(tmp_dual_du_BN) == eltype(dual_B)
u0 = ones(5, 5)
dual = zeros(ForwardDiff.Dual{ForwardDiff.Tag{nothing, Float64}, Float64,
chunk_size}, 5, 5)
results = test(u0, dual, chunk_size)
#allocation tests
@test results[1] == 0
@test results[2] == 0
@test results[3] == 0
@test results[4] == 0
#size tests
@test size(results[5]) == size(u0)
@test size(results[6]) == size(u0)
@test size(results[7]) == size(u0)
@test size(results[8]) == size(u0)
#type tests
@test typeof(results[5]) == typeof(u0)
@test typeof(results[6]) == typeof(u0)
@test_broken typeof(results[7]) == typeof(dual)
@test_broken typeof(results[8]) == typeof(dual)
#eltype tests
@test eltype(results[5]) == eltype(u0)
@test eltype(results[7]) == eltype(dual)

#LArray tests
chunk_size = 4
u0_L = LArray((2, 2); a = 1.0, b = 1.0, c = 1.0, d = 1.0)
zerodual = zero(ForwardDiff.Dual{ForwardDiff.Tag{typeof(something), Float64}, Float64,
u0 = LArray((2, 2); a = 1.0, b = 1.0, c = 1.0, d = 1.0)
zerodual = zero(ForwardDiff.Dual{ForwardDiff.Tag{nothing, Float64}, Float64,
chunk_size})
dual_L = LArray((2, 2); a = zerodual, b = zerodual, c = zerodual, d = zerodual)
cache_L = dualcache(u0_L, chunk_size)
tmp_du_LA = get_tmp(cache_L, u0_L)
tmp_dual_du_LA = get_tmp(cache_L, dual_L)
tmp_du_LN = get_tmp(cache_L, u0_L[1])
tmp_dual_du_LN = get_tmp(cache_L, dual_L[1])
@test size(tmp_du_LA) == size(u0_L)
@test typeof(tmp_du_LA) == typeof(u0_L)
@test eltype(tmp_du_LA) == eltype(u0_L)
@test size(tmp_dual_du_LA) == size(u0_L)
@test typeof(tmp_dual_du_LA) == typeof(dual_L)
@test eltype(tmp_dual_du_LA) == eltype(dual_L)
@test size(tmp_du_LN) == size(u0_L)
@test typeof(tmp_du_LN) == typeof(u0_L)
@test eltype(tmp_du_LN) == eltype(u0_L)
@test size(tmp_dual_du_LN) == size(u0_L)
@test typeof(tmp_dual_du_LN) == typeof(dual_L)
@test eltype(tmp_dual_du_LN) == eltype(dual_L)
dual = LArray((2, 2); a = zerodual, b = zerodual, c = zerodual, d = zerodual)
results = test(u0, dual, chunk_size)
#allocation tests
@test results[1] == 0
@test results[2] == 0
@test_broken results[3] == 0
@test_broken results[4] == 0
#size tests
@test size(results[5]) == size(u0)
@test size(results[6]) == size(u0)
@test size(results[7]) == size(u0)
@test size(results[8]) == size(u0)
#type tests
@test typeof(results[5]) == typeof(u0)
@test typeof(results[6]) == typeof(u0)
@test typeof(results[7]) == typeof(dual)
@test typeof(results[8]) == typeof(dual)
#eltype tests
@test eltype(results[5]) == eltype(u0)
@test eltype(results[7]) == eltype(dual)

#ArrayPartition tests
u0_AP = ArrayPartition(ones(2, 2), ones(3, 3))
dual_a = zeros(ForwardDiff.Dual{ForwardDiff.Tag{typeof(something), Float64}, Float64,
chunk_size = 2
u0 = ArrayPartition(ones(2, 2), ones(3, 3))
dual_a = zeros(ForwardDiff.Dual{ForwardDiff.Tag{nothing, Float64}, Float64,
chunk_size}, 2, 2)
dual_b = zeros(ForwardDiff.Dual{ForwardDiff.Tag{typeof(something), Float64}, Float64,
dual_b = zeros(ForwardDiff.Dual{ForwardDiff.Tag{nothing, Float64}, Float64,
chunk_size}, 3, 3)
dual_AP = ArrayPartition(dual_a, dual_b)
cache_AP = dualcache(u0_AP, chunk_size)
tmp_du_APA = get_tmp(cache_AP, u0_AP)
tmp_dual_du_APA = get_tmp(cache_AP, dual_AP)
tmp_du_APN = get_tmp(cache_AP, u0_AP[1])
tmp_dual_du_APN = get_tmp(cache_AP, dual_AP[1])
@test size(tmp_du_APA) == size(u0_AP)
@test typeof(tmp_du_APA) == typeof(u0_AP)
@test eltype(tmp_du_APA) == eltype(u0_AP)
@test size(tmp_dual_du_APA) == size(u0_AP)
@test typeof(tmp_dual_du_APA) == typeof(dual_AP)
@test eltype(tmp_dual_du_APA) == eltype(dual_AP)
@test size(tmp_du_APN) == size(u0_AP)
@test typeof(tmp_du_APN) == typeof(u0_AP)
@test eltype(tmp_du_APN) == eltype(u0_AP)
@test size(tmp_dual_du_APN) == size(u0_AP)
@test typeof(tmp_dual_du_APN) == typeof(dual_AP)
@test eltype(tmp_dual_du_APN) == eltype(dual_AP)
dual = ArrayPartition(dual_a, dual_b)
results = test(u0, dual, chunk_size)
#allocation tests
@test results[1] == 0
@test results[2] == 0
@test_broken results[3] == 0
@test_broken results[4] == 0
#size tests
@test size(results[5]) == size(u0)
@test size(results[6]) == size(u0)
@test size(results[7]) == size(u0)
@test size(results[8]) == size(u0)
#type tests
@test typeof(results[5]) == typeof(u0)
@test typeof(results[6]) == typeof(u0)
@test typeof(results[7]) == typeof(dual)
@test typeof(results[8]) == typeof(dual)
#eltype tests
@test eltype(results[5]) == eltype(u0)
@test eltype(results[7]) == eltype(dual)
3 changes: 2 additions & 1 deletion test/core_odes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ sol = solve(prob, TRBDF2(chunk_size = chunk_size))
@test sol.retcode == :Success

#with auto-detected chunk_size
prob = ODEProblem(foo, ones(5, 5), (0.0, 1.0), (ones(5, 5), dualcache(zeros(5, 5))))
cache = dualcache(zeros(5, 5))
prob = ODEProblem(foo, ones(5, 5), (0.0, 1.0), (A, cache))
sol = solve(prob, TRBDF2())
@test sol.retcode == :Success

Expand Down

0 comments on commit 3de0cde

Please sign in to comment.