From f389f2ba584ec9443ae3c138fefbd66ea3d99d9a Mon Sep 17 00:00:00 2001 From: Michael Abbott Date: Sat, 19 Dec 2020 16:56:19 +0100 Subject: [PATCH] friendlier error message for gradient failure --- src/grad/reverse.jl | 2 +- src/grad/tracker.jl | 2 +- src/grad/zygote.jl | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/grad/reverse.jl b/src/grad/reverse.jl index da8b335..9bddd93 100644 --- a/src/grad/reverse.jl +++ b/src/grad/reverse.jl @@ -8,7 +8,7 @@ using .ReverseDiff ReverseDiff.@grad function (ev::Eval)(args...) Z = ev.fwd(ReverseDiff.value.(args)...) Z, Δ -> begin - isnothing(ev.rev) && error("no gradient definition here!") + ev.rev===nothing && throw("No gradient definition found! Running `@tullio` with keyword `verbose=true` may print the reason") ev.rev(ReverseDiff.value(Δ), Z, ReverseDiff.value.(args)...) end end diff --git a/src/grad/tracker.jl b/src/grad/tracker.jl index 53e1618..723914f 100644 --- a/src/grad/tracker.jl +++ b/src/grad/tracker.jl @@ -8,7 +8,7 @@ using .Tracker Tracker.@grad function (ev::Eval)(args...) Z = ev.fwd(Tracker.data.(args)...) Z, Δ -> begin - isnothing(ev.rev) && error("no gradient definition here!") + ev.rev===nothing && throw("No gradient definition found! Running `@tullio` with keyword `verbose=true` may print the reason") tuple(ev.rev(Tracker.data(Δ), Z, Tracker.data.(args)...)...) end end diff --git a/src/grad/zygote.jl b/src/grad/zygote.jl index 82808f9..f719c79 100644 --- a/src/grad/zygote.jl +++ b/src/grad/zygote.jl @@ -4,7 +4,7 @@ using .Zygote Zygote.@adjoint function (ev::Eval)(args...) Z = ev.fwd(args...) Z, Δ -> begin - isnothing(ev.rev) && error("no gradient definition here!") + ev.rev===nothing && throw("No gradient definition found! Running `@tullio` with keyword `verbose=true` may print the reason") tuple(nothing, ev.rev(Δ, Z, args...)...) end end