Enable lowering ttir all_reduce and mesh_shard to ttnn and flatbuffer
* all_reduce breaks down into reduce_scatter and all_gather (see the sketch after this list)

* mesh_shard is directly converted

* corresponding dialect unit tests are added

* fix issue in default mesh shape in getOrInsertDevice()

* enable workaround pass to avoid issues in ttnn
wooseokTT committed Dec 6, 2024
1 parent 8a6151b commit c764b41
Showing 16 changed files with 522 additions and 43 deletions.
54 changes: 46 additions & 8 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
@@ -903,13 +903,21 @@ def TTNN_DeallocateOp : TTNN_Op<"deallocate"> {
DefaultValuedAttr<BoolAttr, "false">:$force);
}

def TTNN_ScatterOp: TTNN_ElementwiseBinaryOp<"scatter"> {
let summary = "Scatter op.";
let description = [{
Embeds the values of the 'update' tensor into 'input' at the given index and puts the value in the 'output' tensor.
}];
}

def TTNN_AllGatherOp: TTNN_Op<"all_gather"> {
let summary = "All gather op.";
let description = [{
Tensor All Gather operation
}];

let arguments = (ins AnyRankedTensor:$input,
TT_Device:$device,
SI32Attr:$dim,
DefaultValuedAttr<SI32Attr, "1">:$num_links);

@@ -918,23 +926,53 @@ def TTNN_AllGatherOp: TTNN_Op<"all_gather"> {
let hasVerifier = 1;
}

def TTNN_ReduceScatterOp: TTNN_Op<"reduce_scatter"> {
let summary = "Reduce scatter op.";
let description = [{
Tensor Reduce Scatter operation
}];

let arguments = (ins AnyRankedTensor:$input,
TT_Device:$device,
SI32Attr:$scatter_split_dim,
TT_ReduceTypeAttr:$math_op,
DefaultValuedAttr<SI32Attr, "1">:$num_links);

let results = (outs AnyRankedTensor:$result);

let hasVerifier = 1;
}

def TTNN_AllReduceOp: TTNN_Op<"all_reduce"> {
let summary = "All reduce op.";
let description = [{
Tensor All Reduce operation
}];

let arguments = (ins AnyRankedTensor:$input,
TT_Device:$device,
SI32Attr:$scatter_dim,
SI32Attr:$scatter_num,
TT_ReduceTypeAttr:$math_op,
DefaultValuedAttr<SI32Attr, "1">:$num_links);

let results = (outs AnyRankedTensor:$result);

let hasVerifier = 1;
}

def TTNN_MeshShardOp: TTNN_Op<"mesh_shard"> {
let summary = "Mesh shard op.";
let description = [{
Tensor Mesh Shard operation
}];

let arguments = (ins AnyRankedTensor:$input,
TT_Device:$device,
TT_MeshShardDirectionAttr:$shard_direction,
TT_MeshShardTypeAttr:$shard_type,
TT_GridAttr:$shard_shape);

let results = (outs AnyRankedTensor:$result);

let hasVerifier = 1;
}
13 changes: 9 additions & 4 deletions include/ttmlir/Dialect/TTNN/Pipelines/TTNNPipelines.h
@@ -112,11 +112,16 @@ struct TTIRToTTNNBackendPipelineOptions
ListOption<int64_t> meshShape{
*this, "mesh-shape", llvm::cl::desc("Set the multi-device mesh shape.")};

// Options to enable/disable the workaround pass.
//
Option<bool> layoutWorkaroundsEnabled{
*this, "enable-layout-workaround-pass",
llvm::cl::desc("Enable layout workaround pass."), llvm::cl::init(false)};

Option<bool> decompositionWorkaroundsEnabled{
*this, "enable-decomposition-workaround-pass",
llvm::cl::desc("Enable decomposition workaround pass."),
llvm::cl::init(true)};
};

