Skip to content

Commit

Permalink
allow AbstractVector input for gradients
Browse files Browse the repository at this point in the history
  • Loading branch information
mlubin committed Feb 9, 2016
1 parent 675818e commit 58f9370
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions src/api/gradient.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@

# Exposed API methods #
#---------------------#
@generated function gradient!{T,A}(output::Vector{T}, f, x::Vector, ::Type{A}=Void;
@generated function gradient!{T,A}(output::AbstractVector{T}, f, x::AbstractVector, ::Type{A}=Void;
chunk_size::Int=default_chunk_size,
cache::ForwardDiffCache=dummy_cache)
if A <: Void
return_stmt = :(gradient!(output, result)::Vector{T})
return_stmt = :(gradient!(output, result)::AbstractVector{T})
elseif A <: AllResults
return_stmt = :(gradient!(output, result)::Vector{T}, result)
return_stmt = :(gradient!(output, result)::AbstractVector{T}, result)

This comment has been minimized.

Copy link
@jrevels

jrevels Feb 9, 2016

Member

Might as well remove these assertions, or change them to typeof(output) if you see type inference failing.

else
error("invalid argument $A passed to FowardDiff.gradient")
end
Expand All @@ -21,7 +21,7 @@
end
end

@generated function gradient{T,A}(f, x::Vector{T}, ::Type{A}=Void;
@generated function gradient{T,A}(f, x::AbstractVector{T}, ::Type{A}=Void;
chunk_size::Int=default_chunk_size,
cache::ForwardDiffCache=dummy_cache)
if A <: Void
Expand All @@ -43,14 +43,14 @@ function gradient{A}(f, ::Type{A}=Void;
chunk_size::Int=default_chunk_size,
cache::ForwardDiffCache=ForwardDiffCache())
if mutates
function g!(output::Vector, x::Vector)
function g!(output::AbstractVector, x::AbstractVector)
return ForwardDiff.gradient!(output, f, x, A;
chunk_size=chunk_size,
cache=cache)
end
return g!
else
function g(x::Vector)
function g(x::AbstractVector)
return ForwardDiff.gradient(f, x, A;
chunk_size=chunk_size,
cache=cache)
Expand All @@ -61,15 +61,15 @@ end

# Calculate gradient of a given function #
#----------------------------------------#
function _calc_gradient{S}(f, x::Vector, ::Type{S},
function _calc_gradient{S}(f, x::AbstractVector, ::Type{S},
chunk_size::Int,
cache::ForwardDiffCache)
X = Val{length(x)}
C = Val{chunk_size}
return _calc_gradient(f, x, S, X, C, cache)
end

@generated function _calc_gradient{T,S,xlen,chunk_size}(f, x::Vector{T}, ::Type{S},
@generated function _calc_gradient{T,S,xlen,chunk_size}(f, x::AbstractVector{T}, ::Type{S},
X::Type{Val{xlen}},
C::Type{Val{chunk_size}},
cache::ForwardDiffCache)
Expand Down

0 comments on commit 58f9370

Please sign in to comment.