Skip to content

Commit

Permalink
Merge pull request #3 from dfdx/sum-mean-grad
Browse files Browse the repository at this point in the history
gradient for mea() and sum() with keywords
  • Loading branch information
dfdx authored Jan 19, 2019
2 parents 0c22089 + 77b168d commit 425b1ab
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 46 deletions.
10 changes: 6 additions & 4 deletions src/diffrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -262,13 +262,15 @@ end
# sum() and mean()
@diffrule sum(x::Real) x ds
@diffrule sum(x::AbstractArray) x sum_grad(x, ds)
@diffrule sum(x::AbstractArray, y::Int) x ones(size(x)) .* ds
@diffrule sum(x::AbstractArray, y::Int) y 0.0
@diffrule Base._sum(x::AbstractArray, y::Int) x ones(size(x)) .* ds
@diffrule Base._sum(x::AbstractArray, y::Int) y 0.0

@diffrule Statistics.mean(x::Real) x ds
@diffrule Statistics.mean(x::AbstractArray) x ones(size(x)) ./ length(x) .* ds
@diffrule Statistics.mean(x::AbstractArray, y::Int) x ones(size(x)) ./ length(x) .* ds
@diffrule Statistics.mean(x::AbstractArray, y::Int) y 0.0
# @diffrule Statistics.mean(x::AbstractArray, y::Int) x ones(size(x)) ./ length(x) .* ds
# @diffrule Statistics.mean(x::AbstractArray, y::Int) y 0.0
@diffrule Statistics._mean(x::AbstractArray, y::Int) x mean_grad(x, ds)
@diffrule Statistics._mean(x::AbstractArray, y::Int) y 0.0

# dot()
@diffrule dot(x::Real, y::Real) x y * ds
Expand Down
42 changes: 0 additions & 42 deletions src/grad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,48 +2,6 @@
# GRAD RESULT #
########################################################################

# """
# Fill `fieldmap` property of a tape. `fieldmap` contains a dict of dicts
# (op_id -> field_name -> field_op_id). So `tape.fieldmap[x][:fld]` points
# to a variable representing `x.fld`.
# """
# function map_fields!(tape::Tape)
# fm = Dict{Int, Dict{Symbol, Int}}()
# for op in tape
# if op isa Call && op.fn == Base.getproperty
# op_id = op.args[1]
# field_name = tape[op.args[2]].val
# field_op_id = op.id
# if !haskey(fm, op_id)
# fm[op_id] = Dict()
# end
# fm[op_id][field_name] = field_op_id
# end
# end
# tape.fieldmap = fm
# end


# """
# Flatten field paths in a fieldmap
# """
# function field_paths(fm::Dict{Int, Dict{Symbol, Int}}; current_path=[], result=Dict())
# for (struct_id, dct) in fm
# # path = [current_path; struct_id]
# for (fld, fld_id) in dct
# path = [current_path; fld]
# # struct field is also a struct
# if haskey(fm, fld_id)
# field_paths(fm; current_path=path, result=result)
# else
# result[tuple(path...)] = fld_id
# end
# end
# end
# return result
# end


function field_paths(tape::Tape)
paths = Dict()
for op in reverse(tape.ops)
Expand Down
5 changes: 5 additions & 0 deletions src/helpers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,8 @@ end
function sum_grad(x::AbstractArray, ds; opts...)
return ones(size(x)) .* ds
end


function mean_grad(x::AbstractArray, ds)
return ones(size(x)) ./ length(x) .* ds
end

0 comments on commit 425b1ab

Please sign in to comment.