Enable lowering ttir all_reduce and mesh_shard to ttnn and flatbuffer
- all_reduce is broken down into reduce_scatter and all_gather (see the sketch below)
- mesh_shard is directly converted
- corresponding dialect unit tests are added
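
The decomposition can be illustrated with a small standalone C++ sketch (not part of the commit; the device count, tensor size, and values below are hypothetical): a sum all_reduce over N devices is equivalent to a reduce_scatter that leaves each device with the reduced values of its own shard, followed by an all_gather that redistributes every shard to every device.

// Simulated on the host: all_reduce(sum) == reduce_scatter(sum) + all_gather.
#include <cstddef>
#include <iostream>
#include <vector>

int main() {
  const std::size_t numDevices = 4;
  const std::size_t elems = 8; // must be divisible by numDevices

  // Each "device" starts with its own copy of an 8-element tensor.
  std::vector<std::vector<int>> local(numDevices, std::vector<int>(elems));
  for (std::size_t d = 0; d < numDevices; ++d)
    for (std::size_t i = 0; i < elems; ++i)
      local[d][i] = static_cast<int>(10 * d + i);

  // Step 1: reduce_scatter(sum) -- device d ends up with the summed values
  // for shard d of the tensor.
  const std::size_t shard = elems / numDevices;
  std::vector<std::vector<int>> scattered(numDevices,
                                          std::vector<int>(shard, 0));
  for (std::size_t d = 0; d < numDevices; ++d)
    for (std::size_t i = 0; i < shard; ++i)
      for (std::size_t src = 0; src < numDevices; ++src)
        scattered[d][i] += local[src][d * shard + i];

  // Step 2: all_gather -- every device collects all shards, so each device
  // now holds the full element-wise sum, exactly as all_reduce would produce.
  std::vector<std::vector<int>> reduced(numDevices,
                                        std::vector<int>(elems, 0));
  for (std::size_t d = 0; d < numDevices; ++d)
    for (std::size_t src = 0; src < numDevices; ++src)
      for (std::size_t i = 0; i < shard; ++i)
        reduced[d][src * shard + i] = scattered[src][i];

  for (int v : reduced[0])
    std::cout << v << ' '; // prints the same sums every device now holds
  std::cout << '\n';
  return 0;
}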
wooseokTT committed Nov 27, 2024
1 parent f8121bf commit ed37e0a
Showing 9 changed files with 315 additions and 33 deletions.
22 changes: 21 additions & 1 deletion include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
@@ -852,6 +852,7 @@ def TTNN_AllGatherOp: TTNN_Op<"all_gather"> {
}];

let arguments = (ins AnyRankedTensor:$input,
TT_Device:$device,
SI32Attr:$dim,
DefaultValuedAttr<SI32Attr, "1">:$num_links);

@@ -867,12 +868,31 @@ def TTNN_ReduceScatterOp: TTNN_Op<"reduce_scatter"> {
}];

let arguments = (ins AnyRankedTensor:$input,
TT_Device:$device,
SI32Attr:$scatter_split_dim,
TT_ReduceTypeAttr:$math_op,
DefaultValuedAttr<SI32Attr, "1">:$num_links);
let results = (outs AnyRankedTensor:$result);

let hasVerifier = 1;
}

def TTNN_MeshShardOp: TTNN_Op<"mesh_shard"> {
let summary = "Mesh shard op.";
let description = [{
Shards or replicates a tensor across the device mesh according to
shard_direction, shard_type, and shard_shape.
}];

let arguments = (ins
AnyRankedTensor:$input,
TT_Device:$device,
TT_MeshShardDirectionAttr:$shard_direction,
TT_MeshShardTypeAttr:$shard_type,
TT_GridAttr:$shard_shape);

let results = (outs AnyRankedTensor:$result);

let hasVerifier = 1;
}

#endif
21 changes: 21 additions & 0 deletions include/ttmlir/Target/TTNN/program.fbs
@@ -252,10 +252,29 @@ table DeallocateOp {
table AllGatherOp {
in: tt.target.TensorRef;
out: tt.target.TensorRef;
device: tt.target.DeviceRef;
dim: uint32;
num_links: uint32;
}

table ReduceScatterOp {
in: tt.target.TensorRef;
out: tt.target.TensorRef;
device: tt.target.DeviceRef;
scatter_split_dim: uint32;
math_op: uint32;
num_links: uint32;
}

table MeshShardOp {
in: tt.target.TensorRef;
out: tt.target.TensorRef;
device: tt.target.DeviceRef;
shard_direction: uint32;
shard_type: uint32;
shard_shape: [int64];
}

union OpType {
GetDeviceOp,
ToMemoryConfigOp,
@@ -279,6 +298,8 @@ union OpType {
MaxPool2dOp,
DeallocateOp,
AllGatherOp,
ReduceScatterOp,
MeshShardOp,
ArangeOp,
}
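
As a rough C++ sketch (the generated header path, the tt::target::ttnn namespace, and the field values shown are all assumptions, not verified against the repository), flatc turns each of these tables into a Create* helper that a serializer could call like this:

#include <cstdint>
#include "flatbuffers/flatbuffers.h"
#include "ttmlir/Target/TTNN/program_generated.h" // assumed flatc output path

// Sketch only: builds a ReduceScatterOp table from already-serialized
// tensor/device references; scatter_split_dim, math_op, and num_links are
// placeholder values.
flatbuffers::Offset<tt::target::ttnn::ReduceScatterOp>
buildReduceScatter(flatbuffers::FlatBufferBuilder &fbb,
                   flatbuffers::Offset<tt::target::TensorRef> in,
                   flatbuffers::Offset<tt::target::TensorRef> out,
                   flatbuffers::Offset<tt::target::DeviceRef> device) {
  const uint32_t scatterSplitDim = 3;
  const uint32_t mathOp = 0; // e.g. the enum value used for Sum
  const uint32_t numLinks = 1;
  // flatc generates a CreateReduceScatterOp helper for the table above.
  return tt::target::ttnn::CreateReduceScatterOp(
      fbb, in, out, device, scatterSplitDim, mathOp, numLinks);
}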

140 changes: 132 additions & 8 deletions lib/Conversion/TTIRToTTNN/TTIRToTTNN.cpp
@@ -42,9 +42,13 @@ static Value getOrInsertDevice(ConversionPatternRewriter &rewriter,
DeviceAttr deviceAttr = getCurrentScopeDevice(op);
auto currentInsertionPoint = rewriter.saveInsertionPoint();
rewriter.setInsertionPoint(block, block->begin());
auto meshShape = deviceAttr.getMeshShape();
if (meshShape.empty()) {
meshShape = {1, 1};
}
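// For example (illustrative): a module compiled for a 2x4 device mesh creates
// ttnn.get_device with mesh_shape (2, 4), while a module without a mesh shape
// falls back to the single-device (1, 1) case.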
auto deviceOp = rewriter.create<ttnn::GetDeviceOp>(
op->getLoc(), rewriter.getType<DeviceType>(deviceAttr),
ttnn::MeshShapeAttr::get(op->getContext(), meshShape[0], meshShape[1]));
rewriter.restoreInsertionPoint(currentInsertionPoint);
return deviceOp.getResult();
}
@@ -892,6 +896,128 @@ class SubtractOpConversionPattern
}
};

class AllReduceOpConversionPattern
: public OpConversionPattern<ttir::AllReduceOp> {
public:
using OpConversionPattern<ttir::AllReduceOp>::OpConversionPattern;

LogicalResult
matchAndRewrite(ttir::AllReduceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
RankedTensorType inputType =
mlir::cast<RankedTensorType>(adaptor.getInputs().front().getType());
SmallVector<int64_t> inputTypeShape(inputType.getShape());
auto replicaGroupsShape = adaptor.getReplicaGroups().getType().getShape();
size_t scatter_dim = adaptor.getDim();
int32_t scatter_num =
replicaGroupsShape[scatter_dim % replicaGroupsShape.size()];
auto device = getOrInsertDevice(rewriter, op);

// This lowering is supposed to be a one-to-one mapping from ttir.all_reduce
// to ttnn.all_reduce. However, all_reduce is currently broken down into
// reduce_scatter and all_gather ops because all_reduce support in TTNN is
// not yet stable.

// The reduce_scatter op in TTNN currently does not handle two-dimensional
// tensors correctly. As a temporary workaround, we insert reshape ops before
// and after it to make the tensor four-dimensional.

// TODO(wooseoklee): Once TTNN supports two-dimensional tensors here, we can
// remove this workaround.
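// Illustrative example (hypothetical shapes): for a [1, 512] input with
// dim = 1 reduced over 8 devices (scatter_num = 8), the input is reshaped to
// [1, 1, 1, 512], scatter_dim becomes 3, reduce_scatter produces
// [1, 1, 1, 64], all_gather restores [1, 1, 1, 512], and the final reshape
// yields the original [1, 512] shape.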
if (inputTypeShape.size() < 4) {
std::vector<int64_t> reshapedInputShape(4, 1);
for (size_t i = 0; i < inputTypeShape.size(); ++i) {
reshapedInputShape[i + inputTypeShape.size()] = inputTypeShape[i];
}

ArrayAttr reshapedInputShapeAttr =
rewriter.getI32ArrayAttr(std::vector<int32_t>(
reshapedInputShape.begin(), reshapedInputShape.end()));

auto reshapedInputType =
RankedTensorType::Builder(inputType).setShape(reshapedInputShape);

ttnn::ReshapeOp preReshapeOp = rewriter.create<ttnn::ReshapeOp>(
op.getLoc(), this->getTypeConverter()->convertType(reshapedInputType),
adaptor.getInputs().front(), reshapedInputShapeAttr);

scatter_dim = scatter_dim + (4 - inputTypeShape.size());

reshapedInputShape[scatter_dim] =
static_cast<int32_t>(reshapedInputShape[scatter_dim] / scatter_num);

auto scatteredInputType =
RankedTensorType::Builder(inputType).setShape(reshapedInputShape);

ttnn::ReduceScatterOp reduceScatterOp =
rewriter.create<ttnn::ReduceScatterOp>(
op.getLoc(),
this->getTypeConverter()->convertType(scatteredInputType),
preReshapeOp.getResult(), device, scatter_dim,
adaptor.getReduceType());

RankedTensorType outputType = mlir::cast<RankedTensorType>(op.getType(0));
SmallVector<int64_t> outputTypeShape(outputType.getShape());

std::vector<int64_t> reshapedOutputShape(4, 1);
for (size_t i = 0; i < outputTypeShape.size(); ++i) {
reshapedOutputShape[i + outputTypeShape.size()] = outputTypeShape[i];
}

auto reshapedOutputType =
RankedTensorType::Builder(outputType).setShape(reshapedOutputShape);

ttnn::AllGatherOp allGatherOp = rewriter.create<ttnn::AllGatherOp>(
op.getLoc(),
this->getTypeConverter()->convertType(reshapedOutputType),
reduceScatterOp.getResult(), device, scatter_dim);

ArrayAttr reshapedOutputShapeAttr = rewriter.getI32ArrayAttr(
std::vector<int32_t>(outputTypeShape.begin(), outputTypeShape.end()));

rewriter.replaceOpWithNewOp<ttnn::ReshapeOp>(
op, this->getTypeConverter()->convertType(outputType),
allGatherOp.getResult(), reshapedOutputShapeAttr);
} else {
// TODO(wooseoklee): Once all_reduce support in ttnn is stable, we can
// convert directly to ttnn.all_reduce.
inputTypeShape[scatter_dim] = inputTypeShape[scatter_dim] / scatter_num;
auto scatteredInputType =
RankedTensorType::Builder(inputType).setShape(inputTypeShape);

ttnn::ReduceScatterOp reduceScatterOp =
rewriter.create<ttnn::ReduceScatterOp>(
op.getLoc(),
this->getTypeConverter()->convertType(scatteredInputType),
adaptor.getInputs().front(), device, scatter_dim,
adaptor.getReduceType());

rewriter.replaceOpWithNewOp<ttnn::AllGatherOp>(
op, this->getTypeConverter()->convertType(op.getType(0)),
reduceScatterOp.getResult(), device, scatter_dim);
}
return success();
}
};

class MeshShardOpConversionPattern
: public OpConversionPattern<ttir::MeshShardOp> {
public:
using OpConversionPattern<ttir::MeshShardOp>::OpConversionPattern;

LogicalResult
matchAndRewrite(ttir::MeshShardOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto device = getOrInsertDevice(rewriter, op);
rewriter.replaceOpWithNewOp<ttnn::MeshShardOp>(
op, this->getTypeConverter()->convertType(op.getType()),
adaptor.getInput(), device, adaptor.getShardDirection(),
adaptor.getShardType(), adaptor.getShardShape());

return success();
}
};

class AllGatherOpConversionPattern
: public OpConversionPattern<ttir::AllGatherOp> {
public:
@@ -900,15 +1026,11 @@ class AllGatherOpConversionPattern
LogicalResult
matchAndRewrite(ttir::AllGatherOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto device = getOrInsertDevice(rewriter, op);
rewriter.replaceOpWithNewOp<ttnn::AllGatherOp>(
op, this->getTypeConverter()->convertType(op.getType()),
adaptor.getInput(), device, adaptor.getDim());
return success();
}
};
@@ -1022,6 +1144,8 @@ void populateTTIRToTTNNPatterns(MLIRContext *ctx, RewritePatternSet &patterns,
MaxPool2dOpConversionPattern,
SubtractOpConversionPattern,
AllGatherOpConversionPattern,
MeshShardOpConversionPattern,
AllReduceOpConversionPattern,
ArangeOpConversionPattern
>(typeConverter, ctx);
// ANCHOR_END: op_rewriter_pattern_set
4 changes: 4 additions & 0 deletions lib/Conversion/TTNNToEmitC/TTNNToEmitC.cpp
@@ -750,6 +750,10 @@ void populateTTNNToEmitCPatterns(mlir::MLIRContext *ctx,
//
patterns.add<DefaultOpConversionPattern<ttnn::AllGatherOp>>(typeConverter,
ctx);
patterns.add<DefaultOpConversionPattern<ttnn::ReduceScatterOp>>(typeConverter,
ctx);
patterns.add<DefaultOpConversionPattern<ttnn::MeshShardOp>>(typeConverter,
ctx);

// Module op
//
29 changes: 28 additions & 1 deletion lib/Dialect/TTNN/IR/TTNNOps.cpp
@@ -962,7 +962,34 @@ ::mlir::LogicalResult AllGatherOp::verify() {
}

::mlir::LogicalResult ReduceScatterOp::verify() {
::mlir::RankedTensorType inputType = getInput().getType();
int32_t scatterSplitDim = getScatterSplitDim();
auto mathOp = getMathOp();

if (scatterSplitDim >= inputType.getRank() ||
scatterSplitDim < -inputType.getRank()) {
return emitOpError("Invalid dimension for reduce scatter op.");
}

// Check that the reduction op is one we currently support in TTNN
if (mathOp != ::mlir::tt::ReduceType::Sum &&
mathOp != ::mlir::tt::ReduceType::Max &&
mathOp != ::mlir::tt::ReduceType::Min) {
return emitOpError("Invalid reduction op for reduce scatter op.");
}

return success();
}

::mlir::LogicalResult MeshShardOp::verify() {
auto shardType = getShardType();

// Check sharding is one of replicate or devices
if (shardType != ::mlir::tt::MeshShardType::Replicate &&
shardType != ::mlir::tt::MeshShardType::Devices) {
return emitOpError("Invalid shard_type for mesh_shard op.");
}

return success();
}
