Skip to content

Commit

Permalink
[FIX] Infer input shape in sparse_dense_padded's alter_op if one does…
Browse files Browse the repository at this point in the history
… not exist (apache#7308)

* [FIX] Infer input shape in sparse_dense_padded's alter_op if one does not exist

If there are multiple alter_ops in a model, the first alteration does
not run type inference for the subsequent ones. In this case, we don't
have the shape information, so we run the inferencer manually.

* add todo
  • Loading branch information
tkonolige authored and alexwong committed Feb 11, 2021
1 parent 4e9b543 commit 2f13475
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 1 deletion.
9 changes: 8 additions & 1 deletion python/tvm/topi/cuda/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,14 @@ def is_valid_for_sparse_dense_padded(data, weight_data):
"""
# pylint:disable=invalid-name
warp_size = int(tvm.target.Target.current(allow_none=False).thread_warp_size)
m = get_const_tuple(data.checked_type.shape)[1]
# If there are multiple alter_ops in a model, the first alteration does not
# run type inference for the subsequent ones. In this case, we don't have
# the shape information, so we run the inferencer manually.
try:
m = get_const_tuple(data.checked_type.shape)[1]
except ValueError:
data_infered = relay.transform.InferType()(tvm.IRModule.from_expr(data))["main"]
m = get_const_tuple(data_infered.ret_type.shape)[1]
if len(weight_data.shape) == 1:
bs_m = 1
else:
Expand Down
1 change: 1 addition & 0 deletions src/relay/transforms/alter_op_layout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ class AlterTransformMemorizer : public TransformMemorizer {
* 2. Do not support nested tuple arguments.
*/
Expr AlterOpLayout(const Expr& expr) {
// TODO(@icemelon9): need to rerun type inference after applying an alter op.
AlterTransformMemorizer alterMemorizer(make_object<AlterTransformMemorizerNode>());
auto fcontext = [&](const Call& call) -> ObjectRef { return alterMemorizer; };

Expand Down

0 comments on commit 2f13475

Please sign in to comment.