-
Notifications
You must be signed in to change notification settings - Fork 160
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add DiscRematerializationPass to reduce peak memory
- Loading branch information
Showing
5 changed files
with
315 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
276 changes: 276 additions & 0 deletions
276
tao_compiler/mlir/disc/transforms/disc_rematerialization.cc
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,276 @@ | ||
// Copyright 2021 The BladeDISC Authors. All rights reserved. | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
// This file implements logic for lowering HLO DISC dialect to LHLO DISC | ||
// dialect. | ||
|
||
#include <algorithm> | ||
#include <cstdint> | ||
#include <fstream> | ||
#include <iostream> | ||
#include <list> | ||
#include <iterator> | ||
#include <string> | ||
#include <utility> | ||
#include <vector> | ||
|
||
#include "lhlo/IR/lhlo_ops.h" | ||
#include "llvm/Support/Debug.h" | ||
#include "mlir/Dialect/Arith/IR/Arith.h" | ||
#include "mlir/Dialect/Bufferization/IR/Bufferization.h" | ||
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" | ||
#include "mlir/Dialect/Func/Transforms/FuncConversions.h" | ||
#include "mlir/Dialect/MemRef/IR/MemRef.h" | ||
#include "mlir/Dialect/Shape/IR/Shape.h" | ||
#include "mlir/Dialect/Shape/Transforms/Passes.h" | ||
#include "mlir/Dialect/Tensor/IR/Tensor.h" | ||
#include "mlir/IR/AffineMap.h" | ||
#include "mlir/IR/Attributes.h" | ||
#include "mlir/IR/Builders.h" | ||
#include "mlir/IR/BuiltinOps.h" | ||
#include "mlir/IR/BuiltinTypes.h" | ||
#include "mlir/IR/Location.h" | ||
#include "mlir/IR/MLIRContext.h" | ||
#include "mlir/IR/Operation.h" | ||
#include "mlir/IR/PatternMatch.h" | ||
#include "mlir/Pass/Pass.h" | ||
#include "mlir/Transforms/DialectConversion.h" | ||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project | ||
#include "mlir/disc/IR/disc_shape_ops.h" | ||
#include "mlir/disc/IR/lhlo_disc_ops.h" | ||
#include "mlir/disc/disc_util.h" | ||
#include "mlir/disc/transforms/PassDetail.h" | ||
#include "mlir/disc/transforms/fusion_utils.h" | ||
#include "mlir/disc/transforms/placement_utils.h" | ||
#include "mlir/disc/transforms/rewriters.h" | ||
#include "mlir/disc/transforms/shape_utils.h" | ||
|
||
namespace mlir { | ||
using placement_utils::kDiscPlaceAssignment; | ||
using placement_utils::kGpu; | ||
|
||
namespace mhlo_disc { | ||
namespace { | ||
|
||
bool IsRematerializable(const Operation* op) { | ||
return true; | ||
} | ||
|
||
enum class RematStrategy{ | ||
// Recompute the node at a later program point. | ||
kRecompute, | ||
// Change the layout into a compact form and uncompress it back at a later | ||
// program point. | ||
kCompress, | ||
// Copy the data off the device to the host to be copied back later. | ||
kHostOffload, | ||
|
||
// Combination of different strategies. | ||
kRecomputeAndCompress, | ||
kRecomputeAndHostOffload, | ||
kCompressAndHostOffload, | ||
kAll | ||
}; | ||
|
||
struct Item { | ||
Value memref; | ||
std::vector<int> live_range; | ||
}; | ||
|
||
class LivingItems{ | ||
public: | ||
LivingItems() = default; | ||
void Add(const Value memref, const std::unordered_map<int64_t, int>& op_position_map) { | ||
// Add memrefs and their live range. | ||
} | ||
|
||
void Remove(const Value memref) { | ||
int64_t key = reinterpret_cast<int64_t>(memref.getAsOpaquePointer()); | ||
int index = live_range_map_[key]; | ||
live_range_map_.erase(key); | ||
//living_items_.erase(std::advance(living_items_.begin(), index)); | ||
} | ||
|
||
bool IsExist(const Value memref) { | ||
int64_t key = reinterpret_cast<int64_t>(memref.getAsOpaquePointer()); | ||
return live_range_map_.find(key) != live_range_map_.end(); | ||
} | ||
|
||
private: | ||
std::list<Item> living_items_; | ||
std::map<int64_t, int> live_range_map_; | ||
|
||
}; | ||
|
||
class MemoryUsageTracker { | ||
public: | ||
MemoryUsageTracker() = default; | ||
|
||
void SetAllOperationPositionInfo(const std::unordered_map<int64_t, int>& operation_position_map) { | ||
operation_position_map_ = operation_position_map; | ||
} | ||
|
||
void ProcessAlloc(memref::AllocOp op) { | ||
auto memref = op.getResult(); | ||
if (NeedSkip(memref)) { | ||
return; | ||
} | ||
current_peak_memory_usage_ += GetMemoryUsageForValue(memref); | ||
living_items_.Add(memref, operation_position_map_); | ||
} | ||
|
||
void ProcessDealloc(memref::DeallocOp op) { | ||
auto memref = op.getOperation()->getOperand(0); | ||
if(!living_items_.IsExist(memref)) { | ||
return; | ||
} | ||
|
||
current_peak_memory_usage_ -= GetMemoryUsageForValue(memref); | ||
living_items_.Remove(memref); | ||
} | ||
|
||
size_t GetMemoryUsageForValue(Value memref) { | ||
auto memref_ty = memref.getType().dyn_cast_or_null<MemRefType>(); | ||
if(!memref_ty) { | ||
return 0; | ||
} | ||
|
||
assert(memref_ty.getLayout().isIdentity()); | ||
if(memref_ty.hasStaticShape()) { | ||
int byte_width = memref_ty.getElementTypeBitWidth() / 8; | ||
auto shape = memref_ty.getShape(); | ||
size_t logical_size = byte_width; | ||
for (size_t dimSize : shape) { | ||
logical_size *= dimSize; | ||
} | ||
return logical_size; | ||
} else { | ||
return 1; | ||
} | ||
} | ||
|
||
size_t GetRecomputationScore(Value memref) { | ||
return 0; | ||
} | ||
|
||
size_t GetOffloadScore(Value memref) { | ||
return 0; | ||
} | ||
|
||
size_t GetCompressionScore(Value memref) { | ||
return 0; | ||
} | ||
|
||
std::pair<RematStrategy, size_t> GetRematEvaluation(Value memref) { | ||
switch(remat_strategy_) { | ||
case RematStrategy::kRecompute: | ||
return std::make_pair(RematStrategy::kRecompute, GetRecomputationScore(memref)); | ||
case RematStrategy::kCompress: | ||
return std::make_pair(RematStrategy::kCompress, GetCompressionScore(memref)); | ||
case RematStrategy::kHostOffload: | ||
return std::make_pair(RematStrategy::kHostOffload, GetOffloadScore(memref)); | ||
default: | ||
return std::make_pair(RematStrategy::kRecompute, GetRecomputationScore(memref)); | ||
} | ||
} | ||
|
||
void RematerializeToTargetMemoryUsage(size_t peak_memory_target) { | ||
|
||
} | ||
|
||
void RematerializeToLowestMemoryUsage() { | ||
// Iterate until we cannot get more memory-saving benefit | ||
} | ||
|
||
size_t GetCurrentPeakMemoryUsage() { return current_peak_memory_usage_; } | ||
|
||
bool NeedSkip(const Value memref) { | ||
return GetMemoryUsageForValue(memref) < kSmallMemrefSize; | ||
} | ||
private: | ||
LivingItems living_items_; | ||
size_t current_peak_memory_usage_; | ||
const size_t kSmallMemrefSize = 50 * 1024 * 1024; // memoryrefs under kSmallMemrefSize are not considered when remat; | ||
std::unordered_map<int64_t, int> operation_position_map_; | ||
RematStrategy remat_strategy_; | ||
}; | ||
|
||
|
||
|
||
struct DiscRematerializationPass : public DiscRematerializationPassBase<DiscRematerializationPass> { | ||
using DiscRematerializationPassBase<DiscRematerializationPass>::DiscRematerializationPassBase; | ||
|
||
void getDependentDialects(DialectRegistry& registry) const override { | ||
registry.insert<lmhlo_disc::LmhloDiscDialect, memref::MemRefDialect>(); | ||
} | ||
private: | ||
MemoryUsageTracker memory_usage_tracker_; | ||
|
||
public: | ||
DiscRematerializationPass() = default; | ||
|
||
bool IsDynmaicShapeGraph() { | ||
return false; | ||
} | ||
|
||
size_t GetPeakMemoryLimit() { | ||
if(IsDynmaicShapeGraph()) { | ||
return -1; | ||
} | ||
return 30ll * 1024ll * 1024ll * 1024ll; // 30GB | ||
} | ||
|
||
void runOnOperation() override { | ||
auto& context = getContext(); | ||
RewritePatternSet patterns(&context); | ||
ConversionTarget target(context); | ||
target.addLegalDialect<arith::ArithDialect, lmhlo_disc::LmhloDiscDialect, | ||
memref::MemRefDialect, shape::ShapeDialect, | ||
tensor::TensorDialect>(); | ||
|
||
ModuleOp module = getOperation(); | ||
auto main_func = module.lookupSymbol<mlir::func::FuncOp>("main"); | ||
std::unordered_map<int64_t, int> op_position_map; | ||
for (auto& block : main_func.getBody()) { | ||
for (auto& op : block) { | ||
op_position_map[reinterpret_cast<int64_t>(&op)] = op_position_map.size(); | ||
} | ||
} | ||
|
||
memory_usage_tracker_.SetAllOperationPositionInfo(op_position_map); | ||
// iterate over op_position_map | ||
for (const auto& pair : op_position_map) { | ||
Operation* op = reinterpret_cast<Operation*>(pair.first); | ||
if(isa<memref::AllocOp>(op)) { | ||
memory_usage_tracker_.ProcessAlloc(cast<memref::AllocOp>(op)); | ||
if(!IsDynmaicShapeGraph() && memory_usage_tracker_.GetCurrentPeakMemoryUsage() > GetPeakMemoryLimit()) { | ||
memory_usage_tracker_.RematerializeToTargetMemoryUsage(GetPeakMemoryLimit()); | ||
} | ||
} else if(isa<memref::DeallocOp>(op)) { | ||
memory_usage_tracker_.ProcessDealloc(cast<memref::DeallocOp>(op)); | ||
} | ||
} | ||
|
||
// Dynamic Shape Graph Processing | ||
if(IsDynmaicShapeGraph()) { | ||
memory_usage_tracker_.RematerializeToLowestMemoryUsage(); | ||
} | ||
return; | ||
} | ||
}; | ||
} // namespace | ||
|
||
std::unique_ptr<OperationPass<ModuleOp>> createDiscRematerializationPass() { | ||
return std::make_unique<DiscRematerializationPass>(); | ||
} | ||
|
||
} // namespace mhlo_disc | ||
} // namespace mlir |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters