From d79ba3d329645c50c6e22e38188ec1b73b72a423 Mon Sep 17 00:00:00 2001
From: cheng cheng <472491134@qq.com>
Date: Sun, 19 Jun 2022 02:52:30 +0800
Subject: [PATCH] Strict ordering in memory reuse algorithm (#8441)

---
 .../core/job/intra_job_mem_sharing_util.cpp   | 27 ++++++++++++++++++++-------
 1 file changed, 20 insertions(+), 7 deletions(-)

diff --git a/oneflow/core/job/intra_job_mem_sharing_util.cpp b/oneflow/core/job/intra_job_mem_sharing_util.cpp
index 1af896e1b57..6ee9e8ecea0 100644
--- a/oneflow/core/job/intra_job_mem_sharing_util.cpp
+++ b/oneflow/core/job/intra_job_mem_sharing_util.cpp
@@ -528,7 +528,7 @@ void MemReusedAlgorithm_AllocateByOrderAndMutualExclusion(
 
 void MemReusedAlgorithm_MemSizeFirstAlgo(
     const HashMap<RegstDescProto*, std::vector<RegstDescProto*>>& regst2mutual_exclusion_regsts,
-    MemBlockResultInfo* result) {
+    const HashMap<RegstDescProto*, int64_t>& regst2alloc_order, MemBlockResultInfo* result) {
   std::vector<RegstDescProto*> order;
   order.reserve(regst2mutual_exclusion_regsts.size());
   HashMap<RegstDescProto*, int64_t> regst_desc2size;
@@ -538,7 +538,10 @@ void MemReusedAlgorithm_MemSizeFirstAlgo(
         .second);
   }
   std::sort(order.begin(), order.end(), [&](RegstDescProto* lhs, RegstDescProto* rhs) {
-    return regst_desc2size.at(lhs) > regst_desc2size.at(rhs);
+    int64_t l_size = regst_desc2size.at(lhs);
+    int64_t r_size = regst_desc2size.at(rhs);
+    if (l_size == r_size) { return regst2alloc_order.at(lhs) < regst2alloc_order.at(rhs); }
+    return l_size > r_size;
   });
   MemReusedAlgorithm_AllocateByOrderAndMutualExclusion(order, regst_desc2size,
                                                        regst2mutual_exclusion_regsts, result);
@@ -546,7 +549,7 @@ void MemReusedAlgorithm_MemSizeFirstAlgo(
 
 void MemReusedAlgorithm_MutualExclusionFirstAlgo(
     const HashMap<RegstDescProto*, std::vector<RegstDescProto*>>& regst2mutual_exclusion_regsts,
-    MemBlockResultInfo* result) {
+    const HashMap<RegstDescProto*, int64_t>& regst2alloc_order, MemBlockResultInfo* result) {
   std::vector<RegstDescProto*> order;
   order.reserve(regst2mutual_exclusion_regsts.size());
   HashMap<RegstDescProto*, int64_t> regst_desc2size;
@@ -556,8 +559,10 @@ void MemReusedAlgorithm_MutualExclusionFirstAlgo(
         .second);
   }
   std::sort(order.begin(), order.end(), [&](RegstDescProto* lhs, RegstDescProto* rhs) {
-    return regst2mutual_exclusion_regsts.at(lhs).size()
-           < regst2mutual_exclusion_regsts.at(rhs).size();
+    int64_t l_size = regst2mutual_exclusion_regsts.at(lhs).size();
+    int64_t r_size = regst2mutual_exclusion_regsts.at(rhs).size();
+    if (l_size == r_size) { return regst2alloc_order.at(lhs) < regst2alloc_order.at(rhs); }
+    return l_size > r_size;
   });
   MemReusedAlgorithm_AllocateByOrderAndMutualExclusion(order, regst_desc2size,
                                                        regst2mutual_exclusion_regsts, result);
@@ -704,12 +709,20 @@ void SelectAlgorithmGenMemBlockOffset4Regsts(
     MemBlockResultInfo* result) {
   CHECK_EQ(result->mem_block_size, 0);
   CHECK(result->regst_desc2offset.empty());
+
+  // NOTE(chengcheng): When mem size or exclusion num equal, there need second order by allocate.
+  HashMap<RegstDescProto*, int64_t> regst2alloc_order;
+  for (int64_t i = 0; i < alloc_regsts_timeline.size(); ++i) {
+    const auto& regsts = alloc_regsts_timeline.at(i);
+    for (RegstDescProto* regst : regsts) { CHECK(regst2alloc_order.emplace(regst, i).second); }
+  }
   switch (algo_id) {
     case kMemSizeFirstAlgo:
-      MemReusedAlgorithm_MemSizeFirstAlgo(regst2mutual_exclusion_regsts, result);
+      MemReusedAlgorithm_MemSizeFirstAlgo(regst2mutual_exclusion_regsts, regst2alloc_order, result);
       break;
     case kMutualExclusionFirstAlgo:
-      MemReusedAlgorithm_MutualExclusionFirstAlgo(regst2mutual_exclusion_regsts, result);
+      MemReusedAlgorithm_MutualExclusionFirstAlgo(regst2mutual_exclusion_regsts, regst2alloc_order,
+                                                  result);
       break;
     case kTimeLineAlgo:
       MemReusedAlgorithm_TimeLineAlgo(alloc_regsts_timeline, free_regsts_timeline, result);