Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[aot] Revert C-API Device capability improvements #6772

Merged
merged 1 commit into from
Nov 30, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
190 changes: 7 additions & 183 deletions c_api/include/taichi/cpp/taichi.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -679,179 +679,6 @@ class Event {
}
};

class CapabilityLevelConfigBuilder;
class CapabilityLevelConfig {
public:
std::vector<TiCapabilityLevelInfo> cap_level_infos;

CapabilityLevelConfig() : cap_level_infos() {
}
CapabilityLevelConfig(std::vector<TiCapabilityLevelInfo> &&capabilities)
: cap_level_infos(std::move(capabilities)) {
}

static CapabilityLevelConfigBuilder builder();

uint32_t get(TiCapability capability) const {
for (size_t i = 0; i < cap_level_infos.size(); ++i) {
const TiCapabilityLevelInfo &cap_level_info = cap_level_infos.at(i);
if (cap_level_info.capability == capability) {
return cap_level_info.level;
}
}
return 0;
}

void set(TiCapability capability, uint32_t level) {
std::vector<TiCapabilityLevelInfo>::iterator it = cap_level_infos.begin();
for (; it != cap_level_infos.end(); ++it) {
if (it->capability == capability) {
it->level = level;
return;
}
}
TiCapabilityLevelInfo cap_level_info{};
cap_level_info.capability = capability;
cap_level_info.level = level;
cap_level_infos.emplace_back(std::move(cap_level_info));
}
};

