diff --git a/src/diffrules.jl b/src/diffrules.jl index 27f8041..991ef00 100644 --- a/src/diffrules.jl +++ b/src/diffrules.jl @@ -262,13 +262,15 @@ end # sum() and mean() @diffrule sum(x::Real) x ds @diffrule sum(x::AbstractArray) x sum_grad(x, ds) -@diffrule sum(x::AbstractArray, y::Int) x ones(size(x)) .* ds -@diffrule sum(x::AbstractArray, y::Int) y 0.0 +@diffrule Base._sum(x::AbstractArray, y::Int) x ones(size(x)) .* ds +@diffrule Base._sum(x::AbstractArray, y::Int) y 0.0 @diffrule Statistics.mean(x::Real) x ds @diffrule Statistics.mean(x::AbstractArray) x ones(size(x)) ./ length(x) .* ds -@diffrule Statistics.mean(x::AbstractArray, y::Int) x ones(size(x)) ./ length(x) .* ds -@diffrule Statistics.mean(x::AbstractArray, y::Int) y 0.0 +# @diffrule Statistics.mean(x::AbstractArray, y::Int) x ones(size(x)) ./ length(x) .* ds +# @diffrule Statistics.mean(x::AbstractArray, y::Int) y 0.0 +@diffrule Statistics._mean(x::AbstractArray, y::Int) x mean_grad(x, ds) +@diffrule Statistics._mean(x::AbstractArray, y::Int) y 0.0 # dot() @diffrule dot(x::Real, y::Real) x y * ds diff --git a/src/grad.jl b/src/grad.jl index bbef37a..7c75a18 100644 --- a/src/grad.jl +++ b/src/grad.jl @@ -2,48 +2,6 @@ # GRAD RESULT # ######################################################################## -# """ -# Fill `fieldmap` property of a tape. `fieldmap` contains a dict of dicts -# (op_id -> field_name -> field_op_id). So `tape.fieldmap[x][:fld]` points -# to a variable representing `x.fld`. -# """ -# function map_fields!(tape::Tape) -# fm = Dict{Int, Dict{Symbol, Int}}() -# for op in tape -# if op isa Call && op.fn == Base.getproperty -# op_id = op.args[1] -# field_name = tape[op.args[2]].val -# field_op_id = op.id -# if !haskey(fm, op_id) -# fm[op_id] = Dict() -# end -# fm[op_id][field_name] = field_op_id -# end -# end -# tape.fieldmap = fm -# end - - -# """ -# Flatten field paths in a fieldmap -# """ -# function field_paths(fm::Dict{Int, Dict{Symbol, Int}}; current_path=[], result=Dict()) -# for (struct_id, dct) in fm -# # path = [current_path; struct_id] -# for (fld, fld_id) in dct -# path = [current_path; fld] -# # struct field is also a struct -# if haskey(fm, fld_id) -# field_paths(fm; current_path=path, result=result) -# else -# result[tuple(path...)] = fld_id -# end -# end -# end -# return result -# end - - function field_paths(tape::Tape) paths = Dict() for op in reverse(tape.ops) diff --git a/src/helpers.jl b/src/helpers.jl index 1700900..c76e2da 100644 --- a/src/helpers.jl +++ b/src/helpers.jl @@ -18,3 +18,8 @@ end function sum_grad(x::AbstractArray, ds; opts...) return ones(size(x)) .* ds end + + +function mean_grad(x::AbstractArray, ds) + return ones(size(x)) ./ length(x) .* ds +end