// TTIR to EmitC pipeline options.
11 changes: 11 additions & 0 deletions include/ttmlir/Dialect/TTNN/Transforms/Passes.td
@@ -34,6 +34,17 @@ def TTNNWorkarounds : Pass<"ttnn-workaround", "::mlir::ModuleOp"> {
This pass applies necessary TTNN workarounds to the IR in order to create
a valid and functional IR that can be executed on the hardware.
}];

let options = [
Option<"layouotWorkaroundsEnabled",
"ttnn-enable-layout-workaround-pass",
"bool", /*default=*/"false",
"TTNN Layout Workarounds Pass">,
Option<"decompositionWorkaroundsEnabled",
"ttnn-enable-decomposition-workaround-pass",
"bool", /*default=*/"true",
"TTNN Decompsition Workarounds Pass">,
];
}

#endif
21 changes: 21 additions & 0 deletions include/ttmlir/Target/TTNN/program.fbs
@@ -268,10 +268,29 @@ table DeallocateOp {
table AllGatherOp {
in: tt.target.TensorRef;
out: tt.target.TensorRef;
device: tt.target.DeviceRef;
dim: uint32;
num_links: uint32;
}

table ReduceScatterOp {
in: tt.target.TensorRef;
out: tt.target.TensorRef;
device: tt.target.DeviceRef;
scatter_split_dim: uint32;
math_op: uint32;
num_links: uint32;
}

table MeshShardOp {
in: tt.target.TensorRef;
out: tt.target.TensorRef;
device: tt.target.DeviceRef;
shard_direction: uint32;
shard_type: uint32;
shard_shape: [int64];
}

union OpType {
GetDeviceOp,
ToMemoryConfigOp,
@@ -295,6 +314,8 @@
MaxPool2dOp,
DeallocateOp,
AllGatherOp,
ReduceScatterOp,
MeshShardOp,
ArangeOp,
UpdateCacheOp,
FillCacheOp,
50 changes: 49 additions & 1 deletion lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
@@ -897,6 +897,50 @@ class SubtractOpConversionPattern
}
};

class AllReduceOpConversionPattern
: public OpConversionPattern<ttir::AllReduceOp> {
public:
using OpConversionPattern<ttir::AllReduceOp>::OpConversionPattern;

LogicalResult
matchAndRewrite(ttir::AllReduceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

auto replicaGroupsShape = adaptor.getReplicaGroups().getType().getShape();
size_t scatter_dim = adaptor.getDim();
    // scatter_num (the number of devices along the scatter dimension) is
    // needed by the workaround pass to determine the reduce_scatter output
    // shape and the all_gather input shape.
int32_t scatter_num =
replicaGroupsShape[scatter_dim % replicaGroupsShape.size()];
auto device = ::ttnn::utils::getOrInsertDevice(rewriter, op);
rewriter.replaceOpWithNewOp<ttnn::AllReduceOp>(
op, this->getTypeConverter()->convertType(op.getType(0)),
adaptor.getInputs().front(), device, scatter_dim, scatter_num,
adaptor.getReduceType());

return success();
}
};

class MeshShardOpConversionPattern
: public OpConversionPattern<ttir::MeshShardOp> {
public:
using OpConversionPattern<ttir::MeshShardOp>::OpConversionPattern;

LogicalResult
matchAndRewrite(ttir::MeshShardOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

auto device = ::ttnn::utils::getOrInsertDevice(rewriter, op);
rewriter.replaceOpWithNewOp<ttnn::MeshShardOp>(
op, this->getTypeConverter()->convertType(op.getType()),
adaptor.getInput(), device, adaptor.getShardDirection(),
adaptor.getShardType(), adaptor.getShardShape());

return success();
}
};

