[Optimizer] Fix mem reconfig/reshard.
nobradovictt committed Oct 31, 2024
1 parent 2a2121d commit f041d8d
Showing 7 changed files with 128 additions and 132 deletions.
2 changes: 1 addition & 1 deletion include/ttmlir/Dialect/TT/Utils/OverrideParams.h
@@ -63,7 +63,7 @@ struct InputLayoutOverrideParser

  static void print(llvm::raw_ostream &os,
                    const llvm::StringMap<InputLayoutOverrideParams> &value) {
-    os << "insert-reshard=";
+    os << "insert-memreconfig=";
    size_t count = 0;
    for (const auto &entry : value) {
      os << entry.getKey() << "=";
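Note: the printed form above mirrors exactly what InputLayoutOverrideParser consumes. With overrides for op1 (operand 0) and op2 (operands 0 and 1), the emitted option string would read as follows, per the "Full Example" comment in TTNNPipelines.h below:

insert-memreconfig=op1=0,op2=0:1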
12 changes: 6 additions & 6 deletions include/ttmlir/Dialect/TTNN/Pipelines/TTNNPipelines.h
@@ -31,17 +31,17 @@ struct TTIRToTTNNBackendPipelineOptions
  //
  // Full Example: "op1=0,op2=0:1"
  //
-  // This will insert one TTIR_ToLayoutOps responsible for resharding the op1's
-  // first operand and two TTIR_ToLayoutOps responsible for resharding the op2's
-  // first and second operand.
+  // This will insert one memory reconfig op responsible for resharding the
+  // op1's first operand and two memory reconfig ops responsible for resharding
+  // the op2's first and second operand.
  //
  // Note: This option is only valid if optimizerPassEnabled is true.
  //
  Option<llvm::StringMap<InputLayoutOverrideParams>, InputLayoutOverrideParser>
      overrideInputLayout{
-          *this, "insert-reshard",
+          *this, "insert-memreconfig",
          llvm::cl::desc(
-              "Manually insert TTIR_ToLayoutOp for specific op's operand."),
+              "Manually insert memory reconfig op for specific op's operand."),
          llvm::cl::init(llvm::StringMap<InputLayoutOverrideParams>())};

// Option to override output layout for specific ops.
@@ -83,7 +83,7 @@ struct TTIRToTTNNBackendPipelineOptions
      llvm::cl::desc("Memory layout reconfiguration pass. Temp disabled till "
                     "we support all types "
                     "of shard specs."),
-      llvm::cl::init(false)};
+      llvm::cl::init(true)};

// Option to provide a system descriptor flatbuffer file to compile
// against.
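For a concrete invocation of the renamed flag through the full pipeline, see the updated lit test at the bottom of this diff; aside from the placeholder input file name, everything below is taken from that RUN line:

ttmlir-opt --ttir-to-ttnn-backend-pipeline="enable-optimizer=true memory-layout-analysis-enabled=true memreconfig-enabled=true insert-memreconfig=add_0_1_2=0" input.mlir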
6 changes: 3 additions & 3 deletions include/ttmlir/Dialect/TTNN/Transforms/Optimizer.h
@@ -102,9 +102,9 @@ class TTNNOptimizerBase : public ::mlir::OperationPass<::mlir::ModuleOp> {
  ::mlir::Pass::Option<llvm::StringMap<InputLayoutOverrideParams>,
                       mlir::tt::InputLayoutOverrideParser>
      overrideInputLayout{
-          *this, "insert-reshard",
+          *this, "insert-memreconfig",
          ::llvm::cl::desc(
-              "Manually insert reshard for specific op's operand."),
+              "Manually insert memory reconfig op for specific op's operand."),
          ::llvm::cl::init(llvm::StringMap<InputLayoutOverrideParams>())};
::mlir::Pass::Option<llvm::StringMap<OutputLayoutOverrideParams>,
mlir::tt::OutputLayoutOverrideParser>
@@ -121,7 +121,7 @@ class TTNNOptimizerBase : public ::mlir::OperationPass<::mlir::ModuleOp> {
      ::llvm::cl::desc("Memory layout reconfiguration pass. Temp disabled till "
                       "we support all "
                       "types of shard specs."),
-      ::llvm::cl::init(false)};
+      ::llvm::cl::init(true)};
::mlir::Pass::Option<int64_t> maxLegalLayouts{
*this, "max-legal-layouts",
::llvm::cl::desc(
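The same pair of options is exposed on the standalone optimizer pass declared here. Assuming the pass is registered as ttnn-optimizer (its registration is not part of this diff), a direct invocation would look roughly like:

ttmlir-opt --ttnn-optimizer="insert-memreconfig=add_0_1_2=0 memreconfig-enabled=true" input.mlir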
5 changes: 2 additions & 3 deletions lib/Dialect/TTNN/Analysis/DFShardingPolicy.cpp
@@ -35,7 +35,7 @@ void DFShardingPolicy::run(
    //
    if (l1ChainConfigs->back().isEmpty()) {
      for (auto *op : scheduleableOps) {
-        if (isa<ttnn::ToLayoutOp>(op)) {
+        if (isa<ToLayoutOp>(op)) {
          currentOp = op;
          break;
        }
@@ -52,8 +52,7 @@

// Skip starting sharding chain if currentOp is a memory management op.
//
-      if (l1ChainConfigs->back().isEmpty() &&
-          isa<ttnn::ToLayoutOp>(currentOp)) {
+      if (l1ChainConfigs->back().isEmpty() && isa<ToLayoutOp>(currentOp)) {
currentOp = nullptr;
continue;
}
193 changes: 97 additions & 96 deletions lib/Dialect/TTNN/Transforms/Optimizer.cpp
@@ -115,6 +115,12 @@ class TTNNOptimizer : public impl::TTNNOptimizerBase<TTNNOptimizer> {
return;
}

+      // Skip empty ops. Handled via DPS op output operand update.
+      //
+      if (isa<EmptyOp>(op)) {
+        return;
+      }

if (!isa<RankedTensorType>(op->getResult(0).getType())) {
return;
}
@@ -149,7 +155,7 @@ class TTNNOptimizer : public impl::TTNNOptimizerBase<TTNNOptimizer> {
EmptyOp emptyOp =
mlir::cast<EmptyOp>(op->getOperands().back().getDefiningOp());

-      emptyOp.setMemoryConfigAttr(ttnn::MemoryConfigAttr::get(
+      emptyOp.setMemoryConfigAttr(MemoryConfigAttr::get(
op->getContext(),
TensorMemoryLayoutAttr::get(op->getContext(),
tensorMemoryLayout),
@@ -159,29 +165,6 @@
ShapeAttr::get(op->getContext(),
ttLayoutAttr.getMemref().getShape()))));
}
-      // TODO (nobradovic): Other memory management ops after lowering to
-      // TTNN will need to be special handled as well. Depends on ttnn
-      // layout attr refactor and lowering.
-      //
-      else if (isa<ttnn::ToLayoutOp>(op)) {
-        BufferType bufferType =
-            utils::toTTNNBufferType(ttLayoutAttr.getMemorySpace());
-        TensorMemoryLayout tensorMemoryLayout =
-            utils::toTTNNTensorMemoryLayout(ttLayoutAttr.getMemLayout());
-        // Update the device op with the new tensor type.
-        //
-        ttnn::ToLayoutOp toLayoutOp = llvm::cast<ttnn::ToLayoutOp>(op);
-        toLayoutOp.setMemoryConfigAttr(ttnn::MemoryConfigAttr::get(
-            op->getContext(),
-            ttnn::TensorMemoryLayoutAttr::get(op->getContext(),
-                                              tensorMemoryLayout),
-            ttnn::BufferTypeAttr::get(op->getContext(), bufferType),
-            ttnn::ShardSpecAttr::get(
-                op->getContext(),
-                ttnn::ShapeAttr::get(
-                    op->getContext(),
-                    ttLayoutAttr.getMemref().getShape()))));
-      }
}
});

@@ -233,6 +216,19 @@ class TTNNOptimizer : public impl::TTNNOptimizerBase<TTNNOptimizer> {
assert(overrideInputLayout.size() == overrideReshardEdges.size());
}

+  mlir::TypedValue<mlir::tt::DeviceType>
+  getDeviceOpValue(Operation *contextOp) {
+    Block *block = contextOp->getBlock();
+    mlir::TypedValue<mlir::tt::DeviceType> deviceOpResult;
+    for (auto &op : block->getOperations()) {
+      if (GetDeviceOp deviceOp = dyn_cast<GetDeviceOp>(op)) {
+        deviceOpResult = deviceOp.getResult();
+        break;
+      }
+    }
+    return deviceOpResult;
+  }

void
processMemReconfigEdges(const std::unordered_set<Edge> &memReconfigEdges) {
// Insert memory reconfig ops here based on results of memory layout
@@ -242,86 +238,91 @@ class TTNNOptimizer : public impl::TTNNOptimizerBase<TTNNOptimizer> {
Operation *producerOp = edge.producerOp;
Operation *consumerOp = edge.consumerOp;

+      tt::LayoutAttr consumerOpOutputLayout = mlir::cast<tt::LayoutAttr>(
+          mlir::cast<RankedTensorType>(consumerOp->getResult(0).getType())
+              .getEncoding());
+
+      RankedTensorType producerOpTensorType =
+          mlir::cast<RankedTensorType>(producerOp->getResult(0).getType());
+      llvm::ArrayRef<int64_t> producerOpTensorShape =
+          producerOpTensorType.getShape();
+      tt::LayoutAttr producerOpLayout =
+          mlir::cast<tt::LayoutAttr>(producerOpTensorType.getEncoding());
+
+      // TODO(nobradovic): Match memory space and layout of consumer op.
+      // This actually needs to be properly resolved based on op type, output
+      // layout and other inputs.
+      //
+      RankedTensorType newTensorType = RankedTensorType::get(
+          producerOpTensorShape, producerOpTensorType.getElementType(),
+          producerOpLayout
+              .withElementType(consumerOp->getContext(),
+                               consumerOpOutputLayout.getElementType())
+              .withMemorySpace(consumerOp->getContext(),
+                               consumerOpOutputLayout.getMemorySpace())
+              .withMemoryLayout(consumerOp->getContext(),
+                                consumerOpOutputLayout.getMemLayout())
+              .withGrid(consumerOp->getContext(), producerOpTensorType,
+                        consumerOpOutputLayout.getGrid()));
+
+      // If producerOp is a toLayoutOp, adjust its output layout(update
+      // inplace) to reflect consumerOp's output layout. If producerOp is not a
+      // toLayoutOp, insert a toLayoutOp in between producerOp
+      // and consumerOp.
+      //
-      if (isa<ttnn::ToLayoutOp>(producerOp)) {
-        ttnn::ToLayoutOp toLayoutOp = llvm::cast<ttnn::ToLayoutOp>(producerOp);
-        tt::LayoutAttr consumerOpOutputLayout = mlir::cast<tt::LayoutAttr>(
-            mlir::cast<RankedTensorType>(consumerOp->getResult(0).getType())
-                .getEncoding());
-
-        RankedTensorType toLayoutOpTensorType =
-            mlir::cast<RankedTensorType>(toLayoutOp.getResult().getType());
-        llvm::ArrayRef<int64_t> toLayoutOpTensorShape =
-            toLayoutOpTensorType.getShape();
-        tt::LayoutAttr toLayoutOpLayout =
-            mlir::cast<tt::LayoutAttr>(toLayoutOpTensorType.getEncoding());
-
-        // TODO(nobradovic): Match memory space and layout of consumer op. This
-        // actually needs to be properly resolved based on op type, output
-        // layout and other inputs.
-        //
-        RankedTensorType newTensorType = RankedTensorType::get(
-            toLayoutOpTensorShape, toLayoutOpTensorType.getElementType(),
-            toLayoutOpLayout
-                .withElementType(toLayoutOp->getContext(),
-                                 consumerOpOutputLayout.getElementType())
-                .withMemorySpace(toLayoutOp.getContext(),
-                                 consumerOpOutputLayout.getMemorySpace())
-                .withMemoryLayout(toLayoutOp.getContext(),
-                                  consumerOpOutputLayout.getMemLayout())
-                .withGrid(toLayoutOp.getContext(), toLayoutOpTensorType,
-                          consumerOpOutputLayout.getGrid()));
+      if (isa<ToLayoutOp>(producerOp)) {
+        ToLayoutOp toLayoutOp = llvm::cast<ToLayoutOp>(producerOp);
+
+        BufferType bufferType =
+            utils::toTTNNBufferType(consumerOpOutputLayout.getMemorySpace());
+        TensorMemoryLayout tensorMemoryLayout = utils::toTTNNTensorMemoryLayout(
+            consumerOpOutputLayout.getMemLayout());
+        toLayoutOp.setMemoryConfigAttr(MemoryConfigAttr::get(
+            consumerOp->getContext(),
+            TensorMemoryLayoutAttr::get(consumerOp->getContext(),
+                                        tensorMemoryLayout),
+            BufferTypeAttr::get(consumerOp->getContext(), bufferType),
+            ShardSpecAttr::get(
+                consumerOp->getContext(),
+                ShapeAttr::get(
+                    consumerOp->getContext(),
+                    consumerOpOutputLayout.getMemref().getShape()))));
+
+        toLayoutOp.getResult().setType(newTensorType);
+      } else {
+
+        OpBuilder builder(consumerOp);
+
+        MemRefType outputMemref = consumerOpOutputLayout.getMemref();
+
+        DataTypeAttr outputDataType =
+            DataTypeAttr::get(consumerOp->getContext(),
+                              utils::getDataTypeFromMemRef(outputMemref));
+        BufferType outputBufferType =
+            utils::toTTNNBufferType(consumerOpOutputLayout.getMemorySpace());
+        Layout outputLayoutEnum = utils::getLayoutFromMemRef(outputMemref);
+        LayoutAttr outputLayout =
+            LayoutAttr::get(consumerOp->getContext(), outputLayoutEnum);
+        TensorMemoryLayout outputTensorMemoryLayout =
+            utils::toTTNNTensorMemoryLayout(
+                consumerOpOutputLayout.getMemLayout());
+        MemoryConfigAttr outputMemConfigAttr = MemoryConfigAttr::get(
+            consumerOp->getContext(),
+            TensorMemoryLayoutAttr::get(consumerOp->getContext(),
+                                        outputTensorMemoryLayout),
+            BufferTypeAttr::get(consumerOp->getContext(), outputBufferType),
+            ShardSpecAttr::get(consumerOp->getContext(),
+                               ShapeAttr::get(consumerOp->getContext(),
+                                              outputMemref.getShape())));
+
+        Operation *memoryReconfigOp = builder.create<ToLayoutOp>(
+            consumerOp->getLoc(), newTensorType, producerOp->getResult(0),
+            outputLayout, outputDataType, outputMemConfigAttr,
+            getDeviceOpValue(consumerOp));
+
+        consumerOp->setOperand(edge.operandIndex,
+                               memoryReconfigOp->getResult(0));
+      }
-      // TODO (nobradovic): Memory layout reconfig needs to be reimplemented for
-      // TTNN dialect.
-      // else {
-      //   tt::LayoutAttr consumerOpOutputLayout = mlir::cast<tt::LayoutAttr>(
-      //       mlir::cast<RankedTensorType>(consumerOp->getResult(0).getType())
-      //           .getEncoding());
-
-      //   RankedTensorType producerOpTensorType =
-      //       mlir::cast<RankedTensorType>(producerOp->getResult(0).getType());
-      //   llvm::ArrayRef<int64_t> producerOpTensorShape =
-      //       producerOpTensorType.getShape();
-      //   tt::LayoutAttr producerOpLayout =
-      //       mlir::cast<tt::LayoutAttr>(producerOpTensorType.getEncoding());

-      //   // TODO(nobradovic): Match memory space and layout of consumer op.
-      //   This
-      //   // actually needs to be properly resolved based on op type, output
-      //   // layout and other inputs.
-      //   //
-      //   RankedTensorType newTensorType = RankedTensorType::get(
-      //       producerOpTensorShape, producerOpTensorType.getElementType(),
-      //       producerOpLayout
-      //           .withElementType(consumerOp->getContext(),
-      //                            consumerOpOutputLayout.getElementType())
-      //           .withMemorySpace(consumerOp->getContext(),
-      //                            consumerOpOutputLayout.getMemorySpace())
-      //           .withMemoryLayout(consumerOp->getContext(),
-      //                             consumerOpOutputLayout.getMemLayout())
-      //           .withGrid(consumerOp->getContext(), producerOpTensorType,
-      //                     consumerOpOutputLayout.getGrid()));

-      //   OpBuilder builder(consumerOp);

-      //   mlir::tensor::EmptyOp emptyOp = builder.create<tensor::EmptyOp>(
-      //       consumerOp->getLoc(), producerOpTensorShape,
-      //       producerOpTensorType.getElementType(),
-      //       mlir::cast<LayoutAttr>(newTensorType.getEncoding()));

-      //   Operation *toLayoutOp = builder.create<ttir::ToLayoutOp>(
-      //       consumerOp->getLoc(), newTensorType, producerOp->getResult(0),
-      //       emptyOp);

-      //   consumerOp->setOperand(edge.operandIndex, toLayoutOp->getResult(0));
-      // }
}
}
};
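The net effect on the IR, as exercised by the updated lit test below: when memory layout analysis (or an insert-memreconfig override) picks a reconfig edge, a reconfig op is materialized between producer and consumer. A schematic sketch, with layouts matching the test's CHECK captures and SSA names purely illustrative:

%1 = "ttnn.add"(...) -> tensor<1x32x32xf32, #layout_dram_interleaved>
%2 = "ttnn.to_memory_config"(%1) -> tensor<1x32x32xf32, #layout_l1_block_sharded>
%3 = "ttnn.add"(%2, ...) -> tensor<1x32x32xf32, #layout_l1_block_sharded>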
@@ -1,16 +1,16 @@
-// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="enable-optimizer=true memory-layout-analysis-enabled=true memreconfig-enabled=true insert-reshard=add_0_1_2=0" %s | FileCheck %s
-// UNSUPPORTED: true
+// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="enable-optimizer=true memory-layout-analysis-enabled=true memreconfig-enabled=true insert-memreconfig=add_0_1_2=0 override-output-layout=add_1_2=1x1:dram:interleaved" %s | FileCheck %s
#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
#loc = loc("test_ops.py:17_0_0":0:0)
module attributes {} {
func.func @main(%arg0: tensor<1x32x32xf32> loc("test_ops.py:17_0_0":0:0), %arg1: tensor<1x32x32xf32> loc("test_ops.py:17_0_0":0:0), %arg2: tensor<1x32x32xf32> loc("test_ops.py:17_0_0":0:0)) -> tensor<1x32x32xf32> {
// CHECK: #[[L1_:.*]] = #tt.memory_space<l1>
-    // CHECK: #[[LAYOUT_1:.*]] = #tt.layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), undef, <1x1>, memref<32x32xf32, #l1_>, block_sharded>
+    // CHECK-DAG: #[[LAYOUT_1:.*]] = #tt.layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), undef, <1x1>, memref<32x32xf32, #l1_>, block_sharded>
+    // CHECK-DAG: #[[LAYOUT_2:.*]] = #tt.layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), undef, <1x1>, memref<32x32xf32, #dram>, interleaved>
%0 = tensor.empty() : tensor<1x32x32xf32> loc(#loc5)
-    // CHECK: %[[C:.*]] = "ttnn.add"{{.*}} -> tensor<1x32x32xf32, #[[LAYOUT_1]]>
+    // CHECK: %[[C:.*]] = "ttnn.add"{{.*}} -> tensor<1x32x32xf32, #[[LAYOUT_2]]>
%1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32x32xf32>, tensor<1x32x32xf32>, tensor<1x32x32xf32>) -> tensor<1x32x32xf32> loc(#loc5)
%2 = tensor.empty() : tensor<1x32x32xf32> loc(#loc6)
// CHECK: %{{.*}} = "ttnn.to_layout"(%[[C]], %0) {{.*}} -> tensor<1x32x32xf32, #[[LAYOUT_1]]>
// CHECK: %{{.*}} = "ttnn.to_memory_config"(%[[C]]) {{.*}} -> tensor<1x32x32xf32, #[[LAYOUT_1]]>
%3 = "ttir.add"(%1, %arg0, %2) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32x32xf32>, tensor<1x32x32xf32>, tensor<1x32x32xf32>) -> tensor<1x32x32xf32> loc(#loc6)
%4 = tensor.empty() : tensor<1x32x32xf32> loc(#loc7)
%5 = "ttir.relu"(%3, %4) <{operandSegmentSizes = array<i32: 1, 1>, operand_constraints = [#any_device, #any_device]}> : (tensor<1x32x32xf32>, tensor<1x32x32xf32>) -> tensor<1x32x32xf32> loc(#loc7)