class CapabilityLevelConfigBuilder {
typedef CapabilityLevelConfigBuilder Self;
std::map<TiCapability, uint32_t> cap_level_infos;

public:
CapabilityLevelConfigBuilder() : cap_level_infos() {
}
CapabilityLevelConfigBuilder(const Self &) = delete;
Self &operator=(const Self &) = delete;

Self &spirv_version(uint32_t major, uint32_t minor) {
if (major == 1) {
if (minor == 3) {
cap_level_infos[TI_CAPABILITY_SPIRV_VERSION] = 0x10300;
} else if (minor == 4) {
cap_level_infos[TI_CAPABILITY_SPIRV_VERSION] = 0x10400;
} else if (minor == 5) {
cap_level_infos[TI_CAPABILITY_SPIRV_VERSION] = 0x10500;
} else {
ti_set_last_error(TI_ERROR_ARGUMENT_OUT_OF_RANGE, "minor");
}
} else {
ti_set_last_error(TI_ERROR_ARGUMENT_OUT_OF_RANGE, "major");
}
return *this;
}
Self &spirv_has_int8(bool value = true) {
cap_level_infos[TI_CAPABILITY_SPIRV_HAS_INT8] = value ? TI_TRUE : TI_FALSE;
return *this;
}
Self &spirv_has_int16(bool value = true) {
cap_level_infos[TI_CAPABILITY_SPIRV_HAS_INT16] = value ? TI_TRUE : TI_FALSE;
return *this;
}
Self &spirv_has_int64(bool value = true) {
cap_level_infos[TI_CAPABILITY_SPIRV_HAS_INT64] = value ? TI_TRUE : TI_FALSE;
return *this;
}
Self &spirv_has_float16(bool value = true) {
cap_level_infos[TI_CAPABILITY_SPIRV_HAS_FLOAT16] =
value ? TI_TRUE : TI_FALSE;
return *this;
}
Self &spirv_has_float64(bool value = true) {
cap_level_infos[TI_CAPABILITY_SPIRV_HAS_FLOAT64] =
value ? TI_TRUE : TI_FALSE;
return *this;
}
Self &spirv_has_atomic_i64(bool value = true) {
cap_level_infos[TI_CAPABILITY_SPIRV_HAS_ATOMIC_I64] =
value ? TI_TRUE : TI_FALSE;
return *this;
}
Self &spirv_has_atomic_float16(bool value = true) {
cap_level_infos[TI_CAPABILITY_SPIRV_HAS_ATOMIC_FLOAT16] =
value ? TI_TRUE : TI_FALSE;
return *this;
}
Self &spirv_has_atomic_float16_add(bool value = true) {
cap_level_infos[TI_CAPABILITY_SPIRV_HAS_ATOMIC_FLOAT16_ADD] =
value ? TI_TRUE : TI_FALSE;
return *this;
}
Self &spirv_has_atomic_float16_minmax(bool value = true) {
cap_level_infos[TI_CAPABILITY_SPIRV_HAS_ATOMIC_FLOAT16_MINMAX] =
value ? TI_TRUE : TI_FALSE;
return *this;
}
Self &spirv_has_atomic_float64(bool value = true) {
cap_level_infos[TI_CAPABILITY_SPIRV_HAS_ATOMIC_FLOAT64] =
value ? TI_TRUE : TI_FALSE;
return *this;
}
Self &spirv_has_atomic_float64_add(bool value = true) {
cap_level_infos[TI_CAPABILITY_SPIRV_HAS_ATOMIC_FLOAT64_ADD] =
value ? TI_TRUE : TI_FALSE;
return *this;
}
Self &spirv_has_variable_ptr(bool value = true) {
cap_level_infos[TI_CAPABILITY_SPIRV_HAS_VARIABLE_PTR] =
value ? TI_TRUE : TI_FALSE;
return *this;
}
Self &spirv_has_physical_storage_buffer(bool value = true) {
cap_level_infos[TI_CAPABILITY_SPIRV_HAS_PHYSICAL_STORAGE_BUFFER] =
value ? TI_TRUE : TI_FALSE;
return *this;
}
Self &spirv_has_subgroup_basic(bool value = true) {
cap_level_infos[TI_CAPABILITY_SPIRV_HAS_SUBGROUP_BASIC] =
value ? TI_TRUE : TI_FALSE;
return *this;
}
Self &spirv_has_subgroup_vote(bool value = true) {
cap_level_infos[TI_CAPABILITY_SPIRV_HAS_SUBGROUP_VOTE] =
value ? TI_TRUE : TI_FALSE;
return *this;
}
Self &spirv_has_subgroup_arithmetic(bool value = true) {
cap_level_infos[TI_CAPABILITY_SPIRV_HAS_SUBGROUP_ARITHMETIC] =
value ? TI_TRUE : TI_FALSE;
return *this;
}
Self &spirv_has_subgroup_ballot(bool value = true) {
cap_level_infos[TI_CAPABILITY_SPIRV_HAS_SUBGROUP_BALLOT] =
value ? TI_TRUE : TI_FALSE;
return *this;
}
Self &spirv_has_non_semantic_info(bool value = true) {
cap_level_infos[TI_CAPABILITY_SPIRV_HAS_NON_SEMANTIC_INFO] =
value ? TI_TRUE : TI_FALSE;
return *this;
}
Self &spirv_has_no_integer_wrap_decoration(bool value = true) {
cap_level_infos[TI_CAPABILITY_SPIRV_HAS_NO_INTEGER_WRAP_DECORATION] =
value ? TI_TRUE : TI_FALSE;
return *this;
}

CapabilityLevelConfig build() {
std::vector<TiCapabilityLevelInfo> out{};
for (const auto &pair : cap_level_infos) {
TiCapabilityLevelInfo cap_level_info{};
cap_level_info.capability = pair.first;
cap_level_info.level = pair.second;
out.emplace_back(std::move(cap_level_info));
}
return CapabilityLevelConfig{std::move(out)};
}
};

inline CapabilityLevelConfigBuilder CapabilityLevelConfig::builder() {
return {};
}

