Skip to content

Commit

Permalink
[aot] C-API Device capability improvements (#6702)
Browse files Browse the repository at this point in the history
API adjustments, tests and etc.

This PR improves the user experience issue for C-API users to work with
device capability, especially when the `TiRuntime` is imported from an
outer context.

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Ailing  <ailzhang@users.noreply.github.com>
  • Loading branch information
3 people committed Nov 30, 2022
1 parent 4f7f2f5 commit c27a2e4
Show file tree
Hide file tree
Showing 23 changed files with 374 additions and 104 deletions.
190 changes: 183 additions & 7 deletions c_api/include/taichi/cpp/taichi.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -679,6 +679,179 @@ class Event {
}
};

class CapabilityLevelConfigBuilder;
class CapabilityLevelConfig {
public:
std::vector<TiCapabilityLevelInfo> cap_level_infos;

CapabilityLevelConfig() : cap_level_infos() {
}
CapabilityLevelConfig(std::vector<TiCapabilityLevelInfo> &&capabilities)
: cap_level_infos(std::move(capabilities)) {
}

static CapabilityLevelConfigBuilder builder();

uint32_t get(TiCapability capability) const {
for (size_t i = 0; i < cap_level_infos.size(); ++i) {
const TiCapabilityLevelInfo &cap_level_info = cap_level_infos.at(i);
if (cap_level_info.capability == capability) {
return cap_level_info.level;
}
}
return 0;
}

void set(TiCapability capability, uint32_t level) {
std::vector<TiCapabilityLevelInfo>::iterator it = cap_level_infos.begin();
for (; it != cap_level_infos.end(); ++it) {
if (it->capability == capability) {
it->level = level;
return;
}
}
TiCapabilityLevelInfo cap_level_info{};
cap_level_info.capability = capability;
cap_level_info.level = level;
cap_level_infos.emplace_back(std::move(cap_level_info));
}
};

