diff --git a/src/compiler/reverse.jl b/src/compiler/reverse.jl index ba10ea5b1..532644914 100644 --- a/src/compiler/reverse.jl +++ b/src/compiler/reverse.jl @@ -254,6 +254,15 @@ xaccum(ir) = nothing xaccum(ir, x) = x xaccum(ir, xs...) = push!(ir, xcall(Zygote, :accum, xs...)) +function passthrough_expr(ex::Expr) + # Metadata we want to preserve + isexpr(ex, GlobalRef, :call, :isdefined, :inbounds, :meta, :loopinfo) && return true + # ccalls and more that are safe to preserve/required for proper operation: + # - jl_set_task_threadpoolid: added in 1.9 for @spawn + isexpr(ex, :foreigncall) && unwrapquote(ex.args[1]) in (:jl_set_task_threadpoolid,) && return true + return false +end + function adjoint(pr::Primal) ir, sigs = adjointcfg(pr) for b in reverse(blocks(pr.ir)) @@ -278,10 +287,9 @@ function adjoint(pr::Primal) end elseif ex isa Core.PiNode grads[ex.val] = grads[v] - elseif isexpr(ex, GlobalRef, :call, :isdefined, :inbounds, :meta, :loopinfo) - elseif isexpr(ex) + elseif isexpr(ex) && !passthrough_expr(ex) push!(rb, stmt(xcall(Base, :error, """ - Can't differentiate $(ex.head) expression. + Can't differentiate $(ex.head) expression $ex. You might want to check the Zygote limitations documentation. https://fluxml.ai/Zygote.jl/latest/limitations """),