diff --git a/Project.toml b/Project.toml index 0eca0754..094cc3d2 100644 --- a/Project.toml +++ b/Project.toml @@ -28,7 +28,8 @@ DiffTests = "de460e47-3fe3-5279-bb4a-814414816d5d" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" +BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Calculus", "DiffTests", "LinearAlgebra", "SparseArrays", "Test", "InteractiveUtils"] +test = ["Calculus", "DiffTests", "LinearAlgebra", "SparseArrays", "Test", "InteractiveUtils", "BenchmarkTools"] diff --git a/src/apiutils.jl b/src/apiutils.jl index 7424ffa2..654c64a3 100644 --- a/src/apiutils.jl +++ b/src/apiutils.jl @@ -79,6 +79,6 @@ function seed!(duals::AbstractArray{Dual{T,V,N}}, x, index, offset = index - 1 seed_inds = 1:chunksize dual_inds = seed_inds .+ offset - duals[dual_inds] .= Dual{T,V,N}.(view(x, dual_inds), seeds[seed_inds]) + duals[dual_inds] .= Dual{T,V,N}.(view(x, dual_inds), getindex.(Ref(seeds), seed_inds)) return duals end diff --git a/test/AllocationsTest.jl b/test/AllocationsTest.jl new file mode 100644 index 00000000..cf113acd --- /dev/null +++ b/test/AllocationsTest.jl @@ -0,0 +1,27 @@ +module AllocationsTest + +using ForwardDiff +using BenchmarkTools + +include(joinpath(dirname(@__FILE__), "utils.jl")) + +@testset "Test seed! allocations" begin + x = rand(1000) + cfg = ForwardDiff.GradientConfig(nothing, x) + + balloc = @ballocated ForwardDiff.seed!($(cfg.duals), $x, $(cfg.seeds)) + @test balloc == 0 + + balloc = @ballocated ForwardDiff.seed!($(cfg.duals), $x, $(cfg.seeds[1])) + @test balloc == 0 + + index = 1 + balloc = @ballocated ForwardDiff.seed!($(cfg.duals), $x, $index, $(cfg.seeds)) + @test balloc == 0 + + index = 1 + balloc = @ballocated ForwardDiff.seed!($(cfg.duals), $x, $index, $(cfg.seeds[1])) + @test balloc == 0 +end + +end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index e19c9527..8c519f12 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -31,3 +31,7 @@ println("done (took $t seconds).") println("Testing miscellaneous functionality...") t = @elapsed include("MiscTest.jl") println("done (took $t seconds).") + +println("Testing allocations...") +t = @elapsed include("AllocationsTest.jl") +println("done (took $t seconds).")