Skip to content

Commit

Permalink
Merge pull request #1 from dfdx/broadcast-grad
Browse files Browse the repository at this point in the history
_grad for broadcast
  • Loading branch information
dfdx authored Jan 16, 2019
2 parents c3b8471 + 25114db commit 96dcb15
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 10 deletions.
9 changes: 6 additions & 3 deletions src/grad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,8 @@ function deriv_broadcast!(tape::Tape, op::AbstractOp, i::Int, dy::AbstractOp)
# 3. record_expr_broadcast!() - record `broadcast` and execute immediately
ex = to_unbroadcast_expr(tape, op)
dep_eltypes = [eltype(tape[arg].typ) for arg in op.args[2:end]]
dex = deriv_expr(ex, dep_eltypes, i)
st = Dict(Symbol("%$i") => i for i in op.args)
dex = deriv_expr(ex, dep_eltypes, i-1)
st = Dict(Symbol("%$id") => id for id in op.args)
st[:ds] = dy.id
ret_id = record_expr!(tape, dex; st=st, bcast=true)
return tape[ret_id]
Expand Down Expand Up @@ -105,7 +105,9 @@ end


function back!(tape::Tape)
# z - final variable, y - resulting variable of current op, x - dependencies of y
# z - final variable (usually a loss)
# y - resulting variable of current op
# x - dependencies of y
# dy - derivative of z w.r.t. y
z = tape[tape.resultid]
# using Float32 for seed since for 64-bit args it will be expanded anyway
Expand All @@ -118,6 +120,7 @@ function back!(tape::Tape)
if op isa Call # || op isa Bcast
for i=1:length(op.args)
# backpropagate only non-constant vars
# note that it also prevents backprop on 1st param of broadcast
arg_op = tape[op.args[i]]
if !isa(arg_op, Constant)
step_back!(tape, op, i)
Expand Down
36 changes: 29 additions & 7 deletions src/tape.jl
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ function record_expr!(tape::Tape, ex::Expr; st=Dict(), bcast=false)
new_op_args[i] = st[x]
elseif Meta.isexpr(x, :call)
# recursively record arg expression
arg_id = record_expr!(tape, x; st=st)
arg_id = record_expr!(tape, x; st=st, bcast=bcast)
new_op_args[i] = arg_id
else
# treat as constant
Expand All @@ -160,9 +160,9 @@ function record_expr!(tape::Tape, ex::Expr; st=Dict(), bcast=false)
end


function record_expr!(tape::Tape, x::Symbol; st=Dict(), bcast=false)
ds_id = st[:ds]
return record!(tape, Assign, ds_id, tape[ds_id].val)
function record_expr!(tape::Tape, x::Symbol; st=Dict(), bcast=false)
id = st[x]
return record!(tape, Assign, id, tape[id].val)
end


Expand Down Expand Up @@ -193,10 +193,12 @@ end


"""
Replace a sequence of `broadcasted()` => `materizalize()` calls with a single `broadcast()`
Recover broadcast operation from Broadcast.broadcasted and Broadcast.materialize
"""
function squash_broadcast(tape::Tape)
function recover_broadcast(tape::Tape)
new_tape = copy_with(tape; ops=AbstractOp[])
# TODO: seems like we don't need subs table any more
# remove after squash_assigned is implemented
st = Dict()
for (id, op) in enumerate(tape)
if op isa Call && op.fn === Broadcast.broadcasted
Expand All @@ -222,9 +224,29 @@ function squash_broadcast(tape::Tape)
end


function squash_assigned(tape::Tape)
new_tape = copy_with(tape; ops=AbstractOp[])
st = Dict()
for (id, op) in enumerate(tape)
if op isa Assign
st[id] = op.src_id
else
# record any other operations as is
new_id = length(new_tape) + 1
push!(new_tape.ops, copy_with(op; id=new_id))
st[id] = length(new_tape)
end
end
replace_in_args!(new_tape, st)
new_tape.resultid = get(st, tape.resultid, tape.resultid)
return new_tape
end


function simplify(tape::Tape)
return squash_broadcast(tape)
tape = recover_broadcast(tape)
tape = squash_assigned(tape)
return tape
end


Expand Down

0 comments on commit 96dcb15

Please sign in to comment.