From 73866b5c0e637c1bb6aa02388f143f8e9ed1ce66 Mon Sep 17 00:00:00 2001 From: Miles Lubin Date: Tue, 9 Feb 2016 17:47:35 -0500 Subject: [PATCH] allow AbstractVector input for gradients --- src/api/gradient.jl | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/api/gradient.jl b/src/api/gradient.jl index 7034f7da..5d44ae6d 100644 --- a/src/api/gradient.jl +++ b/src/api/gradient.jl @@ -4,13 +4,13 @@ # Exposed API methods # #---------------------# -@generated function gradient!{T,A}(output::Vector{T}, f, x::Vector, ::Type{A}=Void; +@generated function gradient!{T,A}(output::AbstractVector{T}, f, x::AbstractVector, ::Type{A}=Void; chunk_size::Int=default_chunk_size, cache::ForwardDiffCache=dummy_cache) if A <: Void - return_stmt = :(gradient!(output, result)::Vector{T}) + return_stmt = :(gradient!(output, result)::typeof(output)) elseif A <: AllResults - return_stmt = :(gradient!(output, result)::Vector{T}, result) + return_stmt = :(gradient!(output, result)::typeof(output), result) else error("invalid argument $A passed to FowardDiff.gradient") end @@ -21,7 +21,7 @@ end end -@generated function gradient{T,A}(f, x::Vector{T}, ::Type{A}=Void; +@generated function gradient{T,A}(f, x::AbstractVector{T}, ::Type{A}=Void; chunk_size::Int=default_chunk_size, cache::ForwardDiffCache=dummy_cache) if A <: Void @@ -43,14 +43,14 @@ function gradient{A}(f, ::Type{A}=Void; chunk_size::Int=default_chunk_size, cache::ForwardDiffCache=ForwardDiffCache()) if mutates - function g!(output::Vector, x::Vector) + function g!(output::AbstractVector, x::AbstractVector) return ForwardDiff.gradient!(output, f, x, A; chunk_size=chunk_size, cache=cache) end return g! else - function g(x::Vector) + function g(x::AbstractVector) return ForwardDiff.gradient(f, x, A; chunk_size=chunk_size, cache=cache) @@ -61,7 +61,7 @@ end # Calculate gradient of a given function # #----------------------------------------# -function _calc_gradient{S}(f, x::Vector, ::Type{S}, +function _calc_gradient{S}(f, x::AbstractVector, ::Type{S}, chunk_size::Int, cache::ForwardDiffCache) X = Val{length(x)} @@ -69,7 +69,7 @@ function _calc_gradient{S}(f, x::Vector, ::Type{S}, return _calc_gradient(f, x, S, X, C, cache) end -@generated function _calc_gradient{T,S,xlen,chunk_size}(f, x::Vector{T}, ::Type{S}, +@generated function _calc_gradient{T,S,xlen,chunk_size}(f, x::AbstractVector{T}, ::Type{S}, X::Type{Val{xlen}}, C::Type{Val{chunk_size}}, cache::ForwardDiffCache)