Skip to content

Commit

Permalink
Implementation of relay_to_tir target hook
Browse files Browse the repository at this point in the history
This the first new hook proposed in the Additional Target Hooks RFC, longer
term the compilation should move to using `Target` proper but this unblocks our current work whilst illustrating the eventual interface via `Target` in `target_kind.cc`

I've encapsulated the hook lookup into a method on `TargetKind` (`GetRegisteredHook`), which will eventually mean that the logic can be compacted to:
```
func->target->kind.GetRegisteredHook()
```
  • Loading branch information
Mousius committed Sep 13, 2021
1 parent 16a1e9b commit 69300e7
Show file tree
Hide file tree
Showing 5 changed files with 162 additions and 11 deletions.
7 changes: 7 additions & 0 deletions include/tvm/target/target_kind.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,13 @@ class TargetKind : public ObjectRef {
TVM_DLL static Optional<TargetKind> Get(const String& target_kind_name);
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TargetKind, ObjectRef, TargetKindNode);

/*!
* \brief Look up for TargetKind registered hooks
* \param hook_name Name of the registered hook
* \return The associated PackedFunc for the hook
*/
TVM_DLL const PackedFunc* GetRegisteredHook(const String& hook_name) const;

private:
/*! \brief Mutable access to the container class */
TargetKindNode* operator->() { return static_cast<TargetKindNode*>(data_.get()); }
Expand Down
47 changes: 36 additions & 11 deletions src/relay/backend/te_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,22 @@ using namespace tvm::relay::transform;

TVM_REGISTER_OBJECT_TYPE(TECompilerNode);

/*!
* \brief Get target hook from function after checking TargetKind registry
*
* \param func - Function to get hook from
* \param hook_name - Name of hook to acquire
* \return Pointer to the packed function in the registry or nullptr if not found
*/
const PackedFunc* GetTargetHookFromFunction(const Function& func, const String& hook_name) {
auto code_gen_name = func->GetAttr<String>(attr::kCompiler).value();
auto target_kind = tvm::TargetKind::Get(code_gen_name);
if (target_kind) {
return target_kind.value().GetRegisteredHook(hook_name);
}
return nullptr;
}

