[Relax] Implement Function.check_for_special_case
If a dynamic model is frequently called with specific
arguments or shapes of arguments, performance may be improved by
generating specialized versions of the model.  Previously,
specialized versions of a relax function `func` could be generated
using `func.bind_params` and `func.bind_symbolic_vars`.  However, using
these specialized versions required the calling scope to explicitly
check the preconditions of each kernel and call the appropriate one.

This commit implements a new utility, `check_for_special_case`, which
handles both generating the special case and checking whether the
special case applies.  The function's user-facing signature is
unmodified, while internally it delegates to either the original
function or the specialized version, depending on the result of the
check.  This allows optimized kernels for specific static shapes to be
introduced solely by changing the optimization pipeline, with no
changes required in the calling scope.
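
As a rough sketch of the intended usage (the function and variable
names below are hypothetical, and the TVMScript is illustrative rather
than taken from this commit):

    import tvm
    from tvm.script import relax as R

    @R.function
    def func(x: R.Tensor(["batch_size", 128], "float32")):
        return x

    # Previously: the caller selects between the general kernel and a
    # specialized copy generated by bind_symbolic_vars.
    specialized = func.bind_symbolic_vars({"batch_size": 16})

    # With this commit: the check is folded into the function itself.
    # The user-facing signature is unchanged; internally the function
    # dispatches to a copy specialized for batch_size == 16, and falls
    # back to the original body for any other batch size.
    func = func.check_for_special_case({"batch_size": 16})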
Lunderberg committed Apr 2, 2024
1 parent c20cdaf commit 8b94f7c
Showing 7 changed files with 476 additions and 26 deletions.
59 changes: 40 additions & 19 deletions python/tvm/relax/expr.py
@@ -1021,6 +1021,18 @@ def __call__(self, *args):
"""
return Call(self, args, None, None)

@staticmethod
def _normalize_value(value):
"""Conversions that must occur prior to the FFI conversions"""
if isinstance(value, int):
# Relax uses int64 for symbolic variables, but the FFI
# converts python integers into int32.
return tvm.tir.const(value, "int64")
elif isinstance(value, (_np.ndarray, tvm.nd.NDArray)):
return tvm.relax.const(value)
else:
return value

def bind_symbolic_vars(
self, binding_map: Mapping[Union[str, tvm.tir.Var], PrimExpr]
) -> "Function":
@@ -1042,15 +1054,36 @@ def bind_symbolic_vars(
The updated function
"""

# Relax uses int64 for symbolic variables, but the FFI
# converts python integers into int32.
binding_map = {
key: tvm.tir.const(value, "int64") if isinstance(value, int) else value
for key, value in binding_map.items()
}
binding_map = {key: self._normalize_value(value) for key, value in binding_map.items()}

return _ffi_api.FunctionBindSymbolicVars(self, binding_map) # type: ignore

def check_for_special_case(
self, special_case: Mapping[Union[str, tvm.tir.Var, Var], Union[PrimExpr, Expr]]
) -> "Function":
"""Return a new function that checks for a special case

The returned function has the same user-facing signature as the
original.  Internally, it checks whether the special case
applies, delegating to a version of the function specialized for
that case if it does, and to the original function otherwise.

Parameters
----------
special_case: Mapping[Union[str, tvm.tir.Var, Var], Union[PrimExpr, Expr]]
The mapping of special-case values to check for.  Keys may be
either a `relax.Var`, a `tir.Var`, or a string providing the
name of the variable.  If the variables are referred to by
name, the name must uniquely identify the `tir.Var` or
`relax.Var` in the function signature.

Returns
-------
func: Function
The updated function
"""

special_case = {key: self._normalize_value(value) for key, value in special_case.items()}

return _ffi_api.FunctionCheckForSpecialCase(self, special_case)  # type: ignore

def bind_params(
self,
binding_map: Mapping[
@@ -1085,19 +1118,7 @@ def bind_params(
The updated function
"""

def _normalize_value(value):
# Conversions that must occur prior to the FFI
# conversions.
if isinstance(value, int):
# Relax uses int64 for symbolic variables, but the FFI
# converts python integers into int32.
return tvm.tir.const(value, "int64")
elif isinstance(value, (_np.ndarray, tvm.nd.NDArray)):
return tvm.relax.const(value)
else:
return value

binding_map = {key: _normalize_value(value) for key, value in binding_map.items()}
binding_map = {key: self._normalize_value(value) for key, value in binding_map.items()}

return _ffi_api.FunctionBindParams(self, binding_map) # type: ignore

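The shared `_normalize_value` helper above exists because of the dtype
mismatch noted in its inline comment.  A minimal illustration, assuming
standard tvm APIs:

    import numpy as np
    import tvm

    # A bare python int becomes an int32 PrimExpr by default, while
    # relax symbolic variables are int64:
    assert tvm.tir.const(16).dtype == "int32"
    assert tvm.tir.const(16, "int64").dtype == "int64"

    # numpy arrays and tvm NDArrays are wrapped as relax constants
    # before crossing the FFI:
    weight = tvm.relax.const(np.zeros((4, 4), dtype="float32"))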
93 changes: 87 additions & 6 deletions src/relax/ir/block_builder.cc
@@ -34,6 +34,7 @@
#include <tvm/tir/function.h>

#include <memory>
#include <optional>
#include <unordered_map>
#include <unordered_set>
#include <vector>
@@ -201,7 +202,7 @@ class BlockBuilderImpl : public BlockBuilderNode {
}
}
}
scope_stack_.emplace_back(ScopeFrame({std::move(shape_var_map)}));
scope_stack_.emplace_back(ScopeFrame({params, std::move(shape_var_map)}));
}

void EndScope() final { scope_stack_.pop_back(); }
@@ -314,6 +315,15 @@
// Consider impl alternative: merge with block frame if we have more frame kinds.
//
// TODO(relax-team) tracks the var defined also through match-cast.

/*! \brief The parameters used to define this scope
*
* Can be used to copy the parent scope, for cases that should
* inherit definitions from their parent, but not expose new
* definitions into the parent.
*/
Optional<Array<Var>> params;

/*! \brief set of defined symbolic vars, value as themself. */
Map<tir::Var, PrimExpr> shape_var_map;
};
@@ -705,8 +715,11 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctor<Expr(const Expr&

Expr VisitExpr_(const IfNode* op) final {
Expr new_cond = this->NormalizeArgument(op->cond);
Expr new_true = this->VisitWithNewScope(op->true_branch);
Expr new_false = this->VisitWithNewScope(op->false_branch);

Optional<Array<Var>> scope_params = scope_stack_.size() ? scope_stack_.back().params : NullOpt;

Expr new_true = this->VisitWithNewScope(op->true_branch, scope_params);
Expr new_false = this->VisitWithNewScope(op->false_branch, scope_params);

If if_node;
if (new_cond.same_as(op->cond) && new_true.same_as(op->true_branch) &&
@@ -716,9 +729,77 @@ class Normalizer : public BlockBuilderImpl, private ExprFunctor<Expr(const Expr&
if_node = If(new_cond, new_true, new_false, op->span);
}
if (!if_node->struct_info_.defined()) {
auto true_info = EraseToWellDefinedInScope(GetStructInfo(new_true));
auto false_info = EraseToWellDefinedInScope(GetStructInfo(new_false));
UpdateStructInfo(if_node, StructInfoLCA(true_info, false_info));
StructInfo lca = [&]() -> StructInfo {
StructuralEqual struct_equal;

auto true_info = GetStructInfo(new_true);
auto false_info = GetStructInfo(new_false);

if (struct_equal(true_info, false_info)) {
return true_info;
}

auto prim_cond = [&]() -> Optional<PrimExpr> {
// Follow variable bindings until reaching the underlying
// value of the condition.
Expr cond_expr = new_cond;
while (auto var = cond_expr.as<Var>()) {
if (auto value = LookupBinding(var.value())) {
cond_expr = value.value();
} else {
break;
}
}

if (auto prim_value = cond_expr.as<PrimValueNode>()) {
return prim_value->value;
}

if (auto prim_sinfo = cond_expr->struct_info_.as<PrimStructInfoNode>()) {
if (prim_sinfo->value.defined()) {
return prim_sinfo->value.value();
}
}
return NullOpt;
}();

arith::Analyzer* analyzer = GetAnalyzer();

if (!prim_cond.defined()) {
return StructInfoLCA(true_info, false_info, analyzer);
}

{
// Check whether the "then" branch produces a special case,
// with the "else" branch producing the general case.  The
// early return above guarantees that prim_cond is defined.
With<arith::ConstraintContext> context(analyzer, prim_cond.value());
auto then_lca = StructInfoLCA(true_info, false_info, analyzer);
if (struct_equal(true_info, then_lca)) {
return false_info;
}
}
{
// Check whether the "else" branch produces a special case,
// with the "then" branch producing the general case.
With<arith::ConstraintContext> context(analyzer, !prim_cond.value());
auto else_lca = StructInfoLCA(true_info, false_info, analyzer);
if (struct_equal(false_info, else_lca)) {
return true_info;
}
}

return StructInfoLCA(true_info, false_info, analyzer);
}();

UpdateStructInfo(if_node, lca);
}
return if_node;
}
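
The branch-merging rule above can be summarized with a toy model (pure
python, no tvm; the shape-tuple encoding and helper names are invented
for illustration).  When the condition is a known predicate such as
n == 16, a branch whose struct info is merely the special case of the
other branch's collapses into the general case, rather than eroding
both branches to an unknown shape:

    def lca(a, b, known=None):
        """Toy least-common-ancestor of two shape tuples.

        `known` is an optional (symbolic_dim, value) equality standing
        in for the arith::ConstraintContext above.  Returns None when
        the shapes cannot be unified, standing in for a tensor of
        unknown shape.
        """
        if len(a) != len(b):
            return None
        out = []
        for da, db in zip(a, b):
            if da == db or (known is not None and {da, db} == set(known)):
                out.append(da)
            else:
                return None
        return tuple(out)

    def merge_if(true_info, false_info, cond):
        """Toy If merge; cond is ("n", 16) for the condition n == 16.

        The real code also checks the mirrored case, where the "else"
        branch is the special case under the negated condition.
        """
        if true_info == false_info:
            return true_info
        # Under the constraint, is the "then" branch just a special
        # case of the "else" branch?  If so, the general case
        # describes both branches.
        if lca(true_info, false_info, known=cond) == true_info:
            return false_info
        return lca(true_info, false_info)

    # Then-branch specialized for n == 16, else-branch general:
    assert merge_if((16, 128), ("n", 128), cond=("n", 16)) == ("n", 128)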
