From d8e76a3b3fe0e518e6d60d8ca00f9cc2199c1da7 Mon Sep 17 00:00:00 2001 From: Chris Sidebottom Date: Thu, 9 Jun 2022 19:56:31 +0000 Subject: [PATCH] [Target] Add target_parser to TargetKind This adds the `target_parser` as described in https://github.com/apache/tvm-rfcs/pull/71, which parses an incoming `TargetJSON` and produces a new configuration for generating the final `Target` object from. Marks `set_attrs_preprocessor` as deprecated and errors if both `set_attrs_preprocessor` and `set_target_parser` exist together. --- docs/arch/device_target_interactions.rst | 4 +- include/tvm/target/target_kind.h | 23 +++++++++++ src/target/target.cc | 17 ++++++-- src/target/target_kind.cc | 52 ++++++++++++------------ tests/cpp/target_test.cc | 47 +++++++++++++++++++++ 5 files changed, 112 insertions(+), 31 deletions(-) diff --git a/docs/arch/device_target_interactions.rst b/docs/arch/device_target_interactions.rst index 9c391d31bec0..ec8d52226edd 100644 --- a/docs/arch/device_target_interactions.rst +++ b/docs/arch/device_target_interactions.rst @@ -194,8 +194,8 @@ different code generation targets can run on the same physical device. device type.) All options for a specific target kind are added with the -``add_attr_option`` function, with optional default values. A -preprocessor can be added with ``set_attrs_preprocessor`` to define +``add_attr_option`` function, with optional default values. A `Target` +parser can be added with ``set_target_parser`` to process any parameters that are dynamically based on other parameters or queried from device properties. diff --git a/include/tvm/target/target_kind.h b/include/tvm/target/target_kind.h index 4879470e7654..e20f8547af49 100644 --- a/include/tvm/target/target_kind.h +++ b/include/tvm/target/target_kind.h @@ -37,6 +37,16 @@ namespace tvm { class Target; +/*! + * \brief TargetParser to apply on instantiation of a given TargetKind + * + * \param target_json Target in JSON format to be transformed during parsing. + * + * \return The transformed Target JSON object. + */ +using TargetJSON = Map; +using FTVMTargetParser = TypedPackedFunc; + /*! * \brief RelayToTIR tvm::transform::Pass specific to a TargetKind * @@ -82,6 +92,8 @@ class TargetKindNode : public Object { Array default_keys; /*! \brief Function used to preprocess on target creation */ PackedFunc preprocessor; + /*! \brief Function used to parse a JSON target during creation */ + FTVMTargetParser target_parser; void VisitAttrs(AttrVisitor* v) { v->Visit("name", &name); @@ -207,6 +219,11 @@ class TargetKindRegEntry { */ template inline TargetKindRegEntry& set_attrs_preprocessor(FLambda f); + /*! + * \brief Set the parsing function applied upon target creation + * \param parser The Target parsing function + */ + inline TargetKindRegEntry& set_target_parser(FTVMTargetParser parser); /*! * \brief Register a valid configuration option and its ValueType for validation * \param key The configuration key @@ -353,11 +370,17 @@ inline TargetKindRegEntry& TargetKindRegEntry::set_default_keys(std::vector inline TargetKindRegEntry& TargetKindRegEntry::set_attrs_preprocessor(FLambda f) { + LOG(WARNING) << "set_attrs_preprocessor is deprecated please use set_target_parser instead"; using FType = typename tvm::runtime::detail::function_signature::FType; kind_->preprocessor = tvm::runtime::TypedPackedFunc(std::move(f)).packed(); return *this; } +inline TargetKindRegEntry& TargetKindRegEntry::set_target_parser(FTVMTargetParser parser) { + kind_->target_parser = parser; + return *this; +} + template inline TargetKindRegEntry& TargetKindRegEntry::add_attr_option(const String& key) { ICHECK(!kind_->key2vtype_.count(key)) diff --git a/src/target/target.cc b/src/target/target.cc index 07b347f09817..a0e689a0a87d 100644 --- a/src/target/target.cc +++ b/src/target/target.cc @@ -52,7 +52,7 @@ class TargetInternal { static ObjectPtr FromString(const String& tag_or_config_or_target_str); static ObjectPtr FromConfigString(const String& config_str); static ObjectPtr FromRawString(const String& target_str); - static ObjectPtr FromConfig(std::unordered_map config); + static ObjectPtr FromConfig(Map config); static void ConstructorDispatcher(TVMArgs args, TVMRetValue* rv); static Target WithHost(const Target& target, const Target& target_host) { ObjectPtr n = make_object(*target.get()); @@ -716,17 +716,27 @@ ObjectPtr TargetInternal::FromRawString(const String& target_str) { return TargetInternal::FromConfig(config); } -ObjectPtr TargetInternal::FromConfig(std::unordered_map config) { +ObjectPtr TargetInternal::FromConfig(Map config) { const String kKind = "kind"; const String kTag = "tag"; const String kKeys = "keys"; const String kDeviceName = "device"; const String kHost = "host"; ObjectPtr target = make_object(); + // parse 'kind' if (config.count(kKind)) { if (const auto* kind = config[kKind].as()) { target->kind = GetTargetKind(GetRef(kind)); + ICHECK(!(target->kind->preprocessor != nullptr && target->kind->target_parser != nullptr)) + << "Cannot use both set_attrs_preprocessor and set_target_parser"; + + // Run JSON Parser over JSON input + if (target->kind->target_parser != nullptr) { + VLOG(9) << "TargetInternal::FromConfig - Running target_parser"; + config = target->kind->target_parser(config); + } + config.erase(kKind); } else { throw Error(": Expect type of field \"kind\" is String, but get type: " + @@ -828,8 +838,9 @@ ObjectPtr TargetInternal::FromConfig(std::unordered_mapattrs = attrs; } + return target; -} +} // namespace tvm std::unordered_map TargetInternal::QueryDevice(int device_id, const TargetNode* target) { diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc index 1148013706ab..7620c6fc2e53 100644 --- a/src/target/target_kind.cc +++ b/src/target/target_kind.cc @@ -145,15 +145,15 @@ void CheckOrSetAttr(Map* attrs, const String& name, const Str /*! * \brief Update the attributes in the CUDA target. - * \param attrs The original attributes + * \param target The Target to update * \return The updated attributes */ -Map UpdateCUDAAttrs(Map attrs) { +TargetJSON UpdateCUDAAttrs(TargetJSON target) { // Update -arch=sm_xx int archInt; - if (attrs.count("arch")) { + if (target.count("arch")) { // If -arch has been specified, validate the correctness - String archStr = Downcast(attrs.at("arch")); + String archStr = Downcast(target.at("arch")); archInt = ExtractIntWithPrefix(archStr, "sm_"); ICHECK(archInt != -1) << "ValueError: CUDA target gets an invalid CUDA arch: -arch=" << archStr; } else { @@ -165,23 +165,23 @@ Map UpdateCUDAAttrs(Map attrs) { } else { archInt = std::stod(version.operator std::string()) * 10 + 0.1; } - attrs.Set("arch", String("sm_") + std::to_string(archInt)); + target.Set("arch", String("sm_") + std::to_string(archInt)); } - return attrs; + return target; } /*! * \brief Update the attributes in the LLVM NVPTX target. - * \param attrs The original attributes + * \param target The Target to update * \return The updated attributes */ -Map UpdateNVPTXAttrs(Map attrs) { - CheckOrSetAttr(&attrs, "mtriple", "nvptx64-nvidia-cuda"); +TargetJSON UpdateNVPTXAttrs(TargetJSON target) { + CheckOrSetAttr(&target, "mtriple", "nvptx64-nvidia-cuda"); // Update -mcpu=sm_xx int arch; - if (attrs.count("mcpu")) { + if (target.count("mcpu")) { // If -mcpu has been specified, validate the correctness - String mcpu = Downcast(attrs.at("mcpu")); + String mcpu = Downcast(target.at("mcpu")); arch = ExtractIntWithPrefix(mcpu, "sm_"); ICHECK(arch != -1) << "ValueError: NVPTX target gets an invalid CUDA arch: -mcpu=" << mcpu; } else { @@ -193,22 +193,22 @@ Map UpdateNVPTXAttrs(Map attrs) { } else { arch = std::stod(version.operator std::string()) * 10 + 0.1; } - attrs.Set("mcpu", String("sm_") + std::to_string(arch)); + target.Set("mcpu", String("sm_") + std::to_string(arch)); } - return attrs; + return target; } /*! * \brief Update the attributes in the LLVM ROCm target. - * \param attrs The original attributes + * \param target The Target to update * \return The updated attributes */ -Map UpdateROCmAttrs(Map attrs) { - CheckOrSetAttr(&attrs, "mtriple", "amdgcn-amd-amdhsa-hcc"); +TargetJSON UpdateROCmAttrs(TargetJSON target) { + CheckOrSetAttr(&target, "mtriple", "amdgcn-amd-amdhsa-hcc"); // Update -mcpu=gfx int arch; - if (attrs.count("mcpu")) { - String mcpu = Downcast(attrs.at("mcpu")); + if (target.count("mcpu")) { + String mcpu = Downcast(target.at("mcpu")); arch = ExtractIntWithPrefix(mcpu, "gfx"); ICHECK(arch != -1) << "ValueError: ROCm target gets an invalid GFX version: -mcpu=" << mcpu; } else { @@ -219,7 +219,7 @@ Map UpdateROCmAttrs(Map attrs) { } else { arch = val.operator int(); } - attrs.Set("mcpu", String("gfx") + std::to_string(arch)); + target.Set("mcpu", String("gfx") + std::to_string(arch)); } // Update -mattr before ROCm 3.5: // Before ROCm 3.5 we needed code object v2, starting @@ -235,13 +235,13 @@ Map UpdateROCmAttrs(Map attrs) { } if (version < 305) { Array mattr; - if (attrs.count("mattr")) { - mattr = Downcast>(attrs.at("mattr")); + if (target.count("mattr")) { + mattr = Downcast>(target.at("mattr")); } mattr.push_back("-code-object-v3"); - attrs.Set("mattr", mattr); + target.Set("mattr", mattr); } - return attrs; + return target; } /********** Register Target kinds and attributes **********/ @@ -295,7 +295,7 @@ TVM_REGISTER_TARGET_KIND("cuda", kDLCUDA) .add_attr_option("registers_per_block") .add_attr_option("max_num_threads", Integer(1024)) // TODO(@zxybazh): deprecate it .set_default_keys({"cuda", "gpu"}) - .set_attrs_preprocessor(UpdateCUDAAttrs); + .set_target_parser(UpdateCUDAAttrs); TVM_REGISTER_TARGET_KIND("nvptx", kDLCUDA) .add_attr_option("mcpu") @@ -304,7 +304,7 @@ TVM_REGISTER_TARGET_KIND("nvptx", kDLCUDA) .add_attr_option("max_num_threads", Integer(1024)) .add_attr_option("thread_warp_size", Integer(32)) .set_default_keys({"cuda", "gpu"}) - .set_attrs_preprocessor(UpdateNVPTXAttrs); + .set_target_parser(UpdateNVPTXAttrs); TVM_REGISTER_TARGET_KIND("rocm", kDLROCM) .add_attr_option("mcpu") @@ -318,7 +318,7 @@ TVM_REGISTER_TARGET_KIND("rocm", kDLROCM) .add_attr_option("max_shared_memory_per_block", Integer(65536)) .add_attr_option("thread_warp_size", Integer(64)) .set_default_keys({"rocm", "gpu"}) - .set_attrs_preprocessor(UpdateROCmAttrs); + .set_target_parser(UpdateROCmAttrs); TVM_REGISTER_TARGET_KIND("opencl", kDLOpenCL) .add_attr_option("system-lib") diff --git a/tests/cpp/target_test.cc b/tests/cpp/target_test.cc index 2c85e47e7fb8..6854fc661d0b 100644 --- a/tests/cpp/target_test.cc +++ b/tests/cpp/target_test.cc @@ -34,6 +34,36 @@ TVM_REGISTER_TARGET_KIND("TestTargetKind", kDLCPU) .add_attr_option>("your_names") .add_attr_option>("her_maps"); +TargetJSON TestTargetParser(TargetJSON target) { + String mcpu = Downcast(target.at("mcpu")); + target.Set("mcpu", String("super_") + mcpu); + target.Set("keys", Array({"super"})); + return target; +} + +Map TestAttrsPreProcessor(Map attrs) { + attrs.Set("mattr", String("woof")); + return attrs; +} + +TVM_REGISTER_TARGET_KIND("TestTargetParser", kDLCPU) + .add_attr_option("mattr") + .add_attr_option("mcpu") + .set_default_keys({"cpu"}) + .set_target_parser(TestTargetParser); + +TVM_REGISTER_TARGET_KIND("TestAttrsPreprocessor", kDLCPU) + .add_attr_option("mattr") + .set_default_keys({"cpu"}) + .set_attrs_preprocessor(TestAttrsPreProcessor); + +TVM_REGISTER_TARGET_KIND("TestClashingPreprocessor", kDLCPU) + .add_attr_option("mattr") + .add_attr_option("mcpu") + .set_default_keys({"cpu"}) + .set_attrs_preprocessor(TestAttrsPreProcessor) + .set_target_parser(TestTargetParser); + TEST(TargetKind, GetAttrMap) { auto map = tvm::TargetKind::GetAttrMap("Attr1"); auto target_kind = tvm::TargetKind::Get("TestTargetKind").value(); @@ -136,6 +166,23 @@ TEST(TargetCreationFail, TargetKindNotFound) { ASSERT_EQ(failed, true); } +TEST(TargetCreation, TargetParser) { + Target test_target("TestTargetParser -mcpu=woof"); + ASSERT_EQ(test_target->GetAttr("mcpu").value(), "super_woof"); + ASSERT_EQ(test_target->keys.size(), 2); + ASSERT_EQ(test_target->keys[0], "super"); + ASSERT_EQ(test_target->keys[1], "cpu"); +} + +TEST(TargetCreation, TargetAttrsPreProcessor) { + Target test_target("TestAttrsPreprocessor -mattr=cake"); + ASSERT_EQ(test_target->GetAttr("mattr").value(), "woof"); +} + +TEST(TargetCreation, ClashingTargetProcessing) { + EXPECT_THROW(Target("TestClashingPreprocessor -mcpu=woof -mattr=cake"), InternalError); +} + TVM_REGISTER_TARGET_KIND("test_external_codegen_0", kDLCUDA) .set_attr(tvm::attr::kIsExternalCodegen, Bool(true));