[Relay][Training][Pass] Factor out first-order AD to a module pass #7677
Conversation
```cpp
auto attrs = make_object<InitOpAttrs>();
auto cshape =
    MakeConstantTensor(DataType::Int(32), {static_cast<int64_t>(shape.size())}, shape);
attrs->shape = shape;
```
Suggested change:

```cpp
auto attrs = make_object<InitOpAttrs>();
attrs->shape = shape;
auto cshape =
    MakeConstantTensor(DataType::Int(32), {static_cast<int64_t>(shape.size())}, shape);
```
```cpp
if (const auto* imm = dim.as<IntImmNode>()) {
  cshape.push_back(Integer(GetRef<IntImm>(imm)));
  continue;
}
return post;
```
Suggested change:

```cpp
if (const auto* imm = dim.as<IntImmNode>()) {
  cshape.push_back(Integer(GetRef<IntImm>(imm)));
} else {
  return post;
}
```
```cpp
if (call_node->args.size() == 2) {
  return concrete_map_.at(op_ref)(node_map[data_pat_][0], cshape, like_ty->dtype);
}
return concrete_map_.at(op_ref)(Expr(), cshape, like_ty->dtype);
```
Why the empty `Expr()`?
Yeah, maybe this is too much of a hack. I'm just using it as a placeholder, since the unary matches won't have a corresponding data `Expr` node. I'll rework this tomorrow.
It might be better to refer to the SimplifyExpr pass to separate unary and binary ops. Then maybe we could have a base struct to hold the sharable logic (see the sketch below).
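Something like this, perhaps (a rough sketch; all names are hypothetical and only loosely modeled on the rewriter style in SimplifyExpr):

```cpp
#include <tvm/relay/dataflow_pattern.h>
#include <tvm/relay/expr.h>

using namespace tvm;
using namespace tvm::relay;

// Base struct holding the logic shared by the unary and binary *_like
// rewrites: extracting a static shape from the matched `like` tensor type.
class ConcretizeLikeRewrite {
 public:
  virtual ~ConcretizeLikeRewrite() = default;

  // Shared helper: bail out (return false) if any dimension of `like` is dynamic.
  static bool GetStaticShape(const TensorTypeNode* like_ty, Array<Integer>* cshape) {
    for (const auto& dim : like_ty->shape) {
      const auto* imm = dim.as<IntImmNode>();
      if (imm == nullptr) return false;
      cshape->push_back(Integer(GetRef<IntImm>(imm)));
    }
    return true;
  }

  // Unary subclasses (zeros_like, ones_like) and binary subclasses
  // (reshape_like, ...) each supply their own pattern and concrete rewrite.
  virtual Expr Concretize(const Array<Expr>& args, const Array<Integer>& shape,
                          DataType dtype) const = 0;
};
```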
cc @comaniac
Also cc @yzhliu
```cpp
for (const auto& pr : concrete_map_) {
  if (!op_pat_.defined()) {
    op_pat_ = IsExpr(pr.first);
  } else {
    op_pat_ = op_pat_ || IsExpr(pr.first);
  }
}
```
- Could you elaborate on what this loop does? It seems to me that it unions all the patterns that match ops in `concrete_map_`, but I didn't find `op_pat_` being used elsewhere.
- The construction of `concrete_map_` could be static, so we should be able to move it out of the constructor.
- Thanks for the catch, this was left over from before.
- True, I'm not sure what the code style is for defining static variables like this, but I'll try something.
Just to follow up on the idea of making these static: I'm mainly concerned about how the static lifetime interacts with the operator registry. Do you know how that works?
I don't think it would be an issue. If you search for `static const Op& op` in the codebase, you'll find lots of use cases.
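For reference, the idiom looks like this (a minimal example; the op name here is arbitrary):

```cpp
// Function-local static: the registry lookup runs once, on first use, and the
// cached reference stays valid for the lifetime of the program.
static const Op& op = Op::Get("zeros_like");
```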
```cpp
data_pat_ = IsWildcard();
like_pat_ = IsWildcard();
unary_like_pat_ = (IsOp("zeros_like") || IsOp("ones_like"))({like_pat_});
binary_like_pat_ = (IsOp("reshape_like") || IsOp("collapse_sum_like") ||
                    IsOp("broadcast_to_like"))({data_pat_, like_pat_});
```
- These patterns can also be defined statically.
- Now we have two places that specify the supported `*_like` ops. I feel we should define a list of "unary like" and "binary like" ops once and use it both in `concrete_map_` and here. As a result, we could have logic similar to L61-67 to construct the pattern; see the sketch below.
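Something along these lines, maybe (a rough sketch; only the op names come from this PR, the helper and its structure are hypothetical):

```cpp
#include <tvm/relay/dataflow_pattern.h>

#include <string>
#include <vector>

using namespace tvm::relay;

// Single source of truth for the supported unary *_like ops; the same list
// would also seed concrete_map_.
static const std::vector<std::string> kUnaryLikeOps = {"zeros_like", "ones_like"};

DFPattern MakeUnaryLikePattern(const DFPattern& like_pat) {
  // Union the op patterns, mirroring the loop over concrete_map_ above.
  DFPattern op_pat;
  for (const auto& name : kUnaryLikeOps) {
    DFPattern p = IsOp(name);
    op_pat = op_pat.defined() ? (op_pat || p) : p;
  }
  return op_pat({like_pat});
}
```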
good points, thanks
```cpp
// skip trying to support this for now (ironic, as I was the one who added the feature)
if (const auto* attrs = call_node->attrs.as<ReshapeLikeAttrs>()) {
  if (attrs->lhs_begin != 0 || attrs->rhs_begin != 0 || attrs->lhs_end.defined() ||
      attrs->rhs_end.defined()) {
    return post;
  }
}
```
This is too ad hoc. It means we may not concretize `*_like` ops in certain situations. Instead of hacking the unified callback function, we should maintain this logic in the op-specific function. I can think of two ways to achieve this:

- Put the logic at the beginning of the `concrete_map_` functions and return the same op if not applicable.
- Construct another checker map that includes a checker function for each op, and invoke the corresponding checker function here (see the sketch after this list).
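A possible shape for the second option (names hypothetical; the `ReshapeLikeAttrs` fields are the ones from the snippet above):

```cpp
#include <tvm/relay/attrs/transform.h>
#include <tvm/relay/expr.h>

#include <functional>
#include <unordered_map>

using namespace tvm;
using namespace tvm::relay;

// Per-op applicability checks, consulted before rewriting.
using FCheck = std::function<bool(const CallNode*)>;

std::unordered_map<Op, FCheck, ObjectPtrHash, ObjectPtrEqual> MakeCheckMap() {
  std::unordered_map<Op, FCheck, ObjectPtrHash, ObjectPtrEqual> check_map;
  // Only the default full-shape reshape_like is concretized; the slicing
  // attrs (lhs/rhs begin/end) opt out of the rewrite.
  check_map[Op::Get("reshape_like")] = [](const CallNode* call) {
    const auto* attrs = call->attrs.as<ReshapeLikeAttrs>();
    return attrs->lhs_begin == 0 && attrs->rhs_begin == 0 &&
           !attrs->lhs_end.defined() && !attrs->rhs_end.defined();
  };
  // Ops with no entry are assumed to always be concretizable.
  return check_map;
}
```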
yeah I agree, thanks
```cpp
CHECK(like->checked_type_.defined())
    << "ConcretizeLike requires checked types to be populated, please run type inference";
```
We already have L88-90, so this check seems redundant.
```python
data = relay.var("data", shape=(2, 3, 4), dtype="float32")
shape_like = relay.var("shape_like", shape=(6, 2, 2), dtype="float32")
f = relay.Function([data, shape_like], relay.reshape_like(data, shape_like))
f_expected = relay.Function([data, shape_like], relay.reshape(data, (6, 2, 2)))
```
This exposes a question about the current implementation: after the ConcretizeLike pass we expect unused arguments. If we don't need the `shape_like` tensor anymore, we should remove it from the function.
This is a good point that I was wondering about while I wrote the tests. However, this might be complicated, as removing a parameter means we'll need to remove the argument at each call site (and then maybe re-apply the optimization until fixed point). I imagine we won't want to remove the parameter from a global function, as then the user will have to ensure they don't pass the removed argument.
It should be fine to have another internal pass, called after concretizing the `*_like` ops, that checks and mutates the function signatures; a rough sketch of the ordering follows.
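For instance, something like this (`PruneUnusedParams` is a hypothetical name for the follow-up pass; it does not exist in TVM):

```python
import tvm
from tvm import relay

# Hypothetical ordering: concretize the *_like ops first, then let a
# follow-up pass drop function parameters that became unused.
seq = tvm.transform.Sequential(
    [
        relay.transform.InferType(),
        relay.transform.ConcretizeLike(),     # the pass from this PR
        relay.transform.PruneUnusedParams(),  # hypothetical follow-up pass
    ]
)
# mod = seq(mod)
```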
```python
mod = tvm.IRModule.from_expr(f)
mod_concrete = relay.transform.ConcretizeLike()(mod)
assert tvm.ir.structural_equal(mod_concrete["main"], f_expected)
```
Add `if __name__ == "__main__":`
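i.e., something along these lines (the test name here is hypothetical):

```python
if __name__ == "__main__":
    test_concretize_reshape_like()
```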
@comaniac thanks for the helpful feedback, it seems there are some design decisions that would be worth combining with existing approaches. I'll spend some time this week incorporating the feedback, but in the meantime I think I'll remove the ConcretizeLike pass from this PR and send a follow-up (since the AD refactor can stand alone). How does that sound?
Sounds good to me and I'm good with the first-order AD changes in general.
@MarisaKirisame please take another look, and see https://tvm.apache.org/docs/contribute/code_review.html#approve-and-request-changes-explicitly
Thanks @altanh @comaniac @MarisaKirisame
Since there is negligible code sharing between the first-order and higher-order AD in Relay, I've factored out the first-order AD pass to a separate file. Additionally, I've made it a proper IRModule pass, and fixed a few spots where adding tuples wasn't being lifted correctly. I tried my best to add some diagnostics.
To complement AD, I've also added a pass called `ConcretizeLike` that transforms the `*_like` operators generated during AD into their explicit-shape equivalents (e.g. `zeros_like(y) -> zeros(y.shape)`), which removes some forward dependencies and could enable more opportunities for operator fusion. Note that this is the same thing as `LikeZapp` from @t-vi's blog post, with perhaps a bit more error checking.

cc @t-vi @MarisaKirisame @tqchen @jroesch