[Relay][Training][Pass] Factor out first-order AD to a module pass #7677

Merged · 6 commits · Mar 18, 2021
Changes from 5 commits
40 changes: 39 additions & 1 deletion python/tvm/relay/transform/transform.py
@@ -800,12 +800,50 @@ def gradient(expr, mod=None, mode="higher_order"):
         The transformed expression.
     """
     if mode == "first_order":
-        return _ffi_api.first_order_gradient(expr, mod)
+        warnings.warn(
+            "using transform.gradient for first-order AD is deprecated, please use the "
+            "FirstOrderGradient module pass",
+            DeprecationWarning,
+        )
+        if mod is not None:
+            raise RuntimeError(
+                "to run first-order AD on a module, please use the FirstOrderGradient module pass."
+            )
+        return FirstOrderGradient()(tvm.IRModule.from_expr(expr))["main"]
     if mode == "higher_order":
         return _ffi_api.gradient(expr, mod)
     raise Exception("unknown mode")
+
+
+def FirstOrderGradient():
+    """
+    Transforms all global functions in the module to return the original result, paired with the
+    gradients of the inputs. This pass transforms each global function independently and does not
+    support interprocedural AD. Additionally, this pass does not support any control flow or
+    references, and should only be used on pure data-flow graphs.
+
+    Returns
+    -------
+    ret : tvm.transform.Pass
+        The registered FirstOrderGradient pass.
+    """
+    return _ffi_api.FirstOrderGradient()
+
+
+def ConcretizeLike():
+    """
+    Transforms `*_like` operators to their explicit-shape equivalents (e.g. `zeros_like(x)` to
+    `zeros(x.shape)`), when the target shape is concrete. This removes unnecessary dependencies
+    and can enable more opportunities for operator fusion.
+
+    Returns
+    -------
+    ret : tvm.transform.Pass
+        The registered ConcretizeLike pass.
+    """
+    return _ffi_api.ConcretizeLike()
 
 
 def Defunctionalization(func, mod):
     """
     Performs defunctionalization on func,
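For context, a minimal usage sketch of the new module pass next to the deprecated entry point. This is not part of the diff: it assumes a trivial pure data-flow function, and the explicit `InferType` call is a conservative assumption (checked types must be populated somewhere before gradient rewriting).

```python
import tvm
from tvm import relay

# A tiny pure data-flow function: f(x) = x * x.
x = relay.var("x", shape=(3, 4), dtype="float32")
func = relay.Function([x], x * x)

# New path: run FirstOrderGradient as a module pass. Each global function is
# rewritten to return the original result paired with the input gradients.
mod = tvm.IRModule.from_expr(func)
mod = relay.transform.InferType()(mod)
mod = relay.transform.FirstOrderGradient()(mod)
print(mod["main"])

# Deprecated path: still accepts a bare expression, emits a DeprecationWarning,
# and delegates to the module pass internally (as in the diff above).
grad_func = relay.transform.gradient(func, mode="first_order")
```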
160 changes: 160 additions & 0 deletions src/relay/transforms/concretize_like.cc
@@ -0,0 +1,160 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
 * \file concretize_like.cc
 * \brief Converts `*_like` operators to their explicit-shape equivalents (e.g. `zeros_like(x)`
 * to `zeros(x.shape)`), when the target shape is concrete. This removes unnecessary dependencies
 * and can enable more opportunities for operator fusion.
 */
#include <tvm/relay/dataflow_matcher.h>
#include <tvm/relay/transform.h>

#include <functional>
#include <limits>
#include <unordered_map>

#include "pattern_utils.h"

namespace tvm {
namespace relay {

class ConcretizeLikeRewrite {
 public:
  ConcretizeLikeRewrite() {
    concrete_map_[Op::Get("reshape_like")] = [](Expr data, Array<Integer> shape, DataType dtype) {
      return MakeReshape(data, shape);
    };
    concrete_map_[Op::Get("zeros_like")] = [](Expr data, Array<Integer> shape, DataType dtype) {
      return MakeZeros(shape, dtype);
    };
    concrete_map_[Op::Get("ones_like")] = [](Expr data, Array<Integer> shape, DataType dtype) {
      return MakeOnes(shape, dtype);
    };
    concrete_map_[Op::Get("collapse_sum_like")] = [](Expr data, Array<Integer> shape,
                                                     DataType dtype) {
      ICHECK_LE(shape.size(), std::numeric_limits<int64_t>::max());
      static const Op& op = Op::Get("collapse_sum_to");
      auto attrs = make_object<InitOpAttrs>();
      auto cshape =
          MakeConstantTensor(DataType::Int(32), {static_cast<int64_t>(shape.size())}, shape);
      attrs->shape = shape;

[Review comment: Contributor] Suggested change: set attrs->shape immediately after constructing attrs:

    auto attrs = make_object<InitOpAttrs>();
    attrs->shape = shape;
    auto cshape =
        MakeConstantTensor(DataType::Int(32), {static_cast<int64_t>(shape.size())}, shape);

      return Call(op, {data, cshape}, Attrs(attrs));
    };
    concrete_map_[Op::Get("broadcast_to_like")] = [](Expr data, Array<Integer> shape,
                                                     DataType dtype) {
      return MakeBroadCastTo(data, shape);
    };

    for (const auto& pr : concrete_map_) {
      if (!op_pat_.defined()) {
        op_pat_ = IsExpr(pr.first);
      } else {
        op_pat_ = op_pat_ || IsExpr(pr.first);
      }
    }

[Review comment: Contributor]
1. Could you elaborate on what this loop does? It seems to union all the patterns that match ops in concrete_map_, but I didn't find op_pat_ being used elsewhere.
2. The construction of concrete_map_ could be static, so we should be able to move it out of the constructor.

[Reply: Contributor Author]
1. Thanks for the catch, this was left over from before.
2. True; I'm not sure what the code style is for defining static variables like this, but I'll try something.

[Reply: Contributor Author]
Just to follow up on the idea of making these static: I'm mainly concerned about how the static lifetime interacts with the operator registry. Do you know how that works?

[Reply: Contributor]
I don't think it would be an issue. If you search for `static const Op& op` in the code base, you'll find lots of use cases.

    data_pat_ = IsWildcard();
    like_pat_ = IsWildcard();
    unary_like_pat_ = (IsOp("zeros_like") || IsOp("ones_like"))({like_pat_});
    binary_like_pat_ = (IsOp("reshape_like") || IsOp("collapse_sum_like") ||
                        IsOp("broadcast_to_like"))({data_pat_, like_pat_});

[Review comment: Contributor]
1. These patterns can also be defined statically.
2. Now we have two places that specify the supported *_like ops. I feel we should define a list of "unary like" and "binary like" ops once and use it both in concrete_map_ and here. As a result, we could have logic similar to L61-67 to construct the pattern.

[Reply: Contributor Author]
Good points, thanks.

  }

  Expr Callback(const Expr& pre, const Expr& post,
                const Map<DFPattern, Array<Expr>>& node_map) const {
    // we will rewrite iff the like argument has a fully concrete shape
    const CallNode* call_node = post.as<CallNode>();
    ICHECK(call_node);
    const OpNode* op_node = call_node->op.as<OpNode>();
    ICHECK(op_node);
    const Op op_ref = GetRef<Op>(op_node);
    ICHECK(concrete_map_.count(op_ref) > 0);

    Expr like = node_map[like_pat_][0];

    if (!like->checked_type_.defined()) {
      // TODO(@altanh): maybe because of the input being rewritten?
      return post;
    }

    // skip trying to support this for now (ironic, as I was the one who added the feature)
    if (const auto* attrs = call_node->attrs.as<ReshapeLikeAttrs>()) {
      if (attrs->lhs_begin != 0 || attrs->rhs_begin != 0 || attrs->lhs_end.defined() ||
          attrs->rhs_end.defined()) {
        return post;
      }
    }

[Review comment: Contributor]
This is too ad hoc; it means we may not concretize *_like ops in certain situations. Instead of hacking the unified callback function, we should maintain this logic in the op-specific function. I can think of two ways to achieve this:

1. Put the logic at the beginning of the concrete_map_ functions and return the same op if not applicable.
2. Construct another checker map that includes a checker function for each op, and invoke the corresponding checker function here.

[Reply: Contributor Author]
Yeah, I agree, thanks.
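A sketch of the reviewer's second option, written in Python for brevity; every name below is a hypothetical illustration of the shape of the design, not TVM API:

```python
from types import SimpleNamespace

# Hypothetical "checker map": each *_like op gets a predicate deciding whether
# the rewrite applies, keeping op-specific applicability logic out of the
# shared callback.
def reshape_like_applicable(attrs):
    # Mirrors the ad-hoc ReshapeLikeAttrs check above: only the default
    # full-tensor reshape_like (no lhs/rhs slicing) is concretized.
    return (attrs.lhs_begin == 0 and attrs.rhs_begin == 0
            and attrs.lhs_end is None and attrs.rhs_end is None)

checker_map = {"reshape_like": reshape_like_applicable}

def rewrite_if_applicable(op_name, attrs, rewrite, original):
    checker = checker_map.get(op_name)
    if checker is not None and not checker(attrs):
        return original  # leave unsupported cases untouched
    return rewrite()

# Example: the default reshape_like attrs pass the check.
default_attrs = SimpleNamespace(lhs_begin=0, rhs_begin=0, lhs_end=None, rhs_end=None)
assert reshape_like_applicable(default_attrs)
```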


    CHECK(like->checked_type_.defined())
        << "ConcretizeLike requires checked types to be populated, please run type inference";

[Review comment: Contributor]
We already have L88-90 so this check seems useless.

    const TensorTypeNode* like_ty = like->checked_type().as<TensorTypeNode>();
    ICHECK(like_ty) << "got non-Tensor argument type " << PrettyPrint(like->checked_type());

    Array<Integer> cshape;
    for (const auto& dim : like_ty->shape) {
      if (const auto* imm = dim.as<IntImmNode>()) {
        cshape.push_back(Integer(GetRef<IntImm>(imm)));
        continue;
      }
      return post;

[Review comment: Contributor] Suggested change: use an explicit else branch instead of continue plus a fall-through return:

    if (const auto* imm = dim.as<IntImmNode>()) {
      cshape.push_back(Integer(GetRef<IntImm>(imm)));
    } else {
      return post;
    }

    }

    if (call_node->args.size() == 2) {
      return concrete_map_.at(op_ref)(node_map[data_pat_][0], cshape, like_ty->dtype);
    }
    return concrete_map_.at(op_ref)(Expr(), cshape, like_ty->dtype);

[Review comment: Contributor]
Why the empty Expr()?

[Reply: Contributor Author]
Yeah, maybe this is too much of a hack; I'm just using it as a placeholder, since the unary matches won't have a corresponding data Expr node. I'll rework this tomorrow.

[Reply: Contributor]
It might be better to refer to the SimplifyExpr pass to separate unary and binary ops. Then maybe we could have a base struct for the shared logic.

  }

  DFPattern UnaryPattern() const { return unary_like_pat_; }

  DFPattern BinaryPattern() const { return binary_like_pat_; }

 private:
  using FMake = std::function<Expr(Expr, Array<Integer>, DataType)>;
  std::unordered_map<Op, FMake, ObjectPtrHash, ObjectPtrEqual> concrete_map_;
  DFPattern op_pat_;
  DFPattern data_pat_;
  DFPattern like_pat_;
  DFPattern unary_like_pat_;
  DFPattern binary_like_pat_;
};

namespace transform {

Pass ConcretizeLike() {
  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
      [](Function f, IRModule m, PassContext pc) {
        ConcretizeLikeRewrite rw;
        auto callback_func = PackedFunc([&rw](TVMArgs args, TVMRetValue* rv) {
          Expr pre = args[0];
          Expr post = args[1];
          Map<DFPattern, Array<Expr>> node_map = args[2];
          *rv = rw.Callback(pre, post, node_map);
        });
        Array<DFPatternCallback> callbacks = {
            DFPatternCallback(rw.UnaryPattern(), callback_func, true),
            DFPatternCallback(rw.BinaryPattern(), callback_func, true)};
        return Downcast<Function>(RewritePatterns(callbacks, f, m));
      };
  return CreateFunctionPass(pass_func, 0, "ConcretizeLike", {});
}

TVM_REGISTER_GLOBAL("relay._transform.ConcretizeLike").set_body_typed(ConcretizeLike);

}  // namespace transform

}  // namespace relay
}  // namespace tvm
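End to end, the pass can be exercised from Python roughly as follows. This is a sketch, not part of the diff: the printed IR will vary, and a `*_like` call whose `like` argument has any non-concrete dimension is left untouched, per the concrete-shape check in Callback above.

```python
import tvm
from tvm import relay

x = relay.var("x", shape=(2, 3), dtype="float32")
y = relay.var("y", shape=(2, 3), dtype="float32")

# zeros_like(y) depends on y only through its shape; the shape is fully
# concrete here, so ConcretizeLike can rewrite it to zeros((2, 3), "float32")
# and drop the dependency on y.
out = relay.add(relay.zeros_like(y), x)
mod = tvm.IRModule.from_expr(relay.Function([x, y], out))

mod = relay.transform.InferType()(mod)  # checked types must be populated
mod = relay.transform.ConcretizeLike()(mod)
print(mod)
```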