diff --git a/src/runtime/vulkan/vulkan_device_api.cc b/src/runtime/vulkan/vulkan_device_api.cc index b4987eb321cf..1c190f313aba 100644 --- a/src/runtime/vulkan/vulkan_device_api.cc +++ b/src/runtime/vulkan/vulkan_device_api.cc @@ -122,7 +122,7 @@ void VulkanDeviceAPI::GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv) break; } case kDeviceName: - *rv = prop.device_name; + *rv = String(prop.device_name); break; case kMaxClockRate: diff --git a/src/target/target.cc b/src/target/target.cc index df810185784e..ea897adb77b8 100644 --- a/src/target/target.cc +++ b/src/target/target.cc @@ -29,6 +29,7 @@ #include #include +#include #include #include "../runtime/object_internal.h" @@ -146,17 +147,83 @@ static int FindFirstSubstr(const std::string& str, const std::string& substr) { } static Optional JoinString(const std::vector& array, char separator) { + char escape = '\\'; + char quote = '\''; + if (array.empty()) { return NullOpt; } + std::ostringstream os; - os << array[0]; - for (size_t i = 1; i < array.size(); ++i) { - os << separator << array[i]; + + for (size_t i = 0; i < array.size(); ++i) { + if (i > 0) { + os << separator; + } + + std::string str = array[i]; + + if ((str.find(separator) == std::string::npos) && (str.find(quote) == std::string::npos)) { + os << str; + } else { + os << quote; + for (char c : str) { + if (c == separator || c == quote) { + os << escape; + } + os << c; + } + os << quote; + } } return String(os.str()); } +static std::vector SplitString(const std::string& str, char separator) { + char escape = '\\'; + char quote = '\''; + + std::vector output; + + const char* start = str.data(); + const char* end = start + str.size(); + const char* pos = start; + + std::stringstream current_word; + + auto finish_word = [&]() { + std::string word = current_word.str(); + if (word.size()) { + output.push_back(word); + current_word.str(""); + } + }; + + bool pos_quoted = false; + + while (pos < end) { + if ((*pos == separator) && !pos_quoted) { + finish_word(); + pos++; + } else if ((*pos == escape) && (pos + 1 < end) && (pos[1] == quote)) { + current_word << quote; + pos += 2; + } else if (*pos == quote) { + pos_quoted = !pos_quoted; + pos++; + } else { + current_word << *pos; + pos++; + } + } + + ICHECK(!pos_quoted) << "Mismatched quotes '' in string"; + + finish_word(); + + return output; +} + static int ParseKVPair(const std::string& s, const std::string& s_next, std::string* key, std::string* value) { int pos; @@ -206,9 +273,9 @@ const TargetKindNode::ValueTypeInfo& TargetInternal::FindTypeInfo(const TargetKi ObjectRef TargetInternal::ParseType(const std::string& str, const TargetKindNode::ValueTypeInfo& info) { - std::istringstream is(str); if (info.type_index == Integer::ContainerType::_GetOrAllocRuntimeTypeIndex()) { // Parsing integer + std::istringstream is(str); int v; if (!(is >> v)) { std::string lower(str.size(), '\x0'); @@ -225,19 +292,18 @@ ObjectRef TargetInternal::ParseType(const std::string& str, } return Integer(v); } else if (info.type_index == String::ContainerType::_GetOrAllocRuntimeTypeIndex()) { - // Parsing string - std::string v; - if (!(is >> v)) { - throw Error(": Cannot parse into type \"String\" from string: " + str); - } - return String(v); + // Parsing string, strip leading/trailing spaces + auto start = str.find_first_not_of(' '); + auto end = str.find_last_not_of(' '); + return String(str.substr(start, (end - start + 1))); + } else if (info.type_index == Target::ContainerType::_GetOrAllocRuntimeTypeIndex()) { // Parsing target return Target(TargetInternal::FromString(str)); } else if (info.type_index == ArrayNode::_GetOrAllocRuntimeTypeIndex()) { // Parsing array std::vector result; - for (std::string substr; std::getline(is, substr, ',');) { + for (const std::string& substr : SplitString(str, ',')) { try { ObjectRef parsed = TargetInternal::ParseType(substr, *info.key); result.push_back(parsed); @@ -549,24 +615,14 @@ ObjectPtr TargetInternal::FromConfigString(const String& config_str) { } ObjectPtr TargetInternal::FromRawString(const String& target_str) { + ICHECK_GT(target_str.length(), 0) << "Cannot parse empty target string"; // Split the string by empty spaces - std::string name; - std::vector options; - std::string str; - for (std::istringstream is(target_str); is >> str;) { - if (name.empty()) { - name = str; - } else { - options.push_back(str); - } - } - if (name.empty()) { - throw Error(": Cannot parse empty target string"); - } + std::vector options = SplitString(std::string(target_str), ' '); + std::string name = options[0]; // Create the target config std::unordered_map config = {{"kind", String(name)}}; TargetKind kind = GetTargetKind(name); - for (size_t iter = 0, end = options.size(); iter < end;) { + for (size_t iter = 1, end = options.size(); iter < end;) { std::string key, value; try { // Parse key-value pair diff --git a/tests/python/unittest/test_target_target.py b/tests/python/unittest/test_target_target.py index bb3aa9e86267..5007ef13e4d8 100644 --- a/tests/python/unittest/test_target_target.py +++ b/tests/python/unittest/test_target_target.py @@ -76,6 +76,14 @@ def test_target_string_parse(): assert tvm.target.arm_cpu().device_name == "arm_cpu" +def test_target_string_with_spaces(): + target = tvm.target.Target( + "vulkan -device_name='Name of GPU with spaces' -device_type=discrete" + ) + assert target.attrs["device_name"] == "Name of GPU with spaces" + assert target.attrs["device_type"] == "discrete" + + def test_target_create(): targets = [cuda(), rocm(), mali(), intel_graphics(), arm_cpu("rk3399"), vta(), bifrost()] for tgt in targets: