Skip to content

Commit

Permalink
[Target] Add target_parser to TargetKind
Browse files Browse the repository at this point in the history
This adds the `target_parser` as described in apache/tvm-rfcs#71, which parses an incoming `TargetJSON` and produces a new configuration for generating the final `Target` object from.

Marks `set_attrs_preprocessor` as deprecated and errors if both `set_attrs_preprocessor` and `set_target_parser` exist together.
  • Loading branch information
Mousius committed Jul 18, 2022
1 parent b84ed27 commit d8e76a3
Show file tree
Hide file tree
Showing 5 changed files with 112 additions and 31 deletions.
4 changes: 2 additions & 2 deletions docs/arch/device_target_interactions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -194,8 +194,8 @@ different code generation targets can run on the same physical device.
device type.)

All options for a specific target kind are added with the
``add_attr_option`` function, with optional default values. A
preprocessor can be added with ``set_attrs_preprocessor`` to define
``add_attr_option`` function, with optional default values. A `Target`
parser can be added with ``set_target_parser`` to process
any parameters that are dynamically based on other parameters or
queried from device properties.

Expand Down
23 changes: 23 additions & 0 deletions include/tvm/target/target_kind.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,16 @@ namespace tvm {

class Target;

/*!
* \brief TargetParser to apply on instantiation of a given TargetKind
*
* \param target_json Target in JSON format to be transformed during parsing.
*
* \return The transformed Target JSON object.
*/
using TargetJSON = Map<String, ObjectRef>;
using FTVMTargetParser = TypedPackedFunc<TargetJSON(TargetJSON)>;

/*!
* \brief RelayToTIR tvm::transform::Pass specific to a TargetKind
*
Expand Down Expand Up @@ -82,6 +92,8 @@ class TargetKindNode : public Object {
Array<String> default_keys;
/*! \brief Function used to preprocess on target creation */
PackedFunc preprocessor;
/*! \brief Function used to parse a JSON target during creation */
FTVMTargetParser target_parser;

void VisitAttrs(AttrVisitor* v) {
v->Visit("name", &name);
Expand Down Expand Up @@ -207,6 +219,11 @@ class TargetKindRegEntry {
*/
template <typename FLambda>
inline TargetKindRegEntry& set_attrs_preprocessor(FLambda f);
/*!
* \brief Set the parsing function applied upon target creation
* \param parser The Target parsing function
*/
inline TargetKindRegEntry& set_target_parser(FTVMTargetParser parser);
/*!
* \brief Register a valid configuration option and its ValueType for validation
* \param key The configuration key
Expand Down Expand Up @@ -353,11 +370,17 @@ inline TargetKindRegEntry& TargetKindRegEntry::set_default_keys(std::vector<Stri

template <typename FLambda>
inline TargetKindRegEntry& TargetKindRegEntry::set_attrs_preprocessor(FLambda f) {
LOG(WARNING) << "set_attrs_preprocessor is deprecated please use set_target_parser instead";
using FType = typename tvm::runtime::detail::function_signature<FLambda>::FType;
kind_->preprocessor = tvm::runtime::TypedPackedFunc<FType>(std::move(f)).packed();
return *this;
}

inline TargetKindRegEntry& TargetKindRegEntry::set_target_parser(FTVMTargetParser parser) {
kind_->target_parser = parser;
return *this;
}

template <typename ValueType>
inline TargetKindRegEntry& TargetKindRegEntry::add_attr_option(const String& key) {
ICHECK(!kind_->key2vtype_.count(key))
Expand Down
17 changes: 14 additions & 3 deletions src/target/target.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class TargetInternal {
static ObjectPtr<Object> FromString(const String& tag_or_config_or_target_str);
static ObjectPtr<Object> FromConfigString(const String& config_str);
static ObjectPtr<Object> FromRawString(const String& target_str);
static ObjectPtr<Object> FromConfig(std::unordered_map<String, ObjectRef> config);
static ObjectPtr<Object> FromConfig(Map<String, ObjectRef> config);
static void ConstructorDispatcher(TVMArgs args, TVMRetValue* rv);
static Target WithHost(const Target& target, const Target& target_host) {
ObjectPtr<TargetNode> n = make_object<TargetNode>(*target.get());
Expand Down Expand Up @@ -716,17 +716,27 @@ ObjectPtr<Object> TargetInternal::FromRawString(const String& target_str) {
return TargetInternal::FromConfig(config);
}

ObjectPtr<Object> TargetInternal::FromConfig(std::unordered_map<String, ObjectRef> config) {
ObjectPtr<Object> TargetInternal::FromConfig(Map<String, ObjectRef> config) {
const String kKind = "kind";
const String kTag = "tag";
const String kKeys = "keys";
const String kDeviceName = "device";
const String kHost = "host";
ObjectPtr<TargetNode> target = make_object<TargetNode>();

// parse 'kind'
if (config.count(kKind)) {
if (const auto* kind = config[kKind].as<StringObj>()) {
target->kind = GetTargetKind(GetRef<String>(kind));
ICHECK(!(target->kind->preprocessor != nullptr && target->kind->target_parser != nullptr))
<< "Cannot use both set_attrs_preprocessor and set_target_parser";

// Run JSON Parser over JSON input
if (target->kind->target_parser != nullptr) {
VLOG(9) << "TargetInternal::FromConfig - Running target_parser";
config = target->kind->target_parser(config);
}

config.erase(kKind);
} else {
throw Error(": Expect type of field \"kind\" is String, but get type: " +
Expand Down Expand Up @@ -828,8 +838,9 @@ ObjectPtr<Object> TargetInternal::FromConfig(std::unordered_map<String, ObjectRe
} else {
target->attrs = attrs;
}

return target;
}
} // namespace tvm

std::unordered_map<String, ObjectRef> TargetInternal::QueryDevice(int device_id,
const TargetNode* target) {
Expand Down
52 changes: 26 additions & 26 deletions src/target/target_kind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -145,15 +145,15 @@ void CheckOrSetAttr(Map<String, ObjectRef>* attrs, const String& name, const Str

/*!
* \brief Update the attributes in the CUDA target.
* \param attrs The original attributes
* \param target The Target to update
* \return The updated attributes
*/
Map<String, ObjectRef> UpdateCUDAAttrs(Map<String, ObjectRef> attrs) {
TargetJSON UpdateCUDAAttrs(TargetJSON target) {
// Update -arch=sm_xx
int archInt;
if (attrs.count("arch")) {
if (target.count("arch")) {
// If -arch has been specified, validate the correctness
String archStr = Downcast<String>(attrs.at("arch"));
String archStr = Downcast<String>(target.at("arch"));
archInt = ExtractIntWithPrefix(archStr, "sm_");
ICHECK(archInt != -1) << "ValueError: CUDA target gets an invalid CUDA arch: -arch=" << archStr;
} else {
Expand All @@ -165,23 +165,23 @@ Map<String, ObjectRef> UpdateCUDAAttrs(Map<String, ObjectRef> attrs) {
} else {
archInt = std::stod(version.operator std::string()) * 10 + 0.1;
}
attrs.Set("arch", String("sm_") + std::to_string(archInt));
target.Set("arch", String("sm_") + std::to_string(archInt));
}
return attrs;
return target;
}

/*!
* \brief Update the attributes in the LLVM NVPTX target.
* \param attrs The original attributes
* \param target The Target to update
* \return The updated attributes
*/
Map<String, ObjectRef> UpdateNVPTXAttrs(Map<String, ObjectRef> attrs) {
CheckOrSetAttr(&attrs, "mtriple", "nvptx64-nvidia-cuda");
TargetJSON UpdateNVPTXAttrs(TargetJSON target) {
CheckOrSetAttr(&target, "mtriple", "nvptx64-nvidia-cuda");
// Update -mcpu=sm_xx
int arch;
if (attrs.count("mcpu")) {
if (target.count("mcpu")) {
// If -mcpu has been specified, validate the correctness
String mcpu = Downcast<String>(attrs.at("mcpu"));
String mcpu = Downcast<String>(target.at("mcpu"));
arch = ExtractIntWithPrefix(mcpu, "sm_");
ICHECK(arch != -1) << "ValueError: NVPTX target gets an invalid CUDA arch: -mcpu=" << mcpu;
} else {
Expand All @@ -193,22 +193,22 @@ Map<String, ObjectRef> UpdateNVPTXAttrs(Map<String, ObjectRef> attrs) {
} else {
arch = std::stod(version.operator std::string()) * 10 + 0.1;
}
attrs.Set("mcpu", String("sm_") + std::to_string(arch));
target.Set("mcpu", String("sm_") + std::to_string(arch));
}
return attrs;
return target;
}

/*!
* \brief Update the attributes in the LLVM ROCm target.
* \param attrs The original attributes
* \param target The Target to update
* \return The updated attributes
*/
Map<String, ObjectRef> UpdateROCmAttrs(Map<String, ObjectRef> attrs) {
CheckOrSetAttr(&attrs, "mtriple", "amdgcn-amd-amdhsa-hcc");
TargetJSON UpdateROCmAttrs(TargetJSON target) {
CheckOrSetAttr(&target, "mtriple", "amdgcn-amd-amdhsa-hcc");
// Update -mcpu=gfx
int arch;
if (attrs.count("mcpu")) {
String mcpu = Downcast<String>(attrs.at("mcpu"));
if (target.count("mcpu")) {
String mcpu = Downcast<String>(target.at("mcpu"));
arch = ExtractIntWithPrefix(mcpu, "gfx");
ICHECK(arch != -1) << "ValueError: ROCm target gets an invalid GFX version: -mcpu=" << mcpu;
} else {
Expand All @@ -219,7 +219,7 @@ Map<String, ObjectRef> UpdateROCmAttrs(Map<String, ObjectRef> attrs) {
} else {
arch = val.operator int();
}
attrs.Set("mcpu", String("gfx") + std::to_string(arch));
target.Set("mcpu", String("gfx") + std::to_string(arch));
}
// Update -mattr before ROCm 3.5:
// Before ROCm 3.5 we needed code object v2, starting
Expand All @@ -235,13 +235,13 @@ Map<String, ObjectRef> UpdateROCmAttrs(Map<String, ObjectRef> attrs) {
}
if (version < 305) {
Array<String> mattr;
if (attrs.count("mattr")) {
mattr = Downcast<Array<String>>(attrs.at("mattr"));
if (target.count("mattr")) {
mattr = Downcast<Array<String>>(target.at("mattr"));
}
mattr.push_back("-code-object-v3");
attrs.Set("mattr", mattr);
target.Set("mattr", mattr);
}
return attrs;
return target;
}

/********** Register Target kinds and attributes **********/
Expand Down Expand Up @@ -295,7 +295,7 @@ TVM_REGISTER_TARGET_KIND("cuda", kDLCUDA)
.add_attr_option<Integer>("registers_per_block")
.add_attr_option<Integer>("max_num_threads", Integer(1024)) // TODO(@zxybazh): deprecate it
.set_default_keys({"cuda", "gpu"})
.set_attrs_preprocessor(UpdateCUDAAttrs);
.set_target_parser(UpdateCUDAAttrs);

TVM_REGISTER_TARGET_KIND("nvptx", kDLCUDA)
.add_attr_option<String>("mcpu")
Expand All @@ -304,7 +304,7 @@ TVM_REGISTER_TARGET_KIND("nvptx", kDLCUDA)
.add_attr_option<Integer>("max_num_threads", Integer(1024))
.add_attr_option<Integer>("thread_warp_size", Integer(32))
.set_default_keys({"cuda", "gpu"})
.set_attrs_preprocessor(UpdateNVPTXAttrs);
.set_target_parser(UpdateNVPTXAttrs);

TVM_REGISTER_TARGET_KIND("rocm", kDLROCM)
.add_attr_option<String>("mcpu")
Expand All @@ -318,7 +318,7 @@ TVM_REGISTER_TARGET_KIND("rocm", kDLROCM)
.add_attr_option<Integer>("max_shared_memory_per_block", Integer(65536))
.add_attr_option<Integer>("thread_warp_size", Integer(64))
.set_default_keys({"rocm", "gpu"})
.set_attrs_preprocessor(UpdateROCmAttrs);
.set_target_parser(UpdateROCmAttrs);

TVM_REGISTER_TARGET_KIND("opencl", kDLOpenCL)
.add_attr_option<Bool>("system-lib")
Expand Down
47 changes: 47 additions & 0 deletions tests/cpp/target_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,36 @@ TVM_REGISTER_TARGET_KIND("TestTargetKind", kDLCPU)
.add_attr_option<Array<String>>("your_names")
.add_attr_option<Map<String, Integer>>("her_maps");

TargetJSON TestTargetParser(TargetJSON target) {
String mcpu = Downcast<String>(target.at("mcpu"));
target.Set("mcpu", String("super_") + mcpu);
target.Set("keys", Array<String>({"super"}));
return target;
}

Map<String, ObjectRef> TestAttrsPreProcessor(Map<String, ObjectRef> attrs) {
attrs.Set("mattr", String("woof"));
return attrs;
}

TVM_REGISTER_TARGET_KIND("TestTargetParser", kDLCPU)
.add_attr_option<String>("mattr")
.add_attr_option<String>("mcpu")
.set_default_keys({"cpu"})
.set_target_parser(TestTargetParser);

TVM_REGISTER_TARGET_KIND("TestAttrsPreprocessor", kDLCPU)
.add_attr_option<String>("mattr")
.set_default_keys({"cpu"})
.set_attrs_preprocessor(TestAttrsPreProcessor);

TVM_REGISTER_TARGET_KIND("TestClashingPreprocessor", kDLCPU)
.add_attr_option<String>("mattr")
.add_attr_option<String>("mcpu")
.set_default_keys({"cpu"})
.set_attrs_preprocessor(TestAttrsPreProcessor)
.set_target_parser(TestTargetParser);

TEST(TargetKind, GetAttrMap) {
auto map = tvm::TargetKind::GetAttrMap<std::string>("Attr1");
auto target_kind = tvm::TargetKind::Get("TestTargetKind").value();
Expand Down Expand Up @@ -136,6 +166,23 @@ TEST(TargetCreationFail, TargetKindNotFound) {
ASSERT_EQ(failed, true);
}

TEST(TargetCreation, TargetParser) {
Target test_target("TestTargetParser -mcpu=woof");
ASSERT_EQ(test_target->GetAttr<String>("mcpu").value(), "super_woof");
ASSERT_EQ(test_target->keys.size(), 2);
ASSERT_EQ(test_target->keys[0], "super");
ASSERT_EQ(test_target->keys[1], "cpu");
}

TEST(TargetCreation, TargetAttrsPreProcessor) {
Target test_target("TestAttrsPreprocessor -mattr=cake");
ASSERT_EQ(test_target->GetAttr<String>("mattr").value(), "woof");
}

TEST(TargetCreation, ClashingTargetProcessing) {
EXPECT_THROW(Target("TestClashingPreprocessor -mcpu=woof -mattr=cake"), InternalError);
}

TVM_REGISTER_TARGET_KIND("test_external_codegen_0", kDLCUDA)
.set_attr<Bool>(tvm::attr::kIsExternalCodegen, Bool(true));

Expand Down

0 comments on commit d8e76a3

Please sign in to comment.