Commit fce58b3
Runtime support for multi-chip tensors/ops, including creation and
execution. Updated the program context to use a parent/sub-mesh model,
where sub-meshes do not close devices while the parent mesh is still alive.
jnie-TT committed Nov 6, 2024
1 parent ae93524 commit fce58b3
Showing 29 changed files with 610 additions and 201 deletions.
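
The parent/sub-mesh model mentioned in the commit message boils down to shared ownership: a sub-mesh keeps a reference to its parent, and the underlying devices are closed only when the root mesh itself is destroyed. The following is an illustrative C++ sketch of that lifetime rule; the type and function names are hypothetical, not the actual tt-mlir/tt-metal mesh API.

#include <memory>
#include <utility>
#include <vector>

struct MeshDevice {
  std::vector<int> deviceIds;
  std::shared_ptr<MeshDevice> parent; // empty for a root (parent) mesh

  ~MeshDevice() {
    // Only a root mesh owns the hardware handles; a sub-mesh just drops
    // its reference to the parent and never closes devices itself.
    if (!parent) {
      closeDevices();
    }
  }

  void closeDevices() { /* release hardware handles */ }
};

// Creating a sub-mesh shares the parent's devices instead of reopening
// them; the shared_ptr keeps the parent (and its devices) alive.
std::shared_ptr<MeshDevice> createSubMesh(std::shared_ptr<MeshDevice> parent,
                                          std::vector<int> ids) {
  auto sub = std::make_shared<MeshDevice>();
  sub->deviceIds = std::move(ids);
  sub->parent = std::move(parent);
  return sub;
}
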
17 changes: 17 additions & 0 deletions include/ttmlir/Target/Common/Target.h
@@ -0,0 +1,17 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#ifndef TTMLIR_TARGET_COMMON_TARGET_H
#define TTMLIR_TARGET_COMMON_TARGET_H

#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wcovered-switch-default"

#include "ttmlir/Target/Common/system_desc_generated.h"
#include "ttmlir/Target/Common/types_generated.h"
#include "ttmlir/Target/Common/version_generated.h"

#pragma clang diagnostic pop

#endif
28 changes: 28 additions & 0 deletions include/ttmlir/Target/Common/types.fbs
@@ -85,6 +85,33 @@ table MemoryConfigDesc {
shard_spec: ShardSpec;
}

table ReplicateTensor {
replication_factor: uint32;
}

table ShardTensor {
shard_dim: uint32;
}

table ShardTensor2D {
shard_mesh: Dim2d;
}

table AllGatherTensor {

}

union DistributedTensorConfig {
ReplicateTensor,
ShardTensor,
ShardTensor2D,
AllGatherTensor
}

table DistributionStrategy {
strategy: DistributedTensorConfig;
}

table MemoryDesc {
shape: [int];
tile_shape: Dim2d;
@@ -99,6 +126,7 @@ table LayoutDesc {
oob_val: OOBVal;
core_range_set: [Dim2dRange];
memory_desc: MemoryDesc;
strategy: DistributionStrategy;
}

table TensorDesc {
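
The union added above is built in two steps with the flatc-generated C++ API: first the concrete variant table, then the DistributionStrategy wrapper that records which variant was chosen. A minimal sketch, assuming standard flatc codegen for the helper names (verify CreateReplicateTensor against types_generated.h before relying on it):

#include "flatbuffers/flatbuffers.h"
#include "ttmlir/Target/Common/types_generated.h"

flatbuffers::Offset<::tt::target::DistributionStrategy>
buildReplicateStrategy(flatbuffers::FlatBufferBuilder &fbb) {
  // Concrete variant first (generated helper, assumed name).
  auto replicate =
      ::tt::target::CreateReplicateTensor(fbb, /*replication_factor=*/4);
  // The union field is stored as a (type tag, untyped offset) pair.
  return ::tt::target::CreateDistributionStrategy(
      fbb, ::tt::target::DistributedTensorConfig::ReplicateTensor,
      replicate.Union());
}
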
4 changes: 4 additions & 0 deletions include/ttmlir/Target/TTNN/program.fbs
@@ -46,14 +46,18 @@ table EmptyOp {
shape: [int64];
dtype: DataType;
layout: TensorLayout;
num_shards: uint32;
device: tt.target.DeviceRef; // optional
memcfg: tt.target.MemoryConfigDesc; // optional
strategy: tt.target.DistributionStrategy;
out: tt.target.TensorRef;
}

table FullOp {
device: tt.target.DeviceRef;
fill_value: float;
num_shards: uint32;
strategy: tt.target.DistributionStrategy;
out: tt.target.TensorRef;
}

9 changes: 8 additions & 1 deletion include/ttmlir/Target/Utils/MLIRToFlatbuffer.h
@@ -424,11 +424,18 @@ layoutAttrToFlatbuffer(FlatbufferObjectCache &cache, Attribute attr,
std::vector<int32_t> stride(strideInt64.begin(), strideInt64.end());
auto coreRangeSet =
toFlatbuffer(cache, layoutAttr.getGrid(), deviceAttr.getWorkerGrid());
::tt::target::DistributedTensorConfig distributionType =
::tt::target::DistributedTensorConfig::NONE;
::flatbuffers::Offset<void> distribution = 0;
flatbuffers::Offset<::tt::target::DistributionStrategy> strategy =
::tt::target::CreateDistributionStrategy(*cache.fbb, distributionType,
distribution);
return ::tt::target::CreateLayoutDescDirect(
*cache.fbb, &stride, toFlatbuffer(cache, layoutAttr.getOobVal()),
&coreRangeSet,
cache.getOrCreate(layoutAttr.getMemref(), memrefAttrToFlatbuffer,
layoutAttr.getMemLayout()));
layoutAttr.getMemLayout()),
strategy);
}

inline flatbuffers::Offset<::tt::target::TensorDesc>
2 changes: 1 addition & 1 deletion lib/Dialect/TT/IR/TTOpsTypes.cpp
@@ -12,7 +12,7 @@
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/DialectImplementation.h"
#include "ttmlir/Dialect/TT/IR/TT.h"
#include "ttmlir/Target/Common/system_desc_generated.h"
#include "ttmlir/Target/Common/Target.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/TypeSwitch.h"
20 changes: 18 additions & 2 deletions lib/Target/TTNN/TTNNToFlatbuffer.cpp
@@ -208,14 +208,21 @@ createOp(FlatbufferObjectCache &cache, EmptyOp op) {
::tt::target::TensorLayout layout =
::tt::mlir::ttnn::utils::toTargetTensorLayout(op.getLayout().value());

uint32_t numShards = 1;
::tt::target::DistributedTensorConfig distributionType =
::tt::target::DistributedTensorConfig::NONE;
::flatbuffers::Offset<void> distribution = 0;
flatbuffers::Offset<::tt::target::DistributionStrategy> strategy =
::tt::target::CreateDistributionStrategy(*cache.fbb, distributionType,
distribution);
auto output = getOperandThroughDPSOps(op.getResult());

// If the device is not set, we create on host
//
if (!op.getDevice()) {
return ::tt::target::ttnn::CreateEmptyOp(
*cache.fbb, cache.fbb->CreateVector<int64_t>(shape), dtype, layout,
/* device */ 0, /* memcfg */ 0,
numShards, /* device */ 0, /* memcfg */ 0, strategy,
cache.getOrCreate(output, tensorValueToFlatbuffer,
kHostAllocatedAddress, kHostAllocatedSize));
}
@@ -227,7 +234,8 @@ createOp(FlatbufferObjectCache &cache, EmptyOp op) {

return ::tt::target::ttnn::CreateEmptyOp(
*cache.fbb, cache.fbb->CreateVector<int64_t>(shape), dtype, layout,
cache.at<::tt::target::DeviceRef>(device), memoryConfigDesc,
numShards, cache.at<::tt::target::DeviceRef>(device), memoryConfigDesc,
strategy,
cache.getOrCreate(output, tensorValueToFlatbuffer, kHostAllocatedAddress,
kHostAllocatedSize));
}
@@ -237,8 +245,16 @@ createOp(FlatbufferObjectCache &cache, FullOp op) {
auto device = getOperandThroughDPSOps(op.getDevice());
auto fillValue = op.getFillValue().convertToFloat();
auto output = getOperandThroughDPSOps(op.getResult());
uint32_t numShards = 1;
::tt::target::DistributedTensorConfig distributionType =
::tt::target::DistributedTensorConfig::NONE;
::flatbuffers::Offset<void> distribution = 0;
flatbuffers::Offset<::tt::target::DistributionStrategy> strategy =
::tt::target::CreateDistributionStrategy(*cache.fbb, distributionType,
distribution);
return ::tt::target::ttnn::CreateFullOp(
*cache.fbb, cache.at<::tt::target::DeviceRef>(device), fillValue,
numShards, strategy,
cache.getOrCreate(output, tensorValueToFlatbuffer, kHostAllocatedAddress,
kHostAllocatedSize));
}
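
Both createOp overloads above build the identical placeholder NONE strategy inline. A possible follow-up cleanup, not part of this commit, would hoist that into one helper:

static flatbuffers::Offset<::tt::target::DistributionStrategy>
createDefaultDistributionStrategy(FlatbufferObjectCache &cache) {
  // Single-device default: no distribution variant attached.
  return ::tt::target::CreateDistributionStrategy(
      *cache.fbb, ::tt::target::DistributedTensorConfig::NONE,
      /*distribution=*/0);
}
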
2 changes: 1 addition & 1 deletion python/TTModule.cpp
@@ -11,7 +11,7 @@
#include "mlir/CAPI/IR.h"

#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h"
#include "ttmlir/Target/Common/types_generated.h"
#include "ttmlir/Target/Common/Target.h"
#include "ttmlir/Utils.h"

namespace mlir::ttmlir::python {
1 change: 1 addition & 0 deletions runtime/include/tt/runtime/detail/ttmetal.h
@@ -45,6 +45,7 @@
#pragma clang diagnostic pop

#include "tt/runtime/types.h"
#include "tt/runtime/utils.h"
#include "ttmlir/Target/TTMetal/Target.h"

namespace tt::runtime::ttmetal {
15 changes: 15 additions & 0 deletions runtime/include/tt/runtime/detail/ttnn.h
@@ -62,6 +62,7 @@
#include "ttnn/operations/normalization/softmax/softmax.hpp"
#include "ttnn/operations/pool/maxpool/max_pool2d.hpp"
#include "ttnn/operations/reduction/generic/generic_reductions.hpp"
#include "ttnn/tensor/host_buffer/functions.hpp"
#include "ttnn/tensor/tensor.hpp"
#include "ttnn/tensor/types.hpp"
#pragma clang diagnostic pop
@@ -81,11 +82,25 @@ Tensor createTensor(std::shared_ptr<void> data,
std::vector<std::uint32_t> const &stride,
std::uint32_t itemsize, ::tt::target::DataType dataType);

Tensor
createTensor(std::vector<std::shared_ptr<void>> &data,
std::vector<std::uint32_t> const &shape,
std::vector<std::uint32_t> const &stride, std::uint32_t itemsize,
::tt::target::DataType dataType,
std::unordered_map<std::string, std::string> const &strategy);

inline Tensor createTensor(std::shared_ptr<void> data, TensorDesc const &desc) {
return createTensor(data, desc.shape, desc.stride, desc.itemsize,
desc.dataType);
}

inline Tensor
createTensor(std::vector<std::shared_ptr<void>> &data, TensorDesc const &desc,
std::unordered_map<std::string, std::string> const &strategy) {
return createTensor(data, desc.shape, desc.stride, desc.itemsize,
desc.dataType, strategy);
}

tt::target::DataType getTensorDataType(Tensor tensor);

size_t getNumAvailableDevices();
14 changes: 14 additions & 0 deletions runtime/include/tt/runtime/runtime.h
@@ -36,11 +36,25 @@ Tensor createTensor(std::shared_ptr<void> data,
std::vector<std::uint32_t> const &stride,
std::uint32_t itemsize, ::tt::target::DataType dataType);

Tensor
createTensor(std::vector<std::shared_ptr<void>> &data,
std::vector<std::uint32_t> const &shape,
std::vector<std::uint32_t> const &stride, std::uint32_t itemsize,
::tt::target::DataType dataType,
std::unordered_map<std::string, std::string> const &strategy);

inline Tensor createTensor(std::shared_ptr<void> data, TensorDesc const &desc) {
return createTensor(data, desc.shape, desc.stride, desc.itemsize,
desc.dataType);
}

inline Tensor
createTensor(std::vector<std::shared_ptr<void>> &data, TensorDesc const &desc,
std::unordered_map<std::string, std::string> const &strategy) {
return createTensor(data, desc.shape, desc.stride, desc.itemsize,
desc.dataType, strategy);
}

tt::target::DataType getTensorDataType(Tensor tensor);

size_t getNumAvailableDevices();
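
A usage sketch for the new multi-shard overload follows. The "strategy" and "shard_dim" map keys are hypothetical placeholders; the accepted keys are defined by the TTNN runtime implementation, which this diff does not show.

#include "tt/runtime/runtime.h"

#include <cstdint>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

int main() {
  std::vector<std::uint32_t> shape = {32, 32};
  std::vector<std::uint32_t> stride = {32, 1};
  constexpr std::uint32_t itemsize = sizeof(float);

  // One host buffer per device shard.
  std::vector<std::shared_ptr<void>> shards;
  for (int i = 0; i < 2; ++i) {
    shards.emplace_back(new float[32 * 32](),
                        [](void *p) { delete[] static_cast<float *>(p); });
  }

  // Hypothetical keys; check the TTNN runtime for the real contract.
  std::unordered_map<std::string, std::string> strategy = {
      {"strategy", "shard"}, {"shard_dim", "0"}};

  tt::runtime::Tensor tensor = tt::runtime::createTensor(
      shards, shape, stride, itemsize, ::tt::target::DataType::Float32,
      strategy);
  (void)tensor;
  return 0;
}
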
4 changes: 3 additions & 1 deletion runtime/include/tt/runtime/types.h
@@ -10,9 +10,11 @@
#include <string_view>
#include <vector>

#include "tt/runtime/utils.h"
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wcovered-switch-default"
#include "ttmlir/Target/Common/system_desc_generated.h"
#include "ttmlir/Target/Common/types_generated.h"
#pragma clang diagnostic pop

namespace tt::runtime {

3 changes: 3 additions & 0 deletions runtime/include/tt/runtime/utils.h
@@ -7,7 +7,10 @@

#include <memory>

#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wcovered-switch-default"
#include "ttmlir/Target/Common/types_generated.h"
#pragma clang diagnostic pop

namespace tt::runtime::utils {

24 changes: 24 additions & 0 deletions runtime/lib/runtime.cpp
@@ -124,6 +124,30 @@ Tensor createTensor(std::shared_ptr<void> data,
throw std::runtime_error("runtime is not enabled");
}

Tensor
createTensor(std::vector<std::shared_ptr<void>> &data,
std::vector<std::uint32_t> const &shape,
std::vector<std::uint32_t> const &stride, std::uint32_t itemsize,
::tt::target::DataType dataType,
std::unordered_map<std::string, std::string> const &strategy) {
LOG_ASSERT(not shape.empty());
LOG_ASSERT(not stride.empty());
LOG_ASSERT(itemsize > 0);
#if defined(TT_RUNTIME_ENABLE_TTNN)
if (getCurrentRuntime() == DeviceRuntime::TTNN) {
return ::tt::runtime::ttnn::createTensor(data, shape, stride, itemsize,
dataType, strategy);
}
#endif

#if defined(TT_RUNTIME_ENABLE_TTMETAL)
if (getCurrentRuntime() == DeviceRuntime::TTMetal) {
throw std::runtime_error("Not implemented");
}
#endif
throw std::runtime_error("runtime is not enabled");
}

tt::target::DataType getTensorDataType(Tensor tensor) {
#if defined(TT_RUNTIME_ENABLE_TTNN)
if (getCurrentRuntime() == DeviceRuntime::TTNN) {
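
On the read side, the runtime recovers the chosen variant by switching on the union's type tag. A sketch assuming standard flatc-generated accessors (strategy_type(), strategy_as_ReplicateTensor(), and so on; confirm against types_generated.h):

#include "ttmlir/Target/Common/types_generated.h"

bool isMultiDevice(const ::tt::target::DistributionStrategy *strategy) {
  switch (strategy->strategy_type()) {
  case ::tt::target::DistributedTensorConfig::ReplicateTensor:
    return strategy->strategy_as_ReplicateTensor()->replication_factor() > 1;
  case ::tt::target::DistributedTensorConfig::ShardTensor:
  case ::tt::target::DistributedTensorConfig::ShardTensor2D:
  case ::tt::target::DistributedTensorConfig::AllGatherTensor:
    return true;
  case ::tt::target::DistributedTensorConfig::NONE:
    return false;
  }
  return false; // unreachable: the switch covers every variant
}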