Skip to content

Commit

Permalink
Merge pull request #2 from dfdx/struct-grad
Browse files Browse the repository at this point in the history
gradient of structure fields
  • Loading branch information
dfdx authored Jan 19, 2019
2 parents 96dcb15 + 9e41a10 commit 0c22089
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 23 deletions.
1 change: 0 additions & 1 deletion src/diffrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
########################################################################

# Container of rule tuples; created empty here.
# NOTE(review): the name suggests differentiation rules are pushed into it
# elsewhere in the package - confirm against the rest of diffrules.jl.
const DIFF_RULES = Tuple[]

# Set of symbols :w, :x, :y, :z, :i, :j, :k.
# NOTE(review): presumably the placeholder names allowed in rule patterns -
# confirm against the code that consumes it.
const DIFF_PHS = Set{Symbol}((:w, :x, :y, :z, :i, :j, :k))


Expand Down
87 changes: 76 additions & 11 deletions src/grad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,91 @@
# GRAD RESULT #
########################################################################

# """
# Fill `fieldmap` property of a tape. `fieldmap` contains a dict of dicts
# (op_id -> field_name -> field_op_id). So `tape.fieldmap[x][:fld]` points
# to a variable representing `x.fld`.
# """
# function map_fields!(tape::Tape)
# fm = Dict{Int, Dict{Symbol, Int}}()
# for op in tape
# if op isa Call && op.fn == Base.getproperty
# op_id = op.args[1]
# field_name = tape[op.args[2]].val
# field_op_id = op.id
# if !haskey(fm, op_id)
# fm[op_id] = Dict()
# end
# fm[op_id][field_name] = field_op_id
# end
# end
# tape.fieldmap = fm
# end


# """
# Flatten field paths in a fieldmap
# """
# function field_paths(fm::Dict{Int, Dict{Symbol, Int}}; current_path=[], result=Dict())
# for (struct_id, dct) in fm
# # path = [current_path; struct_id]
# for (fld, fld_id) in dct
# path = [current_path; fld]
# # struct field is also a struct
# if haskey(fm, fld_id)
# field_paths(fm; current_path=path, result=result)
# else
# result[tuple(path...)] = fld_id
# end
# end
# end
# return result
# end


"""
    field_paths(tape::Tape)

Collect chains of `getproperty` calls recorded on `tape`.

For every op that is a `Call` to `Base.getproperty`, walk back through its
first argument while that argument is itself a `getproperty` call, gathering
the field names on the way. The op at which the walk stops is treated as the
root struct.

Returns a `Dict` of the form
`root_op_id => Dict(field_path::Tuple => op_id)`, where `field_path` lists the
field names from the root outwards (e.g. `(:a, :b)` for `x.a.b`) and `op_id`
is the id of the op holding that field's value. Ops that are not
`getproperty` calls contribute nothing.
"""
function field_paths(tape::Tape)
    paths = Dict()
    for op in reverse(tape.ops)
        # Walk from the accessed field back towards the root struct, collecting
        # field names innermost-first.
        root = op
        path = []
        while root isa Call && root.fn == Base.getproperty
            # args[2] is the id of the op holding the field name,
            # args[1] is the id of the op the field is read from.
            push!(path, tape[root.args[2]].val)
            root = tape[root.args[1]]
        end
        isempty(path) && continue
        fields = get!(Dict, paths, root.id)
        # The trailing comma is required: `(xs...)` does not build a tuple, so
        # without it the key is not a field-path tuple (and nested paths cannot
        # be stored at all).
        fields[(reverse(path)...,)] = op.id
    end
    return paths
end


# Result of differentiating a tape: the tape itself plus a lookup table from
# argument id to the variable(s) holding that argument's gradient.
struct GradResult
# tape that the ids stored in `gvars` refer to
tape::Tape
# gradient vars: argid -> gradient var.
# The constructor in this file stores either a single id taken from
# `tape.derivs`, or, for struct arguments, a Dict of field path -> such an id.
gvars::Dict{Int, Any} # gradient vars: argid -> gradient var
end


