diff --git a/src/grad.jl b/src/grad.jl index 7862a18..003ff85 100644 --- a/src/grad.jl +++ b/src/grad.jl @@ -76,8 +76,8 @@ function deriv_broadcast!(tape::Tape, op::AbstractOp, i::Int, dy::AbstractOp) # 3. record_expr_broadcast!() - record `broadcast` and execute immediately ex = to_unbroadcast_expr(tape, op) dep_eltypes = [eltype(tape[arg].typ) for arg in op.args[2:end]] - dex = deriv_expr(ex, dep_eltypes, i) - st = Dict(Symbol("%$i") => i for i in op.args) + dex = deriv_expr(ex, dep_eltypes, i-1) + st = Dict(Symbol("%$id") => id for id in op.args) st[:ds] = dy.id ret_id = record_expr!(tape, dex; st=st, bcast=true) return tape[ret_id] @@ -105,7 +105,9 @@ end function back!(tape::Tape) - # z - final variable, y - resulting variable of current op, x - dependencies of y + # z - final variable (usually a loss) + # y - resulting variable of current op + # x - dependencies of y # dy - derivative of z w.r.t. y z = tape[tape.resultid] # using Float32 for seed since for 64-bit args it will be expanded anyway @@ -118,6 +120,7 @@ function back!(tape::Tape) if op isa Call # || op isa Bcast for i=1:length(op.args) # backpropagate only non-constant vars + # note that it also prevents backprop on 1st param of broadcast arg_op = tape[op.args[i]] if !isa(arg_op, Constant) step_back!(tape, op, i) diff --git a/src/tape.jl b/src/tape.jl index a63dd51..2d4c9fa 100644 --- a/src/tape.jl +++ b/src/tape.jl @@ -140,7 +140,7 @@ function record_expr!(tape::Tape, ex::Expr; st=Dict(), bcast=false) new_op_args[i] = st[x] elseif Meta.isexpr(x, :call) # recursively record arg expression - arg_id = record_expr!(tape, x; st=st) + arg_id = record_expr!(tape, x; st=st, bcast=bcast) new_op_args[i] = arg_id else # treat as constant @@ -160,9 +160,9 @@ function record_expr!(tape::Tape, ex::Expr; st=Dict(), bcast=false) end -function record_expr!(tape::Tape, x::Symbol; st=Dict(), bcast=false) - ds_id = st[:ds] - return record!(tape, Assign, ds_id, tape[ds_id].val) +function record_expr!(tape::Tape, x::Symbol; st=Dict(), bcast=false) + id = st[x] + return record!(tape, Assign, id, tape[id].val) end @@ -193,10 +193,12 @@ end """ -Replace a sequence of `broadcasted()` => `materizalize()` calls with a single `broadcast()` +Recover broadcast operation from Broadcast.broadcasted and Broadcast.materialize """ -function squash_broadcast(tape::Tape) +function recover_broadcast(tape::Tape) new_tape = copy_with(tape; ops=AbstractOp[]) + # TODO: seems like we don't need subs table any more + # remove after squash_assigned is implemented st = Dict() for (id, op) in enumerate(tape) if op isa Call && op.fn === Broadcast.broadcasted @@ -222,9 +224,29 @@ function squash_broadcast(tape::Tape) end +function squash_assigned(tape::Tape) + new_tape = copy_with(tape; ops=AbstractOp[]) + st = Dict() + for (id, op) in enumerate(tape) + if op isa Assign + st[id] = op.src_id + else + # record any other operations as is + new_id = length(new_tape) + 1 + push!(new_tape.ops, copy_with(op; id=new_id)) + st[id] = length(new_tape) + end + end + replace_in_args!(new_tape, st) + new_tape.resultid = get(st, tape.resultid, tape.resultid) + return new_tape +end + function simplify(tape::Tape) - return squash_broadcast(tape) + tape = recover_broadcast(tape) + tape = squash_assigned(tape) + return tape end