class TECompilerImpl : public TECompilerNode {
public:
// Lower the function.
Expand Down Expand Up @@ -135,10 +151,12 @@ class TECompilerImpl : public TECompilerNode {
auto src_func = it.first->source_func;
ICHECK(src_func.defined());
if (src_func->GetAttr<String>(attr::kCompiler).defined()) {
auto code_gen = src_func->GetAttr<String>(attr::kCompiler);
std::string code_gen_name = code_gen.value();
// Skip this function if it was actually lowered to TIR instead of a Runtime Module
if (GetTargetHookFromFunction(src_func, "relay_to_tir") != nullptr) {
continue;
}
cached_ext_funcs.push_back(it.first);

auto code_gen_name = src_func->GetAttr<String>(attr::kCompiler).value();
auto symbol_name = src_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
ICHECK(symbol_name.defined()) << "No external symbol is set for:\n"
<< AsText(src_func, false);
Expand Down Expand Up @@ -208,17 +226,28 @@ class TECompilerImpl : public TECompilerNode {
}
cur_ccache_key_ = key;

// No need to lower external functions for now. We will invoke the external
// codegen tool once and lower all functions together.
if (key->source_func->GetAttr<String>(attr::kCompiler).defined()) {
auto ir_module = IRModule();
const auto name_node = key->source_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
ICHECK(name_node.defined()) << "External function has not been attached a name yet.";
auto func_name = GetUniqueName(name_node.value(), &name_map_);
auto target = Target("ext_dev");
auto global_var = GlobalVar(func_name);
global_var->checked_type_ = key->source_func->checked_type();

auto ir_module = IRModule();
ir_module->Add(global_var, key->source_func);

// Lower to TIR if we have a registered lowering hook
auto custom_lowering_to_tir = GetTargetHookFromFunction(key->source_func, "relay_to_tir");
if (custom_lowering_to_tir != nullptr) {
IRModule lowered_module = (*custom_lowering_to_tir)(ir_module, key->source_func);
value->cached_func =
CachedFunc(key->target, global_var, {}, {}, te::Schedule(), {}, lowered_module);
return value;
}

// No need to lower external functions for now. We will invoke the external
// codegen tool once and lower all functions together.
auto target = Target("ext_dev");
value->cached_func = CachedFunc(target, global_var, {}, {}, te::Schedule(), {}, ir_module);
return value;
}
Expand Down Expand Up @@ -597,13 +626,9 @@ class LowerTensorExprMutator : public ExprMutator {
};

Pass LowerTensorExpr(TargetMap targets, DeviceMap device_context_map,
backend::StaticMemoryPlan memory_plan, const String& module_name,
TECompiler compiler, std::function<void(Function)> process_fn) {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function func, IRModule module, PassContext ctx) {
LowerTensorExprMutator lower_te(module, targets, device_context_map, process_fn,
module_name, compiler);
return Downcast<Function>(lower_te.Mutate(func));
};
return CreateFunctionPass(pass_func, 0, "LowerTensorExpr", {});
}
Expand Down
12 changes: 12 additions & 0 deletions src/target/target_kind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,15 @@ Optional<TargetKind> TargetKind::Get(const String& target_kind_name) {
return reg->kind_;
}

const PackedFunc* TargetKind::GetRegisteredHook(const String& hook_name) const {
auto map = tvm::TargetKind::GetAttrMap<String>(hook_name);
if (map.count(*this)) {
std::string hook_function = map[*this];
return tvm::runtime::Registry::Get(hook_function);
}
return nullptr;
}

/********** Utility functions **********/

/*!
Expand Down Expand Up @@ -356,6 +365,9 @@ TVM_REGISTER_TARGET_KIND("hybrid", kDLCPU) // line break

TVM_REGISTER_TARGET_KIND("composite", kDLCPU).add_attr_option<Array<Target>>("devices");

TVM_REGISTER_TARGET_KIND("test", kDLCPU)
.set_attr<String>("relay_to_tir", "target.test.tir_lowering");

/********** Registry **********/

TVM_REGISTER_GLOBAL("target.ListTargetKinds").set_body_typed(TargetKindRegEntry::ListTargetKinds);
Expand Down
19 changes: 19 additions & 0 deletions tests/cpp/target_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,13 @@

using namespace tvm;

TVM_REGISTER_GLOBAL("target.test_kind.test_registered_function")
.set_body_typed([](IRModule mod, Target target) { return mod; });

TVM_REGISTER_TARGET_KIND("TestTargetKind", kDLCPU)
.set_attr<std::string>("Attr1", "Value1")
.set_attr<String>("known_hook", "target.test_kind.test_registered_function")
.set_attr<String>("unknown_hook", "target.test_kind.test_not_registered_function")
.add_attr_option<Bool>("my_bool")
.add_attr_option<Array<String>>("your_names")
.add_attr_option<Map<String, Integer>>("her_maps");
Expand Down Expand Up @@ -157,3 +162,17 @@ TEST(TargetKindRegistryListTargetKinds, Basic) {
ICHECK_EQ(names.empty(), false);
ICHECK_EQ(std::count(std::begin(names), std::end(names), "llvm"), 1);
}

TEST(TargetHookCheck, HookRegisteredNonNull) {
auto target_kind = tvm::TargetKind::Get("TestTargetKind").value();
const PackedFunc* target_hook =
tvm::runtime::Registry::Get("target.test_kind.test_registered_function");
ICHECK_NE(target_hook, (const PackedFunc*)nullptr);
ICHECK_EQ(target_kind.GetRegisteredHook("known_hook"), target_hook);
}

TEST(TargetHookCheck, HookRegisteredNull) {
auto target_kind = tvm::TargetKind::Get("TestTargetKind").value();
const PackedFunc* unknown_func = nullptr;
ICHECK_EQ(target_kind.GetRegisteredHook("unknown_hook"), unknown_func);
}
88 changes: 88 additions & 0 deletions tests/python/relay/test_target_hooks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Unit tests for target hooks."""
import sys
import numpy as np
import pytest

import tvm
import tvm.relay.testing
import tvm.relay.transform

from tvm import relay
from utils.external_codegen import (
set_external_func_attr,
check_aot_executor_result,
check_graph_executor_result,
)


def translate_relay_add_to_tir_subtract(ir_module, relay_func):
"""A transform to test Relay -> TIR with"""
ib = tvm.tir.ir_builder.create()
A = tvm.tir.decl_buffer(
dtype=relay_func.params[0].checked_type.dtype,
name=relay_func.params[0].name_hint,
shape=relay_func.params[0].checked_type.shape,
)
B = tvm.tir.decl_buffer(
dtype=relay_func.params[1].checked_type.dtype,
name=relay_func.params[1].name_hint,
shape=relay_func.params[1].checked_type.shape,
)
C = tvm.tir.decl_buffer(dtype=relay_func.ret_type.dtype, shape=relay_func.ret_type.shape)

Ap = ib.buffer_ptr(A)
Bp = ib.buffer_ptr(B)
Cp = ib.buffer_ptr(C)

with ib.for_range(0, 8, name="i") as i:
with ib.for_range(0, 8, name="j") as j:
row = i * 8
Cp[row + j] = Ap[row + j] - Bp[row + j]

prim_func = tvm.tir.PrimFunc([A, B, C], ib.get())

ir_module = tvm.lower(prim_func, name=relay_func.attrs["global_symbol"])
return ir_module


@pytest.mark.parametrize("check_result", [check_graph_executor_result, check_aot_executor_result])
def test_tir_external_generation(check_result):
tvm.register_func("target.test.tir_lowering", translate_relay_add_to_tir_subtract, True)

shape = (8, 8)
x_data = np.random.randint(255, size=shape).astype("float32")
y_data = np.random.randint(255, size=shape).astype("float32")
inputs = {"x": x_data, "y": y_data}

x0 = relay.var("x0", shape=shape, dtype="float32")
y0 = relay.var("y0", shape=shape, dtype="float32")
z = x0 + y0
f = relay.Function([x0, y0], z)
f = set_external_func_attr(f, "test", "replace_add_with_subtract")

x = relay.var("x", shape=(8, 8), dtype="float32")
y = relay.var("y", shape=(8, 8), dtype="float32")
call = relay.Call(f, [x, y])
func = tvm.IRModule.from_expr(call)

check_result(func, inputs, (8, 8), x_data - y_data)


if __name__ == "__main__":
sys.exit(pytest.main([__file__] + sys.argv[1:]))

0 comments on commit 69300e7

Please sign in to comment.