From 85a98ce0688bf1fd581ed20374e26068420f4cdf Mon Sep 17 00:00:00 2001
From: Altan Haan
Date: Thu, 18 Mar 2021 07:04:48 -0700
Subject: [PATCH] [Relay][Training][Pass] Factor out first-order AD to a module
 pass (#7677)

---
 python/tvm/relay/transform/transform.py       |  26 +-
 src/relay/transforms/first_order_gradient.cc  | 309 ++++++++++++++++++
 src/relay/transforms/gradient.h               |  54 +++
 .../{gradient.cc => higher_order_gradient.cc} | 274 +---------------
 4 files changed, 391 insertions(+), 272 deletions(-)
 create mode 100644 src/relay/transforms/first_order_gradient.cc
 create mode 100644 src/relay/transforms/gradient.h
 rename src/relay/transforms/{gradient.cc => higher_order_gradient.cc} (64%)

diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py
index b61f209505d8..5b0e480f5f28 100644
--- a/python/tvm/relay/transform/transform.py
+++ b/python/tvm/relay/transform/transform.py
@@ -800,12 +800,36 @@ def gradient(expr, mod=None, mode="higher_order"):
         The transformed expression.
     """
     if mode == "first_order":
-        return _ffi_api.first_order_gradient(expr, mod)
+        warnings.warn(
+            # NOTE(review): trailing space added — adjacent string literals are
+            # concatenated with no separator, so the original message read
+            # "...please use theFirstOrderGradient module pass".
+            "using transform.gradient for first-order AD is deprecated, please use the "
+            "FirstOrderGradient module pass",
+            DeprecationWarning,
+        )
+        if mod is not None:
+            raise RuntimeError(
+                "to run first-order AD on a module, please use the FirstOrderGradient module pass."
+            )
+        return FirstOrderGradient()(tvm.IRModule.from_expr(expr))["main"]
     if mode == "higher_order":
         return _ffi_api.gradient(expr, mod)
     raise Exception("unknown mode")
 
 
+def FirstOrderGradient():
+    """
+    Transforms all global functions in the module to return the original result, paired with the
+    gradients of the inputs. This pass transforms each global function independently and does not
+    support interprocedural AD. Additionally, this pass does not support any control-flow or
+    references, and should only be used on pure data-flow graphs.
+
+    Returns
+    -------
+    ret : tvm.transform.Pass
+        The registered FirstOrderGradient pass.
+    """
+    return _ffi_api.FirstOrderGradient()
+
+
 def Defunctionalization(func, mod):
     """
     Performs defunctionalization on func,
diff --git a/src/relay/transforms/first_order_gradient.cc b/src/relay/transforms/first_order_gradient.cc
new file mode 100644
index 000000000000..55714592ded7
--- /dev/null
+++ b/src/relay/transforms/first_order_gradient.cc
@@ -0,0 +1,309 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file first_order_gradient.cc
+ * \brief First-order Automatic Differentiation in Relay for pure dataflow graphs.
+ */ +#include +#include +#include +#include +#include +#include +#include + +#include "gradient.h" +#include "let_list.h" +#include "pass_utils.h" +#include "pattern_utils.h" + +namespace tvm { +namespace relay { + +template +Expr MultiFactory(const Type& t, F factory, DiagnosticContext diag_ctx) { + if (auto* tt = t.as()) { + return factory(tt->shape, tt->dtype); + } else if (auto* tt = t.as()) { + std::vector res; + for (size_t i = 0; i < tt->fields.size(); i++) { + res.push_back(MultiFactory(tt->fields[i], factory, diag_ctx)); + } + return Tuple(res); + } else { + diag_ctx.EmitFatal(Diagnostic::Error(t->span) + << "could not build tensors using factory for type " << PrettyPrint(t)); + throw; + } +} + +template +Expr MultiFactoryLike(const Expr& e, const Type& t, F factory, F2 factory_like, + DiagnosticContext diag_ctx) { + if (t.as()) { + return factory_like(e); + } else if (auto* tt = t.as()) { + return MultiFactory(t, factory, diag_ctx); + } else { + diag_ctx.EmitFatal(Diagnostic::Error(t->span) + << "could not build tensors using factory for type " << PrettyPrint(t)); + throw; + } +} + +/*! \brief A fragment of the program being built by the automatic differentation + * pass. + */ +struct ADValueNode { + virtual ~ADValueNode() {} + template + T& get() { + auto ret = dynamic_cast(this); + ICHECK(ret) << "cannot downcast"; + return *ret; + } +}; + +using ADValue = std::shared_ptr; + +/*! \brief AD over a program which generates a tensor output. */ +struct ADTensor : ADValueNode { + Expr forward; + mutable Expr reverse; // must be a variable to avoid duplication + ADTensor(LetList* ll, const Expr& forward, DiagnosticContext diag_ctx) + : forward(ll->Push(forward)), + reverse(ll->Push( + MultiFactoryLike(this->forward, forward->checked_type(), Zeros, ZerosLike, diag_ctx))) { + this->forward->checked_type_ = forward->checked_type(); + } +}; + +/*! \brief A staged representation of the program, we reflect + * Relay functions into a function over fragments of AD. 
We + * can compute away this function to obtain a reverse mode program. + */ +struct ADFunction : ADValueNode { + // (ad_args, orig) -> ad_ret + using ADFunctionType = ADValue(const std::vector&, const Call&); + std::function func; + explicit ADFunction(const std::function& func) : func(func) {} +}; + +struct FirstOrderReverseAD : ExprFunctor { + const OpAttrMap rev_map = Op::GetAttrMap("FPrimalGradient"); + std::vector> backprop_actions; + // we assume no closure so no need for lexical scoping + std::unordered_map env; + LetList* ll; + DiagnosticContext diag_ctx; + + FirstOrderReverseAD(LetList* ll, DiagnosticContext diag_ctx) : ll(ll), diag_ctx(diag_ctx) {} + + ADValue VisitExpr(const Expr& n) final { + if (env.count(n)) { + return env.at(n); + } + auto ret = ExprFunctor::VisitExpr(n); + env[n] = ret; + return ret; + } + + static Expr LiftedAdd(const Type& t, const Expr& x, const Expr& y, LetList* ll) { + if (t.as()) { + return ll->Push(Add(x, y)); + } else if (auto* tt = t.as()) { + Array fields; + for (size_t i = 0; i < tt->fields.size(); ++i) { + fields.push_back( + LiftedAdd(tt->fields[i], ll->Push(GetField(x, i)), ll->Push(GetField(y, i)), ll)); + } + return ll->Push(Tuple(fields)); + } else { + LOG(FATAL) << "cannot lift addition for type " << PrettyPrint(t); + throw; + } + } + + ADValue VisitExpr_(const OpNode* op) final { + Op op_ref = GetRef(op); + if (!rev_map.count(op_ref)) { + diag_ctx.EmitFatal(Diagnostic::Error(op->span) + << "the operator " << op->name << " does not have a registered gradient."); + } + return std::make_shared([this, op_ref](const std::vector& ad_args, + const Call& orig) { + std::vector orig_args; + for (const ADValue& adval : ad_args) { + orig_args.push_back(adval->get().forward); + } + auto orig_new = Call(op_ref, orig_args, orig->attrs, orig->type_args); + orig_new->checked_type_ = orig->checked_type(); + auto ret = std::make_shared(ll, orig_new, diag_ctx); + backprop_actions.push_back([this, ad_args, orig_new, ret, 
op_ref](LetList* ll) { + tvm::Array rev = rev_map[op_ref](orig_new, ret->reverse); + if (ad_args.size() != rev.size()) { + diag_ctx.EmitFatal(Diagnostic::Error(op_ref->span) + << "arity mismatch for operator " << op_ref->name + << " and its registered gradient: expected " << ad_args.size() + << " but got " << rev.size() << " gradients."); + } + for (size_t i = 0; i < ad_args.size(); ++i) { + auto& ad_arg = ad_args[i]->get(); + ad_arg.reverse = LiftedAdd(ad_arg.forward->checked_type(), ad_arg.reverse, rev[i], ll); + } + }); + return ret; + }); + } + + ADValue VisitExpr_(const TupleGetItemNode* op) final { + Expr e = GetRef(op); + ADValue tup = VisitExpr(op->tuple); + auto tt = op->tuple->checked_type().as(); + size_t idx = op->index; + auto ret = std::make_shared(ll, e, diag_ctx); + backprop_actions.push_back([tup, tt, idx, ret](LetList* ll) { + auto& ad_tup = tup->get(); + std::vector updated_grads; + for (size_t i = 0; i < tt->fields.size(); ++i) { + Expr grad_pre = GetField(ad_tup.reverse, i); + updated_grads.push_back(i != idx ? 
grad_pre + : LiftedAdd(tt->fields[i], grad_pre, ret->reverse, ll)); + } + ad_tup.reverse = ll->Push(Tuple(updated_grads)); + }); + return ret; + } + + ADValue VisitExpr_(const TupleNode* op) final { + Expr e = GetRef(op); + std::vector fields; + for (const auto& f : op->fields) { + fields.push_back(VisitExpr(f)); + } + auto tt = op->checked_type().as(); + auto ret = std::make_shared(ll, e, diag_ctx); + backprop_actions.push_back([fields, tt, ret](LetList* ll) { + for (size_t i = 0; i < fields.size(); ++i) { + auto& ad_field = fields[i]->get(); + ad_field.reverse = + LiftedAdd(tt->fields[i], ad_field.reverse, GetField(ret->reverse, i), ll); + } + }); + return ret; + } + + ADValue VisitExpr_(const ConstantNode* op) final { + Expr e = GetRef(op); + return std::make_shared(ll, e, diag_ctx); + } + + ADValue VisitExpr_(const CallNode* op) final { + ADValue f = VisitExpr(op->op); + std::vector args; + for (const auto& arg : op->args) { + args.push_back(VisitExpr(arg)); + } + return f->get().func(args, GetRef(op)); + } + + ADValue VisitExpr_(const FunctionNode* op) final { + Function f = GetRef(op); + // todo: assert no closure + return std::make_shared( + [this, f](const std::vector& ad_args, const Call& orig) { + ICHECK_EQ(f->params.size(), ad_args.size()); + for (size_t i = 0; i < f->params.size(); ++i) { + env[f->params[i]] = ad_args[i]; + } + return VisitExpr(f->body); + }); + } + + // Var will always be in env, handled in VisitExpr (without _), so we don't need + // to implement its VisitExpr_. 
+};
+
+namespace transform {
+
+Pass FirstOrderGradient() {
+  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> f = [](IRModule mod, PassContext ctx) {
+    CheckFeature(
+        mod, FeatureSet({fVar, fConstant, fTuple, fTupleGetItem, fFunction, fOp, fCall, fGraph}));
+    IRModule ad_mod = GetRef<IRModule>(mod.CopyOnWrite());
+    DiagnosticContext diag_ctx = DiagnosticContext::Default(ad_mod);
+
+    if (mod->functions.size() > 1) {
+      LOG(WARNING) << "IRModule contains multiple global functions: first-order AD will transform "
+                      "them independently!";
+    }
+
+    for (const auto& pr : mod->functions) {
+      const FunctionNode* func = pr.second.as<FunctionNode>();
+      if (!func) {
+        diag_ctx.Emit(Diagnostic::Warning(pr.second->span)
+                      << "AD can only be performed on Relay functions, skipping "
+                      << PrettyPrint(pr.first));
+        // NOTE(review): without this `continue`, the null `func` is dereferenced below.
+        continue;
+      }
+      if (func->type_params.size() > 0) {
+        diag_ctx.EmitFatal(Diagnostic::Error(pr.second->span)
+                           << "first-order AD does not support polymorphism yet.");
+      }
+      Expr body = LetList::With([&](LetList* ll) {
+        FirstOrderReverseAD reverse_ad(ll, diag_ctx);
+        ADValue rev = reverse_ad(pr.second);
+        std::vector<ADValue> args;
+        for (const auto& p : func->params) {
+          args.push_back(std::make_shared<ADTensor>(ll, p, diag_ctx));
+        }
+        // Placeholder call drives the staged ADFunction; only its checked type is used.
+        Call placeholder = Call(GetRef<Function>(func), {});
+        placeholder->checked_type_ = func->checked_type().as<FuncTypeNode>()->ret_type;
+        auto grad_call = rev->get<ADFunction>().func(args, placeholder);
+        auto& res = grad_call->get<ADTensor>();
+        Expr grad_tuple = LetList::With([&](LetList* ll) {
+          // Seed the output gradient with ones, then replay backprop actions in reverse.
+          res.reverse =
+              MultiFactoryLike(res.forward, res.forward->checked_type(), Ones, OnesLike, diag_ctx);
+          for (auto it = reverse_ad.backprop_actions.rbegin();
+               it != reverse_ad.backprop_actions.rend(); ++it) {
+            (*it)(ll);
+          }
+          std::vector<Expr> grads;
+          for (const auto& a : args) {
+            grads.push_back(a->get<ADTensor>().reverse);
+          }
+          return Tuple(grads);
+        });
+        return Pair(res.forward, grad_tuple);
+      });
+      ad_mod->Update(pr.first,
+                     Function(func->params, body, GradRetType(GetRef<Function>(func)), {}));
+    }
+
+    return ad_mod;
+  };
+  return CreateModulePass(f, 0, "FirstOrderGradient", {});
+}
+
+TVM_REGISTER_GLOBAL("relay._transform.FirstOrderGradient").set_body_typed(FirstOrderGradient);
+
+}  // namespace transform
+
+}  // namespace relay
+}  // namespace tvm
diff --git a/src/relay/transforms/gradient.h b/src/relay/transforms/gradient.h
new file mode 100644
index 000000000000..2e6ffbcc7c9e
--- /dev/null
+++ b/src/relay/transforms/gradient.h
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file gradient.h
+ * \brief Utility functions for Automatic Differentiation in Relay.
+ */
+#ifndef TVM_RELAY_TRANSFORMS_GRADIENT_H_
+#define TVM_RELAY_TRANSFORMS_GRADIENT_H_
+
+#include <tvm/relay/expr.h>
+#include <tvm/relay/function.h>
+
+#include <vector>
+
+namespace tvm {
+namespace relay {
+
+// Builds the return type of a gradient-transformed function:
+// (original_ret_type, (param_type...)). Returns an empty Type (defer to type
+// inference) when any annotation is missing.
+inline Type GradRetType(const Function& f) {
+  // if type annotations are provided, we will construct a ret type;
+  // otherwise, leave it to be inferred
+  if (!f->ret_type.defined()) {
+    return Type();
+  }
+  std::vector<Type> vt;
+  for (const auto& p : f->params) {
+    if (!p->type_annotation.defined()) {
+      return Type();
+    }
+    vt.push_back(p->type_annotation);
+  }
+
+  return TupleType({f->ret_type, TupleType(vt)});
+}
+
+}  // namespace relay
+}  // namespace tvm
+#endif  // TVM_RELAY_TRANSFORMS_GRADIENT_H_
diff --git a/src/relay/transforms/gradient.cc b/src/relay/transforms/higher_order_gradient.cc
similarity index 64%
rename from src/relay/transforms/gradient.cc
rename to src/relay/transforms/higher_order_gradient.cc
index cd3a99655341..202275626d5d 100644
--- a/src/relay/transforms/gradient.cc
+++ b/src/relay/transforms/higher_order_gradient.cc
@@ -18,8 +18,8 @@
  */
 
 /*!
- * \file gradient.cc
- * \brief API for Automatic Differentiation for the Relay IR.
+ * \file higher_order_gradient.cc
+ * \brief Higher-order Automatic Differentiation in Relay IR, for non-graph programs.
  */
 #include <tvm/ir/type_functor.h>
 #include <tvm/relay/analysis.h>