class CapabilityLevelConfigBuilder {
typedef CapabilityLevelConfigBuilder Self;
std::map<TiCapability, uint32_t> cap_level_infos;

public:
CapabilityLevelConfigBuilder() : cap_level_infos() {
}
CapabilityLevelConfigBuilder(const Self &) = delete;
Self &operator=(const Self &) = delete;

Self &spirv_version(uint32_t major, uint32_t minor) {
if (major == 1) {
if (minor == 3) {
cap_level_infos[TI_CAPABILITY_SPIRV_VERSION] = 0x10300;
} else if (minor == 4) {
cap_level_infos[TI_CAPABILITY_SPIRV_VERSION] = 0x10400;
} else if (minor == 5) {
cap_level_infos[TI_CAPABILITY_SPIRV_VERSION] = 0x10500;
} else {
ti_set_last_error(TI_ERROR_ARGUMENT_OUT_OF_RANGE, "minor");
}
} else {
ti_set_last_error(TI_ERROR_ARGUMENT_OUT_OF_RANGE, "major");
}
return *this;
}
Self &spirv_has_int8(bool value = true) {
cap_level_infos[TI_CAPABILITY_SPIRV_HAS_INT8] = value ? TI_TRUE : TI_FALSE;
return *this;
}
Self &spirv_has_int16(bool value = true) {
cap_level_infos[TI_CAPABILITY_SPIRV_HAS_INT16] = value ? TI_TRUE : TI_FALSE;
return *this;
}
Self &spirv_has_int64(bool value = true) {
cap_level_infos[TI_CAPABILITY_SPIRV_HAS_INT64] = value ? TI_TRUE : TI_FALSE;
return *this;
}
Self &spirv_has_float16(bool value = true) {
cap_level_infos[TI_CAPABILITY_SPIRV_HAS_FLOAT16] =
value ? TI_TRUE : TI_FALSE;
return *this;
}
Self &spirv_has_float64(bool value = true) {
cap_level_infos[TI_CAPABILITY_SPIRV_HAS_FLOAT64] =
value ? TI_TRUE : TI_FALSE;
return *this;
}
Self &spirv_has_atomic_i64(bool value = true) {
cap_level_infos[TI_CAPABILITY_SPIRV_HAS_ATOMIC_I64] =
value ? TI_TRUE : TI_FALSE;
return *this;
}
Self &spirv_has_atomic_float16(bool value = true) {
cap_level_infos[TI_CAPABILITY_SPIRV_HAS_ATOMIC_FLOAT16] =
value ? TI_TRUE : TI_FALSE;
return *this;
}
Self &spirv_has_atomic_float16_add(bool value = true) {
cap_level_infos[TI_CAPABILITY_SPIRV_HAS_ATOMIC_FLOAT16_ADD] =
value ? TI_TRUE : TI_FALSE;
return *this;
}
Self &spirv_has_atomic_float16_minmax(bool value = true) {
cap_level_infos[TI_CAPABILITY_SPIRV_HAS_ATOMIC_FLOAT16_MINMAX] =
value ? TI_TRUE : TI_FALSE;
return *this;
}
Self &spirv_has_atomic_float64(bool value = true) {
cap_level_infos[TI_CAPABILITY_SPIRV_HAS_ATOMIC_FLOAT64] =
value ? TI_TRUE : TI_FALSE;
return *this;
}
Self &spirv_has_atomic_float64_add(bool value = true) {
cap_level_infos[TI_CAPABILITY_SPIRV_HAS_ATOMIC_FLOAT64_ADD] =
value ? TI_TRUE : TI_FALSE;
return *this;
}
Self &spirv_has_variable_ptr(bool value = true) {
cap_level_infos[TI_CAPABILITY_SPIRV_HAS_VARIABLE_PTR] =
value ? TI_TRUE : TI_FALSE;
return *this;
}
Self &spirv_has_physical_storage_buffer(bool value = true) {
cap_level_infos[TI_CAPABILITY_SPIRV_HAS_PHYSICAL_STORAGE_BUFFER] =
value ? TI_TRUE : TI_FALSE;
return *this;
}
Self &spirv_has_subgroup_basic(bool value = true) {
cap_level_infos[TI_CAPABILITY_SPIRV_HAS_SUBGROUP_BASIC] =
value ? TI_TRUE : TI_FALSE;
return *this;
}
Self &spirv_has_subgroup_vote(bool value = true) {
cap_level_infos[TI_CAPABILITY_SPIRV_HAS_SUBGROUP_VOTE] =
value ? TI_TRUE : TI_FALSE;
return *this;
}
Self &spirv_has_subgroup_arithmetic(bool value = true) {
cap_level_infos[TI_CAPABILITY_SPIRV_HAS_SUBGROUP_ARITHMETIC] =
value ? TI_TRUE : TI_FALSE;
return *this;
}
Self &spirv_has_subgroup_ballot(bool value = true) {
cap_level_infos[TI_CAPABILITY_SPIRV_HAS_SUBGROUP_BALLOT] =
value ? TI_TRUE : TI_FALSE;
return *this;
}
Self &spirv_has_non_semantic_info(bool value = true) {
cap_level_infos[TI_CAPABILITY_SPIRV_HAS_NON_SEMANTIC_INFO] =
value ? TI_TRUE : TI_FALSE;
return *this;
}
Self &spirv_has_no_integer_wrap_decoration(bool value = true) {
cap_level_infos[TI_CAPABILITY_SPIRV_HAS_NO_INTEGER_WRAP_DECORATION] =
value ? TI_TRUE : TI_FALSE;
return *this;
}

CapabilityLevelConfig build() {
std::vector<TiCapabilityLevelInfo> out{};
for (const auto &pair : cap_level_infos) {
TiCapabilityLevelInfo cap_level_info{};
cap_level_info.capability = pair.first;
cap_level_info.level = pair.second;
out.emplace_back(std::move(cap_level_info));
}
return CapabilityLevelConfig{std::move(out)};
}
};

inline CapabilityLevelConfigBuilder CapabilityLevelConfig::builder() {
return {};
}

