Strict ordering in memory reuse algorithm (#8441)
chengtbf authored Jun 18, 2022
1 parent d7ef39f commit d79ba3d
Showing 1 changed file with 20 additions and 7 deletions.
27 changes: 20 additions & 7 deletions oneflow/core/job/intra_job_mem_sharing_util.cpp
@@ -528,7 +528,7 @@ void MemReusedAlgorithm_AllocateByOrderAndMutualExclusion(
 
 void MemReusedAlgorithm_MemSizeFirstAlgo(
     const HashMap<RegstDescProto*, std::vector<RegstDescProto*>>& regst2mutual_exclusion_regsts,
-    MemBlockResultInfo* result) {
+    const HashMap<RegstDescProto*, int64_t>& regst2alloc_order, MemBlockResultInfo* result) {
   std::vector<RegstDescProto*> order;
   order.reserve(regst2mutual_exclusion_regsts.size());
   HashMap<RegstDescProto*, int64_t> regst_desc2size;
@@ -538,15 +538,18 @@ void MemReusedAlgorithm_MemSizeFirstAlgo(
               .second);
   }
   std::sort(order.begin(), order.end(), [&](RegstDescProto* lhs, RegstDescProto* rhs) {
-    return regst_desc2size.at(lhs) > regst_desc2size.at(rhs);
+    int64_t l_size = regst_desc2size.at(lhs);
+    int64_t r_size = regst_desc2size.at(rhs);
+    if (l_size == r_size) { return regst2alloc_order.at(lhs) < regst2alloc_order.at(rhs); }
+    return l_size > r_size;
   });
   MemReusedAlgorithm_AllocateByOrderAndMutualExclusion(order, regst_desc2size,
                                                        regst2mutual_exclusion_regsts, result);
 }
 
 void MemReusedAlgorithm_MutualExclusionFirstAlgo(
     const HashMap<RegstDescProto*, std::vector<RegstDescProto*>>& regst2mutual_exclusion_regsts,
-    MemBlockResultInfo* result) {
+    const HashMap<RegstDescProto*, int64_t>& regst2alloc_order, MemBlockResultInfo* result) {
   std::vector<RegstDescProto*> order;
   order.reserve(regst2mutual_exclusion_regsts.size());
   HashMap<RegstDescProto*, int64_t> regst_desc2size;
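
The new comparator sorts registers by size, largest first, and breaks size ties by allocation order, earliest first. Because the tie-break key is unique per register, every pair of elements compares unequal and std::sort has exactly one valid output. A minimal standalone sketch of the same comparator shape (the Regst struct and its values here are hypothetical, not OneFlow's types):

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

struct Regst {
  int64_t size;         // primary key: larger first
  int64_t alloc_order;  // secondary key: earlier-allocated first on ties
};

int main() {
  std::vector<Regst> order = {{64, 2}, {128, 1}, {64, 0}, {128, 3}};
  std::sort(order.begin(), order.end(), [](const Regst& lhs, const Regst& rhs) {
    if (lhs.size == rhs.size) { return lhs.alloc_order < rhs.alloc_order; }
    return lhs.size > rhs.size;
  });
  // Prints: (128,1) (128,3) (64,0) (64,2) -- identical on every run.
  for (const Regst& r : order) { std::cout << "(" << r.size << "," << r.alloc_order << ") "; }
  std::cout << "\n";
}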
@@ -556,8 +559,10 @@ void MemReusedAlgorithm_MutualExclusionFirstAlgo(
               .second);
   }
   std::sort(order.begin(), order.end(), [&](RegstDescProto* lhs, RegstDescProto* rhs) {
-    return regst2mutual_exclusion_regsts.at(lhs).size()
-           < regst2mutual_exclusion_regsts.at(rhs).size();
+    int64_t l_size = regst2mutual_exclusion_regsts.at(lhs).size();
+    int64_t r_size = regst2mutual_exclusion_regsts.at(rhs).size();
+    if (l_size == r_size) { return regst2alloc_order.at(lhs) < regst2alloc_order.at(rhs); }
+    return l_size > r_size;
   });
   MemReusedAlgorithm_AllocateByOrderAndMutualExclusion(order, regst_desc2size,
                                                        regst2mutual_exclusion_regsts, result);
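
The exclusion-first comparator gets the same tie-break; note this hunk also changes the primary comparison from ascending (<) to descending (>) exclusion count. A stable sort alone would not make the result reproducible here, because order is filled by iterating a HashMap keyed by pointers, whose iteration order can vary from run to run; the unique allocation-order key makes the output independent of that input permutation. A standalone sketch of that property (Item and its values are made up for illustration):

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

struct Item {
  int64_t exclusion_count;
  int64_t alloc_order;  // unique per item
  bool operator==(const Item& o) const {
    return exclusion_count == o.exclusion_count && alloc_order == o.alloc_order;
  }
};

int main() {
  std::vector<Item> a = {{3, 0}, {3, 1}, {1, 2}, {1, 3}};
  std::vector<Item> b = {{1, 3}, {3, 1}, {1, 2}, {3, 0}};  // same items, shuffled
  auto cmp = [](const Item& l, const Item& r) {
    if (l.exclusion_count == r.exclusion_count) { return l.alloc_order < r.alloc_order; }
    return l.exclusion_count > r.exclusion_count;
  };
  std::sort(a.begin(), a.end(), cmp);
  std::sort(b.begin(), b.end(), cmp);
  assert(a == b);  // identical result regardless of the (hash-map-dependent) input order
}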
@@ -704,12 +709,20 @@ void SelectAlgorithmGenMemBlockOffset4Regsts(
     MemBlockResultInfo* result) {
   CHECK_EQ(result->mem_block_size, 0);
   CHECK(result->regst_desc2offset.empty());
+
+  // NOTE(chengcheng): When mem size or exclusion num are equal, second ordering by allocation order is needed.
+  HashMap<RegstDescProto*, int64_t> regst2alloc_order;
+  for (int64_t i = 0; i < alloc_regsts_timeline.size(); ++i) {
+    const auto& regsts = alloc_regsts_timeline.at(i);
+    for (RegstDescProto* regst : regsts) { CHECK(regst2alloc_order.emplace(regst, i).second); }
+  }
   switch (algo_id) {
     case kMemSizeFirstAlgo:
-      MemReusedAlgorithm_MemSizeFirstAlgo(regst2mutual_exclusion_regsts, result);
+      MemReusedAlgorithm_MemSizeFirstAlgo(regst2mutual_exclusion_regsts, regst2alloc_order, result);
       break;
     case kMutualExclusionFirstAlgo:
-      MemReusedAlgorithm_MutualExclusionFirstAlgo(regst2mutual_exclusion_regsts, result);
+      MemReusedAlgorithm_MutualExclusionFirstAlgo(regst2mutual_exclusion_regsts, regst2alloc_order,
+                                                  result);
       break;
     case kTimeLineAlgo:
       MemReusedAlgorithm_TimeLineAlgo(alloc_regsts_timeline, free_regsts_timeline, result);
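
The allocation-order index is built once from alloc_regsts_timeline and passed only to the two comparison-based algorithms; kTimeLineAlgo already consumes the timeline directly. A standalone sketch of the same index-building pattern (std::string stands in for RegstDescProto*, and the timeline contents are made up):

#include <cstdint>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

int main() {
  // alloc_timeline[i] = names allocated at timeline step i.
  std::vector<std::vector<std::string>> alloc_timeline = {{"a", "b"}, {"c"}, {"d", "e"}};
  std::unordered_map<std::string, int64_t> name2alloc_order;
  for (int64_t i = 0; i < static_cast<int64_t>(alloc_timeline.size()); ++i) {
    for (const std::string& name : alloc_timeline[i]) {
      // emplace returns {iterator, bool}; the bool checks each name is allocated once.
      bool inserted = name2alloc_order.emplace(name, i).second;
      if (!inserted) { std::cerr << "duplicate allocation: " << name << "\n"; }
    }
  }
  for (const auto& [name, step] : name2alloc_order) {
    std::cout << name << " -> step " << step << "\n";
  }
}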