class AllGatherOpConversionPattern
: public OpConversionPattern<ttir::AllGatherOp> {
public:
@@ -905,9 +949,11 @@ class AllGatherOpConversionPattern
LogicalResult
matchAndRewrite(ttir::AllGatherOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

auto device = ::ttnn::utils::getOrInsertDevice(rewriter, op);
rewriter.replaceOpWithNewOp<ttnn::AllGatherOp>(
op, this->getTypeConverter()->convertType(op.getType()),
adaptor.getInput(), device, adaptor.getDim());
return success();
}
};
@@ -1035,6 +1081,8 @@ void populateTTIRToTTNNPatterns(MLIRContext *ctx, RewritePatternSet &patterns,
Conv2dOpConversionPattern,
MaxPool2dOpConversionPattern,
SubtractOpConversionPattern,
MeshShardOpConversionPattern,
AllReduceOpConversionPattern,
AllGatherOpConversionPattern,
ArangeOpConversionPattern,
UpdateCacheOpConversionPattern,
4 changes: 4 additions & 0 deletions lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp
@@ -751,6 +751,10 @@ void populateTTNNToEmitCPatterns(mlir::MLIRContext *ctx,
//
patterns.add<DefaultOpConversionPattern<ttnn::AllGatherOp>>(typeConverter,
ctx);
patterns.add<DefaultOpConversionPattern<ttnn::ReduceScatterOp>>(typeConverter,
ctx);
patterns.add<DefaultOpConversionPattern<ttnn::MeshShardOp>>(typeConverter,
ctx);

// Module op
//
56 changes: 55 additions & 1 deletion lib/Dialect/TTNN/IR/TTNNOps.cpp
@@ -968,7 +968,61 @@ ::mlir::LogicalResult AllGatherOp::verify() {
//===----------------------------------------------------------------------===//

::mlir::LogicalResult ReduceScatterOp::verify() {
::mlir::RankedTensorType inputType = getInput().getType();
int32_t scatterSplitDim = getScatterSplitDim();
auto mathOp = getMathOp();

if (scatterSplitDim >= inputType.getRank() ||
scatterSplitDim < -inputType.getRank()) {
return emitOpError("Invalid dimension for reduce scatter op.");
}

// Check that the reduction op is currently supported in ttnn
if (mathOp != ::mlir::tt::ReduceType::Sum &&
mathOp != ::mlir::tt::ReduceType::Max &&
mathOp != ::mlir::tt::ReduceType::Min) {
return emitOpError("Invalid reduction op for reduce scatter op.");
}

return success();
}

//===----------------------------------------------------------------------===//
// AllReduceOp
//===----------------------------------------------------------------------===//

::mlir::LogicalResult AllReduceOp::verify() {
::mlir::RankedTensorType inputType = getInput().getType();
int32_t dim = getScatterDim();
auto mathOp = getMathOp();

if (dim >= inputType.getRank() || dim < -inputType.getRank()) {
return emitOpError("Invalid dimension for all reduce op.");
}

// Check that the reduction op is currently supported in ttnn
if (mathOp != ::mlir::tt::ReduceType::Sum &&
mathOp != ::mlir::tt::ReduceType::Max &&
mathOp != ::mlir::tt::ReduceType::Min) {
return emitOpError("Invalid reduction op for all reduce op.");
}

return success();
}

//===----------------------------------------------------------------------===//
// MeshShardOp
//===----------------------------------------------------------------------===//

::mlir::LogicalResult MeshShardOp::verify() {
auto shardType = getShardType();

// Check that the shard type is either replicate or devices
if (shardType != ::mlir::tt::MeshShardType::Replicate &&
shardType != ::mlir::tt::MeshShardType::Devices) {
return emitOpError("Invalid shard_type for mesh_shard op.");
}

return success();
}

7 changes: 4 additions & 3 deletions lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp
@@ -65,9 +65,10 @@ void createTTNNPipelineLoweringPasses(
// Create a pass to workaround issues in the TTNN dialect.
void createTTNNPipelineWorkaroundPass(
OpPassManager &pm, const TTIRToTTNNBackendPipelineOptions &options) {
TTNNWorkaroundsOptions workaroundOptions{
options.layoutWorkaroundsEnabled,
options.decompositionWorkaroundsEnabled};
pm.addPass(createTTNNWorkarounds(workaroundOptions));
}

void createTTNNPipelineLayoutDecompositionPass(