class Runtime {
TiArch arch_{TI_ARCH_MAX_ENUM};
TiRuntime runtime_{TI_NULL_HANDLE};
Expand Down Expand Up @@ -722,17 +895,20 @@ class Runtime {
return *this;
}

std::map<TiCapability, uint32_t> get_capabilities() const {
void set_capabilities_ext(
const std::vector<TiCapabilityLevelInfo> &capabilities) {
ti_set_runtime_capabilities_ext(runtime_, (uint32_t)capabilities.size(),
capabilities.data());
}
void set_capabilities_ext(const CapabilityLevelConfig &capabilities) {
set_capabilities_ext(capabilities.cap_level_infos);
}
CapabilityLevelConfig get_capabilities() const {
uint32_t n = 0;
ti_get_runtime_capabilities(runtime_, &n, nullptr);
std::vector<TiCapabilityLevelInfo> devcaps(n);
ti_get_runtime_capabilities(runtime_, &n, devcaps.data());

std::map<TiCapability, uint32_t> out{};
for (auto devcap : devcaps) {
out[devcap.capability] = devcap.level;
}
return out;
return CapabilityLevelConfig{std::move(devcaps)};
}

Memory allocate_memory(const TiMemoryAllocateInfo &allocate_info) {
Expand Down
6 changes: 6 additions & 0 deletions c_api/include/taichi/taichi_core.h
Original file line number Diff line number Diff line change
Expand Up @@ -851,6 +851,12 @@ TI_DLL_EXPORT TiRuntime TI_API_CALL ti_create_runtime(TiArch arch);
// Destroys a Taichi Runtime.
TI_DLL_EXPORT void TI_API_CALL ti_destroy_runtime(TiRuntime runtime);

// Function `ti_set_runtime_capabilities_ext`
TI_DLL_EXPORT void TI_API_CALL
ti_set_runtime_capabilities_ext(TiRuntime runtime,
uint32_t capability_count,
const TiCapabilityLevelInfo *capabilities);

// Function `ti_get_runtime_capabilities`
TI_DLL_EXPORT void TI_API_CALL
ti_get_runtime_capabilities(TiRuntime runtime,
Expand Down
21 changes: 20 additions & 1 deletion c_api/src/taichi_core_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,25 @@ void ti_destroy_runtime(TiRuntime runtime) {
TI_CAPI_TRY_CATCH_END();
}

void ti_set_runtime_capabilities_ext(
TiRuntime runtime,
uint32_t capability_count,
const TiCapabilityLevelInfo *capabilities) {
TI_CAPI_TRY_CATCH_BEGIN();
TI_CAPI_ARGUMENT_NULL(runtime);

Runtime *runtime2 = (Runtime *)runtime;
taichi::lang::DeviceCapabilityConfig devcaps;
for (uint32_t i = 0; i < capability_count; ++i) {
const auto &cap_level_info = capabilities[i];
devcaps.set((taichi::lang::DeviceCapability)cap_level_info.capability,
cap_level_info.level);
}
runtime2->get().set_caps(std::move(devcaps));

TI_CAPI_TRY_CATCH_END();
}

void ti_get_runtime_capabilities(TiRuntime runtime,
uint32_t *capability_count,
TiCapabilityLevelInfo *capabilities) {
Expand All @@ -273,7 +292,7 @@ void ti_get_runtime_capabilities(TiRuntime runtime,

Runtime *runtime2 = (Runtime *)runtime;
const taichi::lang::DeviceCapabilityConfig &devcaps =
runtime2->get().get_current_caps();
runtime2->get().get_caps();

if (capability_count == nullptr) {
return;
Expand Down
2 changes: 1 addition & 1 deletion c_api/src/taichi_gfx_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ Error GfxRuntime::create_aot_module(const taichi::io::VirtualDir *dir,
}

const taichi::lang::DeviceCapabilityConfig &current_devcaps =
params.runtime->get_ti_device()->get_current_caps();
params.runtime->get_ti_device()->get_caps();
const taichi::lang::DeviceCapabilityConfig &required_devcaps =
aot_module->get_required_caps();
for (const auto &pair : required_devcaps.devcaps) {
Expand Down
2 changes: 1 addition & 1 deletion c_api/src/taichi_opengl_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ OpenglRuntime::OpenglRuntime()
caps.set(taichi::lang::DeviceCapability::spirv_has_int64, true);
caps.set(taichi::lang::DeviceCapability::spirv_has_float64, true);
caps.set(taichi::lang::DeviceCapability::spirv_version, 0x10300);
get_gl().set_current_caps(std::move(caps));
get_gl().set_caps(std::move(caps));
}
taichi::lang::Device &OpenglRuntime::get() {
return static_cast<taichi::lang::Device &>(device_);
Expand Down
2 changes: 1 addition & 1 deletion c_api/src/taichi_vulkan_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ VulkanRuntimeImported::Workaround::Workaround(
}
*/

vk_device.set_current_caps(std::move(caps));
vk_device.set_caps(std::move(caps));
vk_device.init_vulkan_structs(
const_cast<taichi::lang::vulkan::VulkanDevice::Params &>(params));
}
Expand Down
19 changes: 19 additions & 0 deletions c_api/taichi.json
Original file line number Diff line number Diff line change
Expand Up @@ -531,6 +531,25 @@
}
]
},
{
"name": "set_runtime_capabilities",
"type": "function",
"is_extension": true,
"parameters": [
{
"type": "handle.runtime"
},
{
"name": "capability_count",
"type": "uint32_t"
},
{
"name": "capabilities",
"type": "structure.capability_level_info",
"count": "capability_count"
}
]
},
{
"name": "get_runtime_capabilities",
"type": "function",
Expand Down
51 changes: 48 additions & 3 deletions c_api/tests/c_api_interface_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,54 @@ TEST_F(CapiTest, DryRunCapabilities) {
{
ti::Runtime runtime(TI_ARCH_VULKAN);
auto devcaps = runtime.get_capabilities();
auto it = devcaps.find(TI_CAPABILITY_SPIRV_VERSION);
assert(it != devcaps.end());
assert(it->second >= 0x10000);
auto level = devcaps.get(TI_CAPABILITY_SPIRV_VERSION);
assert(level >= 0x10000);
}
}
}

TEST_F(CapiTest, SetCapabilities) {
if (capi::utils::is_vulkan_available()) {
// Vulkan Runtime
{
ti::Runtime runtime(TI_ARCH_VULKAN);

{
auto devcaps = ti::CapabilityLevelConfig::builder()
.spirv_version(1, 3)
.spirv_has_atomic_float64_add()
.build();
runtime.set_capabilities_ext(devcaps);
auto devcaps2 = runtime.get_capabilities();
TI_ASSERT(devcaps2.get(TI_CAPABILITY_SPIRV_VERSION) == 0x10300);
TI_ASSERT(devcaps2.get(TI_CAPABILITY_SPIRV_HAS_ATOMIC_FLOAT64_ADD) ==
1);
TI_ASSERT(devcaps2.get(TI_CAPABILITY_SPIRV_HAS_ATOMIC_FLOAT64) == 0);
}
{
auto devcaps =
ti::CapabilityLevelConfig::builder().spirv_version(1, 4).build();
runtime.set_capabilities_ext(devcaps);
auto devcaps2 = runtime.get_capabilities();
TI_ASSERT(devcaps2.get(TI_CAPABILITY_SPIRV_VERSION) == 0x10400);
TI_ASSERT(devcaps2.get(TI_CAPABILITY_SPIRV_HAS_ATOMIC_FLOAT64_ADD) ==
0);
TI_ASSERT(devcaps2.get(TI_CAPABILITY_SPIRV_HAS_ATOMIC_FLOAT64) == 0);
}
{
auto devcaps = ti::CapabilityLevelConfig::builder()
.spirv_version(1, 5)
.spirv_has_atomic_float64()
.spirv_has_atomic_float64(false)
.spirv_has_atomic_float64(true)
.build();
runtime.set_capabilities_ext(devcaps);
auto devcaps2 = runtime.get_capabilities();
TI_ASSERT(devcaps2.get(TI_CAPABILITY_SPIRV_VERSION) == 0x10500);
TI_ASSERT(devcaps2.get(TI_CAPABILITY_SPIRV_HAS_ATOMIC_FLOAT64_ADD) ==
0);
TI_ASSERT(devcaps2.get(TI_CAPABILITY_SPIRV_HAS_ATOMIC_FLOAT64) == 1);
}
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion taichi/cache/gfx/cache_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ CompiledKernelData CacheManager::load_or_compile(CompileConfig *config,
if (kernel->is_evaluator) {
spirv::lower(kernel);
return gfx::run_codegen(kernel, runtime_->get_ti_device()->arch(),
runtime_->get_ti_device()->get_current_caps(),
runtime_->get_ti_device()->get_caps(),
compiled_structs_);
}
std::string kernel_key = make_kernel_key(config, kernel);
Expand Down
Loading

0 comments on commit c27a2e4

Please sign in to comment.