Enable lowering ttir all_reduce and mesh_shard to ttnn and flatbuffer
* all_reduce breaks down into reduce_scatter and all_gather (see the shape sketch after this list)

* mesh_shard is directly converted

* corresponding dialect unit tests are added

* fix issue with default mesh shape in getOrInsertDevice()

* enable workaround pass to avoid issues in ttnn
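
On the first bullet: the conversion pattern below lowers ttir.all_reduce to a single ttnn.all_reduce carrying scatter_dim/scatter_num, and the breakdown into reduce_scatter followed by all_gather happens downstream in the workaround pass. A minimal standalone C++ sketch of the shape algebra behind that decomposition (helper names are hypothetical, not from this commit): reduce_scatter divides the scatter dimension across devices, and all_gather concatenates it back, so the composition preserves the input shape while fully reducing the values.

```cpp
// Standalone sketch (hypothetical helpers, not from this commit) of the
// shape algebra behind "all_reduce = reduce_scatter + all_gather".
#include <cassert>
#include <cstdint>
#include <vector>

using Shape = std::vector<int64_t>;

// reduce_scatter: each of scatterNum devices keeps a reduced 1/scatterNum
// slice along dim.
Shape reduceScatterShape(Shape s, size_t dim, int64_t scatterNum) {
  assert(s[dim] % scatterNum == 0 && "scatter dim must divide evenly");
  s[dim] /= scatterNum;
  return s;
}

// all_gather: concatenating the per-device slices along the same dim
// restores the original extent.
Shape allGatherShape(Shape s, size_t dim, int64_t numDevices) {
  s[dim] *= numDevices;
  return s;
}

int main() {
  Shape in = {1, 1, 256, 512};
  Shape mid = reduceScatterShape(in, 3, /*scatterNum=*/4); // {1, 1, 256, 128}
  Shape out = allGatherShape(mid, 3, /*numDevices=*/4);    // {1, 1, 256, 512}
  assert(out == in); // all_reduce preserves the input shape
  return 0;
}
```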
wooseokTT committed Dec 5, 2024
1 parent 0a172a2 commit 1aa9207
Showing 11 changed files with 397 additions and 40 deletions.
54 changes: 46 additions & 8 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
@@ -876,13 +876,21 @@ def TTNN_DeallocateOp : TTNN_Op<"deallocate"> {
DefaultValuedAttr<BoolAttr, "false">:$force);
}

def TTNN_ScatterOp: TTNN_ElementwiseBinaryOp<"scatter"> {
let summary = "Scatter op.";
let description = [{
Embeds the values of the 'update' tensor into 'input' at the given index and puts the value in the 'output' tensor.
}];
}

def TTNN_AllGatherOp: TTNN_Op<"all_gather"> {
let summary = "All gather op.";
let description = [{
Tensor All Gather operation
}];

let arguments = (ins AnyRankedTensor:$input,
TT_Device:$device,
SI32Attr:$dim,
DefaultValuedAttr<SI32Attr, "1">:$num_links);

@@ -891,23 +899,53 @@ def TTNN_AllGatherOp: TTNN_Op<"all_gather"> {
let hasVerifier = 1;
}

def TTNN_ReduceScatterOp: TTNN_Op<"reduce_scatter"> {
let summary = "Reduce scatter op.";
let description = [{
Tensor Reduce Scatter operation
}];

let arguments = (ins AnyRankedTensor:$input,
TT_Device:$device,
SI32Attr:$scatter_split_dim,
TT_ReduceTypeAttr:$math_op,
DefaultValuedAttr<SI32Attr, "1">:$num_links);

let results = (outs AnyRankedTensor:$result);

let hasVerifier = 1;
}

def TTNN_AllReduceOp: TTNN_Op<"all_reduce"> {
let summary = "All reduce op.";
let description = [{
Tensor All Reduce operation
}];

let arguments = (ins AnyRankedTensor:$input,
TT_Device:$device,
SI32Attr:$scatter_dim,
SI32Attr:$scatter_num,
TT_ReduceTypeAttr:$math_op,
DefaultValuedAttr<SI32Attr, "1">:$num_links);

let results = (outs AnyRankedTensor:$result);

let hasVerifier = 1;
}

def TTNN_MeshShardOp: TTNN_Op<"mesh_shard"> {
let summary = "Mesh shard op.";
let description = [{
Tensor Mesh Shard operation
}];

let arguments = (ins AnyRankedTensor:$input,
TT_Device:$device,
TT_MeshShardDirectionAttr:$shard_direction,
TT_MeshShardTypeAttr:$shard_type,
TT_GridAttr:$shard_shape);

let results = (outs AnyRankedTensor:$result);

let hasVerifier = 1;
21 changes: 21 additions & 0 deletions include/ttmlir/Target/TTNN/program.fbs
@@ -255,10 +255,29 @@ table DeallocateOp {
table AllGatherOp {
in: tt.target.TensorRef;
out: tt.target.TensorRef;
device: tt.target.DeviceRef;
dim: uint32;
num_links: uint32;
}

table ReduceScatterOp {
in: tt.target.TensorRef;
out: tt.target.TensorRef;
device: tt.target.DeviceRef;
scatter_split_dim: uint32;
math_op: uint32;
num_links: uint32;
}

table MeshShardOp {
in: tt.target.TensorRef;
out: tt.target.TensorRef;
device: tt.target.DeviceRef;
shard_direction: uint32;
shard_type: uint32;
shard_shape: [int64];
}

union OpType {
GetDeviceOp,
ToMemoryConfigOp,
@@ -282,6 +301,8 @@ union OpType {
MaxPool2dOp,
DeallocateOp,
AllGatherOp,
ReduceScatterOp,
MeshShardOp,
ArangeOp,
}
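
For orientation, a hedged sketch of how runtime code might read the new MeshShardOp table through the usual flatc-generated C++ accessors. The generated-header path and dispatchMeshShard() are assumptions for illustration, not part of this commit; the accessor names simply mirror the schema's field names.

```cpp
// Hedged sketch: the generated-header path and dispatchMeshShard() are
// assumptions for illustration; accessor names follow flatc's convention
// of mirroring the schema's field names.
#include <cstdint>
#include <vector>

#include "ttmlir/Target/TTNN/program_generated.h" // assumed flatc output

// Hypothetical runtime entry point, declared here only to keep the
// fragment self-contained.
void dispatchMeshShard(const tt::target::TensorRef *in,
                       const tt::target::TensorRef *out,
                       const tt::target::DeviceRef *device,
                       uint32_t shardDirection, uint32_t shardType,
                       const std::vector<int64_t> &shardShape);

void runMeshShardOp(const tt::target::ttnn::MeshShardOp *op) {
  // shard_shape is a [int64] vector in the schema.
  std::vector<int64_t> shape(op->shard_shape()->begin(),
                             op->shard_shape()->end());
  dispatchMeshShard(op->in(), op->out(), op->device(),
                    op->shard_direction(), op->shard_type(), shape);
}
```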

56 changes: 49 additions & 7 deletions lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
@@ -833,6 +833,50 @@ class SubtractOpConversionPattern
}
};

class AllReduceOpConversionPattern
: public OpConversionPattern<ttir::AllReduceOp> {
public:
using OpConversionPattern<ttir::AllReduceOp>::OpConversionPattern;

LogicalResult
matchAndRewrite(ttir::AllReduceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

auto replicaGroupsShape = adaptor.getReplicaGroups().getType().getShape();
size_t scatter_dim = adaptor.getDim();
// scatter_num is needed by the workaround pass to determine the
// reduce_scatter output shape and the all_gather input shape
int32_t scatter_num =
replicaGroupsShape[scatter_dim % replicaGroupsShape.size()];
auto device = ::ttnn::utils::getOrInsertDevice(rewriter, op);
rewriter.replaceOpWithNewOp<ttnn::AllReduceOp>(
op, this->getTypeConverter()->convertType(op.getType(0)),
adaptor.getInputs().front(), device, scatter_dim, scatter_num,
adaptor.getReduceType());

return success();
}
};
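
A worked example of the scatter_num lookup above (shapes assumed for illustration): for a replica_groups of shape 2x4, replicaGroupsShape.size() is 2, so dim = 3 indexes replicaGroupsShape[3 % 2], i.e. the second mesh axis.

```cpp
// Standalone illustration of the modular indexing used for scatter_num;
// the concrete shapes are assumed, not taken from a real test.
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

int main() {
  // replica_groups : tensor<2x4xi32> -> replicaGroupsShape = {2, 4}
  std::vector<int64_t> replicaGroupsShape = {2, 4};
  size_t dim = 3; // scatter dimension taken from the ttir op
  int64_t scatterNum = replicaGroupsShape[dim % replicaGroupsShape.size()];
  assert(scatterNum == 4); // 3 % 2 == 1 -> second mesh axis
  return 0;
}
```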

class MeshShardOpConversionPattern
: public OpConversionPattern<ttir::MeshShardOp> {
public:
using OpConversionPattern<ttir::MeshShardOp>::OpConversionPattern;

LogicalResult
matchAndRewrite(ttir::MeshShardOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

auto device = ::ttnn::utils::getOrInsertDevice(rewriter, op);
rewriter.replaceOpWithNewOp<ttnn::MeshShardOp>(
op, this->getTypeConverter()->convertType(op.getType()),
adaptor.getInput(), device, adaptor.getShardDirection(),
adaptor.getShardType(), adaptor.getShardShape());

return success();
}
};

class AllGatherOpConversionPattern
: public OpConversionPattern<ttir::AllGatherOp> {
public:
@@ -841,15 +885,11 @@ class AllGatherOpConversionPattern
LogicalResult
matchAndRewrite(ttir::AllGatherOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto device = ::ttnn::utils::getOrInsertDevice(rewriter, op);
rewriter.replaceOpWithNewOp<ttnn::AllGatherOp>(
op, this->getTypeConverter()->convertType(op.getType()),
adaptor.getInput(), device, adaptor.getDim());
return success();
}
};
@@ -978,6 +1018,8 @@ void populateTTIRToTTNNPatterns(MLIRContext *ctx, RewritePatternSet &patterns,
Conv2dOpConversionPattern,
MaxPool2dOpConversionPattern,
SubtractOpConversionPattern,
MeshShardOpConversionPattern,
AllReduceOpConversionPattern,
AllGatherOpConversionPattern,
ArangeOpConversionPattern,
ScatterOpConversionPattern
4 changes: 4 additions & 0 deletions lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp
@@ -753,6 +753,10 @@ void populateTTNNToEmitCPatterns(mlir::MLIRContext *ctx,
//
patterns.add<DefaultOpConversionPattern<ttnn::AllGatherOp>>(typeConverter,
ctx);
patterns.add<DefaultOpConversionPattern<ttnn::ReduceScatterOp>>(typeConverter,
ctx);
patterns.add<DefaultOpConversionPattern<ttnn::MeshShardOp>>(typeConverter,
ctx);

// Module op
//
57 changes: 56 additions & 1 deletion lib/Dialect/TTNN/IR/TTNNOps.cpp
@@ -970,7 +970,62 @@ ::mlir::LogicalResult AllGatherOp::verify() {
//===----------------------------------------------------------------------===//

::mlir::LogicalResult ReduceScatterOp::verify() {
::mlir::RankedTensorType inputType = getInput().getType();
int32_t scatterSplitDim = getScatterSplitDim();
auto mathOp = getMathOp();

if (scatterSplitDim >= inputType.getRank() ||
scatterSplitDim < -inputType.getRank()) {
return emitOpError("Invalid dimension for reduce scatter op.");
}

// Check that the reduction op is one we currently support in ttnn
if (mathOp != ::mlir::tt::ReduceType::Sum &&
mathOp != ::mlir::tt::ReduceType::Max &&
mathOp != ::mlir::tt::ReduceType::Min) {
return emitOpError("Invalid reduction op for reduce scatter op.");
}

return success();
}

//===----------------------------------------------------------------------===//
// AllReduceOp
//===----------------------------------------------------------------------===//

::mlir::LogicalResult AllReduceOp::verify() {
::mlir::RankedTensorType inputType = getInput().getType();
int32_t dim = getScatterDim();
auto mathOp = getMathOp();

if (dim >= inputType.getRank() || dim < -inputType.getRank()) {
return emitOpError("Invalid dimension for all reduce op.");
}

// Check that the reduction op is one we currently support in ttnn
if (mathOp != ::mlir::tt::ReduceType::Sum &&
mathOp != ::mlir::tt::ReduceType::Max &&
mathOp != ::mlir::tt::ReduceType::Min) {
return emitOpError("Invalid reduction op for all reduce op.");
}

return success();
}

//===----------------------------------------------------------------------===//
// MeshShardOp
//===----------------------------------------------------------------------===//

::mlir::LogicalResult MeshShardOp::verify() {
auto shardType = getShardType();

// Check that the shard type is either replicate or devices
if (shardType != ::mlir::tt::MeshShardType::Replicate &&
shardType != ::mlir::tt::MeshShardType::Devices) {
return emitOpError("Invalid shard_type for mesh_shard op.");
}

return success();
}

