Skip to content

Commit

Permalink
Runtime stitching APIs and sanity tests, ttnn runtime submit refactor
Browse files: browse the repository at this point in the history
  • Loading branch information
jnie-TT committed Nov 20, 2024
1 parent ff339a1 commit 4361f74
Show file tree
Hide file tree
Showing 56 changed files with 1,771 additions and 588 deletions.
6 changes: 6 additions & 0 deletions .github/workflows/build-and-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,12 @@ jobs:
source env/activate
pytest -ssv runtime/tools/python/test/test_perf.py
- name: ttrt api tests
shell: bash
run: |
source env/activate
pytest -ssv runtime/tools/python/test/test_runtime_api.py
build-and-test-explorer:
needs: build-image
timeout-minutes: 60
Expand Down
4 changes: 2 additions & 2 deletions runtime/include/tt/runtime/detail/ttmetal.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ void closeDevice(Device device);

void deallocateBuffers(Device device);

Event submit(Device device, Binary executable, std::uint32_t programIndex,
std::vector<Tensor> const &inputs,
Event submit(Device deviceHandle, Binary executableHandle,
std::uint32_t programIndex, std::vector<Tensor> const &inputs,
std::vector<Tensor> const &outputs);

void wait(Event event);
Expand Down
37 changes: 33 additions & 4 deletions runtime/include/tt/runtime/detail/ttnn.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,11 +75,16 @@ void closeDevice(Device device);

void deallocateBuffers(Device device);

Event submit(Device device, Binary executable, std::uint32_t programIndex,
std::vector<Tensor> const &inputs,
std::vector<Tensor> const &outputs);
Tensor toHost(Tensor tensor, bool untilize = false);

void wait(Event event);
Tensor toDevice(Tensor tensor, Device device);

Tensor toDevice(Tensor tensor, Device device, Layout layout);

Tensor toLayout(Tensor tensor, Layout layout);

Layout getLayout(Binary executableHandle, std::uint32_t programIndex,
std::uint32_t inputIndex);

std::string getOpDebugString(OpContext opContextHandle);

Expand All @@ -88,10 +93,34 @@ Tensor getOpOutputTensor(OpContext opContextHandle,

std::vector<float> getTensorData(Tensor tensor);

void memcpy(void *dst, Tensor src);

void memcpy(Tensor dst, Tensor src);

void deallocateTensor(Tensor &tensor, bool force = false);

// Legacy entry points that use the Event-returning, caller-supplied-outputs
// form of submit. The non-legacy submit/runProgram declared in this header
// take only inputs and return std::vector<Tensor> instead.
namespace legacy {
/* Will be deprecated soon once frontends (FEs) migrate to the new API. */
// Takes the Event produced by legacy::submit.
// NOTE(review): presumably blocks until the event has completed — confirm
// against the implementation.
void wait(Event event);

// Submits program `programIndex` of `executableHandle` on `deviceHandle`.
// The caller passes both the input and the output tensors; an Event is
// returned rather than the output tensors themselves.
Event submit(Device deviceHandle, Binary executableHandle,
std::uint32_t programIndex, std::vector<Tensor> const &inputs,
std::vector<Tensor> const &outputs);

// Lower-level variant that operates directly on a ::ttnn::MeshDevice and on
// raw ::ttnn::Tensor pointers for inputs and outputs. Returns nothing.
// NOTE(review): presumably results are written into `outputs` — confirm.
void runProgram(::ttnn::MeshDevice &meshDevice, Binary &executableHandle,
std::uint32_t programIndex,
std::vector<::ttnn::Tensor *> const &inputs,
std::vector<::ttnn::Tensor *> const &outputs);
} // namespace legacy

std::vector<Tensor> submit(Device deviceHandle, Binary executableHandle,
std::uint32_t programIndex,
std::vector<Tensor> const &inputs);

std::vector<Tensor> runProgram(::ttnn::MeshDevice &meshDevice,
Binary executableHandle,
std::uint32_t programIndex,
std::vector<::ttnn::Tensor *> const &inputs);

} // namespace tt::runtime::ttnn

Expand Down
29 changes: 25 additions & 4 deletions runtime/include/tt/runtime/runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,19 +63,40 @@ Device openDevice(DeviceIds const &deviceIds, size_t numHWCQs = 1);

void closeDevice(Device device);

Event submit(Device device, Binary executable, std::uint32_t programIndex,
std::vector<Tensor> const &inputs,
std::vector<Tensor> const &outputs);

void wait(Event event);

Tensor toHost(Tensor tensor, bool untilize = false);

Tensor toDevice(Tensor tensor, Device device);

Tensor toDevice(Tensor tensor, Device device, Layout layout);

Tensor toLayout(Tensor tensor, Layout layout);

Layout getLayout(Binary executableHandle, std::uint32_t programIndex,
std::uint32_t inputIndex);

std::string getOpDebugString(OpContext opContextHandle);

Tensor getOpOutputTensor(OpContext opContextHandle,
CallbackContext programContextHandle);

std::vector<float> getTensorData(Tensor tensor);

void memcpy(void *dst, Tensor src);

void memcpy(Tensor dst, Tensor src);

void deallocateTensor(Tensor &tensor, bool force = false);

std::vector<Tensor> submit(Device deviceHandle, Binary executableHandle,
std::uint32_t programIndex,
std::vector<Tensor> const &inputs);

Event submit(Device deviceHandle, Binary executableHandle,
std::uint32_t programIndex, std::vector<Tensor> const &inputs,
std::vector<Tensor> const &outputs);

} // namespace tt::runtime

