Skip to content

Commit

Permalink
[metal] Add AotModuleLoader (#4423)
Browse files Browse the repository at this point in the history
* [metal] Add AotModuleLoader

* fix
  • Loading branch information
k-ye authored Mar 2, 2022
1 parent 1915a80 commit 01cfd83
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 0 deletions.
78 changes: 78 additions & 0 deletions taichi/backends/metal/aot_module_loader_impl.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
#include "taichi/backends/metal/aot_module_loader_impl.h"

#include "taichi/backends/metal/aot_utils.h"
#include "taichi/backends/metal/kernel_manager.h"

namespace taichi {
namespace lang {
namespace metal {
namespace {

class KernelImpl : public aot::Kernel {
public:
explicit KernelImpl(KernelManager *runtime, const std::string &kernel_name)
: runtime_(runtime), kernel_name_(kernel_name) {
}

void launch(RuntimeContext *ctx) override {
runtime_->launch_taichi_kernel(kernel_name_, ctx);
}

private:
KernelManager *const runtime_;
const std::string kernel_name_;
};

class AotModuleLoaderImpl : public aot::ModuleLoader {
public:
explicit AotModuleLoaderImpl(const AotModuleParams &params)
: runtime_(params.runtime) {
const std::string bin_path =
fmt::format("{}/metadata.tcb", params.module_path);
read_from_binary_file(aot_data_, bin_path);
// Do we still need to load each individual kernel?
for (const auto &k : aot_data_.kernels) {
kernels_[k.kernel_name] = &k;
}
}

bool get_field(const std::string &name,
aot::CompiledFieldData &field) override {
TI_ERROR("AOT: get_field for Metal not implemented yet");
return false;
}

size_t get_root_size() const override {
return aot_data_.metadata.root_buffer_size;
}

private:
std::unique_ptr<aot::Kernel> make_new_kernel(
const std::string &name) override {
auto itr = kernels_.find(name);
if (itr == kernels_.end()) {
TI_DEBUG("Failed to load kernel {}", name);
return nullptr;
}
auto *kernel_data = itr->second;
runtime_->register_taichi_kernel(name, kernel_data->source_code,
kernel_data->kernel_attribs,
kernel_data->ctx_attribs);
return std::make_unique<KernelImpl>(runtime_, name);
}

KernelManager *const runtime_;
TaichiAotData aot_data_;
std::unordered_map<std::string, const CompiledKernelData *> kernels_;
};

} // namespace

std::unique_ptr<aot::ModuleLoader> make_aot_module_loader(
const AotModuleParams &params) {
return std::make_unique<AotModuleLoaderImpl>(params);
}

} // namespace metal
} // namespace lang
} // namespace taichi
25 changes: 25 additions & 0 deletions taichi/backends/metal/aot_module_loader_impl.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#pragma once

#include <string>
#include <vector>
#include <unordered_map>

#include "taichi/aot/module_loader.h"

namespace taichi {
namespace lang {
namespace metal {

class KernelManager;

struct AotModuleParams {
std::string module_path;
KernelManager *runtime{nullptr};
};

std::unique_ptr<aot::ModuleLoader> make_aot_module_loader(
const AotModuleParams &params);

} // namespace metal
} // namespace lang
} // namespace taichi

0 comments on commit 01cfd83

Please sign in to comment.