Add DiscRematerializationPass to reduce peak memory
eedalong committed Jun 19, 2024
1 parent 4d35390 commit 37257ed
Showing 5 changed files with 315 additions and 0 deletions.
30 changes: 30 additions & 0 deletions tao_compiler/mlir/disc/BUILD
@@ -1011,6 +1011,35 @@ cc_library(
alwayslink = 1,
)

cc_library(
name = "disc_rematerialization",
srcs = ["transforms/disc_rematerialization.cc"],
hdrs = [
"transforms/passes.h",
"transforms/rewriters.h",
],
deps = [
":lmhlo_disc",
":pass_details",
":placement_utils",
":shape_utils",
":fusion_utils",
"@org_tensorflow//tensorflow/compiler/xla/mlir_hlo:lhlo",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:MemRefDialect",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:ShapeDialect",
"@llvm-project//mlir:ShapeTransforms",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TensorDialect",
"@llvm-project//mlir:Transforms",
"@llvm-project//mlir:SCFDialect",
],
alwayslink = 1,
)

cc_library(
name = "disc_lower_to_library_call",
srcs = ["transforms/disc_lower_to_library_call.cc"],
@@ -2490,6 +2519,7 @@ cc_library(
":disc_optimization_barrier_expand",
":disc_parallel_loop_collapsing",
":disc_parallel_loop_tiling",
":disc_rematerialization",
":disc_remove_dead_buffer",
":disc_remove_shape_constraints",
":disc_shape_optimization",
2 changes: 2 additions & 0 deletions tao_compiler/mlir/disc/disc_compiler.cc
@@ -544,6 +544,8 @@ LogicalResult LowerHLOToLLVM(ModuleOp m, const DISCLoweringOptions& options) {
pm.addNestedPass<FuncOp>(bufferization::createBufferDeallocationPass());
pm.addNestedPass<FuncOp>(disc_ral::createDiscBufferDeallocationPass());

pm.addPass(mhlo_disc::createDiscRematerializationPass());

pm.addPass(disc_ral::createRalInjectExecutionContextPass());
pm.addNestedPass<FuncOp>(
disc_ral::createDiscLowerToLibraryCallPass(gpu_enabled));
276 changes: 276 additions & 0 deletions tao_compiler/mlir/disc/transforms/disc_rematerialization.cc
@@ -0,0 +1,276 @@
// Copyright 2021 The BladeDISC Authors. All rights reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// This file implements a rematerialization pass that tracks buffer live
// ranges and memory usage, and rematerializes (recomputes, compresses, or
// offloads) large buffers to reduce peak device memory usage.

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <fstream>
#include <iostream>
#include <iterator>
#include <limits>
#include <list>
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "lhlo/IR/lhlo_ops.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Shape/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project
#include "mlir/disc/IR/disc_shape_ops.h"
#include "mlir/disc/IR/lhlo_disc_ops.h"
#include "mlir/disc/disc_util.h"
#include "mlir/disc/transforms/PassDetail.h"
#include "mlir/disc/transforms/fusion_utils.h"
#include "mlir/disc/transforms/placement_utils.h"
#include "mlir/disc/transforms/rewriters.h"
#include "mlir/disc/transforms/shape_utils.h"

namespace mlir {
using placement_utils::kDiscPlaceAssignment;
using placement_utils::kGpu;

namespace mhlo_disc {
namespace {

// TODO: exclude ops that cannot be recomputed safely or cheaply (e.g. ops
// with side effects). For now every op is treated as rematerializable.
bool IsRematerializable(const Operation* op) { return true; }

enum class RematStrategy {
// Recompute the node at a later program point.
kRecompute,
// Change the layout into a compact form and uncompress it back at a later
// program point.
kCompress,
// Copy the data off the device to the host to be copied back later.
kHostOffload,

// Combination of different strategies.
kRecomputeAndCompress,
kRecomputeAndHostOffload,
kCompressAndHostOffload,
kAll
};

struct Item {
  Value memref;
  // Live range as positions in the op order, e.g. {first_def, last_use}.
  std::vector<int> live_range;
};

class LivingItems {
 public:
  LivingItems() = default;

  void Add(const Value memref,
           const std::unordered_map<int64_t, int>& op_position_map) {
    // TODO: record the memref and its live range derived from op_position_map
    // (see the ComputeLiveRange sketch below for one possible way).
  }

  void Remove(const Value memref) {
    int64_t key = reinterpret_cast<int64_t>(memref.getAsOpaquePointer());
    auto it = live_range_map_.find(key);
    if (it == live_range_map_.end()) return;
    // TODO: also erase the matching entry from living_items_, e.g.
    //   living_items_.erase(std::next(living_items_.begin(), it->second));
    live_range_map_.erase(it);
  }

  bool IsExist(const Value memref) {
    int64_t key = reinterpret_cast<int64_t>(memref.getAsOpaquePointer());
    return live_range_map_.find(key) != live_range_map_.end();
  }

 private:
  std::list<Item> living_items_;
  std::map<int64_t, int> live_range_map_;
};
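// Illustrative sketch, not part of the pass logic above: one possible way to
// derive a live range for a memref from the op position map, spanning its
// defining op to its last user. LivingItems::Add() could record such a range
// in an Item. The helper name is an assumption for illustration only.
std::vector<int> ComputeLiveRange(
    Value memref, const std::unordered_map<int64_t, int>& op_position_map) {
  int begin = std::numeric_limits<int>::max();
  int end = -1;
  if (Operation* def = memref.getDefiningOp()) {
    auto it = op_position_map.find(reinterpret_cast<int64_t>(def));
    if (it != op_position_map.end()) begin = std::min(begin, it->second);
  }
  for (Operation* user : memref.getUsers()) {
    auto it = op_position_map.find(reinterpret_cast<int64_t>(user));
    if (it == op_position_map.end()) continue;
    begin = std::min(begin, it->second);
    end = std::max(end, it->second);
  }
  return {begin, end};
}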

class MemoryUsageTracker {
public:
MemoryUsageTracker() = default;

void SetAllOperationPositionInfo(const std::unordered_map<int64_t, int>& operation_position_map) {
operation_position_map_ = operation_position_map;
}

void ProcessAlloc(memref::AllocOp op) {
auto memref = op.getResult();
if (NeedSkip(memref)) {
return;
}
current_peak_memory_usage_ += GetMemoryUsageForValue(memref);
living_items_.Add(memref, operation_position_map_);
}

void ProcessDealloc(memref::DeallocOp op) {
auto memref = op.getOperation()->getOperand(0);
if(!living_items_.IsExist(memref)) {
return;
}

current_peak_memory_usage_ -= GetMemoryUsageForValue(memref);
living_items_.Remove(memref);
}

  size_t GetMemoryUsageForValue(Value memref) {
    auto memref_ty = memref.getType().dyn_cast_or_null<MemRefType>();
    if (!memref_ty) {
      return 0;
    }

    assert(memref_ty.getLayout().isIdentity());
    if (memref_ty.hasStaticShape()) {
      // Round sub-byte element types (e.g. i1) up to one byte.
      size_t byte_width =
          std::max<size_t>(1, memref_ty.getElementTypeBitWidth() / 8);
      size_t logical_size = byte_width;
      for (int64_t dim_size : memref_ty.getShape()) {
        logical_size *= dim_size;
      }
      return logical_size;
    }
    // Dynamic shapes: the byte size is unknown at compile time; return a
    // nominal value (such buffers fall below kSmallMemrefSize and are skipped
    // by NeedSkip()).
    return 1;
  }

  // TODO: implement real cost models; the scores below are placeholders.
  size_t GetRecomputationScore(Value memref) { return 0; }

  size_t GetOffloadScore(Value memref) { return 0; }

  size_t GetCompressionScore(Value memref) { return 0; }
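  // Illustrative example (an assumption, not something this pass implements
  // yet): a recomputation score could weigh the bytes freed against an
  // estimated cost of re-executing the producer, e.g.
  //   size_t score = GetMemoryUsageForValue(memref) /
  //                  (1 + EstimateRecomputeCost(memref));  // hypothetical helper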

std::pair<RematStrategy, size_t> GetRematEvaluation(Value memref) {
switch(remat_strategy_) {
case RematStrategy::kRecompute:
return std::make_pair(RematStrategy::kRecompute, GetRecomputationScore(memref));
case RematStrategy::kCompress:
return std::make_pair(RematStrategy::kCompress, GetCompressionScore(memref));
case RematStrategy::kHostOffload:
return std::make_pair(RematStrategy::kHostOffload, GetOffloadScore(memref));
default:
return std::make_pair(RematStrategy::kRecompute, GetRecomputationScore(memref));
}
}

  void RematerializeToTargetMemoryUsage(size_t peak_memory_target) {
    // TODO: rematerialize live buffers until the tracked memory usage drops
    // below peak_memory_target.
  }

  void RematerializeToLowestMemoryUsage() {
    // TODO: iterate until no further memory-saving benefit can be obtained.
  }
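  // Illustrative outline for the two entry points above (an assumption about
  // how they could be implemented; PickBestLiveCandidate and ApplyStrategy
  // are hypothetical helpers):
  //   while (current_peak_memory_usage_ > peak_memory_target) {
  //     Item* best = PickBestLiveCandidate();
  //     if (!best) break;
  //     auto [strategy, score] = GetRematEvaluation(best->memref);
  //     if (score == 0) break;  // No further memory-saving benefit.
  //     ApplyStrategy(strategy, best->memref);  // Rewrite the IR.
  //     current_peak_memory_usage_ -= GetMemoryUsageForValue(best->memref);
  //   }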

size_t GetCurrentPeakMemoryUsage() { return current_peak_memory_usage_; }

bool NeedSkip(const Value memref) {
return GetMemoryUsageForValue(memref) < kSmallMemrefSize;
}
 private:
  LivingItems living_items_;
  size_t current_peak_memory_usage_ = 0;
  // Memrefs smaller than kSmallMemrefSize are not considered for
  // rematerialization.
  const size_t kSmallMemrefSize = 50 * 1024 * 1024;
  std::unordered_map<int64_t, int> operation_position_map_;
  RematStrategy remat_strategy_ = RematStrategy::kRecompute;
};



struct DiscRematerializationPass : public DiscRematerializationPassBase<DiscRematerializationPass> {
using DiscRematerializationPassBase<DiscRematerializationPass>::DiscRematerializationPassBase;

void getDependentDialects(DialectRegistry& registry) const override {
registry.insert<lmhlo_disc::LmhloDiscDialect, memref::MemRefDialect>();
}
private:
MemoryUsageTracker memory_usage_tracker_;

public:
DiscRematerializationPass() = default;

  // TODO: detect dynamic-shape graphs; static shapes are assumed for now.
  bool IsDynamicShapeGraph() { return false; }

  size_t GetPeakMemoryLimit() {
    if (IsDynamicShapeGraph()) {
      // No fixed limit for dynamic-shape graphs.
      return std::numeric_limits<size_t>::max();
    }
    return 30ll * 1024ll * 1024ll * 1024ll;  // 30GB
  }
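  // Illustrative alternative (an assumption, not part of this pass): the
  // hard-coded 30GB limit could be made configurable, e.g. through a
  // hypothetical environment variable:
  //   if (const char* env = std::getenv("DISC_REMAT_PEAK_MEMORY_LIMIT")) {
  //     return std::strtoull(env, nullptr, /*base=*/10);
  //   }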

  void runOnOperation() override {
    auto& context = getContext();
    RewritePatternSet patterns(&context);
    ConversionTarget target(context);
    target.addLegalDialect<arith::ArithDialect, lmhlo_disc::LmhloDiscDialect,
                           memref::MemRefDialect, shape::ShapeDialect,
                           tensor::TensorDialect>();

    ModuleOp module = getOperation();
    auto main_func = module.lookupSymbol<mlir::func::FuncOp>("main");
    if (!main_func) return;

    // Record the position of every op in the main function.
    std::unordered_map<int64_t, int> op_position_map;
    for (auto& block : main_func.getBody()) {
      for (auto& op : block) {
        int position = static_cast<int>(op_position_map.size());
        op_position_map[reinterpret_cast<int64_t>(&op)] = position;
      }
    }
    memory_usage_tracker_.SetAllOperationPositionInfo(op_position_map);

    // Walk the ops again in program order (an unordered_map does not preserve
    // insertion order) and track allocations and deallocations.
    for (auto& block : main_func.getBody()) {
      for (auto& op : block) {
        if (isa<memref::AllocOp>(op)) {
          memory_usage_tracker_.ProcessAlloc(cast<memref::AllocOp>(op));
          if (!IsDynamicShapeGraph() &&
              memory_usage_tracker_.GetCurrentPeakMemoryUsage() >
                  GetPeakMemoryLimit()) {
            memory_usage_tracker_.RematerializeToTargetMemoryUsage(
                GetPeakMemoryLimit());
          }
        } else if (isa<memref::DeallocOp>(op)) {
          memory_usage_tracker_.ProcessDealloc(cast<memref::DeallocOp>(op));
        }
      }
    }

    // For dynamic-shape graphs, shrink memory usage as much as possible.
    if (IsDynamicShapeGraph()) {
      memory_usage_tracker_.RematerializeToLowestMemoryUsage();
    }
  }
};
} // namespace

std::unique_ptr<OperationPass<ModuleOp>> createDiscRematerializationPass() {
return std::make_unique<DiscRematerializationPass>();
}

} // namespace mhlo_disc
} // namespace mlir
5 changes: 5 additions & 0 deletions tao_compiler/mlir/disc/transforms/mhlo_disc_passes.td
@@ -39,3 +39,8 @@ def DiscOpSchedulePass : Pass<"disc-op-schedule", "ModuleOp"> {
let summary = "Schedule ops in a function";
let constructor = "createDiscOpSchedulePass()";
}

def DiscRematerializationPass : Pass<"disc-rematerialization", "ModuleOp"> {
let summary = "Rematerialize buffers to reduce peak memory usage";
let constructor = "createDiscRematerializationPass()";
}
2 changes: 2 additions & 0 deletions tao_compiler/mlir/disc/transforms/passes.h
100644 → 100755
@@ -354,6 +354,8 @@ createDiscOptimizationBarrierExpandPass();

std::unique_ptr<OperationPass<ModuleOp>> createDiscOpSchedulePass();

std::unique_ptr<OperationPass<ModuleOp>> createDiscRematerializationPass();

std::unique_ptr<OperationPass<ModuleOp>> createDiscArgsMutationExpandPass();

} // namespace mhlo_disc
