Skip to content

Commit

Permalink
[Relax] Add module pass tvm.relax.transform.CheckForSpecialCase
Browse files Browse the repository at this point in the history
  • Loading branch information
csullivan committed Apr 3, 2024
1 parent 8b94f7c commit 03055a2
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 0 deletions.
1 change: 1 addition & 0 deletions python/tvm/relax/transform/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
BundleModelParams,
CallTIRRewrite,
CanonicalizeBindings,
CheckForSpecialCase,
CombineParallelMatmul,
ComputePrimValue,
ConvertLayout,
Expand Down
29 changes: 29 additions & 0 deletions python/tvm/relax/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -603,6 +603,35 @@ def BindSymbolicVars(
return _ffi_api.BindSymbolicVars(binding_map, func_name) # type: ignore


def CheckForSpecialCase(
special_case: Mapping[Union[str, tvm.tir.Var, Var], Union[tvm.tir.PrimExpr, Expr]],
func_name: Optional[str] = None,
) -> tvm.ir.transform.Pass:
"""Bind params of function of the module to constant tensors to produce a
special case
Parameters
----------
special_case : Mapping[Union[str, tvm.tir.Var, Var], Union[tvm.tir.PrimExpr, Expr]],
The map from symbolic varname to integer.
func_name : Optional[str]
The function name to be special cased. If None (default), all
functions within the module will be updated.
Returns
-------
ret: tvm.ir.transform.Pass
"""
# Relax uses int64 for symbolic variables, but the FFI
# converts python integers into int32.
special_case = {
key: tvm.tir.const(value, "int64") if isinstance(value, int) else value
for key, value in special_case.items()
}
return _ffi_api.CheckForSpecialCase(special_case, func_name) # type: ignore


def RunCodegen(
target_options: Optional[dict] = None,
entry_functions: Optional[List[str]] = None,
Expand Down
47 changes: 47 additions & 0 deletions src/relax/transform/check_for_special_case.cc
Original file line number Diff line number Diff line change
Expand Up @@ -224,5 +224,52 @@ Function FunctionCheckForSpecialCase(
TVM_REGISTER_GLOBAL("relax.FunctionCheckForSpecialCase")
.set_body_typed(FunctionCheckForSpecialCase);

namespace {
IRModule ModuleCheckForSpecialCase(
IRModule mod,
Map<Variant<tir::Var, relax::Var, String>, Variant<Expr, PrimExpr>> arg_special_case) {
IRModule updates;
for (const auto& [gvar, base_func] : mod->functions) {
if (auto opt = base_func.as<Function>()) {
auto func = opt.value();
auto new_func = FunctionCheckForSpecialCase(func, arg_special_case);
if (!func.same_as(new_func)) {
updates->Add(gvar, new_func);
}
}
}

if (updates->functions.size()) {
mod.CopyOnWrite()->Update(updates);
}
return mod;
}
} // namespace

namespace transform {

Pass CheckForSpecialCase(
Map<Variant<tir::Var, relax::Var, String>, Variant<Expr, PrimExpr>> arg_special_case,
Optional<String> func_name) {
auto pass_func = [=](IRModule mod, PassContext context) -> IRModule {
if (func_name) {
auto gvar = mod->GetGlobalVar(func_name.value());
auto func = Downcast<Function>(mod->Lookup(gvar));
auto new_func = FunctionCheckForSpecialCase(func, arg_special_case);
if (!func.same_as(new_func)) {
mod.CopyOnWrite()->Update(gvar, new_func);
}
} else {
mod = ModuleCheckForSpecialCase(mod, arg_special_case);
}
return mod;
};

return tvm::transform::CreateModulePass(pass_func, 1, "relax.CheckForSpecialCase", {});
}

TVM_REGISTER_GLOBAL("relax.transform.CheckForSpecialCase").set_body_typed(CheckForSpecialCase);

} // namespace transform
} // namespace relax
} // namespace tvm

0 comments on commit 03055a2

Please sign in to comment.