diff --git a/runtime/lib/ttnn/program.cpp b/runtime/lib/ttnn/program.cpp index 3f0f9e1295..11ba0bf903 100644 --- a/runtime/lib/ttnn/program.cpp +++ b/runtime/lib/ttnn/program.cpp @@ -8,7 +8,7 @@ #include "tt/runtime/detail/ttnn.h" #include "tt/runtime/runtime.h" -#include "ttnn/runtime/utils.h" +#include "utils.h" #include "ttmlir/Target/TTNN/Target.h" #include "ttmlir/Version.h" @@ -25,74 +25,119 @@ // some reason a static_assert fails when this is called from within our // namespace. ttnn::Tensor tilize(ttnn::Tensor const &input) { - return ttnn::to_layout(input, ttnn::TILE_LAYOUT, std::nullopt, std::nullopt, + return ttnn::to_layout(input, ::ttnn::TILE_LAYOUT, std::nullopt, std::nullopt, (Device *)nullptr); } +ttnn::Tensor untilize(ttnn::Tensor const &input) { + return ttnn::to_layout(input, ::ttnn::ROW_MAJOR_LAYOUT, std::nullopt, + std::nullopt, (Device *)nullptr); +} + namespace tt::runtime::ttnn { + +static ::ttnn::Tensor convertDataType(const ::ttnn::Tensor &input, + const ::ttnn::DataType &targetDataType) { + const ::ttnn::StorageType storageType = input.storage_type(); + if (storageType == ::tt::tt_metal::StorageType::BORROWED) { + return ::ttnn::to_dtype(input, targetDataType); + } else if (storageType == ::tt::tt_metal::StorageType::DEVICE) { + if (input.get_layout() != ::ttnn::TILE_LAYOUT) { + // typecast op requires tilized tensor + ::ttnn::Tensor converted = + ::ttnn::typecast(::tilize(input), targetDataType); + // untilize and return + return ::untilize(converted); + } + return ::ttnn::typecast(input, targetDataType); + } else { + throw runtime_error("Unsupported storage type"); + } +} + +static ::ttnn::Tensor +updateLayoutAndDataType(const ::ttnn::Tensor &inputTensor, + const ::ttnn::DataType targetDataType, + const ::tt::target::Dim2d *targetTileShape) { + ::ttnn::Tensor outputTensor = inputTensor; + const bool shouldConvertDataType = inputTensor.get_dtype() != targetDataType; + const int targetTileX = targetTileShape->x(); + const int targetTileY = targetTileShape->y(); + const bool shouldTilize = + targetTileX == 32 and targetTileY == 32 and + inputTensor.get_layout() == ::ttnn::ROW_MAJOR_LAYOUT; + const bool shouldUntilize = (targetTileX != 32 or targetTileY != 32) and + inputTensor.get_layout() == ::ttnn::TILE_LAYOUT; + if (shouldTilize) { + outputTensor = ::tilize(outputTensor); + } else if (shouldUntilize) { + outputTensor = ::untilize(outputTensor); + } + if (shouldConvertDataType) { + outputTensor = convertDataType(outputTensor, targetDataType); + } + return outputTensor; +} + static void run(::tt::target::ttnn::ToMemoryConfigOp const *op, ::ttnn::Device &device, std::unordered_map &liveTensors, std::list<::ttnn::Tensor> &tensorPool) { - ::tt::target::DataType targetDataType = - op->out()->desc()->layout()->memory_desc()->data_type(); - assert(targetDataType == ::tt::target::DataType::Float32 or - targetDataType == ::tt::target::DataType::BFloat16); - ::ttnn::DataType targetDataTypeTTNN = utils::toTTNNDataType(targetDataType); const ::ttnn::Tensor &inputTensor = *liveTensors.at(op->in0()->global_id()); assert(inputTensor.storage_type() == ::tt::tt_metal::StorageType::BORROWED or inputTensor.storage_type() == ::tt::tt_metal::StorageType::DEVICE); + + const ::tt::target::Dim2d *targetTileShape = + op->out()->desc()->layout()->memory_desc()->tile_shape(); + TT_FATAL(utils::isValidTileShape(targetTileShape), + "Invalid tile shape ({}, {})", targetTileShape->x(), + targetTileShape->y()); + + ::tt::target::DataType targetDataType = + op->out()->desc()->layout()->memory_desc()->data_type(); + ::ttnn::DataType targetDataTypeTTNN = utils::toTTNNDataType(targetDataType); + const ::tt::target::MemorySpace targetMemorySpace = op->out()->desc()->layout()->memory_desc()->memory_space(); + switch (targetMemorySpace) { case ::tt::target::MemorySpace::System: case ::tt::target::MemorySpace::SystemMMIO: { + ::ttnn::Tensor result; if (inputTensor.storage_type() == ::tt::tt_metal::StorageType::BORROWED) { - ::ttnn::Tensor hostTensor = inputTensor.to(::ttnn::ROW_MAJOR_LAYOUT); - hostTensor = ::ttnn::to_dtype(hostTensor, targetDataTypeTTNN); - ::ttnn::Tensor &outputTensor = *liveTensors.at(op->out()->global_id()); - void *src = ::tt::tt_metal::get_raw_host_data_ptr(hostTensor); - void *dst = ::tt::tt_metal::get_raw_host_data_ptr(outputTensor); - std::uint32_t size = hostTensor.volume() * hostTensor.element_size(); - std::memcpy(dst, src, size); + result = updateLayoutAndDataType(inputTensor, targetDataTypeTTNN, + targetTileShape); } else if (inputTensor.storage_type() == ::tt::tt_metal::StorageType::DEVICE) { - ::ttnn::Tensor hostTensor = - ::ttnn::typecast(inputTensor, targetDataTypeTTNN); - // Following the flow in core.py::to_torch - untilize on host - hostTensor = hostTensor.cpu().to(::ttnn::ROW_MAJOR_LAYOUT); - ::ttnn::Tensor &outputTensor = *liveTensors.at(op->out()->global_id()); - void *src = ::tt::tt_metal::get_raw_host_data_ptr(hostTensor); - void *dst = ::tt::tt_metal::get_raw_host_data_ptr(outputTensor); - std::uint32_t size = hostTensor.volume() * hostTensor.element_size(); - std::memcpy(dst, src, size); + result = updateLayoutAndDataType(inputTensor.cpu(), targetDataTypeTTNN, + targetTileShape); } + ::ttnn::Tensor &outputTensor = *liveTensors.at(op->out()->global_id()); + void *src = ::tt::tt_metal::get_raw_host_data_ptr(result); + void *dst = ::tt::tt_metal::get_raw_host_data_ptr(outputTensor); + std::uint32_t size = result.volume() * result.element_size(); + std::memcpy(dst, src, size); break; } case ::tt::target::MemorySpace::DeviceDRAM: { ::tt::tt_metal::MemoryConfig memConfig = ::ttnn::DRAM_MEMORY_CONFIG; - // Host tensor, currently only supports borrowed storage if (inputTensor.storage_type() == ::tt::tt_metal::StorageType::BORROWED) { - // moving to device first allows us to use device tilize - ::ttnn::Tensor deviceTensor = - ::ttnn::to_device(inputTensor, &device, memConfig); - deviceTensor = ::tilize(deviceTensor); - if (deviceTensor.get_dtype() != targetDataTypeTTNN) { - deviceTensor = ::ttnn::typecast(deviceTensor, targetDataTypeTTNN); + ::ttnn::Tensor result = inputTensor; + // device tilize requires BFLOAT16, if not then tilize on host + if (result.get_dtype() != ::ttnn::DataType::BFLOAT16) { + result = ::tilize(result); } - tensorPool.push_back(deviceTensor); + result = ::ttnn::to_device(result, &device, memConfig); + result = + updateLayoutAndDataType(result, targetDataTypeTTNN, targetTileShape); + tensorPool.push_back(result); liveTensors.try_emplace(op->out()->global_id(), &tensorPool.back()); - // Device tensor, currently only support single-device storage - // Since tensor already on device, update the memory config and break } else if (inputTensor.storage_type() == ::tt::tt_metal::StorageType::DEVICE) { - // Dram to L1 or Dram to Dram - ::ttnn::Tensor deviceTensor = - ::ttnn::to_memory_config(inputTensor, memConfig, targetDataTypeTTNN); - if (deviceTensor.get_dtype() != targetDataTypeTTNN) { - deviceTensor = ::ttnn::typecast(deviceTensor, targetDataTypeTTNN); - } - tensorPool.push_back(deviceTensor); + ::ttnn::Tensor result = updateLayoutAndDataType( + inputTensor, targetDataTypeTTNN, targetTileShape); + result = ::ttnn::to_memory_config(result, memConfig, std::nullopt); + tensorPool.push_back(result); liveTensors.try_emplace(op->out()->global_id(), &tensorPool.back()); } break; @@ -101,28 +146,24 @@ run(::tt::target::ttnn::ToMemoryConfigOp const *op, ::ttnn::Device &device, // But will need it's own code path when we add support for sharding case ::tt::target::MemorySpace::DeviceL1: { ::tt::tt_metal::MemoryConfig memConfig = ::ttnn::L1_MEMORY_CONFIG; - // Host tensor, currently only supports borrowed storage + // moving to device first allows us to use device tilize if (inputTensor.storage_type() == ::tt::tt_metal::StorageType::BORROWED) { - // moving to device first allows us to use device tilize - ::ttnn::Tensor deviceTensor = - ::ttnn::to_device(inputTensor, &device, memConfig); - deviceTensor = ::tilize(deviceTensor); - if (deviceTensor.get_dtype() != targetDataTypeTTNN) { - deviceTensor = ::ttnn::typecast(deviceTensor, targetDataTypeTTNN); + ::ttnn::Tensor result = inputTensor; + // device tilize requires BFLOAT16, if not then tilize on host + if (result.get_dtype() != ::ttnn::DataType::BFLOAT16) { + result = ::tilize(result); } - tensorPool.push_back(deviceTensor); + result = ::ttnn::to_device(result, &device, memConfig); + result = + updateLayoutAndDataType(result, targetDataTypeTTNN, targetTileShape); + tensorPool.push_back(result); liveTensors.try_emplace(op->out()->global_id(), &tensorPool.back()); - // Device tensor, currently only support single-device storage - // Since tensor already on device, update the memory config and break } else if (inputTensor.storage_type() == ::tt::tt_metal::StorageType::DEVICE) { - // L1 to Dram or L1 to L1 - ::ttnn::Tensor deviceTensor = - ::ttnn::to_memory_config(inputTensor, memConfig, targetDataTypeTTNN); - if (deviceTensor.get_dtype() != targetDataTypeTTNN) { - deviceTensor = ::ttnn::typecast(deviceTensor, targetDataTypeTTNN); - } - tensorPool.push_back(deviceTensor); + ::ttnn::Tensor result = updateLayoutAndDataType( + inputTensor, targetDataTypeTTNN, targetTileShape); + result = ::ttnn::to_memory_config(result, memConfig, std::nullopt); + tensorPool.push_back(result); liveTensors.try_emplace(op->out()->global_id(), &tensorPool.back()); } break; diff --git a/runtime/lib/ttnn/runtime.cpp b/runtime/lib/ttnn/runtime.cpp index 83f056ede0..a50e0dec53 100644 --- a/runtime/lib/ttnn/runtime.cpp +++ b/runtime/lib/ttnn/runtime.cpp @@ -5,7 +5,7 @@ #include "tt/runtime/runtime.h" #include "tt/runtime/detail/ttnn.h" #include "tt/runtime/utils.h" -#include "ttnn/runtime/utils.h" +#include "utils.h" #include "ttmlir/Target/TTNN/Target.h" #include "ttmlir/Version.h" diff --git a/runtime/include/ttnn/runtime/utils.h b/runtime/lib/ttnn/utils.h similarity index 67% rename from runtime/include/ttnn/runtime/utils.h rename to runtime/lib/ttnn/utils.h index e6c19fe7cc..235d9e446c 100644 --- a/runtime/include/ttnn/runtime/utils.h +++ b/runtime/lib/ttnn/utils.h @@ -10,20 +10,27 @@ namespace tt::runtime::ttnn::utils { +inline bool isValidTileShape(const ::tt::target::Dim2d *shape) { + return (shape->x() == 0 and shape->y() == 0) or + (shape->x() == 1 and shape->y() == 1) or + (shape->x() == 32 and shape->y() == 32); +} + inline ::ttnn::DataType toTTNNDataType(::tt::target::DataType dataType) { switch (dataType) { case ::tt::target::DataType::Float32: return ::ttnn::DataType::FLOAT32; - // case ::tt::target::DataType::Float16: - // return ::ttnn::DataType::FLOAT16; case ::tt::target::DataType::BFloat16: return ::ttnn::DataType::BFLOAT16; + case ::tt::target::DataType::BFP_BFloat8: + return ::ttnn::DataType::BFLOAT8_B; + case ::tt::target::DataType::BFP_BFloat4: + return ::ttnn::DataType::BFLOAT4_B; case ::tt::target::DataType::UInt32: return ::ttnn::DataType::UINT32; case ::tt::target::DataType::UInt16: return ::ttnn::DataType::UINT16; - // case ::tt::target::DataType::UInt8: - // return ::ttnn::DataType::UINT8; + default: throw std::runtime_error("Unsupported data type"); }