Skip to content

Commit

Permalink
Update collecting of symbolic variables in InferSymbolicVarMap
Browse files Browse the repository at this point in the history
  • Loading branch information
Lunderberg committed Jan 2, 2024
1 parent 5b3c262 commit ec5ffc6
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 1 deletion.
18 changes: 18 additions & 0 deletions src/relax/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,23 @@ tvm::Map<tir::Var, PrimExpr> InferSymbolicVarMap(
}
};

auto bind_from_prim_value = [&bind_from_prim_expr](const StructInfo& var,
const StructInfo& expr) {
auto var_sinfo = var.as<PrimStructInfoNode>();
if (!var_sinfo) return;

auto expr_sinfo = expr.as<PrimStructInfoNode>();
CHECK(expr_sinfo) << "Cannot bind expression with struct type " << expr
<< " to variable with struct type " << var;
CHECK_EQ(var_sinfo->dtype, expr_sinfo->dtype)
<< "Cannot bind expression with struct type " << expr << " to variable with struct type "
<< var << ", due to conflicting PrimExpr DataType";

if (!var_sinfo->value.defined() || !expr_sinfo->value.defined()) return;

bind_from_prim_expr(var_sinfo->value.value(), expr_sinfo->value.value());
};

auto bind_from_shape = [&bind_from_prim_expr](const StructInfo& var, const StructInfo& expr) {
auto var_shape = var.as<ShapeStructInfoNode>();
if (!var_shape) return;
Expand Down Expand Up @@ -195,6 +212,7 @@ tvm::Map<tir::Var, PrimExpr> InferSymbolicVarMap(

bind_from_tensor(var_sinfo, expr_sinfo);
bind_from_shape(var_sinfo, expr_sinfo);
bind_from_prim_value(var_sinfo, expr_sinfo);
}

return tir_var_remap;
Expand Down
5 changes: 4 additions & 1 deletion tests/python/relax/test_bind_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,11 +112,14 @@ def expected() -> R.Shape([16]):


def test_bind_prim_value(prim_value_dtype):
if prim_value_dtype != "int64":
pytest.xfail(reason="Currently, only support int64 as known symbolic value")

N = tir.Var("N", prim_value_dtype)
value = tir.const(16, prim_value_dtype)

@R.function
def before(A: R.Prim(value=N)):
def before(A: R.Prim(value=N)) -> R.Prim(value=N):
R.func_attr({"global_symbol": "main"})
B: R.Prim(value=N) = A
return B
Expand Down

0 comments on commit ec5ffc6

Please sign in to comment.