#endif
14 changes: 12 additions & 2 deletions runtime/include/tt/runtime/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -122,10 +122,20 @@ struct Event : public detail::RuntimeCheckedObjectImpl {

struct Tensor : public detail::RuntimeCheckedObjectImpl {
std::shared_ptr<void> data;

Event event;
Tensor(std::shared_ptr<void> handle, std::shared_ptr<void> data,
DeviceRuntime runtime)
: detail::RuntimeCheckedObjectImpl(handle, runtime), data(data) {}
: detail::RuntimeCheckedObjectImpl(handle, runtime), data(data),
event(nullptr, runtime) {}

Tensor(std::shared_ptr<void> handle, std::shared_ptr<void> data,
std::shared_ptr<void> eventHandle, DeviceRuntime runtime)
: detail::RuntimeCheckedObjectImpl(handle, runtime), data(data),
event(eventHandle, runtime) {}
};

// Runtime-checked handle type for a layout. It adds no members of its own
// and inherits the constructors of detail::RuntimeCheckedObjectImpl.
// NOTE(review): presumably the tensor layout descriptor consumed by
// toLayout/toDevice and produced by getLayout — confirm.
struct Layout : public detail::RuntimeCheckedObjectImpl {
// Inherit the base-class constructors unchanged.
using detail::RuntimeCheckedObjectImpl::RuntimeCheckedObjectImpl;
};

struct CallbackContext : public detail::RuntimeCheckedObjectImpl {
Expand Down
34 changes: 12 additions & 22 deletions runtime/lib/binary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,12 @@ static std::string asJson(void const *fbb, uint8_t const *binarySchema,
flatbuffers::Parser parser(opts);

if (not parser.Deserialize(binarySchema, schemaSize)) {
throw std::runtime_error("Failed to deserialize schema");
LOG_FATAL("Failed to deserialize schema");
}

std::string text;
const char *err = ::flatbuffers::GenerateText(parser, fbb, &text);
if (err) {
throw std::runtime_error("Failed to generate JSON: " + std::string(err));
}

LOG_ASSERT(not err, "Failed to generate JSON: ", err);
return text;
}

Expand All @@ -44,9 +41,7 @@ namespace ttnn {
::tt::target::ttnn::TTNNBinary const *getBinary(Flatbuffer binary) {
bool isTTNN = ::tt::target::ttnn::SizePrefixedTTNNBinaryBufferHasIdentifier(
binary.handle.get());
if (not isTTNN) {
throw std::runtime_error("Unsupported binary format");
}
LOG_ASSERT(isTTNN, "Unsupported binary format");
return ::tt::target::ttnn::GetSizePrefixedTTNNBinary(binary.handle.get());
}

Expand Down Expand Up @@ -128,9 +123,7 @@ ::tt::target::metal::TTMetalBinary const *getBinary(Flatbuffer binary) {
bool isTTMetal =
::tt::target::metal::SizePrefixedTTMetalBinaryBufferHasIdentifier(
binary.handle.get());
if (not isTTMetal) {
throw std::runtime_error("Unsupported binary format");
}
LOG_ASSERT(isTTMetal, "Unsupported binary format");
return ::tt::target::metal::GetSizePrefixedTTMetalBinary(binary.handle.get());
}

Expand Down Expand Up @@ -207,7 +200,7 @@ namespace system_desc {
::tt::target::SystemDescRoot const *getBinary(Flatbuffer binary) {
if (!::tt::target::SizePrefixedSystemDescRootBufferHasIdentifier(
binary.handle.get())) {
throw std::runtime_error("Unsupported binary format");
LOG_FATAL("Unsupported binary format");
}
return ::tt::target::GetSizePrefixedSystemDescRoot(binary.handle.get());
}
Expand All @@ -234,10 +227,7 @@ std::string asJson(Flatbuffer binary) {
Flatbuffer Flatbuffer::loadFromPath(char const *path) {
// load a flatbuffer from path
std::ifstream fbb(path, std::ios::binary | std::ios::ate);
if (!fbb.is_open()) {
throw std::runtime_error("Failed to open file: " + std::string(path));
}

LOG_ASSERT(fbb.is_open(), "Failed to open file: ", path);
std::streampos size = fbb.tellg();
fbb.seekg(0, std::ios::beg);
auto buffer = ::tt::runtime::utils::malloc_shared(size);
Expand Down Expand Up @@ -269,7 +259,7 @@ std::string_view Flatbuffer::getFileIdentifier() const {
return ::tt::target::SystemDescRootIdentifier();
}

throw std::runtime_error("Unsupported binary format");
LOG_FATAL("Unsupported binary format");
}

std::string Flatbuffer::getVersion() const {
Expand All @@ -288,7 +278,7 @@ std::string Flatbuffer::getVersion() const {
return system_desc::getVersion(*this);
}

throw std::runtime_error("Unsupported binary format");
LOG_FATAL("Unsupported binary format");
}

std::string_view Flatbuffer::getTTMLIRGitHash() const {
Expand All @@ -307,7 +297,7 @@ std::string_view Flatbuffer::getTTMLIRGitHash() const {
return system_desc::getTTMLIRGitHash(*this);
}

throw std::runtime_error("Unsupported binary format");
LOG_FATAL("Unsupported binary format");
}

std::string Flatbuffer::asJson() const {
Expand All @@ -326,7 +316,7 @@ std::string Flatbuffer::asJson() const {
return system_desc::asJson(*this);
}

throw std::runtime_error("Unsupported binary format");
LOG_FATAL("Unsupported binary format");
}

SystemDesc SystemDesc::loadFromPath(char const *path) {
Expand All @@ -349,7 +339,7 @@ Binary::getProgramInputs(std::uint32_t programIndex) const {
return metal::getProgramInputs(*this, programIndex);
}

throw std::runtime_error("Unsupported binary format");
LOG_FATAL("Unsupported binary format");
}

std::vector<TensorDesc>
Expand All @@ -364,7 +354,7 @@ Binary::getProgramOutputs(std::uint32_t programIndex) const {
return metal::getProgramOutputs(*this, programIndex);
}

throw std::runtime_error("Unsupported binary format");
LOG_FATAL("Unsupported binary format");
}

const ::tt::target::GoldenTensor *
Expand Down
4 changes: 2 additions & 2 deletions runtime/lib/common/system_desc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ static ::tt::target::Arch toFlatbuffer(::tt::ARCH arch) {
break;
}

throw std::runtime_error("Unsupported arch");
LOG_FATAL("Unsupported arch");
}

static std::vector<::tt::target::ChipChannel>
Expand Down Expand Up @@ -246,7 +246,7 @@ static std::unique_ptr<::tt::runtime::SystemDesc> getCurrentSystemDescImpl(
::tt::target::FinishSizePrefixedSystemDescRootBuffer(fbb, root);
::flatbuffers::Verifier verifier(fbb.GetBufferPointer(), fbb.GetSize());
if (!::tt::target::VerifySizePrefixedSystemDescRootBuffer(verifier)) {
throw std::runtime_error("Failed to verify system desc root buffer");
LOG_FATAL("Failed to verify system desc root buffer");
}
uint8_t *buf = fbb.GetBufferPointer();
auto size = fbb.GetSize();
Expand Down
Loading

0 comments on commit 4361f74

Please sign in to comment.