class Runtime {
TiArch arch_{TI_ARCH_MAX_ENUM};
TiRuntime runtime_{TI_NULL_HANDLE};
Expand Down Expand Up @@ -895,20 +722,17 @@ class Runtime {
return *this;
}

void set_capabilities_ext(
const std::vector<TiCapabilityLevelInfo> &capabilities) {
ti_set_runtime_capabilities_ext(runtime_, (uint32_t)capabilities.size(),
capabilities.data());
}
void set_capabilities_ext(const CapabilityLevelConfig &capabilities) {
set_capabilities_ext(capabilities.cap_level_infos);
}
CapabilityLevelConfig get_capabilities() const {
std::map<TiCapability, uint32_t> get_capabilities() const {
uint32_t n = 0;
ti_get_runtime_capabilities(runtime_, &n, nullptr);
std::vector<TiCapabilityLevelInfo> devcaps(n);
ti_get_runtime_capabilities(runtime_, &n, devcaps.data());
return CapabilityLevelConfig{std::move(devcaps)};

std::map<TiCapability, uint32_t> out{};
for (auto devcap : devcaps) {
out[devcap.capability] = devcap.level;
}
return out;
}

Memory allocate_memory(const TiMemoryAllocateInfo &allocate_info) {
Expand Down
6 changes: 0 additions & 6 deletions c_api/include/taichi/taichi_core.h
Original file line number Diff line number Diff line change
Expand Up @@ -851,12 +851,6 @@ TI_DLL_EXPORT TiRuntime TI_API_CALL ti_create_runtime(TiArch arch);
// Destroys a Taichi Runtime.
TI_DLL_EXPORT void TI_API_CALL ti_destroy_runtime(TiRuntime runtime);

// Function `ti_set_runtime_capabilities_ext`
TI_DLL_EXPORT void TI_API_CALL
ti_set_runtime_capabilities_ext(TiRuntime runtime,
uint32_t capability_count,
const TiCapabilityLevelInfo *capabilities);

// Function `ti_get_runtime_capabilities`
TI_DLL_EXPORT void TI_API_CALL
ti_get_runtime_capabilities(TiRuntime runtime,
Expand Down
21 changes: 1 addition & 20 deletions c_api/src/taichi_core_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -265,25 +265,6 @@ void ti_destroy_runtime(TiRuntime runtime) {
TI_CAPI_TRY_CATCH_END();
}

void ti_set_runtime_capabilities_ext(
TiRuntime runtime,
uint32_t capability_count,
const TiCapabilityLevelInfo *capabilities) {
TI_CAPI_TRY_CATCH_BEGIN();
TI_CAPI_ARGUMENT_NULL(runtime);

Runtime *runtime2 = (Runtime *)runtime;
taichi::lang::DeviceCapabilityConfig devcaps;
for (uint32_t i = 0; i < capability_count; ++i) {
const auto &cap_level_info = capabilities[i];
devcaps.set((taichi::lang::DeviceCapability)cap_level_info.capability,
cap_level_info.level);
}
runtime2->get().set_caps(std::move(devcaps));

TI_CAPI_TRY_CATCH_END();
}

void ti_get_runtime_capabilities(TiRuntime runtime,
uint32_t *capability_count,
TiCapabilityLevelInfo *capabilities) {
Expand All @@ -292,7 +273,7 @@ void ti_get_runtime_capabilities(TiRuntime runtime,

Runtime *runtime2 = (Runtime *)runtime;
const taichi::lang::DeviceCapabilityConfig &devcaps =
runtime2->get().get_caps();
runtime2->get().get_current_caps();

if (capability_count == nullptr) {
return;
Expand Down
2 changes: 1 addition & 1 deletion c_api/src/taichi_gfx_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ Error GfxRuntime::create_aot_module(const taichi::io::VirtualDir *dir,
}

const taichi::lang::DeviceCapabilityConfig &current_devcaps =
params.runtime->get_ti_device()->get_caps();
params.runtime->get_ti_device()->get_current_caps();
const taichi::lang::DeviceCapabilityConfig &required_devcaps =
aot_module->get_required_caps();
for (const auto &pair : required_devcaps.devcaps) {
Expand Down
2 changes: 1 addition & 1 deletion c_api/src/taichi_opengl_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ OpenglRuntime::OpenglRuntime()
caps.set(taichi::lang::DeviceCapability::spirv_has_int64, true);
caps.set(taichi::lang::DeviceCapability::spirv_has_float64, true);
caps.set(taichi::lang::DeviceCapability::spirv_version, 0x10300);
get_gl().set_caps(std::move(caps));
get_gl().set_current_caps(std::move(caps));
}
taichi::lang::Device &OpenglRuntime::get() {
return static_cast<taichi::lang::Device &>(device_);
Expand Down
2 changes: 1 addition & 1 deletion c_api/src/taichi_vulkan_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ VulkanRuntimeImported::Workaround::Workaround(
}
*/

vk_device.set_caps(std::move(caps));
vk_device.set_current_caps(std::move(caps));
vk_device.init_vulkan_structs(
const_cast<taichi::lang::vulkan::VulkanDevice::Params &>(params));
}
Expand Down
19 changes: 0 additions & 19 deletions c_api/taichi.json
Original file line number Diff line number Diff line change
Expand Up @@ -531,25 +531,6 @@
}
]
},
{
"name": "set_runtime_capabilities",
"type": "function",
"is_extension": true,
"parameters": [
{
"type": "handle.runtime"
},
{
"name": "capability_count",
"type": "uint32_t"
},
{
"name": "capabilities",
"type": "structure.capability_level_info",
"count": "capability_count"
}
]
},
{
"name": "get_runtime_capabilities",
"type": "function",
Expand Down
51 changes: 3 additions & 48 deletions c_api/tests/c_api_interface_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,54 +38,9 @@ TEST_F(CapiTest, DryRunCapabilities) {
{
ti::Runtime runtime(TI_ARCH_VULKAN);
auto devcaps = runtime.get_capabilities();
auto level = devcaps.get(TI_CAPABILITY_SPIRV_VERSION);
assert(level >= 0x10000);
}
}
}

TEST_F(CapiTest, SetCapabilities) {
if (capi::utils::is_vulkan_available()) {
// Vulkan Runtime
{
ti::Runtime runtime(TI_ARCH_VULKAN);

{
auto devcaps = ti::CapabilityLevelConfig::builder()
.spirv_version(1, 3)
.spirv_has_atomic_float64_add()
.build();
runtime.set_capabilities_ext(devcaps);
auto devcaps2 = runtime.get_capabilities();
TI_ASSERT(devcaps2.get(TI_CAPABILITY_SPIRV_VERSION) == 0x10300);
TI_ASSERT(devcaps2.get(TI_CAPABILITY_SPIRV_HAS_ATOMIC_FLOAT64_ADD) ==
1);
TI_ASSERT(devcaps2.get(TI_CAPABILITY_SPIRV_HAS_ATOMIC_FLOAT64) == 0);
}
{
auto devcaps =
ti::CapabilityLevelConfig::builder().spirv_version(1, 4).build();
runtime.set_capabilities_ext(devcaps);
auto devcaps2 = runtime.get_capabilities();
TI_ASSERT(devcaps2.get(TI_CAPABILITY_SPIRV_VERSION) == 0x10400);
TI_ASSERT(devcaps2.get(TI_CAPABILITY_SPIRV_HAS_ATOMIC_FLOAT64_ADD) ==
0);
TI_ASSERT(devcaps2.get(TI_CAPABILITY_SPIRV_HAS_ATOMIC_FLOAT64) == 0);
}
{
auto devcaps = ti::CapabilityLevelConfig::builder()
.spirv_version(1, 5)
.spirv_has_atomic_float64()
.spirv_has_atomic_float64(false)
.spirv_has_atomic_float64(true)
.build();
runtime.set_capabilities_ext(devcaps);
auto devcaps2 = runtime.get_capabilities();
TI_ASSERT(devcaps2.get(TI_CAPABILITY_SPIRV_VERSION) == 0x10500);
TI_ASSERT(devcaps2.get(TI_CAPABILITY_SPIRV_HAS_ATOMIC_FLOAT64_ADD) ==
0);
TI_ASSERT(devcaps2.get(TI_CAPABILITY_SPIRV_HAS_ATOMIC_FLOAT64) == 1);
}
auto it = devcaps.find(TI_CAPABILITY_SPIRV_VERSION);
assert(it != devcaps.end());
assert(it->second >= 0x10000);
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion taichi/cache/gfx/cache_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ CompiledKernelData CacheManager::load_or_compile(CompileConfig *config,
if (kernel->is_evaluator) {
spirv::lower(kernel);
return gfx::run_codegen(kernel, runtime_->get_ti_device()->arch(),
runtime_->get_ti_device()->get_caps(),
runtime_->get_ti_device()->get_current_caps(),
compiled_structs_);
}
std::string kernel_key = make_kernel_key(config, kernel);
Expand Down
Loading