function GradResult(tape::Tape)
tape.fieldpaths = field_paths(tape)
gvars = Dict{Int,Any}()
# struct fields
for (argid, dct) in tape.sfields
for (argid, dct) in tape.fieldpaths
gvars[argid] = Dict(field_path => tape.derivs[var_id]
for (field_path, var_id) in dct
if haskey(tape.derivs, var_id)) # not all fields may have derivatives
end
# other arguments
struct_arg_ids = Set(keys(tape.sfields))
struct_arg_ids = Set(keys(tape.fieldpaths))
for op in tape
if op isa Input && !in(op.argid, struct_arg_ids)
gvars[op.argid] = tape.derivs[op.var.id]
if op isa Input && !in(op.id, struct_arg_ids)
gvars[op.id] = tape.derivs[op.id]
end
end
return GradResult(tape, gvars)
Expand All @@ -33,9 +99,9 @@ function Base.getindex(g::GradResult, argid::Int)
tape = g.tape
gvar = g.gvars[argid]
if isa(gvar, Dict)
return Dict(f => tape[id].var.val for (f, id) in gvar)
return Dict(f => tape[id].val for (f, id) in gvar)
else
return tape[gvar].var.val
return tape[gvar].val
end
end

Expand Down Expand Up @@ -71,9 +137,6 @@ end


function deriv_broadcast!(tape::Tape, op::AbstractOp, i::Int, dy::AbstractOp)
# 1. take basic elements (see in Yota, presumably just first())
# 2. find rule for basic elements
# 3. record_expr_broadcast!() - record `broadcast` and execute immediately
ex = to_unbroadcast_expr(tape, op)
dep_eltypes = [eltype(tape[arg].typ) for arg in op.args[2:end]]
dex = deriv_expr(ex, dep_eltypes, i-1)
Expand Down Expand Up @@ -104,6 +167,9 @@ function step_back!(tape::Tape, op::Union{Call}, i::Int)
end


"""
Backpropagate through the tape, record derivatives as new operations
"""
function back!(tape::Tape)
# z - final variable (usually a loss)
# y - resulting variable of current op
Expand All @@ -116,8 +182,7 @@ function back!(tape::Tape)
# set initial derivative value
tape.derivs[z.id] = dy.id
for op in reverse(tape.ops[1:end-1])
println(op)
if op isa Call # || op isa Bcast
if op isa Call && op.fn != Base.getproperty
for i=1:length(op.args)
# backpropagate only non-constant vars
# note that it also prevents backprop on 1st param of broadcast
Expand Down
28 changes: 17 additions & 11 deletions src/tape.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,15 +82,21 @@ end
const MaybeFunction = Union{Function, Nothing}

"""
    Tape

Mutable record of a linearized execution graph together with the bookkeeping
used for differentiation (`derivs`, `fieldpaths`), an optional compiled form
and the device the tape is associated with.

Constructors:

* `Tape(device::AbstractDevice)` - empty tape bound to `device`
* `Tape()` - empty tape bound to `CPU()`
"""
mutable struct Tape
    # linearized execution graph
    ops::Vector{<:AbstractOp}
    # id of result variable
    resultid::Int
    # derivs[var_id] == grad_id
    derivs::Dict{Int,Int}
    # mapping of argid -> Dict(struct field paths -> var id)
    fieldpaths::Dict{Int, Dict}
    # compiled tape or nothing
    compiled::MaybeFunction
    # device of the tape (original note: "function to use for moving
    # intermediate results to device")
    device::AbstractDevice
end

# Empty tape: no ops, result id -1 (NOTE(review): presumably a "not set yet"
# sentinel - confirm), no derivatives, no field paths, nothing compiled.
# Arguments follow the field order of the struct above.
Tape(device::AbstractDevice) = Tape(AbstractOp[], -1, Dict(), Dict(), nothing, device)
Tape() = Tape(CPU())


Expand Down Expand Up @@ -160,7 +166,7 @@ function record_expr!(tape::Tape, ex::Expr; st=Dict(), bcast=false)
end


"""
    record_expr!(tape::Tape, x::Symbol; st=Dict(), bcast=false)

Record a bare symbol `x` onto `tape`.

`x` is looked up in `st` to obtain the id of an existing op, and an `Assign`
is recorded via `record!` with that id and the op's current value
(`tape[id].val`). The result of `record!` is returned.

# Keywords
- `st`: mapping from symbol to op id; indexing it with a symbol it does not
  contain raises the lookup error of the container (`KeyError` for a `Dict`).
- `bcast`: accepted but not used by this method.
"""
function record_expr!(tape::Tape, x::Symbol; st=Dict(), bcast=false)
    id = st[x]
    return record!(tape, Assign, id, tape[id].val)
end
Expand Down Expand Up @@ -225,11 +231,11 @@ end


function squash_assigned(tape::Tape)
new_tape = copy_with(tape; ops=AbstractOp[])
new_tape = copy_with(tape; ops=AbstractOp[])
st = Dict()
for (id, op) in enumerate(tape)
if op isa Assign
st[id] = op.src_id
if op isa Assign
st[id] = op.src_id
else
# record any other operations as is
new_id = length(new_tape) + 1
Expand Down

0 comments on commit 0c22089

Please sign in to comment.