Skip to content

Commit

Permalink
Make the TVM targets list available in Python (apache#7427)
Browse files Browse the repository at this point in the history
* Make the TVM targets list available in Python

Change-Id: I8602723fe57aaf32cee5392d4387a637115dd363

* Rename the APIs to get target kinds

Change-Id: I2e6e32e025e3614a148a30a31e5a2c52fd3563cc
  • Loading branch information
Nicola Lancellotti authored and Lokiiiiii committed Mar 1, 2021
1 parent c8a71ed commit df0f03a
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 0 deletions.
5 changes: 5 additions & 0 deletions include/tvm/target/target_kind.h
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,11 @@ class TargetKindRegEntry {
inline TargetKindRegEntry& add_attr_option(const String& key, ObjectRef default_value);
/*! \brief Set name of the TargetKind to be the same as registry if it is empty */
inline TargetKindRegEntry& set_name();
/*!
* \brief List all the entry names in the registry.
* \return The entry names.
*/
TVM_DLL static Array<String> ListTargetKinds();
/*!
* \brief Register or get a new entry.
* \param target_kind_name The name of the TargetKind.
Expand Down
5 changes: 5 additions & 0 deletions python/tvm/target/target.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,11 @@ def mattr(self):
def libs(self):
return list(self.attrs.get("libs", []))

@staticmethod
def list_kinds():
"""Returns the list of available target names."""
return list(_ffi_api.ListTargetKinds())


# TODO(@tvm-team): Deprecate the helper functions below. Encourage the usage of config dict instead.

Expand Down
9 changes: 9 additions & 0 deletions src/target/target_kind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
*/
#include <tvm/ir/expr.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/registry.h>
#include <tvm/target/target.h>
#include <tvm/target/target_kind.h>

Expand All @@ -44,6 +45,10 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)

using TargetKindRegistry = AttrRegistry<TargetKindRegEntry, TargetKind>;

Array<String> TargetKindRegEntry::ListTargetKinds() {
return TargetKindRegistry::Global()->ListAllNames();
}

TargetKindRegEntry& TargetKindRegEntry::RegisterOrGet(const String& target_kind_name) {
return TargetKindRegistry::Global()->RegisterOrGet(target_kind_name);
}
Expand Down Expand Up @@ -307,4 +312,8 @@ TVM_REGISTER_TARGET_KIND("composite", kDLCPU)
.add_attr_option<Target>("target_host")
.add_attr_option<Array<Target>>("devices");

/********** Registry **********/

TVM_REGISTER_GLOBAL("target.ListTargetKinds").set_body_typed(TargetKindRegEntry::ListTargetKinds);

} // namespace tvm
6 changes: 6 additions & 0 deletions tests/cpp/target_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,12 @@ TEST(TargetCreation, DeduplicateKeys) {
ICHECK_EQ(target->GetAttr<Bool>("link-params"), false);
}

TEST(TargetKindRegistryListTargetKinds, Basic) {
Array<String> names = TargetKindRegEntry::ListTargetKinds();
ICHECK_EQ(names.empty(), false);
ICHECK_EQ(std::count(std::begin(names), std::end(names), "llvm"), 1);
}

int main(int argc, char** argv) {
testing::InitGoogleTest(&argc, argv);
testing::FLAGS_gtest_death_test_style = "threadsafe";
Expand Down
8 changes: 8 additions & 0 deletions tests/python/unittest/test_target_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,10 +151,18 @@ def test_target_tag_1():
assert tgt.attrs["registers_per_block"] == 32768


def test_list_kinds():
targets = tvm.target.Target.list_kinds()
assert len(targets) != 0
assert "llvm" in targets
assert all(isinstance(target_name, str) for target_name in targets)


if __name__ == "__main__":
test_target_dispatch()
test_target_string_parse()
test_target_create()
test_target_config()
test_config_map()
test_composite_target()
test_list_kinds()

0 comments on commit df0f03a

Please sign in to comment.