From ebf056b54a1d9ab26e724ac0c993ce90b942b61f Mon Sep 17 00:00:00 2001 From: Altan Haan Date: Mon, 8 Mar 2021 20:14:47 +0000 Subject: [PATCH 1/6] first order AD as a separate pass --- python/tvm/relay/build_module.py | 12 +- python/tvm/relay/op/_tensor_grad.py | 14 ++ python/tvm/relay/op/tensor.py | 7 +- python/tvm/relay/transform/transform.py | 4 + src/relay/transforms/first_order_ad.cc | 312 ++++++++++++++++++++++++ 5 files changed, 344 insertions(+), 5 deletions(-) create mode 100644 src/relay/transforms/first_order_ad.cc diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py index 8e69d288df12..38fbcf221393 100644 --- a/python/tvm/relay/build_module.py +++ b/python/tvm/relay/build_module.py @@ -380,11 +380,12 @@ class GraphExecutor(_interpreter.Executor): The target option to build the function. """ - def __init__(self, mod, ctx, target): + def __init__(self, mod, ctx, target, debug=False): assert mod is not None self.mod = mod self.ctx = ctx self.target = target + self.debug = debug def _make_executor(self, expr=None): if expr: @@ -394,7 +395,12 @@ def _make_executor(self, expr=None): if _ty.is_dynamic(ret_type): raise ValueError("Graph Runtime only supports static graphs, got output type", ret_type) mod = build(self.mod, target=self.target) - gmodule = _graph_rt.GraphModule(mod["default"](self.ctx)) + if self.debug: + from ..contrib.debugger import debug_runtime + + gmodule = debug_runtime.create(mod.get_json(), mod.lib, self.ctx) + else: + gmodule = _graph_rt.GraphModule(mod["default"](self.ctx)) def _unflatten(flat_iter, cur_type): if isinstance(cur_type, _ty.TensorType): @@ -473,6 +479,8 @@ def create_executor(kind="debug", mod=None, ctx=None, target="llvm"): return _interpreter.Interpreter(mod, ctx, target) if kind == "graph": return GraphExecutor(mod, ctx, target) + if kind == "graph_debug": + return GraphExecutor(mod, ctx, target, debug=True) if kind == "vm": return VMExecutor(mod, ctx, target) raise RuntimeError("unknown execution strategy: {0}".format(kind)) diff --git a/python/tvm/relay/op/_tensor_grad.py b/python/tvm/relay/op/_tensor_grad.py index 5836aebce393..0d1a1273c32a 100644 --- a/python/tvm/relay/op/_tensor_grad.py +++ b/python/tvm/relay/op/_tensor_grad.py @@ -596,6 +596,12 @@ def cast_grad(orig, grad): return [cast_like(grad, x)] +@register_gradient("cast_like") +def cast_like_grad(orig, grad): + data, dtype_like = orig.args + return [cast_like(grad, data), zeros_like(dtype_like)] + + @register_gradient("nn.batch_flatten") def batch_flatten_grad(orig, grad): """Returns grad reshaped to data dims""" @@ -866,3 +872,11 @@ def less_equal_grad(orig, grad): Returns the gradient of less_equal. """ return [zeros_like(orig.args[0]), zeros_like(orig.args[1])] + + +@register_gradient("split") +def split_grad(orig, grad): + """ + Returns the gradient of split, which is the concatenation of the downstream gradients. + """ + return [concatenate(grad, orig.attrs.axis)] diff --git a/python/tvm/relay/op/tensor.py b/python/tvm/relay/op/tensor.py index 5b011043f588..0715540f057e 100644 --- a/python/tvm/relay/op/tensor.py +++ b/python/tvm/relay/op/tensor.py @@ -22,7 +22,7 @@ from . import _make from .dyn import _make as _dyn_make -from ..expr import Tuple, Expr, Constant +from ..expr import Tuple, Expr from . import op as reg @@ -1096,12 +1096,13 @@ def concatenate(data, axis): result: relay.Expr The concatenated tensor. 
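+
+    Examples
+    --------
+    A rough sketch of the new behavior (variable names here are illustrative): ``data``
+    may now be any expression of tuple type, e.g. a gradient that is already a tuple,
+    instead of only a Python sequence of tensors.
+
+    .. code-block:: python
+
+        x = relay.var("x", shape=(2, 3))
+        y = relay.var("y", shape=(2, 3))
+        c1 = relay.concatenate([x, y], axis=0)          # Python list, as before
+        c2 = relay.concatenate(relay.Tuple([x, y]), 0)  # tuple-typed Expr, newly accepted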
""" - data = list(data) + if not isinstance(data, Expr): + data = Tuple(list(data)) if not data: raise ValueError("relay.concatenate requires data to be non-empty.") if not isinstance(axis, int): raise ValueError("For now, we only support integer axis") - return _make.concatenate(Tuple(data), axis) + return _make.concatenate(data, axis) def stack(data, axis): diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py index b61f209505d8..7092f029c01f 100644 --- a/python/tvm/relay/transform/transform.py +++ b/python/tvm/relay/transform/transform.py @@ -806,6 +806,10 @@ def gradient(expr, mod=None, mode="higher_order"): raise Exception("unknown mode") +def FirstOrderAD(): + return _ffi_api.FirstOrderAD() + + def Defunctionalization(func, mod): """ Performs defunctionalization on func, diff --git a/src/relay/transforms/first_order_ad.cc b/src/relay/transforms/first_order_ad.cc new file mode 100644 index 000000000000..6ff0475b4be4 --- /dev/null +++ b/src/relay/transforms/first_order_ad.cc @@ -0,0 +1,312 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file first_order_ad.cc + * \brief First-order AD for Relay. + */ +#include +#include +#include +#include +#include +#include + +#include "let_list.h" +#include "pass_utils.h" +#include "pattern_utils.h" + +namespace tvm { +namespace relay { + +template +Expr MultiFactory(const Type& t, F factory) { + if (auto* tt = t.as()) { + return factory(tt->shape, tt->dtype); + } else if (auto* tt = t.as()) { + std::vector res; + for (size_t i = 0; i < tt->fields.size(); i++) { + res.push_back(MultiFactory(tt->fields[i], factory)); + } + return Tuple(res); + } else { + LOG(FATAL) << "unsupported type to create tensors of: " << tt; + throw; + } +} + +template +Expr MultiFactoryLike(const Expr& e, const Type& t, F factory, F2 factory_like) { + if (t.as()) { + return factory_like(e); + } else if (auto* tt = t.as()) { + return MultiFactory(t, factory); + } else { + LOG(FATAL) << "unsupported type to tensors of: " << tt; + throw; + } +} + +/*! \brief A fragment of the program being built by the automatic differentation + * pass. + */ +struct ADValueNode { + virtual ~ADValueNode() {} + template + T& get() { + auto ret = dynamic_cast(this); + ICHECK(ret) << "cannot downcast"; + return *ret; + } +}; + +using ADValue = std::shared_ptr; + +/*! \brief AD over a program which generates a tensor output. */ +struct ADTensor : ADValueNode { + Expr forward; + mutable Expr reverse; // must be a variable to avoid duplication + ADTensor(LetList* ll, const Expr& forward) + : forward(ll->Push(forward)), + reverse( + ll->Push(MultiFactoryLike(this->forward, forward->checked_type(), Zeros, ZerosLike))) { + this->forward->checked_type_ = forward->checked_type(); + } +}; + +/*! 
\brief A staged representation of the program, we reflect + * Relay functions into a function over fragments of AD. We + * can compute away this function to obtain a reverse mode program. + */ +struct ADFunction : ADValueNode { + // (ad_args, orig) -> ad_ret + using ADFunctionType = ADValue(const std::vector&, const Call&); + std::function func; + explicit ADFunction(const std::function& func) : func(func) {} +}; + +struct FirstOrderReverseAD : ExprFunctor { + const OpAttrMap rev_map = Op::GetAttrMap("FPrimalGradient"); + std::vector> backprop_actions; + // we assume no closure so no need for lexical scoping + std::unordered_map env; + LetList* ll; + DiagnosticContext diag_ctx; + + FirstOrderReverseAD(LetList* ll, DiagnosticContext diag_ctx) : ll(ll), diag_ctx(diag_ctx) {} + + ADValue VisitExpr(const Expr& n) final { + if (env.count(n)) { + return env.at(n); + } + auto ret = ExprFunctor::VisitExpr(n); + env[n] = ret; + return ret; + } + + static Expr LiftedAdd(const Type& t, const Expr& x, const Expr& y, LetList* ll) { + if (t.as()) { + return ll->Push(Add(x, y)); + } else if (auto* tt = t.as()) { + Array fields; + for (size_t i = 0; i < tt->fields.size(); ++i) { + fields.push_back( + LiftedAdd(tt->fields[i], ll->Push(GetField(x, i)), ll->Push(GetField(y, i)), ll)); + } + return ll->Push(Tuple(fields)); + } else { + LOG(FATAL) << "cannot lift addition for type " << PrettyPrint(t); + throw; + } + } + + ADValue VisitExpr_(const OpNode* op) final { + Op op_ref = GetRef(op); + if (!rev_map.count(op_ref)) { + diag_ctx.EmitFatal(Diagnostic::Error(op->span) + << "the operator " << op->name << " does not have a registered gradient."); + } + return std::make_shared([this, op_ref](const std::vector& ad_args, + const Call& orig) { + std::vector orig_args; + for (const ADValue& adval : ad_args) { + orig_args.push_back(adval->get().forward); + } + auto orig_new = Call(op_ref, orig_args, orig->attrs, orig->type_args); + orig_new->checked_type_ = orig->checked_type(); + auto ret = std::make_shared(ll, orig_new); + backprop_actions.push_back([this, ad_args, orig_new, ret, op_ref](LetList* ll) { + tvm::Array rev = rev_map[op_ref](orig_new, ret->reverse); + if (ad_args.size() != rev.size()) { + diag_ctx.EmitFatal(Diagnostic::Error(op_ref->span) + << "arity mismatch for operator " << op_ref->name + << " and its registered gradient: expected " << ad_args.size() + << " but got " << rev.size() << " gradients."); + } + for (size_t i = 0; i < ad_args.size(); ++i) { + auto& ad_arg = ad_args[i]->get(); + ad_arg.reverse = LiftedAdd(ad_arg.forward->checked_type(), ad_arg.reverse, rev[i], ll); + } + }); + return ret; + }); + } + + ADValue VisitExpr_(const TupleGetItemNode* op) final { + Expr e = GetRef(op); + ADValue tup = VisitExpr(op->tuple); + auto tt = op->tuple->checked_type().as(); + size_t idx = op->index; + auto ret = std::make_shared(ll, e); + backprop_actions.push_back([tup, tt, idx, ret](LetList* ll) { + auto& ad_tup = tup->get(); + std::vector updated_grads; + for (size_t i = 0; i < tt->fields.size(); ++i) { + Expr grad_pre = GetField(ad_tup.reverse, i); + updated_grads.push_back(i != idx ? 
grad_pre + : LiftedAdd(tt->fields[i], grad_pre, ret->reverse, ll)); + } + ad_tup.reverse = ll->Push(Tuple(updated_grads)); + }); + return ret; + } + + ADValue VisitExpr_(const TupleNode* op) final { + Expr e = GetRef(op); + std::vector fields; + for (const auto& f : op->fields) { + fields.push_back(VisitExpr(f)); + } + auto tt = op->checked_type().as(); + auto ret = std::make_shared(ll, e); + backprop_actions.push_back([fields, tt, ret](LetList* ll) { + for (size_t i = 0; i < fields.size(); ++i) { + auto& ad_field = fields[i]->get(); + ad_field.reverse = + LiftedAdd(tt->fields[i], ad_field.reverse, GetField(ret->reverse, i), ll); + } + }); + return ret; + } + + ADValue VisitExpr_(const ConstantNode* op) final { + Expr e = GetRef(op); + return std::make_shared(ll, e); + } + + ADValue VisitExpr_(const CallNode* op) final { + ADValue f = VisitExpr(op->op); + std::vector args; + for (const auto& arg : op->args) { + args.push_back(VisitExpr(arg)); + } + return f->get().func(args, GetRef(op)); + } + + ADValue VisitExpr_(const FunctionNode* op) final { + Function f = GetRef(op); + // todo: assert no closure + return std::make_shared( + [this, f](const std::vector& ad_args, const Call& orig) { + ICHECK_EQ(f->params.size(), ad_args.size()); + for (size_t i = 0; i < f->params.size(); ++i) { + env[f->params[i]] = ad_args[i]; + } + return VisitExpr(f->body); + }); + } + + // Var will always be in env, handled in VisitExpr (without _), so we don't need + // to implement its VisitExpr_. +}; + +Type GradientReturnType(const Function& f) { + // if type annotations are provided, we will construct a ret type; + // otherwise, leave it to be inferred + if (!f->ret_type.defined()) { + return Type(); + } + std::vector vt; + for (const auto& p : f->params) { + if (!p->type_annotation.defined()) { + return Type(); + } + vt.push_back(p->type_annotation); + } + + return TupleType({f->ret_type, TupleType(vt)}); +} + +namespace transform { + +Pass FirstOrderAD() { + runtime::TypedPackedFunc f = [](IRModule mod, PassContext ctx) { + IRModule ad_mod = GetRef(mod.CopyOnWrite()); + DiagnosticContext diag_ctx = DiagnosticContext::Default(ad_mod); + + for (const auto& pr : mod->functions) { + const FunctionNode* func = pr.second.as(); + if (!func) { + diag_ctx.EmitFatal(Diagnostic::Error(pr.second->span) + << "AD can only be performed on Relay functions."); + } + if (func->type_params.size() > 0) { + diag_ctx.EmitFatal(Diagnostic::Error(pr.second->span) + << "first-order AD does not support polymorphism yet."); + } + Expr body = LetList::With([&](LetList* ll) { + FirstOrderReverseAD reverse_ad(ll, diag_ctx); + ADValue rev = reverse_ad(pr.second); + std::vector args; + for (const auto& p : func->params) { + args.push_back(std::make_shared(ll, p)); + } + Call placeholder = Call(GetRef(func), {}); + placeholder->checked_type_ = func->checked_type().as()->ret_type; + auto grad_call = rev->get().func(args, placeholder); + auto& res = grad_call->get(); + Expr grad_tuple = LetList::With([&](LetList* ll) { + res.reverse = MultiFactoryLike(res.forward, res.forward->checked_type(), Ones, OnesLike); + for (auto it = reverse_ad.backprop_actions.rbegin(); + it != reverse_ad.backprop_actions.rend(); ++it) { + (*it)(ll); + } + std::vector grads; + for (const auto& a : args) { + grads.push_back(a->get().reverse); + } + return Tuple(grads); + }); + return Pair(res.forward, grad_tuple); + }); + ad_mod->Update(pr.first, + Function(func->params, body, GradientReturnType(GetRef(func)), {})); + } + + return ad_mod; + }; + return 
CreateModulePass(f, 0, "FirstOrderAD", {}); +} + +TVM_REGISTER_GLOBAL("relay._transform.FirstOrderAD").set_body_typed(FirstOrderAD); + +} // namespace transform + +} // namespace relay +} // namespace tvm \ No newline at end of file From 4b74bd615d19a3eb459da7e31405f527e39359cd Mon Sep 17 00:00:00 2001 From: Altan Haan Date: Mon, 15 Mar 2021 23:04:57 -0700 Subject: [PATCH 2/6] ConcretizeLike pass --- python/tvm/relay/transform/transform.py | 4 + src/relay/transforms/first_order_ad.cc | 166 +++++++++++++++++++++--- 2 files changed, 153 insertions(+), 17 deletions(-) diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py index 7092f029c01f..20c09c4e0320 100644 --- a/python/tvm/relay/transform/transform.py +++ b/python/tvm/relay/transform/transform.py @@ -810,6 +810,10 @@ def FirstOrderAD(): return _ffi_api.FirstOrderAD() +def ConcretizeLike(): + return _ffi_api.ConcretizeLike() + + def Defunctionalization(func, mod): """ Performs defunctionalization on func, diff --git a/src/relay/transforms/first_order_ad.cc b/src/relay/transforms/first_order_ad.cc index 6ff0475b4be4..2cde23a7e0f0 100644 --- a/src/relay/transforms/first_order_ad.cc +++ b/src/relay/transforms/first_order_ad.cc @@ -23,6 +23,7 @@ */ #include #include +#include #include #include #include @@ -36,29 +37,32 @@ namespace tvm { namespace relay { template -Expr MultiFactory(const Type& t, F factory) { +Expr MultiFactory(const Type& t, F factory, DiagnosticContext diag_ctx) { if (auto* tt = t.as()) { return factory(tt->shape, tt->dtype); } else if (auto* tt = t.as()) { std::vector res; for (size_t i = 0; i < tt->fields.size(); i++) { - res.push_back(MultiFactory(tt->fields[i], factory)); + res.push_back(MultiFactory(tt->fields[i], factory, diag_ctx)); } return Tuple(res); } else { - LOG(FATAL) << "unsupported type to create tensors of: " << tt; + diag_ctx.EmitFatal(Diagnostic::Error(t->span) + << "could not build tensors using factory for type " << PrettyPrint(t)); throw; } } template -Expr MultiFactoryLike(const Expr& e, const Type& t, F factory, F2 factory_like) { +Expr MultiFactoryLike(const Expr& e, const Type& t, F factory, F2 factory_like, + DiagnosticContext diag_ctx) { if (t.as()) { return factory_like(e); } else if (auto* tt = t.as()) { - return MultiFactory(t, factory); + return MultiFactory(t, factory, diag_ctx); } else { - LOG(FATAL) << "unsupported type to tensors of: " << tt; + diag_ctx.EmitFatal(Diagnostic::Error(t->span) + << "could not build tensors using factory for type " << PrettyPrint(t)); throw; } } @@ -82,10 +86,10 @@ using ADValue = std::shared_ptr; struct ADTensor : ADValueNode { Expr forward; mutable Expr reverse; // must be a variable to avoid duplication - ADTensor(LetList* ll, const Expr& forward) + ADTensor(LetList* ll, const Expr& forward, DiagnosticContext diag_ctx) : forward(ll->Push(forward)), - reverse( - ll->Push(MultiFactoryLike(this->forward, forward->checked_type(), Zeros, ZerosLike))) { + reverse(ll->Push( + MultiFactoryLike(this->forward, forward->checked_type(), Zeros, ZerosLike, diag_ctx))) { this->forward->checked_type_ = forward->checked_type(); } }; @@ -150,7 +154,7 @@ struct FirstOrderReverseAD : ExprFunctor { } auto orig_new = Call(op_ref, orig_args, orig->attrs, orig->type_args); orig_new->checked_type_ = orig->checked_type(); - auto ret = std::make_shared(ll, orig_new); + auto ret = std::make_shared(ll, orig_new, diag_ctx); backprop_actions.push_back([this, ad_args, orig_new, ret, op_ref](LetList* ll) { tvm::Array rev = rev_map[op_ref](orig_new, 
ret->reverse); if (ad_args.size() != rev.size()) { @@ -173,7 +177,7 @@ struct FirstOrderReverseAD : ExprFunctor { ADValue tup = VisitExpr(op->tuple); auto tt = op->tuple->checked_type().as(); size_t idx = op->index; - auto ret = std::make_shared(ll, e); + auto ret = std::make_shared(ll, e, diag_ctx); backprop_actions.push_back([tup, tt, idx, ret](LetList* ll) { auto& ad_tup = tup->get(); std::vector updated_grads; @@ -194,7 +198,7 @@ struct FirstOrderReverseAD : ExprFunctor { fields.push_back(VisitExpr(f)); } auto tt = op->checked_type().as(); - auto ret = std::make_shared(ll, e); + auto ret = std::make_shared(ll, e, diag_ctx); backprop_actions.push_back([fields, tt, ret](LetList* ll) { for (size_t i = 0; i < fields.size(); ++i) { auto& ad_field = fields[i]->get(); @@ -207,7 +211,7 @@ struct FirstOrderReverseAD : ExprFunctor { ADValue VisitExpr_(const ConstantNode* op) final { Expr e = GetRef(op); - return std::make_shared(ll, e); + return std::make_shared(ll, e, diag_ctx); } ADValue VisitExpr_(const CallNode* op) final { @@ -253,6 +257,107 @@ Type GradientReturnType(const Function& f) { return TupleType({f->ret_type, TupleType(vt)}); } +class ConcretizeLikeRewrite { + public: + ConcretizeLikeRewrite() { + concrete_map_[Op::Get("reshape_like")] = [](Expr data, Array shape, DataType dtype) { + return MakeReshape(data, shape); + }; + concrete_map_[Op::Get("zeros_like")] = [](Expr data, Array shape, DataType dtype) { + return MakeZeros(shape, dtype); + }; + concrete_map_[Op::Get("ones_like")] = [](Expr data, Array shape, DataType dtype) { + return MakeOnes(shape, dtype); + }; + concrete_map_[Op::Get("collapse_sum_like")] = [](Expr data, Array shape, + DataType dtype) { + static const Op& op = Op::Get("collapse_sum_to"); + auto attrs = make_object(); + auto cshape = + MakeConstantTensor(DataType::Int(64), {static_cast(shape.size())}, shape); + attrs->shape = shape; + ICHECK_LE(shape.size(), std::numeric_limits::max()); + return Call(op, {data, cshape}, Attrs(attrs)); + }; + concrete_map_[Op::Get("broadcast_to_like")] = [](Expr data, Array shape, + DataType dtype) { + return MakeBroadCastTo(data, shape); + }; + + for (const auto& pr : concrete_map_) { + if (!op_pat_.defined()) { + op_pat_ = IsExpr(pr.first); + } else { + op_pat_ = op_pat_ || IsExpr(pr.first); + } + } + + data_pat_ = IsWildcard(); + like_pat_ = IsWildcard(); + unary_like_pat_ = (IsOp("zeros_like") || IsOp("ones_like"))({like_pat_}); + binary_like_pat_ = (IsOp("reshape_like") || IsOp("collapse_sum_like") || + IsOp("broadcast_to_like"))({data_pat_, like_pat_}); + } + + Expr Callback(const Expr& pre, const Expr& post, + const Map>& node_map) const { + // we will rewrite iff the like argument has fully concrete shape + const CallNode* call_node = post.as(); + ICHECK(call_node); + const OpNode* op_node = call_node->op.as(); + ICHECK(op_node); + const Op op_ref = GetRef(op_node); + ICHECK(concrete_map_.count(op_ref) > 0); + + Expr like = node_map[like_pat_][0]; + + if (!like->checked_type_.defined()) { + // TODO(@altanh): why is this? 
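+      // (a guess, not verified: the matched input may itself be the result of an earlier
+      // rewrite in this pass, so type inference has not repopulated checked_type_ yet;
+      // conservatively keep the original call)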
+ return post; + } + + // skip trying to support this for now (ironic, as I was the one who added the feature) + if (const auto* attrs = call_node->attrs.as()) { + if (attrs->lhs_begin != 0 || attrs->rhs_begin != 0 || attrs->lhs_end.defined() || + attrs->rhs_end.defined()) { + return post; + } + } + + CHECK(like->checked_type_.defined()) + << "ConcretizeLike requires checked types to be populated, please run type inference"; + const TensorTypeNode* like_ty = like->checked_type().as(); + ICHECK(like_ty) << "got non-Tensor argument type " << PrettyPrint(like->checked_type()); + + Array cshape; + for (const auto& dim : like_ty->shape) { + if (const auto* imm = dim.as()) { + cshape.push_back(Integer(GetRef(imm))); + continue; + } + return post; + } + + if (call_node->args.size() == 2) { + return concrete_map_.at(op_ref)(node_map[data_pat_][0], cshape, like_ty->dtype); + } + return concrete_map_.at(op_ref)(Expr(), cshape, like_ty->dtype); + } + + DFPattern UnaryPattern() const { return unary_like_pat_; } + + DFPattern BinaryPattern() const { return binary_like_pat_; } + + private: + using FMake = std::function, DataType)>; + std::unordered_map concrete_map_; + DFPattern op_pat_; + DFPattern data_pat_; + DFPattern like_pat_; + DFPattern unary_like_pat_; + DFPattern binary_like_pat_; +}; + namespace transform { Pass FirstOrderAD() { @@ -260,11 +365,17 @@ Pass FirstOrderAD() { IRModule ad_mod = GetRef(mod.CopyOnWrite()); DiagnosticContext diag_ctx = DiagnosticContext::Default(ad_mod); + if (mod->functions.size() > 1) { + LOG(WARNING) << "IRModule contains multiple global functions: first-order AD will transform " + "them indepedently!"; + } + for (const auto& pr : mod->functions) { const FunctionNode* func = pr.second.as(); if (!func) { - diag_ctx.EmitFatal(Diagnostic::Error(pr.second->span) - << "AD can only be performed on Relay functions."); + diag_ctx.Emit(Diagnostic::Warning(pr.second->span) + << "AD can only be performed on Relay functions, skipping " + << PrettyPrint(pr.first)); } if (func->type_params.size() > 0) { diag_ctx.EmitFatal(Diagnostic::Error(pr.second->span) @@ -275,14 +386,15 @@ Pass FirstOrderAD() { ADValue rev = reverse_ad(pr.second); std::vector args; for (const auto& p : func->params) { - args.push_back(std::make_shared(ll, p)); + args.push_back(std::make_shared(ll, p, diag_ctx)); } Call placeholder = Call(GetRef(func), {}); placeholder->checked_type_ = func->checked_type().as()->ret_type; auto grad_call = rev->get().func(args, placeholder); auto& res = grad_call->get(); Expr grad_tuple = LetList::With([&](LetList* ll) { - res.reverse = MultiFactoryLike(res.forward, res.forward->checked_type(), Ones, OnesLike); + res.reverse = + MultiFactoryLike(res.forward, res.forward->checked_type(), Ones, OnesLike, diag_ctx); for (auto it = reverse_ad.backprop_actions.rbegin(); it != reverse_ad.backprop_actions.rend(); ++it) { (*it)(ll); @@ -306,6 +418,26 @@ Pass FirstOrderAD() { TVM_REGISTER_GLOBAL("relay._transform.FirstOrderAD").set_body_typed(FirstOrderAD); +Pass ConcretizeLike() { + runtime::TypedPackedFunc pass_func = + [](Function f, IRModule m, PassContext pc) { + ConcretizeLikeRewrite rw; + auto callback_func = PackedFunc([&rw](TVMArgs args, TVMRetValue* rv) { + Expr pre = args[0]; + Expr post = args[1]; + Map> node_map = args[2]; + *rv = rw.Callback(pre, post, node_map); + }); + Array callbacks = { + DFPatternCallback(rw.UnaryPattern(), callback_func, true), + DFPatternCallback(rw.BinaryPattern(), callback_func, true)}; + return Downcast(RewritePatterns(callbacks, f, m)); + }; 
+ return CreateFunctionPass(pass_func, 0, "ConcretizeLike", {}); +} + +TVM_REGISTER_GLOBAL("relay._transform.ConcretizeLike").set_body_typed(ConcretizeLike); + } // namespace transform } // namespace relay From bb4ec4c875709a4a58b23aac79565dee7418f3f0 Mon Sep 17 00:00:00 2001 From: Altan Haan Date: Tue, 16 Mar 2021 20:13:29 -0700 Subject: [PATCH 3/6] some refactoring, add unit tests for ConcretizeLike --- python/tvm/relay/op/tensor.py | 2 +- python/tvm/relay/transform/transform.py | 36 ++- src/relay/transforms/concretize_like.cc | 160 ++++++++++ ...st_order_ad.cc => first_order_gradient.cc} | 153 +--------- src/relay/transforms/gradient.h | 52 ++++ .../{gradient.cc => higher_order_gradient.cc} | 274 +----------------- .../python/relay/test_pass_concretize_like.py | 108 +++++++ 7 files changed, 366 insertions(+), 419 deletions(-) create mode 100644 src/relay/transforms/concretize_like.cc rename src/relay/transforms/{first_order_ad.cc => first_order_gradient.cc} (67%) create mode 100644 src/relay/transforms/gradient.h rename src/relay/transforms/{gradient.cc => higher_order_gradient.cc} (64%) create mode 100644 tests/python/relay/test_pass_concretize_like.py diff --git a/python/tvm/relay/op/tensor.py b/python/tvm/relay/op/tensor.py index 0715540f057e..31252ec69306 100644 --- a/python/tvm/relay/op/tensor.py +++ b/python/tvm/relay/op/tensor.py @@ -22,7 +22,7 @@ from . import _make from .dyn import _make as _dyn_make -from ..expr import Tuple, Expr +from ..expr import Tuple, Expr, Constant from . import op as reg diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py index 20c09c4e0320..cfe71979e3f6 100644 --- a/python/tvm/relay/transform/transform.py +++ b/python/tvm/relay/transform/transform.py @@ -800,17 +800,47 @@ def gradient(expr, mod=None, mode="higher_order"): The transformed expression. """ if mode == "first_order": - return _ffi_api.first_order_gradient(expr, mod) + warnings.warn( + "using transform.gradient for first-order AD is deprecated, please use the" + "FirstOrderGradient module pass", + DeprecationWarning, + ) + if mod is not None: + raise RuntimeError( + "to run first-order AD on a module, please use the FirstOrderGradient module pass." + ) + return FirstOrderGradient()(tvm.IRModule.from_expr(expr))["main"] if mode == "higher_order": return _ffi_api.gradient(expr, mod) raise Exception("unknown mode") -def FirstOrderAD(): - return _ffi_api.FirstOrderAD() +def FirstOrderGradient(): + """ + Transforms all global functions in the module to return the original result, paired with the + gradients of the inputs. This pass transforms each global function independently and does not + support interprocedural AD. Additionally, this pass does not support any control-flow or + references, and should only be used on pure data-flow graphs. + + Returns + ------- + ret : tvm.transform.Pass + The registered FirstOrderGradient pass. + """ + return _ffi_api.FirstOrderGradient() def ConcretizeLike(): + """ + Transforms `op_like` functions to their explicit-shape equivalent (e.g. `zeros_like(x, y)` + to `zeros(x, y.shape)`), when the target shape is concrete. This removes unnecessary + dependencies and can enable more opportunities for operator fusion. + + Returns + ------- + ret : tvm.transform.Pass + The registered ConcretizeLike pass. 
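+
+    Examples
+    --------
+    A minimal sketch of how the pass might be invoked (shapes and names are
+    illustrative); type information must be available so the ``like`` argument
+    has a concrete TensorType.
+
+    .. code-block:: python
+
+        x = relay.var("x", shape=(3, 4), dtype="float32")
+        mod = tvm.IRModule.from_expr(relay.Function([x], relay.zeros_like(x)))
+        mod = relay.transform.InferType()(mod)
+        mod = relay.transform.ConcretizeLike()(mod)
+        # zeros_like(x) is rewritten to zeros((3, 4), "float32")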
+ """ return _ffi_api.ConcretizeLike() diff --git a/src/relay/transforms/concretize_like.cc b/src/relay/transforms/concretize_like.cc new file mode 100644 index 000000000000..fc5bb519e241 --- /dev/null +++ b/src/relay/transforms/concretize_like.cc @@ -0,0 +1,160 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file concretize_like.cc + * \brief Converts `*_like` operators to their explicit shape equivalent (e.g. `zeros_like(x, y)` to + * `zeros(x, y.shape)`), when the target shape is concrete. This removes unnecessary dependencies + * and can enable more opportunities for operator fusion. + */ +#include +#include + +#include "pattern_utils.h" + +namespace tvm { +namespace relay { + +class ConcretizeLikeRewrite { + public: + ConcretizeLikeRewrite() { + concrete_map_[Op::Get("reshape_like")] = [](Expr data, Array shape, DataType dtype) { + return MakeReshape(data, shape); + }; + concrete_map_[Op::Get("zeros_like")] = [](Expr data, Array shape, DataType dtype) { + return MakeZeros(shape, dtype); + }; + concrete_map_[Op::Get("ones_like")] = [](Expr data, Array shape, DataType dtype) { + return MakeOnes(shape, dtype); + }; + concrete_map_[Op::Get("collapse_sum_like")] = [](Expr data, Array shape, + DataType dtype) { + ICHECK_LE(shape.size(), std::numeric_limits::max()); + static const Op& op = Op::Get("collapse_sum_to"); + auto attrs = make_object(); + auto cshape = + MakeConstantTensor(DataType::Int(32), {static_cast(shape.size())}, shape); + attrs->shape = shape; + return Call(op, {data, cshape}, Attrs(attrs)); + }; + concrete_map_[Op::Get("broadcast_to_like")] = [](Expr data, Array shape, + DataType dtype) { + return MakeBroadCastTo(data, shape); + }; + + for (const auto& pr : concrete_map_) { + if (!op_pat_.defined()) { + op_pat_ = IsExpr(pr.first); + } else { + op_pat_ = op_pat_ || IsExpr(pr.first); + } + } + + data_pat_ = IsWildcard(); + like_pat_ = IsWildcard(); + unary_like_pat_ = (IsOp("zeros_like") || IsOp("ones_like"))({like_pat_}); + binary_like_pat_ = (IsOp("reshape_like") || IsOp("collapse_sum_like") || + IsOp("broadcast_to_like"))({data_pat_, like_pat_}); + } + + Expr Callback(const Expr& pre, const Expr& post, + const Map>& node_map) const { + // we will rewrite iff the like argument has fully concrete shape + const CallNode* call_node = post.as(); + ICHECK(call_node); + const OpNode* op_node = call_node->op.as(); + ICHECK(op_node); + const Op op_ref = GetRef(op_node); + ICHECK(concrete_map_.count(op_ref) > 0); + + Expr like = node_map[like_pat_][0]; + + if (!like->checked_type_.defined()) { + // TODO(@altanh): maybe because of the input being rewritten? 
+ return post; + } + + // skip trying to support this for now (ironic, as I was the one who added the feature) + if (const auto* attrs = call_node->attrs.as()) { + if (attrs->lhs_begin != 0 || attrs->rhs_begin != 0 || attrs->lhs_end.defined() || + attrs->rhs_end.defined()) { + return post; + } + } + + CHECK(like->checked_type_.defined()) + << "ConcretizeLike requires checked types to be populated, please run type inference"; + const TensorTypeNode* like_ty = like->checked_type().as(); + ICHECK(like_ty) << "got non-Tensor argument type " << PrettyPrint(like->checked_type()); + + Array cshape; + for (const auto& dim : like_ty->shape) { + if (const auto* imm = dim.as()) { + cshape.push_back(Integer(GetRef(imm))); + continue; + } + return post; + } + + if (call_node->args.size() == 2) { + return concrete_map_.at(op_ref)(node_map[data_pat_][0], cshape, like_ty->dtype); + } + return concrete_map_.at(op_ref)(Expr(), cshape, like_ty->dtype); + } + + DFPattern UnaryPattern() const { return unary_like_pat_; } + + DFPattern BinaryPattern() const { return binary_like_pat_; } + + private: + using FMake = std::function, DataType)>; + std::unordered_map concrete_map_; + DFPattern op_pat_; + DFPattern data_pat_; + DFPattern like_pat_; + DFPattern unary_like_pat_; + DFPattern binary_like_pat_; +}; + +namespace transform { + +Pass ConcretizeLike() { + runtime::TypedPackedFunc pass_func = + [](Function f, IRModule m, PassContext pc) { + ConcretizeLikeRewrite rw; + auto callback_func = PackedFunc([&rw](TVMArgs args, TVMRetValue* rv) { + Expr pre = args[0]; + Expr post = args[1]; + Map> node_map = args[2]; + *rv = rw.Callback(pre, post, node_map); + }); + Array callbacks = { + DFPatternCallback(rw.UnaryPattern(), callback_func, true), + DFPatternCallback(rw.BinaryPattern(), callback_func, true)}; + return Downcast(RewritePatterns(callbacks, f, m)); + }; + return CreateFunctionPass(pass_func, 0, "ConcretizeLike", {}); +} + +TVM_REGISTER_GLOBAL("relay._transform.ConcretizeLike").set_body_typed(ConcretizeLike); + +} // namespace transform + +} // namespace relay +} // namespace tvm diff --git a/src/relay/transforms/first_order_ad.cc b/src/relay/transforms/first_order_gradient.cc similarity index 67% rename from src/relay/transforms/first_order_ad.cc rename to src/relay/transforms/first_order_gradient.cc index 2cde23a7e0f0..4b7a82b80940 100644 --- a/src/relay/transforms/first_order_ad.cc +++ b/src/relay/transforms/first_order_gradient.cc @@ -18,8 +18,8 @@ */ /*! - * \file first_order_ad.cc - * \brief First-order AD for Relay. + * \file first_order_gradient.cc + * \brief First-order Automatic Differentiation in Relay for pure dataflow graphs. */ #include #include @@ -29,6 +29,7 @@ #include #include +#include "gradient.h" #include "let_list.h" #include "pass_utils.h" #include "pattern_utils.h" @@ -240,128 +241,12 @@ struct FirstOrderReverseAD : ExprFunctor { // to implement its VisitExpr_. 
}; -Type GradientReturnType(const Function& f) { - // if type annotations are provided, we will construct a ret type; - // otherwise, leave it to be inferred - if (!f->ret_type.defined()) { - return Type(); - } - std::vector vt; - for (const auto& p : f->params) { - if (!p->type_annotation.defined()) { - return Type(); - } - vt.push_back(p->type_annotation); - } - - return TupleType({f->ret_type, TupleType(vt)}); -} - -class ConcretizeLikeRewrite { - public: - ConcretizeLikeRewrite() { - concrete_map_[Op::Get("reshape_like")] = [](Expr data, Array shape, DataType dtype) { - return MakeReshape(data, shape); - }; - concrete_map_[Op::Get("zeros_like")] = [](Expr data, Array shape, DataType dtype) { - return MakeZeros(shape, dtype); - }; - concrete_map_[Op::Get("ones_like")] = [](Expr data, Array shape, DataType dtype) { - return MakeOnes(shape, dtype); - }; - concrete_map_[Op::Get("collapse_sum_like")] = [](Expr data, Array shape, - DataType dtype) { - static const Op& op = Op::Get("collapse_sum_to"); - auto attrs = make_object(); - auto cshape = - MakeConstantTensor(DataType::Int(64), {static_cast(shape.size())}, shape); - attrs->shape = shape; - ICHECK_LE(shape.size(), std::numeric_limits::max()); - return Call(op, {data, cshape}, Attrs(attrs)); - }; - concrete_map_[Op::Get("broadcast_to_like")] = [](Expr data, Array shape, - DataType dtype) { - return MakeBroadCastTo(data, shape); - }; - - for (const auto& pr : concrete_map_) { - if (!op_pat_.defined()) { - op_pat_ = IsExpr(pr.first); - } else { - op_pat_ = op_pat_ || IsExpr(pr.first); - } - } - - data_pat_ = IsWildcard(); - like_pat_ = IsWildcard(); - unary_like_pat_ = (IsOp("zeros_like") || IsOp("ones_like"))({like_pat_}); - binary_like_pat_ = (IsOp("reshape_like") || IsOp("collapse_sum_like") || - IsOp("broadcast_to_like"))({data_pat_, like_pat_}); - } - - Expr Callback(const Expr& pre, const Expr& post, - const Map>& node_map) const { - // we will rewrite iff the like argument has fully concrete shape - const CallNode* call_node = post.as(); - ICHECK(call_node); - const OpNode* op_node = call_node->op.as(); - ICHECK(op_node); - const Op op_ref = GetRef(op_node); - ICHECK(concrete_map_.count(op_ref) > 0); - - Expr like = node_map[like_pat_][0]; - - if (!like->checked_type_.defined()) { - // TODO(@altanh): why is this? 
- return post; - } - - // skip trying to support this for now (ironic, as I was the one who added the feature) - if (const auto* attrs = call_node->attrs.as()) { - if (attrs->lhs_begin != 0 || attrs->rhs_begin != 0 || attrs->lhs_end.defined() || - attrs->rhs_end.defined()) { - return post; - } - } - - CHECK(like->checked_type_.defined()) - << "ConcretizeLike requires checked types to be populated, please run type inference"; - const TensorTypeNode* like_ty = like->checked_type().as(); - ICHECK(like_ty) << "got non-Tensor argument type " << PrettyPrint(like->checked_type()); - - Array cshape; - for (const auto& dim : like_ty->shape) { - if (const auto* imm = dim.as()) { - cshape.push_back(Integer(GetRef(imm))); - continue; - } - return post; - } - - if (call_node->args.size() == 2) { - return concrete_map_.at(op_ref)(node_map[data_pat_][0], cshape, like_ty->dtype); - } - return concrete_map_.at(op_ref)(Expr(), cshape, like_ty->dtype); - } - - DFPattern UnaryPattern() const { return unary_like_pat_; } - - DFPattern BinaryPattern() const { return binary_like_pat_; } - - private: - using FMake = std::function, DataType)>; - std::unordered_map concrete_map_; - DFPattern op_pat_; - DFPattern data_pat_; - DFPattern like_pat_; - DFPattern unary_like_pat_; - DFPattern binary_like_pat_; -}; - namespace transform { -Pass FirstOrderAD() { +Pass FirstOrderGradient() { runtime::TypedPackedFunc f = [](IRModule mod, PassContext ctx) { + CheckFeature( + mod, FeatureSet({fVar, fConstant, fTuple, fTupleGetItem, fFunction, fOp, fCall, fGraph})); IRModule ad_mod = GetRef(mod.CopyOnWrite()); DiagnosticContext diag_ctx = DiagnosticContext::Default(ad_mod); @@ -408,35 +293,15 @@ Pass FirstOrderAD() { return Pair(res.forward, grad_tuple); }); ad_mod->Update(pr.first, - Function(func->params, body, GradientReturnType(GetRef(func)), {})); + Function(func->params, body, GradRetType(GetRef(func)), {})); } return ad_mod; }; - return CreateModulePass(f, 0, "FirstOrderAD", {}); -} - -TVM_REGISTER_GLOBAL("relay._transform.FirstOrderAD").set_body_typed(FirstOrderAD); - -Pass ConcretizeLike() { - runtime::TypedPackedFunc pass_func = - [](Function f, IRModule m, PassContext pc) { - ConcretizeLikeRewrite rw; - auto callback_func = PackedFunc([&rw](TVMArgs args, TVMRetValue* rv) { - Expr pre = args[0]; - Expr post = args[1]; - Map> node_map = args[2]; - *rv = rw.Callback(pre, post, node_map); - }); - Array callbacks = { - DFPatternCallback(rw.UnaryPattern(), callback_func, true), - DFPatternCallback(rw.BinaryPattern(), callback_func, true)}; - return Downcast(RewritePatterns(callbacks, f, m)); - }; - return CreateFunctionPass(pass_func, 0, "ConcretizeLike", {}); + return CreateModulePass(f, 0, "FirstOrderGradient", {}); } -TVM_REGISTER_GLOBAL("relay._transform.ConcretizeLike").set_body_typed(ConcretizeLike); +TVM_REGISTER_GLOBAL("relay._transform.FirstOrderGradient").set_body_typed(FirstOrderGradient); } // namespace transform diff --git a/src/relay/transforms/gradient.h b/src/relay/transforms/gradient.h new file mode 100644 index 000000000000..455d86613b8e --- /dev/null +++ b/src/relay/transforms/gradient.h @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file gradient.h + * \brief Utility functions for Automatic Differentiation in Relay. + */ +#ifndef TVM_RELAY_TRANSFORMS_GRADIENT_H_ +#define TVM_RELAY_TRANSFORMS_GRADIENT_H_ + +#include +#include + +namespace tvm { +namespace relay { + +inline Type GradRetType(const Function& f) { + // if type annotations are provided, we will construct a ret type; + // otherwise, leave it to be inferred + if (!f->ret_type.defined()) { + return Type(); + } + std::vector vt; + for (const auto& p : f->params) { + if (!p->type_annotation.defined()) { + return Type(); + } + vt.push_back(p->type_annotation); + } + + return TupleType({f->ret_type, TupleType(vt)}); +} + +} // namespace relay +} // namespace tvm +#endif // TVM_RELAY_TRANSFORMS_GRADIENT_H_ diff --git a/src/relay/transforms/gradient.cc b/src/relay/transforms/higher_order_gradient.cc similarity index 64% rename from src/relay/transforms/gradient.cc rename to src/relay/transforms/higher_order_gradient.cc index cd3a99655341..202275626d5d 100644 --- a/src/relay/transforms/gradient.cc +++ b/src/relay/transforms/higher_order_gradient.cc @@ -18,8 +18,8 @@ */ /*! - * \file gradient.cc - * \brief API for Automatic Differentiation for the Relay IR. + * \file higher_order_gradient.cc + * \brief Higher-order Automatic Differentiation in Relay IR, for non-graph programs. */ #include #include @@ -28,6 +28,7 @@ #include #include +#include "gradient.h" #include "let_list.h" #include "pass_utils.h" #include "pattern_utils.h" @@ -64,13 +65,6 @@ using namespace tvm::runtime; * output. There are multiple implementation of AD in relay, with different characteristic. However, * they all transform the input expr according to WithGradientType. */ -Type WithGradientType(const Type&); - -/*! return an expression that represent differentiation of e (according to WithGradientType). - * This version only work on first order code without control flow. - */ -Expr FirstOrderGradient(const Expr& e, const Optional& mod); - Type WithGradientType(const Type& t) { // TODO(@M.K.): stricter checking auto ty = t.as(); @@ -94,268 +88,6 @@ Expr DeGlobal(const Optional& mod, const Expr& e) { } } -/*! \brief A fragment of the program being built by the automatic differentation - * pass. 
- */ -struct ADValueNode { - virtual ~ADValueNode() {} - template - T& get() { - auto ret = dynamic_cast(this); - ICHECK(ret) << "cannot downcast"; - return *ret; - } -}; - -template -Expr MultiFactory(const Type& t, F factory) { - if (auto* tt = t.as()) { - return factory(tt->shape, tt->dtype); - } else if (auto* tt = t.as()) { - std::vector res; - for (size_t i = 0; i < tt->fields.size(); i++) { - res.push_back(MultiFactory(tt->fields[i], factory)); - } - return Tuple(res); - } else { - LOG(FATAL) << "unsupported type to create tensors of: " << tt; - throw; - } -} - -template -Expr MultiFactoryLike(const Expr& e, const Type& t, F factory, F2 factory_like) { - if (t.as()) { - return factory_like(e); - } else if (auto* tt = t.as()) { - return MultiFactory(t, factory); - } else { - LOG(FATAL) << "unsupported type to tensors of: " << tt; - throw; - } -} - -using ADValue = std::shared_ptr; - -/*! \brief AD over a program which generates a tensor output. */ -struct ADTensor : ADValueNode { - Expr forward; - mutable Expr reverse; // must be a variable to avoid duplication - ADTensor(LetList* ll, const Expr& forward) - : forward(ll->Push(forward)), - reverse( - ll->Push(MultiFactoryLike(this->forward, forward->checked_type(), Zeros, ZerosLike))) { - this->forward->checked_type_ = forward->checked_type(); - } -}; - -/*! \brief A staged representation of the program, we reflect - * Relay functions into a function over fragments of AD. We - * can compute away this function to obtain a reverse mode program. - */ -struct ADFunction : ADValueNode { - std::function&, const Attrs&, - const tvm::Array&)> - func; - explicit ADFunction(const std::function&, - const Attrs&, const tvm::Array&)>& func) - : func(func) {} -}; - -struct FirstOrderReverseAD : ExprFunctor { - using TBase = ExprFunctor; - const OpAttrMap rev_map = Op::GetAttrMap("FPrimalGradient"); - std::vector> backprop_actions; - // we assume no closure so no need for lexical scoping - std::unordered_map env; - LetList* ll; - - FirstOrderReverseAD(LetList* ll) : ll(ll) {} - - ADValue VisitExpr(const Expr& n) final { - if (env.count(n)) { - return env.at(n); - } - auto ret = TBase::VisitExpr(n); - env[n] = ret; - return ret; - } - - Expr UpdateGrad(const Type& t, const Expr& arg, const Expr& grad, LetList* ll) { - if (t.as()) { - return ll->Push(Add(arg, grad)); - } else if (auto* tt = t.as()) { - Array updates; - for (size_t i = 0; i < tt->fields.size(); ++i) { - updates.push_back(this->UpdateGrad(tt->fields[i], ll->Push(GetField(arg, i)), - ll->Push(GetField(grad, i)), ll)); - } - return ll->Push(Tuple(updates)); - } else { - LOG(FATAL) << "unsupported arg type of operator: " << t; - throw; - } - } - - ADValue VisitExpr_(const OpNode* op) final { - Op op_ref = GetRef(op); - ICHECK(rev_map.count(op_ref)) << op->name << " does not have reverse mode defined"; - return std::make_shared( - [this, op_ref](const Type& orig_type, const std::vector& args, const Attrs& attrs, - const tvm::Array& type_args) { - std::vector call_args; - for (const ADValue& adval : args) { - call_args.push_back(adval->get().forward); - } - auto orig = Call(op_ref, call_args, attrs, type_args); - orig->checked_type_ = orig_type; - auto ret = std::make_shared(ll, orig); - backprop_actions.push_back([this, args, orig, ret, op_ref](LetList* ll) { - tvm::Array rev = rev_map[op_ref](orig, ret->reverse); - ICHECK(args.size() == rev.size()); - for (size_t i = 0; i < args.size(); ++i) { - auto ad_arg = args[i]->get(); - auto ad_arg_type = ad_arg.forward->checked_type(); - 
args[i]->get().reverse = - this->UpdateGrad(ad_arg_type, ad_arg.reverse, rev[i], ll); - } - }); - return ret; - }); - } - - ADValue VisitExpr_(const TupleGetItemNode* op) final { - Expr e = GetRef(op); - ADValue tup = VisitExpr(op->tuple); - auto tt = op->tuple->checked_type().as(); - size_t size = tt->fields.size(); - size_t idx = op->index; - auto ret = std::make_shared(ll, e); - backprop_actions.push_back([tup, idx, size, ret](LetList* ll) { - auto rev = tup->get().reverse; - // special-case Tuple, to avoid long chains of GetItem/Tuple, - // but we might have functions using tuples, so we don't know - // that the reverse node is always a tuple - std::vector grfields; - if (auto tup_node = rev.as()) { - for (size_t i = 0; i < size; ++i) { - grfields.push_back(i != idx ? tup_node->fields[i] - : Add(tup_node->fields[i], ret->reverse)); - } - } else { - for (size_t i = 0; i < size; ++i) { - grfields.push_back(i != idx ? TupleGetItem(rev, i) - : Add(TupleGetItem(rev, i), ret->reverse)); - } - } - tup->get().reverse = ll->Push(Tuple(grfields)); - }); - return ret; - } - - ADValue VisitExpr_(const TupleNode* op) final { - Expr e = GetRef(op); - std::vector fields; - for (const auto& f : op->fields) { - fields.push_back(VisitExpr(f)); - } - auto ret = std::make_shared(ll, e); - backprop_actions.push_back([fields, ret](LetList* ll) { - for (size_t i = 0; i < fields.size(); ++i) { - fields[i]->get().reverse = - ll->Push(Add(fields[i]->get().reverse, TupleGetItem(ret->reverse, i))); - } - }); - return ret; - } - - ADValue VisitExpr_(const ConstantNode* op) final { - Expr e = GetRef(op); - return std::make_shared(ll, e); - } - - ADValue VisitExpr_(const CallNode* op) final { - ADValue f = VisitExpr(op->op); - std::vector args; - for (const auto& arg : op->args) { - args.push_back(VisitExpr(arg)); - } - return f->get().func(op->checked_type(), args, op->attrs, op->type_args); - } - - ADValue VisitExpr_(const FunctionNode* op) final { - Function f = GetRef(op); - // todo: assert no closure - return std::make_shared( - [this, f](const Type& orig_type, const std::vector& args, const Attrs& attrs, - const tvm::Array& type_args) { - ICHECK_EQ(f->params.size(), args.size()); - for (size_t i = 0; i < f->params.size(); ++i) { - env[f->params[i]] = args[i]; - } - return VisitExpr(f->body); - }); - } - - // Var will always be in env, handled in VisitExpr (without _), so we don't need - // to implement its VisitExpr_. -}; - -Type GradRetType(const Function& f) { - // if type annotations are provided, we will construct a ret type; - // otherwise, leave it to be inferred - if (!f->ret_type.defined()) { - return Type(); - } - std::vector vt; - for (const auto& p : f->params) { - if (!p->type_annotation.defined()) { - return Type(); - } - vt.push_back(p->type_annotation); - } - - return TupleType({f->ret_type, TupleType(vt)}); -} - -Expr FirstOrderGradient(const Expr& re, const Optional& mod) { - // Currently we first remove any global functions for the first - // order case. - auto e = DeGlobal(mod, re); - auto f = e.as(); - ICHECK(f) << "FOWithGradient expects its argument to be a function: " << f; - ICHECK(f->type_params.size() == 0) << "no polymorphism supported for now"; - - // We will then build a sequence of lets which implement reverse mode. 
- Expr body = LetList::With([&](LetList* ll) { - FirstOrderReverseAD reverse_ad(ll); - ADValue rev = reverse_ad(e); - std::vector args; - for (const auto& p : f->params) { - args.push_back(std::make_shared(ll, p)); - } - auto c = rev->get().func(f->checked_type(), args, Attrs(), {}); - const auto& res = c->get(); - Expr grad = LetList::With([&](LetList* ll) { - res.reverse = MultiFactoryLike(res.forward, res.forward->checked_type(), Ones, OnesLike); - for (auto it = reverse_ad.backprop_actions.rbegin(); it != reverse_ad.backprop_actions.rend(); - ++it) { - (*it)(ll); - } - std::vector grad_res; - for (const auto& a : args) { - grad_res.push_back(a->get().reverse); - } - return Tuple(grad_res); - }); - return Pair(res.forward, grad); - }); - - return Function(f->params, body, GradRetType(GetRef(f)), {}); -} - -TVM_REGISTER_GLOBAL("relay._transform.first_order_gradient").set_body_typed(FirstOrderGradient); - static Type bpt = RelayRefType(FuncType({}, TupleType(Array()), {}, {})); struct ReverseADType : TypeMutator { diff --git a/tests/python/relay/test_pass_concretize_like.py b/tests/python/relay/test_pass_concretize_like.py new file mode 100644 index 000000000000..ea4ec6038494 --- /dev/null +++ b/tests/python/relay/test_pass_concretize_like.py @@ -0,0 +1,108 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""Tests for the ConcretizeLike pass.""" +import tvm +import tvm.relay.testing +from tvm import relay +from tvm.relay.testing import run_infer_type + + +def test_reshape_like(): + data = relay.var("data", shape=(2, 3, 4), dtype="float32") + shape_like = relay.var("shape_like", shape=(6, 2, 2), dtype="float32") + f = relay.Function([data, shape_like], relay.reshape_like(data, shape_like)) + f_expected = relay.Function([data, shape_like], relay.reshape(data, (6, 2, 2))) + f_expected = run_infer_type(f_expected) + + mod = tvm.IRModule.from_expr(f) + mod_concrete = relay.transform.ConcretizeLike()(mod) + assert tvm.ir.structural_equal(mod_concrete["main"], f_expected) + + +def test_zeros_like(): + dtype = "int32" + shape_like = relay.var("shape_like", shape=(3, 4, 5), dtype=dtype) + f = relay.Function([shape_like], relay.zeros_like(shape_like)) + f_expected = relay.Function([shape_like], relay.zeros((3, 4, 5), dtype)) + f_expected = run_infer_type(f_expected) + + mod = tvm.IRModule.from_expr(f) + mod_concrete = relay.transform.ConcretizeLike()(mod) + assert tvm.ir.structural_equal(mod_concrete["main"], f_expected) + + +def test_ones_like(): + dtype = "int32" + shape_like = relay.var("shape_like", shape=(3, 4, 5), dtype=dtype) + f = relay.Function([shape_like], relay.ones_like(shape_like)) + f_expected = relay.Function([shape_like], relay.ones((3, 4, 5), dtype)) + f_expected = run_infer_type(f_expected) + + mod = tvm.IRModule.from_expr(f) + mod_concrete = relay.transform.ConcretizeLike()(mod) + assert tvm.ir.structural_equal(mod_concrete["main"], f_expected) + + +def test_collapse_sum_like(): + data = relay.var("data", shape=(3, 3, 3), dtype="float32") + shape_like = relay.var("shape_like", shape=(3,), dtype="float32") + f = relay.Function([data, shape_like], relay.collapse_sum_like(data, shape_like)) + f_expected = relay.Function([data, shape_like], relay.collapse_sum_to(data, (3,))) + f_expected = run_infer_type(f_expected) + + mod = tvm.IRModule.from_expr(f) + mod_concrete = relay.transform.ConcretizeLike()(mod) + assert tvm.ir.structural_equal(mod_concrete["main"], f_expected) + + +def test_broadcast_to_like(): + data = relay.var("data", shape=(3,), dtype="float32") + shape_like = relay.var("shape_like", shape=(3, 3, 3), dtype="float32") + f = relay.Function([data, shape_like], relay.broadcast_to_like(data, shape_like)) + f_expected = relay.Function([data, shape_like], relay.broadcast_to(data, (3, 3, 3))) + f_expected = run_infer_type(f_expected) + + mod = tvm.IRModule.from_expr(f) + mod_concrete = relay.transform.ConcretizeLike()(mod) + assert tvm.ir.structural_equal(mod_concrete["main"], f_expected) + + +def test_multiple(): + x = relay.var("x", shape=(2, 3), dtype="float32") + y = relay.var("x", shape=(3,), dtype="float32") + l = x + y + + dl = relay.ones_like(l) + dx = relay.zeros_like(x) + dy = relay.zeros_like(y) + dx = dx + relay.collapse_sum_like(dl, dx) + dy = dy + relay.collapse_sum_like(dl, dy) + ret = relay.Tuple([dx, dy]) + f = relay.Function([x, y], ret) + + dl_c = relay.ones((2, 3), "float32") + dx_c = relay.zeros((2, 3), "float32") + dy_c = relay.zeros((3,), "float32") + dx_c = dx_c + relay.collapse_sum_to(dl_c, (2, 3)) + dy_c = dy_c + relay.collapse_sum_to(dl_c, (3,)) + ret_c = relay.Tuple([dx_c, dy_c]) + f_expected = relay.Function([x, y], ret_c) + f_expected = run_infer_type(f_expected) + + mod = tvm.IRModule.from_expr(f) + mod_concrete = relay.transform.ConcretizeLike()(mod) + assert tvm.ir.structural_equal(mod_concrete["main"], f_expected) From 
ee72271c4126bcf4ceaf94a9644a6048d51e44fb Mon Sep 17 00:00:00 2001 From: Altan Haan Date: Tue, 16 Mar 2021 20:19:16 -0700 Subject: [PATCH 4/6] revert changes unrelated to this PR --- python/tvm/relay/build_module.py | 12 ++---------- python/tvm/relay/op/_tensor_grad.py | 14 -------------- python/tvm/relay/op/tensor.py | 6 +++--- src/relay/transforms/first_order_gradient.cc | 2 +- 4 files changed, 6 insertions(+), 28 deletions(-) diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py index 38fbcf221393..8e69d288df12 100644 --- a/python/tvm/relay/build_module.py +++ b/python/tvm/relay/build_module.py @@ -380,12 +380,11 @@ class GraphExecutor(_interpreter.Executor): The target option to build the function. """ - def __init__(self, mod, ctx, target, debug=False): + def __init__(self, mod, ctx, target): assert mod is not None self.mod = mod self.ctx = ctx self.target = target - self.debug = debug def _make_executor(self, expr=None): if expr: @@ -395,12 +394,7 @@ def _make_executor(self, expr=None): if _ty.is_dynamic(ret_type): raise ValueError("Graph Runtime only supports static graphs, got output type", ret_type) mod = build(self.mod, target=self.target) - if self.debug: - from ..contrib.debugger import debug_runtime - - gmodule = debug_runtime.create(mod.get_json(), mod.lib, self.ctx) - else: - gmodule = _graph_rt.GraphModule(mod["default"](self.ctx)) + gmodule = _graph_rt.GraphModule(mod["default"](self.ctx)) def _unflatten(flat_iter, cur_type): if isinstance(cur_type, _ty.TensorType): @@ -479,8 +473,6 @@ def create_executor(kind="debug", mod=None, ctx=None, target="llvm"): return _interpreter.Interpreter(mod, ctx, target) if kind == "graph": return GraphExecutor(mod, ctx, target) - if kind == "graph_debug": - return GraphExecutor(mod, ctx, target, debug=True) if kind == "vm": return VMExecutor(mod, ctx, target) raise RuntimeError("unknown execution strategy: {0}".format(kind)) diff --git a/python/tvm/relay/op/_tensor_grad.py b/python/tvm/relay/op/_tensor_grad.py index 0d1a1273c32a..5836aebce393 100644 --- a/python/tvm/relay/op/_tensor_grad.py +++ b/python/tvm/relay/op/_tensor_grad.py @@ -596,12 +596,6 @@ def cast_grad(orig, grad): return [cast_like(grad, x)] -@register_gradient("cast_like") -def cast_like_grad(orig, grad): - data, dtype_like = orig.args - return [cast_like(grad, data), zeros_like(dtype_like)] - - @register_gradient("nn.batch_flatten") def batch_flatten_grad(orig, grad): """Returns grad reshaped to data dims""" @@ -872,11 +866,3 @@ def less_equal_grad(orig, grad): Returns the gradient of less_equal. """ return [zeros_like(orig.args[0]), zeros_like(orig.args[1])] - - -@register_gradient("split") -def split_grad(orig, grad): - """ - Returns the gradient of split, which is the concatenation of the downstream gradients. - """ - return [concatenate(grad, orig.attrs.axis)] diff --git a/python/tvm/relay/op/tensor.py b/python/tvm/relay/op/tensor.py index 31252ec69306..c476d3a1c883 100644 --- a/python/tvm/relay/op/tensor.py +++ b/python/tvm/relay/op/tensor.py @@ -1096,13 +1096,13 @@ def concatenate(data, axis): result: relay.Expr The concatenated tensor. 
""" - if not isinstance(data, Expr): - data = Tuple(list(data)) + data = list(data) + if not data: raise ValueError("relay.concatenate requires data to be non-empty.") if not isinstance(axis, int): raise ValueError("For now, we only support integer axis") - return _make.concatenate(data, axis) + return _make.concatenate(Tuple(data), axis) def stack(data, axis): diff --git a/src/relay/transforms/first_order_gradient.cc b/src/relay/transforms/first_order_gradient.cc index 4b7a82b80940..55714592ded7 100644 --- a/src/relay/transforms/first_order_gradient.cc +++ b/src/relay/transforms/first_order_gradient.cc @@ -306,4 +306,4 @@ TVM_REGISTER_GLOBAL("relay._transform.FirstOrderGradient").set_body_typed(FirstO } // namespace transform } // namespace relay -} // namespace tvm \ No newline at end of file +} // namespace tvm From 4a0185981507bfb8ed928a265f4d34a4779cdb30 Mon Sep 17 00:00:00 2001 From: Altan Haan Date: Tue, 16 Mar 2021 20:35:36 -0700 Subject: [PATCH 5/6] missed one --- python/tvm/relay/op/tensor.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/tvm/relay/op/tensor.py b/python/tvm/relay/op/tensor.py index c476d3a1c883..5b011043f588 100644 --- a/python/tvm/relay/op/tensor.py +++ b/python/tvm/relay/op/tensor.py @@ -1097,7 +1097,6 @@ def concatenate(data, axis): The concatenated tensor. """ data = list(data) - if not data: raise ValueError("relay.concatenate requires data to be non-empty.") if not isinstance(axis, int): From 775086b0e7675e2ee37dbb6bbf2655b9917e3195 Mon Sep 17 00:00:00 2001 From: Altan Haan Date: Wed, 17 Mar 2021 11:52:24 -0700 Subject: [PATCH 6/6] delete ConcretizeLike --- python/tvm/relay/transform/transform.py | 14 -- src/relay/transforms/concretize_like.cc | 160 ------------------ src/relay/transforms/gradient.h | 2 + .../python/relay/test_pass_concretize_like.py | 108 ------------ 4 files changed, 2 insertions(+), 282 deletions(-) delete mode 100644 src/relay/transforms/concretize_like.cc delete mode 100644 tests/python/relay/test_pass_concretize_like.py diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py index cfe71979e3f6..5b0e480f5f28 100644 --- a/python/tvm/relay/transform/transform.py +++ b/python/tvm/relay/transform/transform.py @@ -830,20 +830,6 @@ def FirstOrderGradient(): return _ffi_api.FirstOrderGradient() -def ConcretizeLike(): - """ - Transforms `op_like` functions to their explicit-shape equivalent (e.g. `zeros_like(x, y)` - to `zeros(x, y.shape)`), when the target shape is concrete. This removes unnecessary - dependencies and can enable more opportunities for operator fusion. - - Returns - ------- - ret : tvm.transform.Pass - The registered ConcretizeLike pass. - """ - return _ffi_api.ConcretizeLike() - - def Defunctionalization(func, mod): """ Performs defunctionalization on func, diff --git a/src/relay/transforms/concretize_like.cc b/src/relay/transforms/concretize_like.cc deleted file mode 100644 index fc5bb519e241..000000000000 --- a/src/relay/transforms/concretize_like.cc +++ /dev/null @@ -1,160 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file concretize_like.cc - * \brief Converts `*_like` operators to their explicit shape equivalent (e.g. `zeros_like(x, y)` to - * `zeros(x, y.shape)`), when the target shape is concrete. This removes unnecessary dependencies - * and can enable more opportunities for operator fusion. - */ -#include -#include - -#include "pattern_utils.h" - -namespace tvm { -namespace relay { - -class ConcretizeLikeRewrite { - public: - ConcretizeLikeRewrite() { - concrete_map_[Op::Get("reshape_like")] = [](Expr data, Array shape, DataType dtype) { - return MakeReshape(data, shape); - }; - concrete_map_[Op::Get("zeros_like")] = [](Expr data, Array shape, DataType dtype) { - return MakeZeros(shape, dtype); - }; - concrete_map_[Op::Get("ones_like")] = [](Expr data, Array shape, DataType dtype) { - return MakeOnes(shape, dtype); - }; - concrete_map_[Op::Get("collapse_sum_like")] = [](Expr data, Array shape, - DataType dtype) { - ICHECK_LE(shape.size(), std::numeric_limits::max()); - static const Op& op = Op::Get("collapse_sum_to"); - auto attrs = make_object(); - auto cshape = - MakeConstantTensor(DataType::Int(32), {static_cast(shape.size())}, shape); - attrs->shape = shape; - return Call(op, {data, cshape}, Attrs(attrs)); - }; - concrete_map_[Op::Get("broadcast_to_like")] = [](Expr data, Array shape, - DataType dtype) { - return MakeBroadCastTo(data, shape); - }; - - for (const auto& pr : concrete_map_) { - if (!op_pat_.defined()) { - op_pat_ = IsExpr(pr.first); - } else { - op_pat_ = op_pat_ || IsExpr(pr.first); - } - } - - data_pat_ = IsWildcard(); - like_pat_ = IsWildcard(); - unary_like_pat_ = (IsOp("zeros_like") || IsOp("ones_like"))({like_pat_}); - binary_like_pat_ = (IsOp("reshape_like") || IsOp("collapse_sum_like") || - IsOp("broadcast_to_like"))({data_pat_, like_pat_}); - } - - Expr Callback(const Expr& pre, const Expr& post, - const Map>& node_map) const { - // we will rewrite iff the like argument has fully concrete shape - const CallNode* call_node = post.as(); - ICHECK(call_node); - const OpNode* op_node = call_node->op.as(); - ICHECK(op_node); - const Op op_ref = GetRef(op_node); - ICHECK(concrete_map_.count(op_ref) > 0); - - Expr like = node_map[like_pat_][0]; - - if (!like->checked_type_.defined()) { - // TODO(@altanh): maybe because of the input being rewritten? 
- return post; - } - - // skip trying to support this for now (ironic, as I was the one who added the feature) - if (const auto* attrs = call_node->attrs.as()) { - if (attrs->lhs_begin != 0 || attrs->rhs_begin != 0 || attrs->lhs_end.defined() || - attrs->rhs_end.defined()) { - return post; - } - } - - CHECK(like->checked_type_.defined()) - << "ConcretizeLike requires checked types to be populated, please run type inference"; - const TensorTypeNode* like_ty = like->checked_type().as(); - ICHECK(like_ty) << "got non-Tensor argument type " << PrettyPrint(like->checked_type()); - - Array cshape; - for (const auto& dim : like_ty->shape) { - if (const auto* imm = dim.as()) { - cshape.push_back(Integer(GetRef(imm))); - continue; - } - return post; - } - - if (call_node->args.size() == 2) { - return concrete_map_.at(op_ref)(node_map[data_pat_][0], cshape, like_ty->dtype); - } - return concrete_map_.at(op_ref)(Expr(), cshape, like_ty->dtype); - } - - DFPattern UnaryPattern() const { return unary_like_pat_; } - - DFPattern BinaryPattern() const { return binary_like_pat_; } - - private: - using FMake = std::function, DataType)>; - std::unordered_map concrete_map_; - DFPattern op_pat_; - DFPattern data_pat_; - DFPattern like_pat_; - DFPattern unary_like_pat_; - DFPattern binary_like_pat_; -}; - -namespace transform { - -Pass ConcretizeLike() { - runtime::TypedPackedFunc pass_func = - [](Function f, IRModule m, PassContext pc) { - ConcretizeLikeRewrite rw; - auto callback_func = PackedFunc([&rw](TVMArgs args, TVMRetValue* rv) { - Expr pre = args[0]; - Expr post = args[1]; - Map> node_map = args[2]; - *rv = rw.Callback(pre, post, node_map); - }); - Array callbacks = { - DFPatternCallback(rw.UnaryPattern(), callback_func, true), - DFPatternCallback(rw.BinaryPattern(), callback_func, true)}; - return Downcast(RewritePatterns(callbacks, f, m)); - }; - return CreateFunctionPass(pass_func, 0, "ConcretizeLike", {}); -} - -TVM_REGISTER_GLOBAL("relay._transform.ConcretizeLike").set_body_typed(ConcretizeLike); - -} // namespace transform - -} // namespace relay -} // namespace tvm diff --git a/src/relay/transforms/gradient.h b/src/relay/transforms/gradient.h index 455d86613b8e..2e6ffbcc7c9e 100644 --- a/src/relay/transforms/gradient.h +++ b/src/relay/transforms/gradient.h @@ -27,6 +27,8 @@ #include #include +#include + namespace tvm { namespace relay { diff --git a/tests/python/relay/test_pass_concretize_like.py b/tests/python/relay/test_pass_concretize_like.py deleted file mode 100644 index ea4ec6038494..000000000000 --- a/tests/python/relay/test_pass_concretize_like.py +++ /dev/null @@ -1,108 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-"""Tests for the ConcretizeLike pass.""" -import tvm -import tvm.relay.testing -from tvm import relay -from tvm.relay.testing import run_infer_type - - -def test_reshape_like(): - data = relay.var("data", shape=(2, 3, 4), dtype="float32") - shape_like = relay.var("shape_like", shape=(6, 2, 2), dtype="float32") - f = relay.Function([data, shape_like], relay.reshape_like(data, shape_like)) - f_expected = relay.Function([data, shape_like], relay.reshape(data, (6, 2, 2))) - f_expected = run_infer_type(f_expected) - - mod = tvm.IRModule.from_expr(f) - mod_concrete = relay.transform.ConcretizeLike()(mod) - assert tvm.ir.structural_equal(mod_concrete["main"], f_expected) - - -def test_zeros_like(): - dtype = "int32" - shape_like = relay.var("shape_like", shape=(3, 4, 5), dtype=dtype) - f = relay.Function([shape_like], relay.zeros_like(shape_like)) - f_expected = relay.Function([shape_like], relay.zeros((3, 4, 5), dtype)) - f_expected = run_infer_type(f_expected) - - mod = tvm.IRModule.from_expr(f) - mod_concrete = relay.transform.ConcretizeLike()(mod) - assert tvm.ir.structural_equal(mod_concrete["main"], f_expected) - - -def test_ones_like(): - dtype = "int32" - shape_like = relay.var("shape_like", shape=(3, 4, 5), dtype=dtype) - f = relay.Function([shape_like], relay.ones_like(shape_like)) - f_expected = relay.Function([shape_like], relay.ones((3, 4, 5), dtype)) - f_expected = run_infer_type(f_expected) - - mod = tvm.IRModule.from_expr(f) - mod_concrete = relay.transform.ConcretizeLike()(mod) - assert tvm.ir.structural_equal(mod_concrete["main"], f_expected) - - -def test_collapse_sum_like(): - data = relay.var("data", shape=(3, 3, 3), dtype="float32") - shape_like = relay.var("shape_like", shape=(3,), dtype="float32") - f = relay.Function([data, shape_like], relay.collapse_sum_like(data, shape_like)) - f_expected = relay.Function([data, shape_like], relay.collapse_sum_to(data, (3,))) - f_expected = run_infer_type(f_expected) - - mod = tvm.IRModule.from_expr(f) - mod_concrete = relay.transform.ConcretizeLike()(mod) - assert tvm.ir.structural_equal(mod_concrete["main"], f_expected) - - -def test_broadcast_to_like(): - data = relay.var("data", shape=(3,), dtype="float32") - shape_like = relay.var("shape_like", shape=(3, 3, 3), dtype="float32") - f = relay.Function([data, shape_like], relay.broadcast_to_like(data, shape_like)) - f_expected = relay.Function([data, shape_like], relay.broadcast_to(data, (3, 3, 3))) - f_expected = run_infer_type(f_expected) - - mod = tvm.IRModule.from_expr(f) - mod_concrete = relay.transform.ConcretizeLike()(mod) - assert tvm.ir.structural_equal(mod_concrete["main"], f_expected) - - -def test_multiple(): - x = relay.var("x", shape=(2, 3), dtype="float32") - y = relay.var("x", shape=(3,), dtype="float32") - l = x + y - - dl = relay.ones_like(l) - dx = relay.zeros_like(x) - dy = relay.zeros_like(y) - dx = dx + relay.collapse_sum_like(dl, dx) - dy = dy + relay.collapse_sum_like(dl, dy) - ret = relay.Tuple([dx, dy]) - f = relay.Function([x, y], ret) - - dl_c = relay.ones((2, 3), "float32") - dx_c = relay.zeros((2, 3), "float32") - dy_c = relay.zeros((3,), "float32") - dx_c = dx_c + relay.collapse_sum_to(dl_c, (2, 3)) - dy_c = dy_c + relay.collapse_sum_to(dl_c, (3,)) - ret_c = relay.Tuple([dx_c, dy_c]) - f_expected = relay.Function([x, y], ret_c) - f_expected = run_infer_type(f_expected) - - mod = tvm.IRModule.from_expr(f) - mod_concrete = relay.transform.ConcretizeLike()(mod) - assert tvm.ir.structural_equal(mod_concrete["main"], f_expected)