Skip to content

Commit

Permalink
L1 interleaved policy (#1117)
Browse files Browse the repository at this point in the history
V1 implementation for L1Interleaved policy (#1132)
  • Loading branch information
fbajraktariTT authored Nov 5, 2024
1 parent 3dbf089 commit c038025
Show file tree
Hide file tree
Showing 24 changed files with 477 additions and 32 deletions.
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,7 @@ third_party/tt-metal
.cache
*pycache*
*.egg-info
ttrt-artifacts/*
query_results.json
run_results.json
ttrt_report.xml
2 changes: 2 additions & 0 deletions include/ttmlir/Dialect/TT/IR/TTOpsTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,9 @@ def TT_LayoutAttr : TT_Attr<"Layout", "layout"> {
bool isSystemMemorySpace() const { return ::mlir::tt::isSystemMemorySpace(getMemorySpace()); }
bool isDeviceMemorySpace() const { return ::mlir::tt::isDeviceMemorySpace(getMemorySpace()); }
bool hasShardedTensorMemoryLayout() const;
bool hasInterleavedTensorMemoryLayout() const;
bool hasShardedL1TensorMemoryLayout() const;
bool hasInterleavedL1TensorMemoryLayout() const;
bool isTiled() const;
Type getElementType() const;
Type getScalarElementType() const;
Expand Down
47 changes: 47 additions & 0 deletions include/ttmlir/Dialect/TT/Utils/MemoryLayoutAnalysisParams.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

// Include guard renamed to match the filename (MemoryLayoutAnalysisParams.h);
// the previous guard (...MEMORYLAYOUTANALYSIS_H) did not match and invited a
// collision with a header actually named MemoryLayoutAnalysis.h.
#ifndef TTMLIR_DIALECT_TT_UTILS_MEMORYLAYOUTANALYSISPARAMS_H
#define TTMLIR_DIALECT_TT_UTILS_MEMORYLAYOUTANALYSISPARAMS_H

#include <llvm/ADT/StringSwitch.h>
#include <llvm/Support/CommandLine.h>

#include <optional>

namespace mlir::tt {

// Policies supported by memory layout analysis.
enum class MemoryLayoutAnalysisPolicyType { DFSharding, L1Interleaved };

// llvm::cl parser so a MemoryLayoutAnalysisPolicyType can be selected on the
// command line, e.g. memory-layout-analysis-policy=DFSharding.
struct MemoryLayoutAnalysisPolicyTypeParser
    : public llvm::cl::parser<MemoryLayoutAnalysisPolicyType> {
public:
  MemoryLayoutAnalysisPolicyTypeParser(llvm::cl::Option &opt)
      : llvm::cl::parser<MemoryLayoutAnalysisPolicyType>(opt) {}

  // Parses `arg` into `value`. Per the llvm::cl::parser contract, returns
  // false on success and true on error. The previous implementation had no
  // .Default(...) on the StringSwitch — an unrecognized argument fell off the
  // end (undefined behavior) — and always returned success; unknown values
  // are now reported via opt.error().
  bool parse(llvm::cl::Option &opt, llvm::StringRef argName,
             llvm::StringRef arg, MemoryLayoutAnalysisPolicyType &value) {
    std::optional<MemoryLayoutAnalysisPolicyType> parsed =
        llvm::StringSwitch<std::optional<MemoryLayoutAnalysisPolicyType>>(arg)
            .Case("DFSharding", MemoryLayoutAnalysisPolicyType::DFSharding)
            .Case("L1Interleaved",
                  MemoryLayoutAnalysisPolicyType::L1Interleaved)
            .Default(std::nullopt);
    if (!parsed) {
      return opt.error(llvm::Twine("Invalid memory layout analysis policy: ") +
                       arg);
    }
    value = *parsed;
    return false;
  }

  // Prints the canonical command-line form of `value`.
  static void print(llvm::raw_ostream &os,
                    const MemoryLayoutAnalysisPolicyType &value) {
    llvm::StringRef policy;
    switch (value) {
    case MemoryLayoutAnalysisPolicyType::DFSharding:
      policy = "DFSharding";
      break;
    case MemoryLayoutAnalysisPolicyType::L1Interleaved:
      policy = "L1Interleaved";
      break;
    }
    os << "memory-layout-analysis-policy=" << policy << "\n";
  }
};

} // namespace mlir::tt

#endif // TTMLIR_DIALECT_TT_UTILS_MEMORYLAYOUTANALYSISPARAMS_H
1 change: 0 additions & 1 deletion include/ttmlir/Dialect/TT/Utils/OverrideParams.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
#define TTMLIR_DIALECT_TT_UTILS_OVERRIDEPARAMS_H

#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h"
#include <cstdint>
#include <llvm/Support/CommandLine.h>

namespace mlir::tt {
Expand Down
21 changes: 11 additions & 10 deletions include/ttmlir/Dialect/TTNN/Analysis/DFShardingPolicy.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,16 @@

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "ttmlir/Dialect/TTNN/Analysis/L1ChainConfig.h"
#include "ttmlir/Dialect/TTNN/Analysis/MemoryLayoutAnalysisPolicy.h"

namespace mlir::tt::ttnn {

// Process ops in DFS schedulable order and build shard chain configs.
// Schedule is also produced as a side effect of sharding.
//
class DFShardingPolicy {
class DFShardingPolicy : public MemoryLayoutAnalysisPolicy {
private:
Operation *rootOp;
std::vector<L1ChainConfig> *l1ChainConfigs;
llvm::DenseMap<Operation *, std::vector<tt::LayoutAttr>> legalLayouts;
llvm::DenseMap<func::FuncOp, llvm::SmallVector<Operation *>> *schedule;
unsigned usableL1CacheSize = 0;
std::unordered_set<Edge> overrideReshardEdges;

public:
DFShardingPolicy(
Expand All @@ -28,11 +25,15 @@ class DFShardingPolicy {
&legalLayouts,
llvm::DenseMap<func::FuncOp, llvm::SmallVector<Operation *>> &schedule,
unsigned usableL1CacheSize)
: rootOp(rootOp), l1ChainConfigs(&l1ChainConfigs),
legalLayouts(legalLayouts), schedule(&schedule),
usableL1CacheSize(usableL1CacheSize) {}
: MemoryLayoutAnalysisPolicy(rootOp, l1ChainConfigs, legalLayouts,
schedule, usableL1CacheSize),
overrideReshardEdges() {}

void run(const std::unordered_set<Edge> &overrideReshardEdges);
void run() final;

void setOverrideReshardEdges(const std::unordered_set<Edge> &reshardEdges) {
overrideReshardEdges = reshardEdges;
}
};

} // namespace mlir::tt::ttnn
Expand Down
30 changes: 30 additions & 0 deletions include/ttmlir/Dialect/TTNN/Analysis/L1InterleavedPolicy.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#ifndef TTMLIR_DIALECT_TTNN_ANALYSIS_L1INTERLEAVEDPOLICY_H
#define TTMLIR_DIALECT_TTNN_ANALYSIS_L1INTERLEAVEDPOLICY_H

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "ttmlir/Dialect/TTNN/Analysis/L1ChainConfig.h"
#include "ttmlir/Dialect/TTNN/Analysis/MemoryLayoutAnalysisPolicy.h"

namespace mlir::tt::ttnn {

// Memory layout analysis policy selecting L1 interleaved layouts (see
// MemoryLayoutAnalysisPolicyType::L1Interleaved), as opposed to the sharded
// layouts produced by DFShardingPolicy. V1 implementation.
class L1InterleavedPolicy : public MemoryLayoutAnalysisPolicy {
public:
  // All parameters are forwarded verbatim to the MemoryLayoutAnalysisPolicy
  // base: rootOp is the op to analyze, l1ChainConfigs and schedule are
  // outputs written during run(), legalLayouts lists per-op candidate
  // layouts, and usableL1CacheSize is the available L1 budget
  // (presumably bytes — confirm against callers).
  L1InterleavedPolicy(
      Operation *rootOp, std::vector<L1ChainConfig> &l1ChainConfigs,
      const llvm::DenseMap<Operation *, std::vector<tt::LayoutAttr>>
          &legalLayouts,
      llvm::DenseMap<func::FuncOp, llvm::SmallVector<Operation *>> &schedule,
      unsigned usableL1CacheSize)
      : MemoryLayoutAnalysisPolicy(rootOp, l1ChainConfigs, legalLayouts,
                                   schedule, usableL1CacheSize) {}

  // Executes the policy; overrides MemoryLayoutAnalysisPolicy::run.
  void run() final;
};

} // namespace mlir::tt::ttnn

#endif // TTMLIR_DIALECT_TTNN_ANALYSIS_L1INTERLEAVEDPOLICY_H
11 changes: 5 additions & 6 deletions include/ttmlir/Dialect/TTNN/Analysis/MemoryLayoutAnalysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,30 +6,29 @@
#define TTMLIR_DIALECT_TTNN_ANALYSIS_MEMORYLAYOUTANALYSIS_H

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "ttmlir/Dialect/TT/Utils/MemoryLayoutAnalysisParams.h"
#include "ttmlir/Dialect/TTNN/Analysis/Edge.h"
#include "ttmlir/Dialect/TTNN/Analysis/L1ChainConfig.h"
#include "ttmlir/Dialect/TTNN/Analysis/TTNNAnalysis.h"

namespace mlir::tt::ttnn {

enum class MemoryLayoutAnalysisPolicyType {
DFSharding,
};

struct MemoryLayoutAnalysisInput {
llvm::DenseMap<Operation *, std::vector<tt::LayoutAttr>> legalLayouts;
unsigned usableL1CacheSize = 0;
std::unordered_set<Edge> overrideReshardEdges;
MemoryLayoutAnalysisPolicyType policy;

MemoryLayoutAnalysisInput() : legalLayouts() {}

MemoryLayoutAnalysisInput(
const llvm::DenseMap<Operation *, std::vector<tt::LayoutAttr>>
&legalLayouts,
unsigned usableL1CacheSize,
const std::unordered_set<Edge> &overrideReshardEdges)
const std::unordered_set<Edge> &overrideReshardEdges,
MemoryLayoutAnalysisPolicyType policy)
: legalLayouts(legalLayouts), usableL1CacheSize(usableL1CacheSize),
overrideReshardEdges(overrideReshardEdges) {}
overrideReshardEdges(overrideReshardEdges), policy(policy) {}

bool operator==(const MemoryLayoutAnalysisInput &rhs) const {
return legalLayouts == rhs.legalLayouts;
Expand Down
39 changes: 39 additions & 0 deletions include/ttmlir/Dialect/TTNN/Analysis/MemoryLayoutAnalysisPolicy.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#ifndef TTMLIR_DIALECT_TTNN_ANALYSIS_MEMORYLAYOUTANALYSISPOLICY_H
#define TTMLIR_DIALECT_TTNN_ANALYSIS_MEMORYLAYOUTANALYSISPOLICY_H

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "ttmlir/Dialect/TTNN/Analysis/L1ChainConfig.h"

namespace mlir::tt::ttnn {

// Abstract base for memory layout analysis policies (DFShardingPolicy,
// L1InterleavedPolicy). Holds the inputs and output slots shared by every
// policy run; concrete policies implement run().
class MemoryLayoutAnalysisPolicy {
protected:
  // Root operation the policy analyzes. Non-owning.
  Operation *rootOp;
  // Output: chain configs produced by run(). Points at caller-owned storage.
  std::vector<L1ChainConfig> *l1ChainConfigs;
  // Per-op candidate layouts the policy may choose from.
  // NOTE(review): stored by value — this copies the caller's DenseMap;
  // consider a const reference if the caller's lifetime allows.
  llvm::DenseMap<Operation *, std::vector<tt::LayoutAttr>> legalLayouts;
  // Output: per-function op schedule. Points at caller-owned storage.
  llvm::DenseMap<func::FuncOp, llvm::SmallVector<Operation *>> *schedule;
  // L1 budget available to the policy.
  unsigned usableL1CacheSize = 0;

public:
  // Virtual destructor: instances are deleted through base pointers.
  // (Was `{};` — empty body with a stray semicolon; `= default` is the idiom.)
  virtual ~MemoryLayoutAnalysisPolicy() = default;

  MemoryLayoutAnalysisPolicy(
      Operation *rootOp, std::vector<L1ChainConfig> &l1ChainConfigs,
      const llvm::DenseMap<Operation *, std::vector<tt::LayoutAttr>>
          &legalLayouts,
      llvm::DenseMap<func::FuncOp, llvm::SmallVector<Operation *>> &schedule,
      unsigned usableL1CacheSize)
      : rootOp(rootOp), l1ChainConfigs(&l1ChainConfigs),
        legalLayouts(legalLayouts), schedule(&schedule),
        usableL1CacheSize(usableL1CacheSize) {}

  // Executes the policy over rootOp, filling l1ChainConfigs and schedule.
  virtual void run() = 0;
};

} // namespace mlir::tt::ttnn

#endif // TTMLIR_DIALECT_TTNN_ANALYSIS_MEMORYLAYOUTANALYSISPOLICY_H
14 changes: 10 additions & 4 deletions include/ttmlir/Dialect/TTNN/Pipelines/TTNNPipelines.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,11 @@
#define TTMLIR_DIALECT_TTNN_PIPELINES_TTNNPIPELINES_H

#include "mlir/Pass/PassOptions.h"
#include "ttmlir/Dialect/TT/Utils/MemoryLayoutAnalysisParams.h"
#include "ttmlir/Dialect/TT/Utils/OverrideParams.h"
#include <cstdint>
#include <llvm/ADT/SmallVector.h>
#include <llvm/ADT/StringRef.h>
#include <llvm/Support/CommandLine.h>

namespace mlir::tt::ttnn {

// Options for the TTIR to TTNN backend pipeline.
//
struct TTIRToTTNNBackendPipelineOptions
Expand Down Expand Up @@ -85,6 +83,14 @@ struct TTIRToTTNNBackendPipelineOptions
"of shard specs."),
llvm::cl::init(false)};

// Specify policy for memory layout analysis.
//
Option<MemoryLayoutAnalysisPolicyType, MemoryLayoutAnalysisPolicyTypeParser>
memoryLayoutAnalysisPolicy{
*this, "memory-layout-analysis-policy",
llvm::cl::desc("Specify policy for memory layout analysis."),
llvm::cl::init(MemoryLayoutAnalysisPolicyType::DFSharding)};

// Option to provide a system descriptor flatbuffer file to compile
// against.
//
Expand Down
9 changes: 9 additions & 0 deletions include/ttmlir/Dialect/TTNN/Transforms/Optimizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ struct TTNNOptimizerOptions {
llvm::StringMap<OutputLayoutOverrideParams> overrideOutputLayout =
llvm::StringMap<OutputLayoutOverrideParams>();
bool memoryLayoutAnalysisEnabled = false;
MemoryLayoutAnalysisPolicyType memoryLayoutAnalysisPolicy =
MemoryLayoutAnalysisPolicyType::DFSharding;
bool memReconfigEnabled = false;
int64_t maxLegalLayouts = 64;
};
Expand Down Expand Up @@ -95,6 +97,7 @@ class TTNNOptimizerBase : public ::mlir::OperationPass<::mlir::ModuleOp> {
memoryLayoutAnalysisEnabled =
std::move(options.memoryLayoutAnalysisEnabled);
memReconfigEnabled = std::move(options.memReconfigEnabled);
memoryLayoutAnalysisPolicy = std::move(options.memoryLayoutAnalysisPolicy);
maxLegalLayouts = std::move(options.maxLegalLayouts);
}

Expand Down Expand Up @@ -122,6 +125,12 @@ class TTNNOptimizerBase : public ::mlir::OperationPass<::mlir::ModuleOp> {
"we support all "
"types of shard specs."),
::llvm::cl::init(false)};
::mlir::Pass::Option<mlir::tt::MemoryLayoutAnalysisPolicyType,
mlir::tt::MemoryLayoutAnalysisPolicyTypeParser>
memoryLayoutAnalysisPolicy{
*this, "memory-layout-analysis-policy",
llvm::cl::desc("Specify policy for memory layout analysis."),
llvm::cl::init(MemoryLayoutAnalysisPolicyType::DFSharding)};
::mlir::Pass::Option<int64_t> maxLegalLayouts{
*this, "max-legal-layouts",
::llvm::cl::desc(
Expand Down
1 change: 1 addition & 0 deletions include/ttmlir/Dialect/TTNN/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "ttmlir/Dialect/TT/Utils/OverrideParams.h"
#include "ttmlir/Dialect/TTNN/Analysis/MemoryLayoutAnalysis.h"
#include "ttmlir/Dialect/TTNN/IR/TTNN.h"
#include "ttmlir/Dialect/TTNN/IR/TTNNOps.h"

Expand Down
9 changes: 9 additions & 0 deletions lib/Dialect/TT/IR/TTOpsTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -579,13 +579,22 @@ bool LayoutAttr::hasShardedTensorMemoryLayout() const {
getMemLayout() == TensorMemoryLayout::BlockSharded);
}

// True iff this layout's tensor memory layout is interleaved, regardless of
// memory space.
bool LayoutAttr::hasInterleavedTensorMemoryLayout() const {
  return getMemLayout() == TensorMemoryLayout::Interleaved;
}

bool LayoutAttr::hasShardedL1TensorMemoryLayout() const {
return ::mlir::tt::isL1MemorySpace(getMemorySpace()) and
(getMemLayout() == TensorMemoryLayout::HeightSharded or
getMemLayout() == TensorMemoryLayout::WidthSharded or
getMemLayout() == TensorMemoryLayout::BlockSharded);
}

// True iff this layout lives in L1 memory space AND its tensor memory layout
// is interleaved.
bool LayoutAttr::hasInterleavedL1TensorMemoryLayout() const {
  return ::mlir::tt::isL1MemorySpace(getMemorySpace()) and
         getMemLayout() == TensorMemoryLayout::Interleaved;
}

bool LayoutAttr::isTiled() const {
return ::mlir::isa<::mlir::tt::TileType>(getElementType());
}
Expand Down
1 change: 1 addition & 0 deletions lib/Dialect/TTNN/Analysis/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ add_mlir_dialect_library(MLIRTTNNAnalysis
MemoryLayoutAnalysis.cpp
L1ChainConfig.cpp
DFShardingPolicy.cpp
L1InterleavedPolicy.cpp
ShardSolver.cpp

ADDITIONAL_HEADER_DIRS
Expand Down
3 changes: 1 addition & 2 deletions lib/Dialect/TTNN/Analysis/DFShardingPolicy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@

namespace mlir::tt::ttnn {

void DFShardingPolicy::run(
const std::unordered_set<Edge> &overrideReshardEdges) {
void DFShardingPolicy::run() {
rootOp->walk([&](func::FuncOp func) {
DeviceAttr deviceAttr = getCurrentScopeDevice(func);
mlir::tt::scheduler::Scheduler scheduler(&func);
Expand Down
Loading

0 comments on commit c038025

Please sign in to comment.