Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[metal] Pull out the Runtime MSL code into its own module #3086

Merged
merged 3 commits into from
Oct 6, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions taichi/backends/metal/aot_module_builder_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,12 @@ namespace lang {
namespace metal {

AotModuleBuilderImpl::AotModuleBuilderImpl(
const CompiledRuntimeModule *compiled_runtime_module,
const CompiledStructs *compiled_structs,
const BufferMetaData &buffer_meta_data)
: compiled_structs_(compiled_structs), buffer_meta_data_(buffer_meta_data) {
: compiled_runtime_module_(compiled_runtime_module),
compiled_structs_(compiled_structs),
buffer_meta_data_(buffer_meta_data) {
ti_aot_data_.metadata = buffer_meta_data;
}

Expand Down Expand Up @@ -50,8 +53,8 @@ void AotModuleBuilderImpl::dump(const std::string &output_dir,

void AotModuleBuilderImpl::add_per_backend(const std::string &identifier,
Kernel *kernel) {
auto compiled =
run_codegen(compiled_structs_, kernel, &strtab_, /*offloaded=*/nullptr);
auto compiled = run_codegen(compiled_runtime_module_, compiled_structs_,
kernel, &strtab_, /*offloaded=*/nullptr);
compiled.kernel_name = identifier;
ti_aot_data_.kernels.push_back(std::move(compiled));
}
Expand All @@ -76,8 +79,8 @@ void AotModuleBuilderImpl::add_per_backend_field(const std::string &identifier,
void AotModuleBuilderImpl::add_per_backend_tmpl(const std::string &identifier,
const std::string &key,
Kernel *kernel) {
auto compiled =
run_codegen(compiled_structs_, kernel, &strtab_, /*offloaded=*/nullptr);
auto compiled = run_codegen(compiled_runtime_module_, compiled_structs_,
kernel, &strtab_, /*offloaded=*/nullptr);
for (auto &k : ti_aot_data_.tmpl_kernels) {
if (k.kernel_bundle_name == identifier) {
k.kernel_tmpl_map.insert(std::make_pair(key, compiled));
Expand Down
9 changes: 7 additions & 2 deletions taichi/backends/metal/aot_module_builder_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,11 @@ namespace metal {

class AotModuleBuilderImpl : public AotModuleBuilder {
public:
explicit AotModuleBuilderImpl(const CompiledStructs *compiled_structs,
const BufferMetaData &buffer_meta_data);
explicit AotModuleBuilderImpl(
const CompiledRuntimeModule *compiled_runtime_module,
const CompiledStructs *compiled_structs,
const BufferMetaData &buffer_meta_data);

void dump(const std::string &output_dir,
const std::string &filename) const override;

Expand All @@ -32,10 +35,12 @@ class AotModuleBuilderImpl : public AotModuleBuilder {
Kernel *kernel) override;

private:
const CompiledRuntimeModule *compiled_runtime_module_;
const CompiledStructs *compiled_structs_;
BufferMetaData buffer_meta_data_;
PrintStringTable strtab_;
TaichiAotData ti_aot_data_;

void metalgen(const stdfs::path &dir,
const std::string &filename,
const CompiledKernelData &k) const;
Expand Down
26 changes: 17 additions & 9 deletions taichi/backends/metal/codegen_metal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,12 +101,14 @@ class KernelCodegenImpl : public IRVisitor {
// TODO(k-ye): Create a Params to hold these ctor params.
KernelCodegenImpl(const std::string &taichi_kernel_name,
Kernel *kernel,
const CompiledRuntimeModule *compiled_runtime_module,
const CompiledStructs *compiled_structs,
PrintStringTable *print_strtab,
const Config &config,
OffloadedStmt *offloaded)
: mtl_kernel_prefix_(taichi_kernel_name),
kernel_(kernel),
compiled_runtime_module_(compiled_runtime_module),
compiled_structs_(compiled_structs),
needs_root_buffer_(compiled_structs_->root_size > 0),
print_strtab_(print_strtab),
Expand Down Expand Up @@ -806,7 +808,8 @@ class KernelCodegenImpl : public IRVisitor {
emit("");
current_appender().append_raw(shaders::kMetalHelpersSourceCode);
emit("");
current_appender().append_raw(compiled_structs_->runtime_utils_source_code);
current_appender().append_raw(
compiled_runtime_module_->runtime_utils_source_code);
emit("");
current_appender().append_raw(compiled_structs_->snode_structs_source_code);
emit("");
Expand Down Expand Up @@ -1495,6 +1498,7 @@ class KernelCodegenImpl : public IRVisitor {

const std::string mtl_kernel_prefix_;
Kernel *const kernel_;
const CompiledRuntimeModule *const compiled_runtime_module_;
const CompiledStructs *const compiled_structs_;
const bool needs_root_buffer_;
PrintStringTable *const print_strtab_;
Expand All @@ -1515,30 +1519,34 @@ class KernelCodegenImpl : public IRVisitor {

} // namespace

CompiledKernelData run_codegen(const CompiledStructs *compiled_structs,
Kernel *kernel,
PrintStringTable *strtab,
OffloadedStmt *offloaded) {
CompiledKernelData run_codegen(
const CompiledRuntimeModule *compiled_runtime_module,
const CompiledStructs *compiled_structs,
Kernel *kernel,
PrintStringTable *strtab,
OffloadedStmt *offloaded) {
const auto id = Program::get_kernel_id();
const auto taichi_kernel_name(
fmt::format("mtl_k{:04d}_{}", id, kernel->name));

KernelCodegenImpl::Config cgen_config;
cgen_config.allow_simdgroup = EnvConfig::instance().is_simdgroup_enabled();

KernelCodegenImpl codegen(taichi_kernel_name, kernel, compiled_structs,
strtab, cgen_config, offloaded);
KernelCodegenImpl codegen(taichi_kernel_name, kernel, compiled_runtime_module,
compiled_structs, strtab, cgen_config, offloaded);

return codegen.run();
}

FunctionType compile_to_metal_executable(
Kernel *kernel,
KernelManager *kernel_mgr,
const CompiledRuntimeModule *compiled_runtime_module,
const CompiledStructs *compiled_structs,
OffloadedStmt *offloaded) {
const auto compiled_res = run_codegen(
compiled_structs, kernel, kernel_mgr->print_strtable(), offloaded);
const auto compiled_res =
run_codegen(compiled_runtime_module, compiled_structs, kernel,
kernel_mgr->print_strtable(), offloaded);
kernel_mgr->register_taichi_kernel(
compiled_res.kernel_name, compiled_res.source_code,
compiled_res.kernel_attribs, compiled_res.ctx_attribs);
Expand Down
11 changes: 7 additions & 4 deletions taichi/backends/metal/codegen_metal.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,20 @@ namespace taichi {
namespace lang {
namespace metal {

CompiledKernelData run_codegen(const CompiledStructs *compiled_structs,
Kernel *kernel,
PrintStringTable *print_strtab,
OffloadedStmt *offloaded);
CompiledKernelData run_codegen(
const CompiledRuntimeModule *compiled_runtime_module,
const CompiledStructs *compiled_structs,
Kernel *kernel,
PrintStringTable *print_strtab,
OffloadedStmt *offloaded);

// If |offloaded| is nullptr, this compiles the AST in |kernel|. Otherwise it
// compiles just |offloaded|. These ASTs must have already been lowered at the
// CHI level.
FunctionType compile_to_metal_executable(
Kernel *kernel,
KernelManager *kernel_mgr,
const CompiledRuntimeModule *compiled_runtime_module,
const CompiledStructs *compiled_structs,
OffloadedStmt *offloaded = nullptr);

Expand Down
9 changes: 6 additions & 3 deletions taichi/backends/metal/constants.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,17 @@

#include <string>

#include "taichi/lang_util.h"
#include "taichi/inc/constants.h"

TLANG_NAMESPACE_BEGIN
namespace taichi {
namespace lang {
namespace metal {

inline constexpr int kMaxNumThreadsGridStrideLoop = 64 * 1024;
inline constexpr int kNumRandSeeds = 64 * 1024; // 256 KB is nothing
inline constexpr int kMslVersionNone = 0;
inline constexpr int kMaxNumSNodes = taichi_max_num_snodes;

} // namespace metal
TLANG_NAMESPACE_END
} // namespace lang
} // namespace taichi
29 changes: 18 additions & 11 deletions taichi/backends/metal/kernel_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <chrono>
#include <cstring>
#include <limits>
#include <optional>
#include <random>
#include <string_view>

Expand Down Expand Up @@ -568,6 +569,7 @@ class KernelManager::Impl {
public:
explicit Impl(Params params)
: config_(params.config),
compiled_runtime_module_(params.compiled_runtime_module),
compiled_structs_(params.compiled_structs),
mem_pool_(params.mem_pool),
host_result_buffer_(params.host_result_buffer),
Expand Down Expand Up @@ -605,24 +607,26 @@ class KernelManager::Impl {
device_.get(), global_tmps_mem_->ptr(), global_tmps_mem_->size());
TI_ASSERT(global_tmps_buffer_ != nullptr);

TI_ASSERT(compiled_structs_.runtime_size > 0);
TI_ASSERT(compiled_runtime_module_.runtime_size > 0);
const size_t mem_pool_bytes =
(config_->device_memory_GB * 1024 * 1024 * 1024ULL);
runtime_mem_ = std::make_unique<BufferMemoryView>(
compiled_structs_.runtime_size + mem_pool_bytes, mem_pool_);
compiled_runtime_module_.runtime_size + mem_pool_bytes, mem_pool_);
runtime_buffer_ = new_mtl_buffer_no_copy(device_.get(), runtime_mem_->ptr(),
runtime_mem_->size());
buffer_meta_data_.runtime_buffer_size = compiled_structs_.runtime_size;
buffer_meta_data_.runtime_buffer_size =
compiled_runtime_module_.runtime_size;
TI_DEBUG(
"Metal runtime buffer size: {} bytes (sizeof(Runtime)={} "
"memory_pool={})",
runtime_mem_->size(), compiled_structs_.runtime_size, mem_pool_bytes);
runtime_mem_->size(), compiled_runtime_module_.runtime_size,
mem_pool_bytes);

ActionRecorder::get_instance().record(
"allocate_runtime_buffer",
{ActionArg("runtime_buffer_size_in_bytes", (int64)runtime_mem_->size()),
ActionArg("runtime_struct_size_in_bytes",
(int64)compiled_structs_.runtime_size),
(int64)compiled_runtime_module_.runtime_size),
ActionArg("memory_pool_size", (int64)mem_pool_bytes)});

TI_ASSERT_INFO(
Expand Down Expand Up @@ -794,7 +798,7 @@ class KernelManager::Impl {
i, snode_type_name(sn_meta.snode->type), rtm_meta->element_stride,
rtm_meta->num_slots, rtm_meta->mem_offset_in_parent);
}
size_t addr_offset = sizeof(SNodeMeta) * max_snodes;
size_t addr_offset = sizeof(SNodeMeta) * kMaxNumSNodes;
addr += addr_offset;
TI_DEBUG("Initialized SNodeMeta, size={} accumulated={}", addr_offset,
(addr - addr_begin));
Expand All @@ -819,7 +823,7 @@ class KernelManager::Impl {
}
TI_DEBUG("");
}
addr_offset = sizeof(SNodeExtractors) * max_snodes;
addr_offset = sizeof(SNodeExtractors) * kMaxNumSNodes;
addr += addr_offset;
TI_DEBUG("Initialized SNodeExtractors, size={} accumulated={}", addr_offset,
(addr - addr_begin));
Expand All @@ -843,7 +847,7 @@ class KernelManager::Impl {
TI_DEBUG("ListManagerData\n id={}\n num_elems_per_chunk={}\n", i,
num_elems_per_chunk);
}
addr_offset = sizeof(ListManagerData) * max_snodes;
addr_offset = sizeof(ListManagerData) * kMaxNumSNodes;
addr += addr_offset;
TI_DEBUG("Initialized ListManagerData, size={} accumulated={}", addr_offset,
(addr - addr_begin));
Expand Down Expand Up @@ -892,7 +896,7 @@ class KernelManager::Impl {
init_node_mgr(sn_desc, nm_data);
snode_id_to_nodemgrs.push_back(std::make_pair(i, nm_data));
}
addr_offset = sizeof(NodeManagerData) * max_snodes;
addr_offset = sizeof(NodeManagerData) * kMaxNumSNodes;
addr += addr_offset;
TI_DEBUG("Initialized NodeManagerData, size={} accumulated={}", addr_offset,
(addr - addr_begin));
Expand All @@ -901,7 +905,7 @@ class KernelManager::Impl {
auto *const ambient_indices_begin =
reinterpret_cast<NodeManagerData::ElemIndex *>(addr);
dev_runtime_mirror_.ambient_indices = ambient_indices_begin;
addr_offset = sizeof(NodeManagerData::ElemIndex) * max_snodes;
addr_offset = sizeof(NodeManagerData::ElemIndex) * kMaxNumSNodes;
addr += addr_offset;
TI_DEBUG(
"Delayed the initialization of SNode ambient elements, size={} "
Expand Down Expand Up @@ -959,7 +963,6 @@ class KernelManager::Impl {
TI_DEBUG("AmbientIndex\n id={}\n mem_alloc->next={}\n", snode_id,
mem_alloc->next);
}

did_modify_range(runtime_buffer_.get(), /*location=*/0,
runtime_mem_->size());
}
Expand Down Expand Up @@ -1129,6 +1132,7 @@ class KernelManager::Impl {
}

CompileConfig *const config_;
const CompiledRuntimeModule compiled_runtime_module_;
const CompiledStructs compiled_structs_;
BufferMetaData buffer_meta_data_;
MemoryPool *const mem_pool_;
Expand Down Expand Up @@ -1212,6 +1216,9 @@ KernelManager::KernelManager(Params params)
KernelManager::~KernelManager() {
}

void KernelManager::add_compiled_snode_tree(const CompiledStructs &snode_tree) {
}

void KernelManager::register_taichi_kernel(
const std::string &taichi_kernel_name,
const std::string &mtl_kernel_source_code,
Expand Down
2 changes: 2 additions & 0 deletions taichi/backends/metal/kernel_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ namespace metal {
class KernelManager {
public:
struct Params {
CompiledRuntimeModule compiled_runtime_module;
CompiledStructs compiled_structs;
CompileConfig *config;
MemoryPool *mem_pool;
Expand All @@ -38,6 +39,7 @@ class KernelManager {
// To make Pimpl + std::unique_ptr work
~KernelManager();

void add_compiled_snode_tree(const CompiledStructs &snode_tree);
// Register a Taichi kernel to the Metal runtime.
// * |mtl_kernel_source_code| is the complete source code compiled from a
// Taichi kernel. It may include one or more Metal compute kernels. Each
Expand Down
14 changes: 11 additions & 3 deletions taichi/backends/metal/metal_program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@ FunctionType MetalProgramImpl::compile(Kernel *kernel,
if (!kernel->lowered()) {
kernel->lower();
}
return metal::compile_to_metal_executable(kernel, metal_kernel_mgr_.get(),
&metal_compiled_structs_.value(),
offloaded);
return metal::compile_to_metal_executable(
kernel, metal_kernel_mgr_.get(), &(compiled_runtime_module_.value()),
&(metal_compiled_structs_.value()), offloaded);
}

std::size_t MetalProgramImpl::get_snode_num_dynamically_allocated(SNode *snode,
Expand All @@ -33,6 +33,7 @@ void MetalProgramImpl::materialize_runtime(MemoryPool *memory_pool,
sizeof(uint64) * taichi_result_buffer_entries, 8);
params_.mem_pool = memory_pool;
params_.profiler = profiler;
compiled_runtime_module_ = metal::compile_runtime_module();
}

void MetalProgramImpl::materialize_snode_tree(
Expand All @@ -48,6 +49,7 @@ void MetalProgramImpl::materialize_snode_tree(
metal_compiled_structs_ = metal::compile_structs(*root);
if (metal_kernel_mgr_ == nullptr) {
params_.compiled_structs = metal_compiled_structs_.value();
params_.compiled_runtime_module = compiled_runtime_module_.value();
params_.config = config;
params_.host_result_buffer = result_buffer;
params_.root_id = root->id;
Expand All @@ -56,5 +58,11 @@ void MetalProgramImpl::materialize_snode_tree(
}
}

std::unique_ptr<AotModuleBuilder> MetalProgramImpl::make_aot_module_builder() {
return std::make_unique<metal::AotModuleBuilderImpl>(
&(compiled_runtime_module_.value()), &(metal_compiled_structs_.value()),
metal_kernel_mgr_->get_buffer_meta_data());
}

} // namespace lang
} // namespace taichi
Loading