diff --git a/src/relay/transforms/gradient.cc b/src/relay/transforms/gradient.cc
index 9441f8af5d27..cd3a99655341 100644
--- a/src/relay/transforms/gradient.cc
+++ b/src/relay/transforms/gradient.cc
@@ -181,6 +181,22 @@ struct FirstOrderReverseAD : ExprFunctor<ADValue(const Expr&)> {
     return ret;
   }
 
+  Expr UpdateGrad(const Type& t, const Expr& arg, const Expr& grad, LetList* ll) {
+    if (t.as<TensorTypeNode>()) {
+      return ll->Push(Add(arg, grad));
+    } else if (auto* tt = t.as<TupleTypeNode>()) {
+      Array<Expr> updates;
+      for (size_t i = 0; i < tt->fields.size(); ++i) {
+        updates.push_back(this->UpdateGrad(tt->fields[i], ll->Push(GetField(arg, i)),
+                                           ll->Push(GetField(grad, i)), ll));
+      }
+      return ll->Push(Tuple(updates));
+    } else {
+      LOG(FATAL) << "unsupported arg type of operator: " << t;
+      throw;
+    }
+  }
+
   ADValue VisitExpr_(const OpNode* op) final {
     Op op_ref = GetRef<Op>(op);
     ICHECK(rev_map.count(op_ref)) << op->name << " does not have reverse mode defined";
@@ -198,8 +214,10 @@ struct FirstOrderReverseAD : ExprFunctor<ADValue(const Expr&)> {
       tvm::Array<Expr> rev = rev_map[op_ref](orig, ret->reverse);
       ICHECK(args.size() == rev.size());
       for (size_t i = 0; i < args.size(); ++i) {
+        auto ad_arg = args[i]->get<ADTensor>();
+        auto ad_arg_type = ad_arg.forward->checked_type();
         args[i]->get<ADTensor>().reverse =
-            ll->Push(Add(args[i]->get<ADTensor>().reverse, rev[i]));
+            this->UpdateGrad(ad_arg_type, ad_arg.reverse, rev[i], ll);
       }
     });
     return ret;
diff --git a/tests/python/relay/test_pass_gradient.py b/tests/python/relay/test_pass_gradient.py
index 93bad3a19c53..0604ed51272c 100644
--- a/tests/python/relay/test_pass_gradient.py
+++ b/tests/python/relay/test_pass_gradient.py
@@ -255,6 +255,29 @@ def _test_tuple(mode):
     tvm.testing.assert_allclose(grad_z.asnumpy(), -1 * np.ones_like(grad_z.asnumpy()))
 
 
+def _test_tuple_argument(mode):
+    shape = (2, 3)
+    dtype = "float32"
+    tensor_type = relay.TensorType(shape, dtype)
+    fields = 3
+    tuple_type = relay.TupleType([tensor_type] * fields)
+    tup = relay.var("tup", type_annotation=tuple_type)
+    body = relay.TupleGetItem(tup, 0)
+    for i in range(1, fields):
+        body = relay.add(body, relay.TupleGetItem(tup, i))
+    func = relay.Function([tup], body)
+    func = run_infer_type(func)
+    back_func = run_infer_type(gradient(func, mode=mode))
+    xs = [rand(dtype, *shape) for _ in range(fields)]
+    xs_np = np.array([x.asnumpy() for x in xs])
+    expected_forward = np.sum(xs_np, axis=0)
+    ex = create_executor()
+    forward, grad = ex.evaluate(back_func)(tuple(xs))
+    tvm.testing.assert_allclose(forward.asnumpy(), expected_forward)
+    for field in grad[0]:
+        tvm.testing.assert_allclose(field.asnumpy(), np.ones_like(field.asnumpy()))
+
+
 def test_tuple():
     _test_tuple("higher_order")
 
@@ -263,6 +286,16 @@ def test_tuple_first_order():
     _test_tuple("first_order")
 
 
+@pytest.mark.xfail(raises=tvm.error.TVMError)
+def test_tuple_argument():
+    # fails until we add support for top-level tuple arguments in higher-order AD
+    _test_tuple_argument("higher_order")
+
+
+def test_tuple_argument_first_order():
+    _test_tuple_argument("first_order")
+
+
 def test_pow():
     mod = tvm.IRModule()
     p = Prelude(mod)
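
For context, here is a minimal standalone sketch (not part of the patch) of what the extended first-order path enables: differentiating a Relay function whose parameter is a tuple. It mirrors `_test_tuple_argument` above; the import locations for `gradient`, `create_executor`, `rand`, and `run_infer_type` are assumptions based on the helpers that `test_pass_gradient.py` already uses.

```python
# Sketch only: take the first-order gradient of f(tup) = tup[0] + tup[1] + tup[2],
# where the single parameter is tuple-typed. With this patch, UpdateGrad walks the
# tuple type recursively instead of assuming a tensor-typed argument.
import numpy as np
from tvm import relay
from tvm.relay import create_executor
from tvm.relay.transform import gradient
from tvm.relay.testing import rand, run_infer_type  # assumed helper locations

shape, dtype, fields = (2, 3), "float32", 3
tensor_type = relay.TensorType(shape, dtype)
tup = relay.var("tup", type_annotation=relay.TupleType([tensor_type] * fields))

body = relay.TupleGetItem(tup, 0)
for i in range(1, fields):
    body = relay.add(body, relay.TupleGetItem(tup, i))
func = run_infer_type(relay.Function([tup], body))

# first_order mode now accepts the tuple-typed parameter and returns the forward
# value plus a tuple of per-field gradients (all ones_like for a plain sum).
back_func = run_infer_type(gradient(func, mode="first_order"))
xs = tuple(rand(dtype, *shape) for _ in range(fields))
forward, grad = create_executor().evaluate(back_func)(xs)
print(forward.asnumpy())      # elementwise sum of the three inputs
print(grad[0][0].asnumpy())   # ones_like(xs[0])
```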