#92: Refactor toMemoryConfigOp v2
jnie-TT committed Aug 2, 2024
1 parent 231548f commit d5ad7f7
Showing 3 changed files with 109 additions and 62 deletions.
154 changes: 97 additions & 57 deletions runtime/lib/ttnn/program.cpp
@@ -8,7 +8,7 @@

#include "tt/runtime/detail/ttnn.h"
#include "tt/runtime/runtime.h"
#include "ttnn/runtime/utils.h"
#include "utils.h"

#include "ttmlir/Target/TTNN/Target.h"
#include "ttmlir/Version.h"
@@ -25,74 +25,119 @@
// some reason a static_assert fails when this is called from within our
// namespace.
ttnn::Tensor tilize(ttnn::Tensor const &input) {
return ttnn::to_layout(input, ttnn::TILE_LAYOUT, std::nullopt, std::nullopt,
return ttnn::to_layout(input, ::ttnn::TILE_LAYOUT, std::nullopt, std::nullopt,
(Device *)nullptr);
}

ttnn::Tensor untilize(ttnn::Tensor const &input) {
return ttnn::to_layout(input, ::ttnn::ROW_MAJOR_LAYOUT, std::nullopt,
std::nullopt, (Device *)nullptr);
}

namespace tt::runtime::ttnn {

static ::ttnn::Tensor convertDataType(const ::ttnn::Tensor &input,
const ::ttnn::DataType &targetDataType) {
const ::ttnn::StorageType storageType = input.storage_type();
if (storageType == ::tt::tt_metal::StorageType::BORROWED) {
return ::ttnn::to_dtype(input, targetDataType);
} else if (storageType == ::tt::tt_metal::StorageType::DEVICE) {
if (input.get_layout() != ::ttnn::TILE_LAYOUT) {
// typecast op requires tilized tensor
::ttnn::Tensor converted =
::ttnn::typecast(::tilize(input), targetDataType);
// untilize and return
return ::untilize(converted);
}
return ::ttnn::typecast(input, targetDataType);
} else {
throw runtime_error("Unsupported storage type");
}
}

static ::ttnn::Tensor
updateLayoutAndDataType(const ::ttnn::Tensor &inputTensor,
const ::ttnn::DataType targetDataType,
const ::tt::target::Dim2d *targetTileShape) {
::ttnn::Tensor outputTensor = inputTensor;
const bool shouldConvertDataType = inputTensor.get_dtype() != targetDataType;
const int targetTileX = targetTileShape->x();
const int targetTileY = targetTileShape->y();
const bool shouldTilize =
targetTileX == 32 and targetTileY == 32 and
inputTensor.get_layout() == ::ttnn::ROW_MAJOR_LAYOUT;
const bool shouldUntilize = (targetTileX != 32 or targetTileY != 32) and
inputTensor.get_layout() == ::ttnn::TILE_LAYOUT;
if (shouldTilize) {
outputTensor = ::tilize(outputTensor);
} else if (shouldUntilize) {
outputTensor = ::untilize(outputTensor);
}
if (shouldConvertDataType) {
outputTensor = convertDataType(outputTensor, targetDataType);
}
return outputTensor;
}

static void
run(::tt::target::ttnn::ToMemoryConfigOp const *op, ::ttnn::Device &device,
std::unordered_map<std::uint32_t, ::ttnn::Tensor *> &liveTensors,
std::list<::ttnn::Tensor> &tensorPool) {
::tt::target::DataType targetDataType =
op->out()->desc()->layout()->memory_desc()->data_type();
assert(targetDataType == ::tt::target::DataType::Float32 or
targetDataType == ::tt::target::DataType::BFloat16);
::ttnn::DataType targetDataTypeTTNN = utils::toTTNNDataType(targetDataType);
const ::ttnn::Tensor &inputTensor = *liveTensors.at(op->in0()->global_id());
assert(inputTensor.storage_type() == ::tt::tt_metal::StorageType::BORROWED or
inputTensor.storage_type() == ::tt::tt_metal::StorageType::DEVICE);

const ::tt::target::Dim2d *targetTileShape =
op->out()->desc()->layout()->memory_desc()->tile_shape();
TT_FATAL(utils::isValidTileShape(targetTileShape),
"Invalid tile shape ({}, {})", targetTileShape->x(),
targetTileShape->y());

::tt::target::DataType targetDataType =
op->out()->desc()->layout()->memory_desc()->data_type();
::ttnn::DataType targetDataTypeTTNN = utils::toTTNNDataType(targetDataType);

const ::tt::target::MemorySpace targetMemorySpace =
op->out()->desc()->layout()->memory_desc()->memory_space();

switch (targetMemorySpace) {
case ::tt::target::MemorySpace::System:
case ::tt::target::MemorySpace::SystemMMIO: {
::ttnn::Tensor result;
if (inputTensor.storage_type() == ::tt::tt_metal::StorageType::BORROWED) {
::ttnn::Tensor hostTensor = inputTensor.to(::ttnn::ROW_MAJOR_LAYOUT);
hostTensor = ::ttnn::to_dtype(hostTensor, targetDataTypeTTNN);
::ttnn::Tensor &outputTensor = *liveTensors.at(op->out()->global_id());
void *src = ::tt::tt_metal::get_raw_host_data_ptr(hostTensor);
void *dst = ::tt::tt_metal::get_raw_host_data_ptr(outputTensor);
std::uint32_t size = hostTensor.volume() * hostTensor.element_size();
std::memcpy(dst, src, size);
result = updateLayoutAndDataType(inputTensor, targetDataTypeTTNN,
targetTileShape);
} else if (inputTensor.storage_type() ==
::tt::tt_metal::StorageType::DEVICE) {
::ttnn::Tensor hostTensor =
::ttnn::typecast(inputTensor, targetDataTypeTTNN);
// Following the flow in core.py::to_torch - untilize on host
hostTensor = hostTensor.cpu().to(::ttnn::ROW_MAJOR_LAYOUT);
::ttnn::Tensor &outputTensor = *liveTensors.at(op->out()->global_id());
void *src = ::tt::tt_metal::get_raw_host_data_ptr(hostTensor);
void *dst = ::tt::tt_metal::get_raw_host_data_ptr(outputTensor);
std::uint32_t size = hostTensor.volume() * hostTensor.element_size();
std::memcpy(dst, src, size);
result = updateLayoutAndDataType(inputTensor.cpu(), targetDataTypeTTNN,
targetTileShape);
}
::ttnn::Tensor &outputTensor = *liveTensors.at(op->out()->global_id());
void *src = ::tt::tt_metal::get_raw_host_data_ptr(result);
void *dst = ::tt::tt_metal::get_raw_host_data_ptr(outputTensor);
std::uint32_t size = result.volume() * result.element_size();
std::memcpy(dst, src, size);
break;
}
case ::tt::target::MemorySpace::DeviceDRAM: {
::tt::tt_metal::MemoryConfig memConfig = ::ttnn::DRAM_MEMORY_CONFIG;
// Host tensor, currently only supports borrowed storage
if (inputTensor.storage_type() == ::tt::tt_metal::StorageType::BORROWED) {
// moving to device first allows us to use device tilize
::ttnn::Tensor deviceTensor =
::ttnn::to_device(inputTensor, &device, memConfig);
deviceTensor = ::tilize(deviceTensor);
if (deviceTensor.get_dtype() != targetDataTypeTTNN) {
deviceTensor = ::ttnn::typecast(deviceTensor, targetDataTypeTTNN);
::ttnn::Tensor result = inputTensor;
// device tilize requires BFLOAT16, if not then tilize on host
if (result.get_dtype() != ::ttnn::DataType::BFLOAT16) {
result = ::tilize(result);
}
tensorPool.push_back(deviceTensor);
result = ::ttnn::to_device(result, &device, memConfig);
result =
updateLayoutAndDataType(result, targetDataTypeTTNN, targetTileShape);
tensorPool.push_back(result);
liveTensors.try_emplace(op->out()->global_id(), &tensorPool.back());
// Device tensor, currently only support single-device storage
// Since tensor already on device, update the memory config and break
} else if (inputTensor.storage_type() ==
::tt::tt_metal::StorageType::DEVICE) {
// Dram to L1 or Dram to Dram
::ttnn::Tensor deviceTensor =
::ttnn::to_memory_config(inputTensor, memConfig, targetDataTypeTTNN);
if (deviceTensor.get_dtype() != targetDataTypeTTNN) {
deviceTensor = ::ttnn::typecast(deviceTensor, targetDataTypeTTNN);
}
tensorPool.push_back(deviceTensor);
::ttnn::Tensor result = updateLayoutAndDataType(
inputTensor, targetDataTypeTTNN, targetTileShape);
result = ::ttnn::to_memory_config(result, memConfig, std::nullopt);
tensorPool.push_back(result);
liveTensors.try_emplace(op->out()->global_id(), &tensorPool.back());
}
break;
@@ -101,28 +146,23 @@ run(::tt::target::ttnn::ToMemoryConfigOp const *op, ::ttnn::Device &device,
// But will need its own code path when we add support for sharding
case ::tt::target::MemorySpace::DeviceL1: {
::tt::tt_metal::MemoryConfig memConfig = ::ttnn::L1_MEMORY_CONFIG;
// Host tensor, currently only supports borrowed storage
if (inputTensor.storage_type() == ::tt::tt_metal::StorageType::BORROWED) {
// moving to device first allows us to use device tilize
::ttnn::Tensor deviceTensor =
::ttnn::to_device(inputTensor, &device, memConfig);
deviceTensor = ::tilize(deviceTensor);
if (deviceTensor.get_dtype() != targetDataTypeTTNN) {
deviceTensor = ::ttnn::typecast(deviceTensor, targetDataTypeTTNN);
::ttnn::Tensor result = inputTensor;
// device tilize requires BFLOAT16, if not then tilize on host
if (result.get_dtype() != ::ttnn::DataType::BFLOAT16) {
result = ::tilize(result);
}
tensorPool.push_back(deviceTensor);
result = ::ttnn::to_device(result, &device, memConfig);
result =
updateLayoutAndDataType(result, targetDataTypeTTNN, targetTileShape);
tensorPool.push_back(result);
liveTensors.try_emplace(op->out()->global_id(), &tensorPool.back());
// Device tensor, currently only support single-device storage
// Since tensor already on device, update the memory config and break
} else if (inputTensor.storage_type() ==
::tt::tt_metal::StorageType::DEVICE) {
// L1 to Dram or L1 to L1
::ttnn::Tensor deviceTensor =
::ttnn::to_memory_config(inputTensor, memConfig, targetDataTypeTTNN);
if (deviceTensor.get_dtype() != targetDataTypeTTNN) {
deviceTensor = ::ttnn::typecast(deviceTensor, targetDataTypeTTNN);
}
tensorPool.push_back(deviceTensor);
::ttnn::Tensor result = updateLayoutAndDataType(
inputTensor, targetDataTypeTTNN, targetTileShape);
result = ::ttnn::to_memory_config(result, memConfig, std::nullopt);
tensorPool.push_back(result);
liveTensors.try_emplace(op->out()->global_id(), &tensorPool.back());
}
break;
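
To make the refactor easier to follow, here is a minimal, self-contained sketch of the decision logic that the new updateLayoutAndDataType helper centralizes. The Layout and DataType enums, the TileShape and Plan structs, and planUpdate below are illustrative stand-ins invented for this sketch, not the real ttnn types or the committed code.

// Illustrative sketch only: plain stand-ins mirroring the tilize/untilize/
// typecast decisions that updateLayoutAndDataType makes in the diff above.
#include <cstdio>

enum class Layout { RowMajor, Tile };        // stand-in for ::ttnn::Layout
enum class DataType { Float32, BFloat16 };   // stand-in for ::ttnn::DataType

struct TileShape {
  int x;
  int y;
};

struct Plan {
  bool tilize;
  bool untilize;
  bool convertDataType;
};

// Tilize when the target tile is 32x32 and the input is row-major, untilize
// when the target tile is not 32x32 and the input is tiled, and typecast
// whenever the data types differ.
Plan planUpdate(Layout inputLayout, DataType inputType, DataType targetType,
                TileShape targetTile) {
  const bool targetIsTile32 = targetTile.x == 32 && targetTile.y == 32;
  return Plan{
      /*tilize=*/targetIsTile32 && inputLayout == Layout::RowMajor,
      /*untilize=*/!targetIsTile32 && inputLayout == Layout::Tile,
      /*convertDataType=*/inputType != targetType,
  };
}

int main() {
  // A row-major Float32 tensor headed for a 32x32-tiled BFloat16 layout
  // should be tilized and typecast, but not untilized.
  Plan p = planUpdate(Layout::RowMajor, DataType::Float32, DataType::BFloat16,
                      TileShape{32, 32});
  std::printf("tilize=%d untilize=%d convert=%d\n", p.tilize, p.untilize,
              p.convertDataType);
  return 0;
}

Centralizing these three decisions is what lets the DeviceDRAM and DeviceL1 branches above collapse to nearly identical code paths built around a single updateLayoutAndDataType call.
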
2 changes: 1 addition & 1 deletion runtime/lib/ttnn/runtime.cpp
@@ -5,7 +5,7 @@
#include "tt/runtime/runtime.h"
#include "tt/runtime/detail/ttnn.h"
#include "tt/runtime/utils.h"
#include "ttnn/runtime/utils.h"
#include "utils.h"

#include "ttmlir/Target/TTNN/Target.h"
#include "ttmlir/Version.h"
15 changes: 11 additions & 4 deletions runtime/include/ttnn/runtime/utils.h → runtime/lib/ttnn/utils.h
@@ -10,20 +10,27 @@

namespace tt::runtime::ttnn::utils {

inline bool isValidTileShape(const ::tt::target::Dim2d *shape) {
return (shape->x() == 0 and shape->y() == 0) or
(shape->x() == 1 and shape->y() == 1) or
(shape->x() == 32 and shape->y() == 32);
}

inline ::ttnn::DataType toTTNNDataType(::tt::target::DataType dataType) {
switch (dataType) {
case ::tt::target::DataType::Float32:
return ::ttnn::DataType::FLOAT32;
// case ::tt::target::DataType::Float16:
// return ::ttnn::DataType::FLOAT16;
case ::tt::target::DataType::BFloat16:
return ::ttnn::DataType::BFLOAT16;
case ::tt::target::DataType::BFP_BFloat8:
return ::ttnn::DataType::BFLOAT8_B;
case ::tt::target::DataType::BFP_BFloat4:
return ::ttnn::DataType::BFLOAT4_B;
case ::tt::target::DataType::UInt32:
return ::ttnn::DataType::UINT32;
case ::tt::target::DataType::UInt16:
return ::ttnn::DataType::UINT16;
// case ::tt::target::DataType::UInt8:
// return ::ttnn::DataType::UINT8;

default:
throw std::runtime_error("Unsupported data type");
}
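
The relocated utils.h also gains an isValidTileShape predicate, which the TT_FATAL check in program.cpp uses to reject unsupported tile shapes. The standalone mirror below restates that rule for quick reference; the Dim2d struct here is a plain stand-in for the flatbuffer ::tt::target::Dim2d accessor type, not the real one.

// Standalone mirror of the isValidTileShape rule from the diff above: only
// 0x0, 1x1, and 32x32 tile shapes are accepted.
#include <cassert>

struct Dim2d {
  int x;
  int y;
};

bool isValidTileShape(const Dim2d &shape) {
  return (shape.x == 0 && shape.y == 0) || (shape.x == 1 && shape.y == 1) ||
         (shape.x == 32 && shape.y == 32);
}

int main() {
  assert(isValidTileShape({32, 32}));  // accepted
  assert(isValidTileShape({1, 1}));    // accepted
  assert(!isValidTileShape({16, 16})); // rejected; program.cpp raises TT_FATAL
  return 0;
}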