@@ -28,6 +28,7 @@
 #include <tvm/relay/transform.h>
 #include <tvm/te/operation.h>
 
+#include "gradient.h"
 #include "let_list.h"
 #include "pass_utils.h"
 #include "pattern_utils.h"
@@ -64,13 +65,6 @@ using namespace tvm::runtime;
  * output. There are multiple implementation of AD in relay, with different characteristic. However,
  * they all transform the input expr according to WithGradientType.
  */
-Type WithGradientType(const Type&);
-
-/*! return an expression that represent differentiation of e (according to WithGradientType).
- * This version only work on first order code without control flow.
- */ -Expr FirstOrderGradient(const Expr& e, const Optional& mod); - Type WithGradientType(const Type& t) { // TODO(@M.K.): stricter checking auto ty = t.as(); @@ -94,268 +88,6 @@ Expr DeGlobal(const Optional& mod, const Expr& e) { } } -/*! \brief A fragment of the program being built by the automatic differentation - * pass. - */ -struct ADValueNode { - virtual ~ADValueNode() {} - template - T& get() { - auto ret = dynamic_cast(this); - ICHECK(ret) << "cannot downcast"; - return *ret; - } -}; - -template -Expr MultiFactory(const Type& t, F factory) { - if (auto* tt = t.as()) { - return factory(tt->shape, tt->dtype); - } else if (auto* tt = t.as()) { - std::vector res; - for (size_t i = 0; i < tt->fields.size(); i++) { - res.push_back(MultiFactory(tt->fields[i], factory)); - } - return Tuple(res); - } else { - LOG(FATAL) << "unsupported type to create tensors of: " << tt; - throw; - } -} - -template -Expr MultiFactoryLike(const Expr& e, const Type& t, F factory, F2 factory_like) { - if (t.as()) { - return factory_like(e); - } else if (auto* tt = t.as()) { - return MultiFactory(t, factory); - } else { - LOG(FATAL) << "unsupported type to tensors of: " << tt; - throw; - } -} - -using ADValue = std::shared_ptr; - -/*! \brief AD over a program which generates a tensor output. */ -struct ADTensor : ADValueNode { - Expr forward; - mutable Expr reverse; // must be a variable to avoid duplication - ADTensor(LetList* ll, const Expr& forward) - : forward(ll->Push(forward)), - reverse( - ll->Push(MultiFactoryLike(this->forward, forward->checked_type(), Zeros, ZerosLike))) { - this->forward->checked_type_ = forward->checked_type(); - } -}; - -/*! \brief A staged representation of the program, we reflect - * Relay functions into a function over fragments of AD. We - * can compute away this function to obtain a reverse mode program. 
- */ -struct ADFunction : ADValueNode { - std::function&, const Attrs&, - const tvm::Array&)> - func; - explicit ADFunction(const std::function&, - const Attrs&, const tvm::Array&)>& func) - : func(func) {} -}; - -struct FirstOrderReverseAD : ExprFunctor { - using TBase = ExprFunctor; - const OpAttrMap rev_map = Op::GetAttrMap("FPrimalGradient"); - std::vector> backprop_actions; - // we assume no closure so no need for lexical scoping - std::unordered_map env; - LetList* ll; - - FirstOrderReverseAD(LetList* ll) : ll(ll) {} - - ADValue VisitExpr(const Expr& n) final { - if (env.count(n)) { - return env.at(n); - } - auto ret = TBase::VisitExpr(n); - env[n] = ret; - return ret; - } - - Expr UpdateGrad(const Type& t, const Expr& arg, const Expr& grad, LetList* ll) { - if (t.as()) { - return ll->Push(Add(arg, grad)); - } else if (auto* tt = t.as()) { - Array updates; - for (size_t i = 0; i < tt->fields.size(); ++i) { - updates.push_back(this->UpdateGrad(tt->fields[i], ll->Push(GetField(arg, i)), - ll->Push(GetField(grad, i)), ll)); - } - return ll->Push(Tuple(updates)); - } else { - LOG(FATAL) << "unsupported arg type of operator: " << t; - throw; - } - } - - ADValue VisitExpr_(const OpNode* op) final { - Op op_ref = GetRef(op); - ICHECK(rev_map.count(op_ref)) << op->name << " does not have reverse mode defined"; - return std::make_shared( - [this, op_ref](const Type& orig_type, const std::vector& args, const Attrs& attrs, - const tvm::Array& type_args) { - std::vector call_args; - for (const ADValue& adval : args) { - call_args.push_back(adval->get().forward); - } - auto orig = Call(op_ref, call_args, attrs, type_args); - orig->checked_type_ = orig_type; - auto ret = std::make_shared(ll, orig); - backprop_actions.push_back([this, args, orig, ret, op_ref](LetList* ll) { - tvm::Array rev = rev_map[op_ref](orig, ret->reverse); - ICHECK(args.size() == rev.size()); - for (size_t i = 0; i < args.size(); ++i) { - auto ad_arg = args[i]->get(); - auto ad_arg_type = 
ad_arg.forward->checked_type(); - args[i]->get().reverse = - this->UpdateGrad(ad_arg_type, ad_arg.reverse, rev[i], ll); - } - }); - return ret; - }); - } - - ADValue VisitExpr_(const TupleGetItemNode* op) final { - Expr e = GetRef(op); - ADValue tup = VisitExpr(op->tuple); - auto tt = op->tuple->checked_type().as(); - size_t size = tt->fields.size(); - size_t idx = op->index; - auto ret = std::make_shared(ll, e); - backprop_actions.push_back([tup, idx, size, ret](LetList* ll) { - auto rev = tup->get().reverse; - // special-case Tuple, to avoid long chains of GetItem/Tuple, - // but we might have functions using tuples, so we don't know - // that the reverse node is always a tuple - std::vector grfields; - if (auto tup_node = rev.as()) { - for (size_t i = 0; i < size; ++i) { - grfields.push_back(i != idx ? tup_node->fields[i] - : Add(tup_node->fields[i], ret->reverse)); - } - } else { - for (size_t i = 0; i < size; ++i) { - grfields.push_back(i != idx ? TupleGetItem(rev, i) - : Add(TupleGetItem(rev, i), ret->reverse)); - } - } - tup->get().reverse = ll->Push(Tuple(grfields)); - }); - return ret; - } - - ADValue VisitExpr_(const TupleNode* op) final { - Expr e = GetRef(op); - std::vector fields; - for (const auto& f : op->fields) { - fields.push_back(VisitExpr(f)); - } - auto ret = std::make_shared(ll, e); - backprop_actions.push_back([fields, ret](LetList* ll) { - for (size_t i = 0; i < fields.size(); ++i) { - fields[i]->get().reverse = - ll->Push(Add(fields[i]->get().reverse, TupleGetItem(ret->reverse, i))); - } - }); - return ret; - } - - ADValue VisitExpr_(const ConstantNode* op) final { - Expr e = GetRef(op); - return std::make_shared(ll, e); - } - - ADValue VisitExpr_(const CallNode* op) final { - ADValue f = VisitExpr(op->op); - std::vector args; - for (const auto& arg : op->args) { - args.push_back(VisitExpr(arg)); - } - return f->get().func(op->checked_type(), args, op->attrs, op->type_args); - } - - ADValue VisitExpr_(const FunctionNode* op) final { - 
Function f = GetRef(op); - // todo: assert no closure - return std::make_shared( - [this, f](const Type& orig_type, const std::vector& args, const Attrs& attrs, - const tvm::Array& type_args) { - ICHECK_EQ(f->params.size(), args.size()); - for (size_t i = 0; i < f->params.size(); ++i) { - env[f->params[i]] = args[i]; - } - return VisitExpr(f->body); - }); - } - - // Var will always be in env, handled in VisitExpr (without _), so we don't need - // to implement its VisitExpr_. -}; - -Type GradRetType(const Function& f) { - // if type annotations are provided, we will construct a ret type; - // otherwise, leave it to be inferred - if (!f->ret_type.defined()) { - return Type(); - } - std::vector vt; - for (const auto& p : f->params) { - if (!p->type_annotation.defined()) { - return Type(); - } - vt.push_back(p->type_annotation); - } - - return TupleType({f->ret_type, TupleType(vt)}); -} - -Expr FirstOrderGradient(const Expr& re, const Optional& mod) { - // Currently we first remove any global functions for the first - // order case. - auto e = DeGlobal(mod, re); - auto f = e.as(); - ICHECK(f) << "FOWithGradient expects its argument to be a function: " << f; - ICHECK(f->type_params.size() == 0) << "no polymorphism supported for now"; - - // We will then build a sequence of lets which implement reverse mode. 
- Expr body = LetList::With([&](LetList* ll) { - FirstOrderReverseAD reverse_ad(ll); - ADValue rev = reverse_ad(e); - std::vector args; - for (const auto& p : f->params) { - args.push_back(std::make_shared(ll, p)); - } - auto c = rev->get().func(f->checked_type(), args, Attrs(), {}); - const auto& res = c->get(); - Expr grad = LetList::With([&](LetList* ll) { - res.reverse = MultiFactoryLike(res.forward, res.forward->checked_type(), Ones, OnesLike); - for (auto it = reverse_ad.backprop_actions.rbegin(); it != reverse_ad.backprop_actions.rend(); - ++it) { - (*it)(ll); - } - std::vector grad_res; - for (const auto& a : args) { - grad_res.push_back(a->get().reverse); - } - return Tuple(grad_res); - }); - return Pair(res.forward, grad); - }); - - return Function(f->params, body, GradRetType(GetRef(f)), {}); -} - -TVM_REGISTER_GLOBAL("relay._transform.first_order_gradient").set_body_typed(FirstOrderGradient); - static Type bpt = RelayRefType(FuncType({}, TupleType(Array()), {}, {})); struct ReverseADType : TypeMutator {