From 8a4238a9c2fb0deecbfd8cd168c9edc2971fd326 Mon Sep 17 00:00:00 2001 From: Haozheng Fan Date: Thu, 21 Jan 2021 17:50:04 +0800 Subject: [PATCH 01/11] fix unstable compute --- src/te/autodiff/ad_simplify.cc | 28 +++++++++++++---------- tests/python/unittest/test_te_autodiff.py | 24 ++++++++++++++++--- 2 files changed, 37 insertions(+), 15 deletions(-) diff --git a/src/te/autodiff/ad_simplify.cc b/src/te/autodiff/ad_simplify.cc index cc0e82066171..5c0921190b05 100644 --- a/src/te/autodiff/ad_simplify.cc +++ b/src/te/autodiff/ad_simplify.cc @@ -413,15 +413,17 @@ class FactorOutAtomicFormulasFunctor auto res_b = VisitExpr(op->b); // For the And case we return the union of the sets of atomic formulas - std::unordered_set res_set; - res_set.reserve(res_a.atomic_formulas.size() + res_b.atomic_formulas.size()); + std::unordered_set res_a_set; + res_a_set.reserve(res_a.atomic_formulas.size()); std::copy(res_a.atomic_formulas.begin(), res_a.atomic_formulas.end(), - std::inserter(res_set, res_set.end())); - std::copy(res_b.atomic_formulas.begin(), res_b.atomic_formulas.end(), - std::inserter(res_set, res_set.end())); - - std::vector res{res_set.begin(), res_set.end()}; + std::inserter(res_a_set, res_a_set.end())); + std::vector res = res_a.atomic_formulas; + for (const auto& e : res_b.atomic_formulas) { + if (res_a_set.find(e) == res_a_set.end()) { + res.emplace_back(e); + } + } // And the residuals are combined with && return {res, res_a.rest && res_b.rest}; } @@ -443,10 +445,13 @@ class FactorOutAtomicFormulasFunctor // For the Or case we intersect the sets of atomic formulas std::unordered_set res_set; + std::vector res; res_set.reserve(std::min(res_a.atomic_formulas.size(), res_b.atomic_formulas.size())); - for (const auto& res_b_formula : res_b_set) { + res.reserve(std::min(res_a.atomic_formulas.size(), res_b.atomic_formulas.size())); + for (const auto& res_b_formula : res_b.atomic_formulas) { if (res_a_set.count(res_b_formula)) { res_set.insert(res_b_formula); + res.push_back(res_b_formula); } } @@ -454,13 +459,13 @@ class FactorOutAtomicFormulasFunctor // which are left behind, and then combine them with the residuals into the new residual. std::vector new_cond_a; new_cond_a.reserve(res_a.atomic_formulas.size() - res_set.size()); - for (const auto& formula : res_a_set) { + for (const auto& formula : res_a.atomic_formulas) { if (!res_set.count(formula)) new_cond_a.emplace_back(formula); } std::vector new_cond_b; new_cond_b.reserve(res_b.atomic_formulas.size() - res_set.size()); - for (const auto& formula : res_b_set) { + for (const auto& formula : res_b.atomic_formulas) { if (!res_set.count(formula)) new_cond_b.emplace_back(formula); } @@ -468,7 +473,6 @@ class FactorOutAtomicFormulasFunctor res_b.atomic_formulas = std::move(new_cond_b); PrimExpr new_rest = res_a.to_expr() || res_b.to_expr(); - std::vector res{res_set.begin(), res_set.end()}; return {res, new_rest}; } @@ -775,7 +779,6 @@ arith::IntConstraintsTransform SimplifyDomain(const arith::IntConstraints& iter_ if (eliminate_div_mod) { transf = transf + EliminateDivModFromDomainConditions(transf->dst); } - // TODO(sgrechanik-h): Repeating the following steps has a positive effect, however we probably // should find a better terminating criterion (like stop when the domain volume stops decreasing) // Also 2 steps seems to be slightly better than 3 @@ -787,6 +790,7 @@ arith::IntConstraintsTransform SimplifyDomain(const arith::IntConstraints& iter_ return transf; } + // Use the condition of a reduction op to simplify its domain (axis) PrimExpr SimplifyReductionDomain(const PrimExpr& expr, const Map& outer_vranges) { if (const ReduceNode* red = expr.as()) { diff --git a/tests/python/unittest/test_te_autodiff.py b/tests/python/unittest/test_te_autodiff.py index 6031182091fe..be4af238d985 100644 --- a/tests/python/unittest/test_te_autodiff.py +++ b/tests/python/unittest/test_te_autodiff.py @@ -343,7 +343,25 @@ def test_reduction_init(): check_grad(B, A0) +def test_stable(): + X = te.placeholder((32, 512, 16, 16), name="X") + W = te.placeholder((1024, 512, 1, 1), name="W") + strides, padding, dilation = 2, 0, 1 + R = topi.nn.conv2d(X, W, strides, padding, dilation) + ones = topi.full_like(R, 1.0) + grads = te.gradient(R, [X], head=ones) + dag = tvm.auto_scheduler.ComputeDAG(grads) + repeat = 100 + for i in range(repeat): + grads = te.gradient(R, [X], head=ones) + new_dag = tvm.auto_scheduler.ComputeDAG(grads) + print(dag) + print(new_dag) + assert str(dag) == str(new_dag) + + if __name__ == "__main__": - test_basic_operation() - test_topi() - test_stride_dilation() + # test_basic_operation() + # test_topi() + # test_stride_dilation() + test_stable() From 3c68f3fc7b6f915f86420e3b5fc94195669330f5 Mon Sep 17 00:00:00 2001 From: Haozheng Fan Date: Thu, 21 Jan 2021 23:51:59 +0800 Subject: [PATCH 02/11] fix --- tests/python/unittest/test_te_autodiff.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/python/unittest/test_te_autodiff.py b/tests/python/unittest/test_te_autodiff.py index be4af238d985..6b012ec9eb04 100644 --- a/tests/python/unittest/test_te_autodiff.py +++ b/tests/python/unittest/test_te_autodiff.py @@ -361,7 +361,7 @@ def test_stable(): if __name__ == "__main__": - # test_basic_operation() - # test_topi() - # test_stride_dilation() + test_basic_operation() + test_topi() + test_stride_dilation() test_stable() From 29f1a80fc66ab55ffd9a123a2769f20cd2d4e00a Mon Sep 17 00:00:00 2001 From: Haozheng Fan Date: Thu, 21 Jan 2021 23:52:59 +0800 Subject: [PATCH 03/11] fix --- tests/python/unittest/test_te_autodiff.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/python/unittest/test_te_autodiff.py b/tests/python/unittest/test_te_autodiff.py index 6b012ec9eb04..8a5abeaee97b 100644 --- a/tests/python/unittest/test_te_autodiff.py +++ b/tests/python/unittest/test_te_autodiff.py @@ -355,8 +355,6 @@ def test_stable(): for i in range(repeat): grads = te.gradient(R, [X], head=ones) new_dag = tvm.auto_scheduler.ComputeDAG(grads) - print(dag) - print(new_dag) assert str(dag) == str(new_dag) From 237d0f414ecb29df81ba41c1fb813963debe166f Mon Sep 17 00:00:00 2001 From: Haozheng Fan Date: Fri, 22 Jan 2021 00:16:07 +0800 Subject: [PATCH 04/11] lint --- src/te/autodiff/ad_simplify.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/te/autodiff/ad_simplify.cc b/src/te/autodiff/ad_simplify.cc index 5c0921190b05..96f278e63be7 100644 --- a/src/te/autodiff/ad_simplify.cc +++ b/src/te/autodiff/ad_simplify.cc @@ -779,6 +779,7 @@ arith::IntConstraintsTransform SimplifyDomain(const arith::IntConstraints& iter_ if (eliminate_div_mod) { transf = transf + EliminateDivModFromDomainConditions(transf->dst); } + // TODO(sgrechanik-h): Repeating the following steps has a positive effect, however we probably // should find a better terminating criterion (like stop when the domain volume stops decreasing) // Also 2 steps seems to be slightly better than 3 @@ -790,7 +791,6 @@ arith::IntConstraintsTransform SimplifyDomain(const arith::IntConstraints& iter_ return transf; } - // Use the condition of a reduction op to simplify its domain (axis) PrimExpr SimplifyReductionDomain(const PrimExpr& expr, const Map& outer_vranges) { if (const ReduceNode* red = expr.as()) { From 0c7af5008e8e91ed237c51a55e9ec709a3b45015 Mon Sep 17 00:00:00 2001 From: Haozheng Fan Date: Fri, 22 Jan 2021 11:17:26 +0000 Subject: [PATCH 05/11] sort linear equation --- src/arith/solve_linear_equation.cc | 25 ++++++++++++++++++++----- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/src/arith/solve_linear_equation.cc b/src/arith/solve_linear_equation.cc index 22bf7360563d..cf67acf7e85b 100644 --- a/src/arith/solve_linear_equation.cc +++ b/src/arith/solve_linear_equation.cc @@ -427,11 +427,10 @@ IntConstraintsTransform SolveLinearEquations(const IntConstraints& system_to_sol // We have to transform ranges of the old variables into relations over new variables because // new ranges are not enough usually. - for (const auto& p : system_to_solve->ranges) { - const Var& old_var = p.first; - const Range& old_range = p.second; - if (old_to_new_map.count(old_var)) { - PrimExpr express_by_new_vars = old_to_new_map[old_var]; + for (const auto& old_var: system_to_solve->variables) { + if (system_to_solve->ranges.find(old_var) != system_to_solve->ranges.end()) { + const Range& old_range = system_to_solve->ranges.at(old_var); + PrimExpr express_by_new_vars = old_to_new_map.at(old_var); PrimExpr lower_cond = analyzer_solution.Simplify(old_range->min <= express_by_new_vars); PrimExpr upper_cond = analyzer_solution.Simplify(express_by_new_vars < old_range->min + old_range->extent); @@ -443,6 +442,22 @@ IntConstraintsTransform SolveLinearEquations(const IntConstraints& system_to_sol } } } + // for (const auto& p : system_to_solve->ranges) { + // const Var& old_var = p.first; + // const Range& old_range = p.second; + // if (old_to_new_map.count(old_var)) { + // PrimExpr express_by_new_vars = old_to_new_map[old_var]; + // PrimExpr lower_cond = analyzer_solution.Simplify(old_range->min <= express_by_new_vars); + // PrimExpr upper_cond = + // analyzer_solution.Simplify(express_by_new_vars < old_range->min + old_range->extent); + // if (!tir::is_const_int(lower_cond, 1)) { + // new_relations.push_back(lower_cond); + // } + // if (!tir::is_const_int(upper_cond, 1)) { + // new_relations.push_back(upper_cond); + // } + // } + // } // Add the rest conditions for (const PrimExpr& cond : rest) { From 0bb32d86b547a9b4c9f06f582363833e7bdc0fb8 Mon Sep 17 00:00:00 2001 From: Haozheng Fan Date: Fri, 22 Jan 2021 19:43:22 +0800 Subject: [PATCH 06/11] sort inequalities --- src/arith/solve_linear_inequality.cc | 39 ++++++++++++++-------------- 1 file changed, 20 insertions(+), 19 deletions(-) diff --git a/src/arith/solve_linear_inequality.cc b/src/arith/solve_linear_inequality.cc index f4de9ffb197b..a1f9f193e18b 100644 --- a/src/arith/solve_linear_inequality.cc +++ b/src/arith/solve_linear_inequality.cc @@ -95,8 +95,8 @@ struct ExprLess { }; void DebugPrint( - const std::unordered_set& current_ineq_set, - const std::unordered_set& next_ineq_set, + const std::vector& current_ineq_set, + const std::vector& next_ineq_set, const std::vector& rest, const std::vector>& coef_pos, const std::vector>& coef_neg) { std::cout << "Current ineq set:\n["; @@ -148,9 +148,10 @@ class NormalizeComparisons : public ExprMutator { arith::Analyzer analyzer_; }; -void AddInequality(std::unordered_set* inequality_set, +void AddInequality(std::vector* inequality_set, const PrimExpr& new_ineq, Analyzer* analyzer) { - if (analyzer->CanProve(new_ineq) || inequality_set->find(new_ineq) != inequality_set->end()) { + if (analyzer->CanProve(new_ineq) || std::find( + inequality_set->begin(), inequality_set->end(), new_ineq) != inequality_set->end()) { // redundant: follows from the vranges // or has already been added return; @@ -168,13 +169,13 @@ void AddInequality(std::unordered_set } } - inequality_set->insert(new_ineq); + inequality_set->push_back(new_ineq); } void ClassifyByPolarity( const Var& var, - const std::unordered_set& current_ineq_set, - std::unordered_set* next_ineq_set, + const std::vector& current_ineq_set, + std::vector* next_ineq_set, std::vector* rest, std::vector>* coef_pos, std::vector>* coef_neg, Analyzer* analyzer) { // Take formulas from current_ineq_set and classify them according to polarity wrt var @@ -218,14 +219,14 @@ void ClassifyByPolarity( } } -void MoveEquality(std::unordered_set* upper_bounds, - std::unordered_set* lower_bounds, - std::unordered_set* equalities) { +void MoveEquality(std::vector* upper_bounds, + std::vector* lower_bounds, + std::vector* equalities) { // those exist in both upper & lower bounds will be moved to equalities for (auto ub = upper_bounds->begin(); ub != upper_bounds->end();) { - auto lb = lower_bounds->find(*ub); + auto lb = std::find(lower_bounds->begin(), lower_bounds->end(), *ub); if (lb != lower_bounds->end()) { - equalities->insert(*lb); + equalities->push_back(*lb); lower_bounds->erase(lb); ub = upper_bounds->erase(ub); } else { @@ -249,8 +250,8 @@ PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_t // and move to the next variable. // normalized inequality - std::unordered_set current_ineq_set_to_solve; - std::unordered_set next_ineq_set_to_solve; + std::vector current_ineq_set_to_solve; + std::vector next_ineq_set_to_solve; // A vector of pairs (c, e), c > 0, representing formulas of the form c*v + e <= 0 std::vector> coef_pos; // A vector of pairs (c, e), c < 0, representing formulas of the form c*v + e <= 0 @@ -321,8 +322,8 @@ PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_t } // The resulting lower and upper bounds - std::unordered_set upper_bounds; - std::unordered_set lower_bounds; + std::vector upper_bounds; + std::vector lower_bounds; upper_bounds.reserve(coef_pos.size()); lower_bounds.reserve(coef_neg.size()); @@ -345,7 +346,7 @@ PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_t } } // Add the upper bound - upper_bounds.insert(bound); + upper_bounds.push_back(bound); } for (const auto& neg : coef_neg) { PrimExpr bound = make_const(v.dtype(), -coef_lcm / neg.first) * neg.second; @@ -366,10 +367,10 @@ PartialSolvedInequalities SolveLinearInequalities(const IntConstraints& system_t } } // Add the lower bound - lower_bounds.insert(bound); + lower_bounds.push_back(bound); } - std::unordered_set equal; + std::vector equal; equal.reserve(std::min(upper_bounds.size(), lower_bounds.size())); MoveEquality(&upper_bounds, &lower_bounds, &equal); std::vector equal_list(equal.begin(), equal.end()); From 38d838c437d85d6a1df7f5e46d1898d0ebc1c7dd Mon Sep 17 00:00:00 2001 From: Haozheng Fan Date: Fri, 22 Jan 2021 19:49:30 +0800 Subject: [PATCH 07/11] fix --- src/arith/solve_linear_inequality.cc | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/arith/solve_linear_inequality.cc b/src/arith/solve_linear_inequality.cc index a1f9f193e18b..c128229edf1a 100644 --- a/src/arith/solve_linear_inequality.cc +++ b/src/arith/solve_linear_inequality.cc @@ -150,8 +150,10 @@ class NormalizeComparisons : public ExprMutator { void AddInequality(std::vector* inequality_set, const PrimExpr& new_ineq, Analyzer* analyzer) { - if (analyzer->CanProve(new_ineq) || std::find( - inequality_set->begin(), inequality_set->end(), new_ineq) != inequality_set->end()) { + if (std::find(inequality_set->begin(), inequality_set->end(), new_ineq) != inequality_set->end()) { + return; + } + if (analyzer->CanProve(new_ineq)) { // redundant: follows from the vranges // or has already been added return; From 56eb07d0446040a999d16b7099fac409048efbe5 Mon Sep 17 00:00:00 2001 From: Haozheng Fan Date: Sun, 24 Jan 2021 23:15:56 +0800 Subject: [PATCH 08/11] fix find --- src/arith/solve_linear_inequality.cc | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/arith/solve_linear_inequality.cc b/src/arith/solve_linear_inequality.cc index c128229edf1a..6e0e449eade9 100644 --- a/src/arith/solve_linear_inequality.cc +++ b/src/arith/solve_linear_inequality.cc @@ -150,10 +150,8 @@ class NormalizeComparisons : public ExprMutator { void AddInequality(std::vector* inequality_set, const PrimExpr& new_ineq, Analyzer* analyzer) { - if (std::find(inequality_set->begin(), inequality_set->end(), new_ineq) != inequality_set->end()) { - return; - } - if (analyzer->CanProve(new_ineq)) { + if (analyzer->CanProve(new_ineq) || std::find_if(inequality_set->begin(), inequality_set->end(), + [&](const PrimExpr& e) { return StructuralEqual()(e, new_ineq); }) != inequality_set->end()) { // redundant: follows from the vranges // or has already been added return; @@ -226,7 +224,9 @@ void MoveEquality(std::vector* upper_bounds, std::vector* equalities) { // those exist in both upper & lower bounds will be moved to equalities for (auto ub = upper_bounds->begin(); ub != upper_bounds->end();) { - auto lb = std::find(lower_bounds->begin(), lower_bounds->end(), *ub); + auto lb = std::find(lower_bounds->begin(), lower_bounds->end(), [&](const PrimExpr& e) { + return StructuralEqual()(e, *ub); + }); if (lb != lower_bounds->end()) { equalities->push_back(*lb); lower_bounds->erase(lb); From 455f3ffeb83e6d46c4024c67daa9f54e0712f928 Mon Sep 17 00:00:00 2001 From: Haozheng Fan Date: Mon, 25 Jan 2021 16:06:54 +0000 Subject: [PATCH 09/11] lint --- src/arith/solve_linear_equation.cc | 18 +----------------- tests/python/unittest/test_te_autodiff.py | 16 ---------------- 2 files changed, 1 insertion(+), 33 deletions(-) diff --git a/src/arith/solve_linear_equation.cc b/src/arith/solve_linear_equation.cc index cf67acf7e85b..d66e75d9d361 100644 --- a/src/arith/solve_linear_equation.cc +++ b/src/arith/solve_linear_equation.cc @@ -427,7 +427,7 @@ IntConstraintsTransform SolveLinearEquations(const IntConstraints& system_to_sol // We have to transform ranges of the old variables into relations over new variables because // new ranges are not enough usually. - for (const auto& old_var: system_to_solve->variables) { + for (const auto& old_var : system_to_solve->variables) { if (system_to_solve->ranges.find(old_var) != system_to_solve->ranges.end()) { const Range& old_range = system_to_solve->ranges.at(old_var); PrimExpr express_by_new_vars = old_to_new_map.at(old_var); @@ -442,22 +442,6 @@ IntConstraintsTransform SolveLinearEquations(const IntConstraints& system_to_sol } } } - // for (const auto& p : system_to_solve->ranges) { - // const Var& old_var = p.first; - // const Range& old_range = p.second; - // if (old_to_new_map.count(old_var)) { - // PrimExpr express_by_new_vars = old_to_new_map[old_var]; - // PrimExpr lower_cond = analyzer_solution.Simplify(old_range->min <= express_by_new_vars); - // PrimExpr upper_cond = - // analyzer_solution.Simplify(express_by_new_vars < old_range->min + old_range->extent); - // if (!tir::is_const_int(lower_cond, 1)) { - // new_relations.push_back(lower_cond); - // } - // if (!tir::is_const_int(upper_cond, 1)) { - // new_relations.push_back(upper_cond); - // } - // } - // } // Add the rest conditions for (const PrimExpr& cond : rest) { diff --git a/tests/python/unittest/test_te_autodiff.py b/tests/python/unittest/test_te_autodiff.py index 8a5abeaee97b..6031182091fe 100644 --- a/tests/python/unittest/test_te_autodiff.py +++ b/tests/python/unittest/test_te_autodiff.py @@ -343,23 +343,7 @@ def test_reduction_init(): check_grad(B, A0) -def test_stable(): - X = te.placeholder((32, 512, 16, 16), name="X") - W = te.placeholder((1024, 512, 1, 1), name="W") - strides, padding, dilation = 2, 0, 1 - R = topi.nn.conv2d(X, W, strides, padding, dilation) - ones = topi.full_like(R, 1.0) - grads = te.gradient(R, [X], head=ones) - dag = tvm.auto_scheduler.ComputeDAG(grads) - repeat = 100 - for i in range(repeat): - grads = te.gradient(R, [X], head=ones) - new_dag = tvm.auto_scheduler.ComputeDAG(grads) - assert str(dag) == str(new_dag) - - if __name__ == "__main__": test_basic_operation() test_topi() test_stride_dilation() - test_stable() From b1f799d697de4a2229adfe818d5e16ecd3cc35b7 Mon Sep 17 00:00:00 2001 From: Haozheng Fan Date: Mon, 25 Jan 2021 16:19:55 +0000 Subject: [PATCH 10/11] fix find --- src/arith/solve_linear_inequality.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/arith/solve_linear_inequality.cc b/src/arith/solve_linear_inequality.cc index 6e0e449eade9..7b74d9218565 100644 --- a/src/arith/solve_linear_inequality.cc +++ b/src/arith/solve_linear_inequality.cc @@ -224,7 +224,7 @@ void MoveEquality(std::vector* upper_bounds, std::vector* equalities) { // those exist in both upper & lower bounds will be moved to equalities for (auto ub = upper_bounds->begin(); ub != upper_bounds->end();) { - auto lb = std::find(lower_bounds->begin(), lower_bounds->end(), [&](const PrimExpr& e) { + auto lb = std::find_if(lower_bounds->begin(), lower_bounds->end(), [&](const PrimExpr& e) { return StructuralEqual()(e, *ub); }); if (lb != lower_bounds->end()) { From 2acd72b2fb3c0dacfa4e808eef20dd712921d7ac Mon Sep 17 00:00:00 2001 From: Haozheng Fan Date: Mon, 25 Jan 2021 16:39:45 +0000 Subject: [PATCH 11/11] lint --- src/arith/solve_linear_inequality.cc | 37 +++++++++++++--------------- 1 file changed, 17 insertions(+), 20 deletions(-) diff --git a/src/arith/solve_linear_inequality.cc b/src/arith/solve_linear_inequality.cc index 7b74d9218565..dd9044833546 100644 --- a/src/arith/solve_linear_inequality.cc +++ b/src/arith/solve_linear_inequality.cc @@ -94,11 +94,10 @@ struct ExprLess { } }; -void DebugPrint( - const std::vector& current_ineq_set, - const std::vector& next_ineq_set, - const std::vector& rest, const std::vector>& coef_pos, - const std::vector>& coef_neg) { +void DebugPrint(const std::vector& current_ineq_set, + const std::vector& next_ineq_set, const std::vector& rest, + const std::vector>& coef_pos, + const std::vector>& coef_neg) { std::cout << "Current ineq set:\n["; for (auto& ineq : current_ineq_set) { std::cout << ineq << ", "; @@ -148,10 +147,12 @@ class NormalizeComparisons : public ExprMutator { arith::Analyzer analyzer_; }; -void AddInequality(std::vector* inequality_set, - const PrimExpr& new_ineq, Analyzer* analyzer) { - if (analyzer->CanProve(new_ineq) || std::find_if(inequality_set->begin(), inequality_set->end(), - [&](const PrimExpr& e) { return StructuralEqual()(e, new_ineq); }) != inequality_set->end()) { +void AddInequality(std::vector* inequality_set, const PrimExpr& new_ineq, + Analyzer* analyzer) { + if (analyzer->CanProve(new_ineq) || + std::find_if(inequality_set->begin(), inequality_set->end(), [&](const PrimExpr& e) { + return StructuralEqual()(e, new_ineq); + }) != inequality_set->end()) { // redundant: follows from the vranges // or has already been added return; @@ -172,12 +173,10 @@ void AddInequality(std::vector* inequality_set, inequality_set->push_back(new_ineq); } -void ClassifyByPolarity( - const Var& var, - const std::vector& current_ineq_set, - std::vector* next_ineq_set, - std::vector* rest, std::vector>* coef_pos, - std::vector>* coef_neg, Analyzer* analyzer) { +void ClassifyByPolarity(const Var& var, const std::vector& current_ineq_set, + std::vector* next_ineq_set, std::vector* rest, + std::vector>* coef_pos, + std::vector>* coef_neg, Analyzer* analyzer) { // Take formulas from current_ineq_set and classify them according to polarity wrt var // and store to coef_pos and coef_neg respectively. for (const PrimExpr& ineq : current_ineq_set) { @@ -219,14 +218,12 @@ void ClassifyByPolarity( } } -void MoveEquality(std::vector* upper_bounds, - std::vector* lower_bounds, +void MoveEquality(std::vector* upper_bounds, std::vector* lower_bounds, std::vector* equalities) { // those exist in both upper & lower bounds will be moved to equalities for (auto ub = upper_bounds->begin(); ub != upper_bounds->end();) { - auto lb = std::find_if(lower_bounds->begin(), lower_bounds->end(), [&](const PrimExpr& e) { - return StructuralEqual()(e, *ub); - }); + auto lb = std::find_if(lower_bounds->begin(), lower_bounds->end(), + [&](const PrimExpr& e) { return StructuralEqual()(e, *ub); }); if (lb != lower_bounds->end()) { equalities->push_back(*lb); lower_bounds->erase(lb);