Skip to content

Commit

Permalink
Merge pull request #132 from dfdx/faster-grad
Browse files Browse the repository at this point in the history
Faster grad
  • Loading branch information
dfdx authored Oct 30, 2022
2 parents ad94b3e + 24b679b commit b9ae446
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 21 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Yota"
uuid = "cd998857-8626-517d-b929-70ad188a48f0"
authors = ["Andrei Zhabinski <andrei.zhabinski@gmail.com>"]
version = "0.8.1"
version = "0.8.2"

[deps]
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
Expand All @@ -20,5 +20,5 @@ ChainRules = "1.43"
ChainRulesCore = "1.15"
FiniteDifferences = "0.12"
NNlib = "0.8"
Umlaut = "0.4.7"
Umlaut = "0.4.8"
julia = "1.6"
41 changes: 22 additions & 19 deletions src/grad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ struct ChainRulesCtx end


function has_rrule(f, args...)
@nospecialize
F = Core.Typeof(f)
Args = Core.Typeof.(args)
Core.Compiler.return_type(rrule, Tuple{YotaRuleConfig, F, Args...}) !== Nothing && return true
Expand All @@ -18,7 +19,7 @@ function has_rrule(f, args...)
return false
end

Umlaut.isprimitive(::ChainRulesCtx, f, args...) = has_rrule(f, args...)
Umlaut.isprimitive(::ChainRulesCtx, f, args...) = has_rrule(Base.inferencebarrier(f), Base.inferencebarrier(args)...)


struct GradCtx
Expand Down Expand Up @@ -94,31 +95,32 @@ function set_or_add_deriv!(tape::Tape, x::Variable, dx::Variable)
end


function todo_list!(tape::Tape{GradCtx}, y_id::Int, result::Set{Int})
push!(result, y_id)
y = V(tape, y_id)
# since `y = getfield(rr, 2)`, we use arguments of the original rrule instead
y_fargs = is_kwfunc(y._op.fn) ? tape[y].args[3:end] : tape[y].args
for x in y_fargs
if x isa V && !in(x.id, result) && tape[x] isa Call
todo_list!(tape, x.id, result)
end
end
end

"""
Collect variables that we need to step through during the reverse pass.
The returned vector is already deduplicated and reverse-sorted
"""
function todo_list(tape::Tape{GradCtx}, y=tape.result)
y_orig = y
@assert(tape[y] isa Call, "The tape's result is expected to be a Call, " *
function todo_list(tape::Tape{GradCtx})
@assert(tape[tape.result] isa Call, "The tape's result is expected to be a Call, " *
"but instead $(typeof(tape[tape.result])) was encountered")
y_fargs = [tape[y].fn; tape[y].args...]
is_rrule_based = haskey(tape.c.pullbacks, y)
if is_rrule_based
# use rrule instead
y = tape[y].args[1]
y_fargs = is_kwfunc(y._op.fn) ? tape[y].args[3:end] : tape[y].args
end
y_todo = [x for x in y_fargs if x isa V && tape[x] isa Call]
x_todos = [todo_list(tape, x) for x in y_todo]
# include y itself (original), its parents and their parents recursively
todo = [[y_orig]; y_todo; vcat(x_todos...)]
# deduplicate
todo = collect(Set([bound(tape, v) for v in todo]))
todo = sort(todo, by=v->v.id, rev=true)
return todo
result = Set{Int}()
todo_list!(tape, tape.result.id, result)
ids = sort(collect(result), rev=true)
return [V(tape, id) for id in ids]
end


call_values(op::Call) = Umlaut.var_values([op.fn, op.args...])

"""
Expand Down Expand Up @@ -212,6 +214,7 @@ function back!(tape::Tape; seed=1)
tape.c.derivs[z] = dy
# queue of variables to calculate derivatives for
deriv_todo = todo_list(tape)
deriv_todo = sort(deriv_todo, by=v->v.id, rev=true)
for y in deriv_todo
try
step_back!(tape, y)
Expand Down
9 changes: 9 additions & 0 deletions src/rulesets.jl
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,15 @@ function rrule(::YotaRuleConfig, nt::Type{NamedTuple{names}}, t::Tuple) where {n
end


function rrule(::YotaRuleConfig, nt::Type{Tuple}, t::Tuple)
val = Tuple(t)
function tuple_pullback(dy)
return NoTangent(), dy
end
return val, tuple_pullback
end


function rrule(::YotaRuleConfig, ::typeof(getindex), s::NamedTuple, f::Symbol)
y = getindex(s, f)
function nt_getindex_pullback(dy)
Expand Down
1 change: 1 addition & 0 deletions test/test_rulesets.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ config = YotaRuleConfig()
@testset "tuple" begin
test_rrule(config, tuple, 1.0, 2.0, 3.0; check_inferred=false)
test_rrule(config, NamedTuple{(:dims,)}, (1,))
test_rrule(config, Tuple, (1,))
test_rrule(config, getindex, (a=42.0, b=54.0), :a; check_inferred=false)
end

Expand Down

0 comments on commit b9ae446

Please sign in to comment.