From 59012d90651dea7ae17a97f3b68ccd91f13ef2de Mon Sep 17 00:00:00 2001 From: Marton Hajdu Date: Tue, 6 Aug 2024 13:26:45 +0200 Subject: [PATCH 1/6] Add default ordering comparator; create ordering comparator in separate virtual function --- Forwards.hpp | 2 +- Indexing/Index.hpp | 2 +- Inferences/ForwardDemodulation.cpp | 2 +- Kernel/KBO.cpp | 85 +---------------------------- Kernel/KBO.hpp | 80 ++++++++++++++++++++++++++- Kernel/KBOComparator.cpp | 66 ++++++++++------------ Kernel/KBOComparator.hpp | 10 ++-- Kernel/LPO.cpp | 9 +-- Kernel/LPO.hpp | 2 +- Kernel/LPOComparator.cpp | 19 ++++--- Kernel/LPOComparator.hpp | 6 +- Kernel/Ordering.cpp | 7 +++ Kernel/Ordering.hpp | 17 ++++-- Kernel/SKIKBO.cpp | 5 -- Kernel/SKIKBO.hpp | 3 - Shell/InstanceRedundancyHandler.cpp | 6 +- 16 files changed, 155 insertions(+), 166 deletions(-) diff --git a/Forwards.hpp b/Forwards.hpp index 39b528451..9bde3ea3a 100644 --- a/Forwards.hpp +++ b/Forwards.hpp @@ -111,7 +111,7 @@ class LiteralSelector; class Ordering; typedef Lib::SmartPtr OrderingSP; struct OrderingComparator; -typedef std::unique_ptr OrderingComparatorUP; +typedef std::unique_ptr OrderingComparatorUP; typedef unsigned SplitLevel; typedef const SharedSet SplitSet; diff --git a/Indexing/Index.hpp b/Indexing/Index.hpp index a325cd1b5..c28b0ea0e 100644 --- a/Indexing/Index.hpp +++ b/Indexing/Index.hpp @@ -116,7 +116,7 @@ struct TermLiteralClause struct DemodulatorData { DemodulatorData(TypedTermList term, TermList rhs, Clause* clause, bool preordered, const Ordering& ord) - : term(term), rhs(rhs), clause(clause), preordered(preordered) + : term(term), rhs(rhs), clause(clause), preordered(preordered), comparator(ord.createComparator(term, rhs)) { #if VDEBUG ASS(term.containsAllVariablesOf(rhs)); diff --git a/Inferences/ForwardDemodulation.cpp b/Inferences/ForwardDemodulation.cpp index 1b7b7ac56..ae9af4c39 100644 --- a/Inferences/ForwardDemodulation.cpp +++ b/Inferences/ForwardDemodulation.cpp @@ -173,7 +173,7 @@ bool ForwardDemodulationImpl::perform(Clause* cl, Clause* auto appl = lhs.isVar() ? (SubstApplicator*)&applWithEqSort : (SubstApplicator*)&applWithoutEqSort; if (_precompiledComparison) { - if (!preordered && (_preorderedOnly || !ordering.isGreater(lhs,rhs,appl,const_cast(qr.data->comparator)))) { + if (!preordered && (_preorderedOnly || !qr.data->comparator->check(appl))) { continue; } } else { diff --git a/Kernel/KBO.cpp b/Kernel/KBO.cpp index 691ea282f..74226d34e 100644 --- a/Kernel/KBO.cpp +++ b/Kernel/KBO.cpp @@ -38,82 +38,6 @@ using namespace std; using namespace Lib; using namespace Shell; - -/** - * Class to represent the current state of the KBO comparison. - * Based on Bernd Loechner's "Things to Know when Implementing KBO" - * (https://doi.org/10.1007/s10817-006-9031-4) - * @since 30/04/2008 flight Brussels-Tel Aviv - */ -class KBO::State -{ -public: - /** Initialise the state */ - State(KBO* kbo) - : _kbo(*kbo) - {} - - void init() - { - _weightDiff=0; - _posNum=0; - _negNum=0; - _lexResult=EQUAL; - _varDiffs.reset(); - } - - /** - * Lexicographic traversal of two terms with same top symbol, - * i.e. traversing their symbols in lockstep, as descibed in - * the Loechner et al. paper above. It performs a bidirectional - * comparison between the two terms, i.e. we can get any value - * of @b Result. - */ - Result traverseLexBidir(AppliedTerm t1, AppliedTerm t2); - /** - * Optimised, unidirectional version of @b traverseLexBidir - * where we only care about @b GREATER and @b EQUAL, otherwise - * it returns as early as possible with @b INCOMPARABLE. - */ - Result traverseLexUnidir(AppliedTerm t1, AppliedTerm t2); - /** - * Performs a non-lexicographic (i.e. non-lockstep) traversal - * of two terms in case their top symbols are not the same. - */ - template - Result traverseNonLex(AppliedTerm t1, AppliedTerm t2); - - template void traverse(AppliedTerm tt); - - Result result(AppliedTerm t1, AppliedTerm t2); -protected: - template void recordVariable(unsigned var); - - bool checkVars() const { return _negNum <= 0; } - Result innerResult(TermList t1, TermList t2); - Result applyVariableCondition(Result res) - { - if(_posNum>0 && (res==LESS || res==EQUAL)) { - res=INCOMPARABLE; - } else if(_negNum>0 && (res==GREATER || res==EQUAL)) { - res=INCOMPARABLE; - } - return res; - } - - int _weightDiff; - /** The variable counters */ - DHMap _varDiffs; - /** Number of variables, that occur more times in the first literal */ - int _posNum; - /** Number of variables, that occur more times in the second literal */ - int _negNum; - /** First comparison result */ - Result _lexResult; - /** The ordering used */ - KBO& _kbo; -}; // class KBO::State - /** * Return result of comparison between @b l1 and @b l2 under * an assumption, that @b traverse method have been called @@ -955,14 +879,9 @@ bool KBO::isGreater(AppliedTerm lhs, AppliedTerm rhs) const return isGreaterOrEq(lhs,rhs)==GREATER; } -bool KBO::isGreater(TermList lhs, TermList rhs, const SubstApplicator* applicator, OrderingComparatorUP& comparator) const +OrderingComparatorUP KBO::createComparator(TermList lhs, TermList rhs) const { - if (!comparator) { - // cout << "preprocessing " << lhs << " " << rhs << endl; - comparator = make_unique(lhs, rhs, *this); - // cout << comparator->toString() << endl; - } - return static_cast(comparator.get())->check(applicator); + return make_unique(lhs, rhs, *this); } int KBO::symbolWeight(const Term* t) const diff --git a/Kernel/KBO.hpp b/Kernel/KBO.hpp index 4cfd03f3a..92257e6c1 100644 --- a/Kernel/KBO.hpp +++ b/Kernel/KBO.hpp @@ -158,7 +158,7 @@ class KBO Result compare(AppliedTerm t1, AppliedTerm t2) const override; bool isGreater(AppliedTerm t1, AppliedTerm t2) const override; - bool isGreater(TermList lhs, TermList rhs, const SubstApplicator* applicator, OrderingComparatorUP& comparator) const override; + OrderingComparatorUP createComparator(TermList lhs, TermList rhs) const override; protected: Result isGreaterOrEq(AppliedTerm tt1, AppliedTerm tt2) const; @@ -166,7 +166,6 @@ class KBO Result comparePredicates(Literal* l1, Literal* l2) const override; - class State; friend class KBOComparator; // int functionSymbolWeight(unsigned fun) const; @@ -186,6 +185,83 @@ class KBO template void showConcrete_(std::ostream&) const; + /** + * Class to represent the current state of the KBO comparison. + * Based on Bernd Loechner's "Things to Know when Implementing KBO" + * (https://doi.org/10.1007/s10817-006-9031-4) + * @since 30/04/2008 flight Brussels-Tel Aviv + */ + class State + { + public: + /** Initialise the state */ + State(KBO* kbo) + : _kbo(*kbo) + {} + + void init() + { + _weightDiff=0; + _posNum=0; + _negNum=0; + _lexResult=EQUAL; + _varDiffs.reset(); + } + + /** + * Lexicographic traversal of two terms with same top symbol, + * i.e. traversing their symbols in lockstep, as descibed in + * the Loechner et al. paper above. It performs a bidirectional + * comparison between the two terms, i.e. we can get any value + * of @b Result. + */ + Result traverseLexBidir(AppliedTerm t1, AppliedTerm t2); + /** + * Optimised, unidirectional version of @b traverseLexBidir + * where we only care about @b GREATER and @b EQUAL, otherwise + * it returns as early as possible with @b INCOMPARABLE. + */ + Result traverseLexUnidir(AppliedTerm t1, AppliedTerm t2); + /** + * Performs a non-lexicographic (i.e. non-lockstep) traversal + * of two terms in case their top symbols are not the same. + */ + template + Result traverseNonLex(AppliedTerm t1, AppliedTerm t2); + + template void traverse(AppliedTerm tt); + + Result result(AppliedTerm t1, AppliedTerm t2); + protected: + template void recordVariable(unsigned var); + + bool checkVars() const { return _negNum <= 0; } + Result innerResult(TermList t1, TermList t2); + Result applyVariableCondition(Result res) + { + if(_posNum>0 && (res==LESS || res==EQUAL)) { + res=INCOMPARABLE; + } else if(_negNum>0 && (res==GREATER || res==EQUAL)) { + res=INCOMPARABLE; + } + return res; + } + + friend class KBOComparator; + + int _weightDiff; + /** The variable counters */ + DHMap _varDiffs; + /** Number of variables, that occur more times in the first literal */ + int _posNum; + /** Number of variables, that occur more times in the second literal */ + int _negNum; + /** First comparison result */ + Result _lexResult; + /** The ordering used */ + KBO& _kbo; + }; // class KBO::State + /** * State used for comparing terms and literals */ diff --git a/Kernel/KBOComparator.cpp b/Kernel/KBOComparator.cpp index 82a007a8c..fd1fc3c4c 100644 --- a/Kernel/KBOComparator.cpp +++ b/Kernel/KBOComparator.cpp @@ -20,12 +20,20 @@ using namespace std; using namespace Lib; using namespace Shell; -KBOComparator::KBOComparator(TermList tl1, TermList tl2, const KBO& kbo) - : _kbo(kbo), _instructions() +KBOComparator::KBOComparator(TermList lhs, TermList rhs, const KBO& kbo) + : OrderingComparator(lhs, rhs, kbo), _ready(false), _instructions() { +} + +void KBOComparator::makeReady() +{ + ASS(!_ready); + + const auto& kbo = static_cast(_ord); + // stack of subcomparisons in lexicographic order (w.r.t. tl1 and tl2) Stack> todo; - todo.push(make_pair(tl1,tl2)); + todo.push(make_pair(_lhs,_rhs)); while (todo.isNonEmpty()) { auto kv = todo.pop(); @@ -81,14 +89,14 @@ KBOComparator::KBOComparator(TermList tl1, TermList tl2, const KBO& kbo) // if both are proper terms, we calculate // weight and variable balances first - DHMap vars; - int w = 0; - countSymbols(kbo, vars, w, lhs, 1); - countSymbols(kbo, vars, w, rhs, -1); // we only care about the non-zero weights and counts bool varInbalance = false; - DHMap::Iterator vit(vars); + // TODO kbo.state could be nulled out until this + // to make sure no one overwrites the values + auto state = kbo._state; + auto w = kbo._state->_weightDiff; + decltype(state->_varDiffs)::Iterator vit(state->_varDiffs); Stack> nonzeros; while (vit.hasNext()) { unsigned v; @@ -96,6 +104,7 @@ KBOComparator::KBOComparator(TermList tl1, TermList tl2, const KBO& kbo) vit.next(v,cnt); if (cnt!=0) { nonzeros.push(make_pair(v,cnt)); + w-=cnt; // we have to remove the variable weights from w } if (cnt<0) { varInbalance = true; @@ -154,8 +163,15 @@ KBOComparator::KBOComparator(TermList tl1, TermList tl2, const KBO& kbo) } } -bool KBOComparator::check(const SubstApplicator* applicator) const +bool KBOComparator::check(const SubstApplicator* applicator) { + if (!_ready) { + makeReady(); + _ready = true; + } + + const auto& kbo = static_cast(_ord); + for (unsigned i = 0; i < _instructions.size();) { switch (static_cast(_instructions[i]._tag())) { case InstructionTag::WEIGHT: { @@ -181,7 +197,7 @@ bool KBOComparator::check(const SubstApplicator* applicator) const return false; } } - auto w = _kbo.computeWeight(tt); + auto w = kbo.computeWeight(tt); weight += coeff*w; // due to descending order of counts, // this also means failure @@ -200,7 +216,7 @@ bool KBOComparator::check(const SubstApplicator* applicator) const break; } case InstructionTag::COMPARE_VV: { - auto res = _kbo.isGreaterOrEq( + auto res = kbo.isGreaterOrEq( AppliedTerm(TermList::var(_instructions[i]._firstUint()), applicator, true), AppliedTerm(TermList::var(_instructions[i]._secondUint()), applicator, true)); if (res==Ordering::EQUAL) { @@ -211,7 +227,7 @@ bool KBOComparator::check(const SubstApplicator* applicator) const } case InstructionTag::COMPARE_VT: { ASS(_instructions[i+1]._tag()==InstructionTag::DATA); - auto res = _kbo.isGreaterOrEq( + auto res = kbo.isGreaterOrEq( AppliedTerm(TermList::var(_instructions[i]._firstUint()), applicator, true), AppliedTerm(TermList(_instructions[i+1]._term()), applicator, true)); if (res==Ordering::EQUAL) { @@ -223,7 +239,7 @@ bool KBOComparator::check(const SubstApplicator* applicator) const case InstructionTag::COMPARE_TV: { ASS(_instructions[i+1]._tag()==InstructionTag::DATA); // note that in this case the term is the second argument - auto res = _kbo.isGreaterOrEq( + auto res = kbo.isGreaterOrEq( AppliedTerm(TermList(_instructions[i+1]._term()), applicator, true), AppliedTerm(TermList::var(_instructions[i]._firstUint()), applicator, true)); if (res==Ordering::EQUAL) { @@ -242,30 +258,6 @@ bool KBOComparator::check(const SubstApplicator* applicator) const return false; } - -void KBOComparator::countSymbols(const KBO& kbo, DHMap& vars, int& w, TermList t, int coeff) -{ - if (t.isVar()) { - int* vcnt; - vars.getValuePtr(t.var(), vcnt, 0); - (*vcnt) += coeff; - return; - } - - w += coeff*kbo.symbolWeight(t.term()); - SubtermIterator sti(t.term()); - while (sti.hasNext()) { - auto st = sti.next(); - if (st.isVar()) { - int* vcnt; - vars.getValuePtr(st.var(), vcnt, 0); - (*vcnt) += coeff; - } else { - w += coeff*kbo.symbolWeight(st.term()); - } - } -} - std::string KBOComparator::toString() const { std::stringstream str; diff --git a/Kernel/KBOComparator.hpp b/Kernel/KBOComparator.hpp index e0220bbe0..1210cb8d8 100644 --- a/Kernel/KBOComparator.hpp +++ b/Kernel/KBOComparator.hpp @@ -33,16 +33,14 @@ class KBOComparator : public OrderingComparator { public: - /** The runtime specialization happens in the constructor. */ - KBOComparator(TermList tl1, TermList tl2, const KBO& kbo); + KBOComparator(TermList lhs, TermList rhs, const KBO& kbo); /** Executes the runtime specialized instructions with concrete substitution. */ - bool check(const SubstApplicator* applicator) const; + bool check(const SubstApplicator* applicator) override; std::string toString() const override; private: - // TODO this could be done with KBO::State - static void countSymbols(const KBO& kbo, DHMap& vars, int& w, TermList t, int coeff); + void makeReady(); enum InstructionTag { DATA = 0u, @@ -112,7 +110,7 @@ class KBOComparator private: uint64_t _content; }; - const KBO& _kbo; + bool _ready; Stack _instructions; }; diff --git a/Kernel/LPO.cpp b/Kernel/LPO.cpp index 82a77ab94..ccf3b56ab 100644 --- a/Kernel/LPO.cpp +++ b/Kernel/LPO.cpp @@ -247,14 +247,9 @@ Ordering::Result LPO::majo(AppliedTerm s, AppliedTerm t, const TermList* tl, uns return GREATER; } -bool LPO::isGreater(TermList lhs, TermList rhs, const SubstApplicator* applicator, OrderingComparatorUP& comparator) const +OrderingComparatorUP LPO::createComparator(TermList lhs, TermList rhs) const { - if (!comparator) { - // cout << "preprocessing " << lhs << " " << rhs << endl; - comparator = make_unique(lhs, rhs, *this); - // cout << comparator->toString() << endl; - } - return static_cast(comparator.get())->check(applicator); + return make_unique(lhs, rhs, *this); } void LPO::showConcrete(ostream&) const diff --git a/Kernel/LPO.hpp b/Kernel/LPO.hpp index 433f36501..9cb81665f 100644 --- a/Kernel/LPO.hpp +++ b/Kernel/LPO.hpp @@ -48,7 +48,7 @@ class LPO Result compare(TermList tl1, TermList tl2) const override; Result compare(AppliedTerm tl1, AppliedTerm tl2) const override; bool isGreater(AppliedTerm tl1, AppliedTerm tl2) const override; - bool isGreater(TermList lhs, TermList rhs, const SubstApplicator* applicator, OrderingComparatorUP& comparator) const override; + OrderingComparatorUP createComparator(TermList lhs, TermList rhs) const override; void showConcrete(std::ostream&) const override; diff --git a/Kernel/LPOComparator.cpp b/Kernel/LPOComparator.cpp index 5d3a2904b..67a6519e1 100644 --- a/Kernel/LPOComparator.cpp +++ b/Kernel/LPOComparator.cpp @@ -458,16 +458,21 @@ pair,BranchTag>* LPOComparator::createHelper(TermList tl1, Te return *ptr; } -LPOComparator::LPOComparator(TermList tl1, TermList tl2, const LPO& lpo) - : _lpo(lpo), _instructions(), _res() +LPOComparator::LPOComparator(TermList lhs, TermList rhs, const LPO& lpo) + : OrderingComparator(lhs, rhs, lpo), _ready(false), _instructions(), _res(BranchTag::T_JUMP) { - auto kv = createHelper(tl1, tl2, lpo); - _instructions = kv->first; - _res = kv->second; } -bool LPOComparator::check(const SubstApplicator* applicator) const +bool LPOComparator::check(const SubstApplicator* applicator) { + const auto& lpo = static_cast(_ord); + if (!_ready) { + auto kv = createHelper(_lhs, _rhs, lpo); + _res = kv->second; + _instructions = kv->first; + _ready = true; + } + // we calculate all three values in each iteration // to optimise CPU branch prediction (the values are // computed regardless and hence no branching is needed) @@ -476,7 +481,7 @@ bool LPOComparator::check(const SubstApplicator* applicator) const auto curr = _instructions.begin(); while (cond) { - auto comp = _lpo.lpo(AppliedTerm(curr->lhs,applicator,true),AppliedTerm(curr->rhs,applicator,true)); + auto comp = lpo.lpo(AppliedTerm(curr->lhs,applicator,true),AppliedTerm(curr->rhs,applicator,true)); const auto& branch = curr->getBranch(comp); cond = branch.tag == BranchTag::T_JUMP; diff --git a/Kernel/LPOComparator.hpp b/Kernel/LPOComparator.hpp index 8deb0c30b..d4c2aa4c8 100644 --- a/Kernel/LPOComparator.hpp +++ b/Kernel/LPOComparator.hpp @@ -29,10 +29,10 @@ class LPOComparator { public: /** The runtime specialization happens in the constructor. */ - LPOComparator(TermList tl1, TermList tl2, const LPO& lpo); + LPOComparator(TermList lhs, TermList rhs, const LPO& lpo); /** Executes the runtime specialized instructions with concrete substitution. */ - bool check(const SubstApplicator* applicator) const; + bool check(const SubstApplicator* applicator) override; std::string toString() const override; /** @@ -106,7 +106,7 @@ class LPOComparator static pair,Instruction::BranchTag> alphaChain(const LPO& lpo, Term* s, unsigned i, TermList tl2); static pair,Instruction::BranchTag>* createHelper(TermList tl1, TermList tl2, const LPO& lpo); - const LPO& _lpo; + bool _ready; /** This is non-empty if @b _res is @b BranchTag::T_JUMP */ Stack _instructions; diff --git a/Kernel/Ordering.cpp b/Kernel/Ordering.cpp index 62380edac..bdfadf731 100644 --- a/Kernel/Ordering.cpp +++ b/Kernel/Ordering.cpp @@ -49,6 +49,13 @@ using namespace std; using namespace Lib; using namespace Kernel; +bool OrderingComparator::check(const SubstApplicator* applicator) +{ + return _ord.isGreater( + AppliedTerm(_lhs, applicator, /*aboveVar*/true), + AppliedTerm(_rhs, applicator, /*aboveVar*/true)); +} + OrderingSP Ordering::s_globalOrdering; /** diff --git a/Kernel/Ordering.hpp b/Kernel/Ordering.hpp index 9440e2a0a..9a573740c 100644 --- a/Kernel/Ordering.hpp +++ b/Kernel/Ordering.hpp @@ -41,8 +41,14 @@ using namespace Shell; */ struct OrderingComparator { + OrderingComparator(TermList lhs, TermList rhs, const Ordering& ord) : _lhs(lhs), _rhs(rhs), _ord(ord) {} virtual ~OrderingComparator() = default; - virtual std::string toString() const = 0; + virtual std::string toString() const { return _lhs.toString()+" > "+_rhs.toString(); } + virtual bool check(const SubstApplicator* applicator); + + TermList _lhs; + TermList _rhs; + const Ordering& _ord; }; /** @@ -84,11 +90,10 @@ class Ordering virtual bool isGreater(AppliedTerm t1, AppliedTerm t2) const { return compare(t1, t2) == Result::GREATER; } - /** Optimised function used for checking that @b lhs is greater than @b rhs, - * under substitution represented by @b applicator. */ - virtual bool isGreater(TermList lhs, TermList rhs, const SubstApplicator* applicator, OrderingComparatorUP& comparator) const - { return isGreater(AppliedTerm(lhs, applicator, /* aboveVar */ true), - AppliedTerm(rhs, applicator, /* aboveVar */ true)); } + /** Creates optimised object to check that @b lhs is greater than @b rhs. + * @see OrderingComparator. */ + virtual OrderingComparatorUP createComparator(TermList lhs, TermList rhs) const + { return std::make_unique(lhs, rhs, *this); } virtual void show(std::ostream& out) const = 0; diff --git a/Kernel/SKIKBO.cpp b/Kernel/SKIKBO.cpp index ba193ce3b..7c730737f 100644 --- a/Kernel/SKIKBO.cpp +++ b/Kernel/SKIKBO.cpp @@ -627,11 +627,6 @@ Ordering::Result SKIKBO::compare(TermList tl1, TermList tl2) const return res; } -bool SKIKBO::isGreater(TermList lhs, TermList rhs, const SubstApplicator* applicator, OrderingComparatorUP& comparator) const -{ - NOT_IMPLEMENTED; -} - int SKIKBO::symbolWeight(Term* t) const { if (t->isSort()){ diff --git a/Kernel/SKIKBO.hpp b/Kernel/SKIKBO.hpp index 11e673f50..0d0c106ae 100644 --- a/Kernel/SKIKBO.hpp +++ b/Kernel/SKIKBO.hpp @@ -63,9 +63,6 @@ class SKIKBO using PrecedenceOrdering::compare; Result compare(TermList tl1, TermList tl2) const override; - Result compare(AppliedTerm tl1, AppliedTerm tl2) const override { NOT_IMPLEMENTED; } - bool isGreater(AppliedTerm tl1, AppliedTerm tl2) const override { NOT_IMPLEMENTED; } - bool isGreater(TermList lhs, TermList rhs, const SubstApplicator* applicator, OrderingComparatorUP& comparator) const override; static unsigned maximumReductionLength(Term* t); static TermList reduce(TermStack& args, TermList& head); diff --git a/Shell/InstanceRedundancyHandler.cpp b/Shell/InstanceRedundancyHandler.cpp index a54597601..947cf8872 100644 --- a/Shell/InstanceRedundancyHandler.cpp +++ b/Shell/InstanceRedundancyHandler.cpp @@ -124,9 +124,9 @@ class InstanceRedundancyHandler::SubstitutionCoverTree TermList operator()(unsigned v) const override { return matcher.bindings[v]; } } applicator; - if (ord->isGreater(TermList(ld->lhs),TermList(ld->rhs),&applicator,ld->comp)) { - return true; - } + // if (ord->isGreater(TermList(ld->lhs),TermList(ld->rhs),&applicator,ld->comp)) { + // return true; + // } } } matcher.reset(); From 750c1f6317c6678d936c486b854c91fcc1d77ced Mon Sep 17 00:00:00 2001 From: Marton Hajdu Date: Wed, 7 Aug 2024 10:10:20 +0200 Subject: [PATCH 2/6] Add second LPOComparator implementation which preprocesses lazily; memory deallocation not yet handled --- Kernel/LPO.cpp | 3 +- Kernel/LPO.hpp | 1 + Kernel/LPOComparator.cpp | 199 +++++++++++++++++++++++++++++++++++++++ Kernel/LPOComparator.hpp | 65 +++++++++++++ 4 files changed, 267 insertions(+), 1 deletion(-) diff --git a/Kernel/LPO.cpp b/Kernel/LPO.cpp index ccf3b56ab..65057fdf3 100644 --- a/Kernel/LPO.cpp +++ b/Kernel/LPO.cpp @@ -249,7 +249,8 @@ Ordering::Result LPO::majo(AppliedTerm s, AppliedTerm t, const TermList* tl, uns OrderingComparatorUP LPO::createComparator(TermList lhs, TermList rhs) const { - return make_unique(lhs, rhs, *this); + // return make_unique(lhs, rhs, *this); + return make_unique(lhs, rhs, *this); } void LPO::showConcrete(ostream&) const diff --git a/Kernel/LPO.hpp b/Kernel/LPO.hpp index 9cb81665f..01d083ff6 100644 --- a/Kernel/LPO.hpp +++ b/Kernel/LPO.hpp @@ -66,6 +66,7 @@ class LPO Result majo(AppliedTerm s, AppliedTerm t, const TermList* tl, unsigned arity) const; friend class LPOComparator; + friend class LPOComparator2; }; } diff --git a/Kernel/LPOComparator.cpp b/Kernel/LPOComparator.cpp index 67a6519e1..5755ddda9 100644 --- a/Kernel/LPOComparator.cpp +++ b/Kernel/LPOComparator.cpp @@ -491,4 +491,203 @@ bool LPOComparator::check(const SubstApplicator* applicator) return res; } +// LPOComparator2 + +LPOComparator2::LPOComparator2(TermList lhs, TermList rhs, const LPO& lpo) + : OrderingComparator(lhs, rhs, lpo), _root(new Node(lhs, rhs)) +{ +} + +LPOComparator2::~LPOComparator2() +{ + delete _root.n; +} + +ostream& operator<<(ostream& out, const LPOComparator2::BranchTag& t) +{ + switch (t) { + case LPOComparator2::BranchTag::T_EQUAL: + out << "="; + break; + case LPOComparator2::BranchTag::T_GREATER: + out << ">"; + break; + case LPOComparator2::BranchTag::T_INCOMPARABLE: + out << "!"; + break; + case LPOComparator2::BranchTag::T_COMPARISON: + out << "$"; + break; + case LPOComparator2::BranchTag::T_UNKNOWN: + out << "?"; + break; + } + return out; +} + +std::ostream& operator<<(std::ostream& str, const LPOComparator2::Branch& branch) +{ + str << branch.tag << " "; + if (branch.n) { + str << branch.n->lhs << " " << branch.n->rhs; + } else { + str << "null"; + } + return str; +} + +ostream& operator<<(ostream& str, const LPOComparator2& comp) +{ + str << "comparator for " << comp._lhs << " > " << comp._rhs << endl; + Stack> todo; + todo.push(make_pair(comp._root,0)); + unsigned cnt = 1; + + while (todo.isNonEmpty()) { + auto kv = todo.pop(); + str << cnt++ << " "; + for (unsigned i = 0; i < kv.second; i++) { + str << " "; + } + str << kv.first << endl; + if (kv.first.n) { + todo.push(make_pair(kv.first.n->incBranch,kv.second+1)); + todo.push(make_pair(kv.first.n->gtBranch,kv.second+1)); + todo.push(make_pair(kv.first.n->eqBranch,kv.second+1)); + } + } + return str; +} + +std::string LPOComparator2::toString() const +{ + std::stringstream str; + str << *this << endl; + return str.str(); +} + +/** + * Implements an @b LPO::majo call via instructions. + */ +void LPOComparator2::majoChain(Branch* branch, TermList tl1, Term* t, unsigned i, Branch success, Branch fail) +{ + for (unsigned j = i; j < t->arity(); j++) { + *branch = Branch(new Node(tl1,*t->nthArgument(j))); + branch->n->eqBranch = fail; + branch->n->incBranch = fail; + branch = &branch->n->gtBranch; + } + *branch = success; +} + +/** + * Implements an @b LPO::alpha call via instructions. + */ +void LPOComparator2::alphaChain(Branch* branch, Term* s, unsigned i, TermList tl2, Branch success, Branch fail) +{ + for (unsigned j = i; j < s->arity(); j++) { + *branch = Branch(new Node(*s->nthArgument(j),tl2)); + branch->n->eqBranch = success; + branch->n->gtBranch = success; + branch = &branch->n->incBranch; + } + *branch = fail; +} + +void LPOComparator2::expand(Branch& branch, const LPO& lpo) +{ + while (branch.tag == BranchTag::T_UNKNOWN) { + // Use compare here to filter out as many + // precomputable comparisons as possible. + auto node = branch.n; + auto comp = lpo.compare(node->lhs,node->rhs); + if (comp != Ordering::INCOMPARABLE) { + if (comp == Ordering::LESS) { + branch = node->incBranch; + } else if (comp == Ordering::GREATER) { + branch = node->gtBranch; + } else { + branch = node->eqBranch; + } + continue; + } + // If we have a variable, we cannot preprocess further. + if (node->lhs.isVar() || node->rhs.isVar()) { + branch.tag = BranchTag::T_COMPARISON; + continue; + } + + auto s = node->lhs.term(); + auto t = node->rhs.term(); + + // preserve original branches + auto eqBranch = node->eqBranch; + auto gtBranch = node->gtBranch; + auto incBranch = node->incBranch; + + switch (lpo.comparePrecedences(s, t)) { + case Ordering::EQUAL: { + ASS(s->arity()); // constants cannot be incomparable + + // copies for unification + auto lhs = node->lhs; + auto rhs = node->rhs; + + // TODO release branch->n somewhere + auto curr = &branch; + + // lexicographic comparisons + for (unsigned i = 0; i < s->arity(); i++) + { + auto s_arg = *lhs.term()->nthArgument(i); + auto t_arg = *rhs.term()->nthArgument(i); + *curr = Branch(new Node(s_arg,t_arg)); + // greater branch is a majo chain + majoChain(&curr->n->gtBranch, lhs, rhs.term(), i+1, gtBranch, incBranch); + // incomparable branch is an alpha chain + alphaChain(&curr->n->incBranch, lhs.term(), i+1, rhs, gtBranch, incBranch); + curr = &curr->n->eqBranch; + if (!unify(s_arg,t_arg,lhs,rhs)) { + *curr = incBranch; + goto loop_end; + } + } + *curr = eqBranch; + break; + } + case Ordering::GREATER: { + ASS(t->arity()); + majoChain(&branch, node->lhs, t, 0, gtBranch, incBranch); + break; + } + case Ordering::LESS: { + ASS(s->arity()); + alphaChain(&branch, s, 0, node->rhs, gtBranch, incBranch); + break; + } + default: + ASSERTION_VIOLATION; + } +loop_end: + continue; + } +} + +bool LPOComparator2::check(const SubstApplicator* applicator) +{ + const auto& lpo = static_cast(_ord); + auto curr = &_root; + + while (curr->tag == BranchTag::T_COMPARISON || curr->tag == BranchTag::T_UNKNOWN) { + expand(*curr, lpo); + if (curr->tag != BranchTag::T_COMPARISON) { + break; + } + ASS(curr->n); + auto comp = lpo.lpo(AppliedTerm(curr->n->lhs,applicator,true),AppliedTerm(curr->n->rhs,applicator,true)); + curr = &curr->n->getBranch(comp); + } + return curr->tag == BranchTag::T_GREATER; +} + } diff --git a/Kernel/LPOComparator.hpp b/Kernel/LPOComparator.hpp index d4c2aa4c8..5cf5786bc 100644 --- a/Kernel/LPOComparator.hpp +++ b/Kernel/LPOComparator.hpp @@ -117,5 +117,70 @@ class LPOComparator Instruction::BranchTag _res; }; +class LPOComparator2 +: public OrderingComparator +{ +public: + /** The runtime specialization happens in the constructor. */ + LPOComparator2(TermList lhs, TermList rhs, const LPO& lpo); + ~LPOComparator2() override; + + /** Executes the runtime specialized instructions with concrete substitution. */ + bool check(const SubstApplicator* applicator) override; + std::string toString() const override; + + enum class BranchTag : uint8_t { + T_EQUAL, + T_GREATER, + T_INCOMPARABLE, + T_COMPARISON, + T_UNKNOWN, + }; + + struct Node; + + struct Branch { + BranchTag tag; + Node* n; + + explicit Branch(BranchTag t) : tag(t), n(nullptr) {} + explicit Branch(Node* n) : tag(BranchTag::T_UNKNOWN), n(n) {} + }; + + struct Node { + Node(TermList lhs, TermList rhs) + : lhs(lhs), rhs(rhs), eqBranch(BranchTag::T_EQUAL), gtBranch(BranchTag::T_GREATER), incBranch(BranchTag::T_INCOMPARABLE) {} + + auto& getBranch(Ordering::Result r) { + switch (r) { + case Ordering::EQUAL: + return eqBranch; + case Ordering::GREATER: + return gtBranch; + case Ordering::INCOMPARABLE: + return incBranch; + default: + ASSERTION_VIOLATION; + } + } + + TermList lhs; + TermList rhs; + Branch eqBranch; + Branch gtBranch; + Branch incBranch; + }; + + +private: + static void majoChain(Branch* branch, TermList tl1, Term* t, unsigned i, Branch success, Branch fail); + static void alphaChain(Branch* branch, Term* s, unsigned i, TermList tl2, Branch success, Branch fail); + static void expand(Branch& branch, const LPO& lpo); + + friend ostream& operator<<(ostream& str, const LPOComparator2& comp); + + Branch _root; +}; + } #endif From 7b8c313fdf3f0d2cd5ca701ab4e238167faa2d5a Mon Sep 17 00:00:00 2001 From: Marton Hajdu Date: Wed, 7 Aug 2024 11:48:09 +0200 Subject: [PATCH 3/6] Release nodes properly; add some assertions --- Kernel/LPOComparator.cpp | 132 ++++++++++++++++++--------------------- Kernel/LPOComparator.hpp | 34 ++++++++-- 2 files changed, 90 insertions(+), 76 deletions(-) diff --git a/Kernel/LPOComparator.cpp b/Kernel/LPOComparator.cpp index 5755ddda9..ef32415df 100644 --- a/Kernel/LPOComparator.cpp +++ b/Kernel/LPOComparator.cpp @@ -498,23 +498,15 @@ LPOComparator2::LPOComparator2(TermList lhs, TermList rhs, const LPO& lpo) { } -LPOComparator2::~LPOComparator2() -{ - delete _root.n; -} - ostream& operator<<(ostream& out, const LPOComparator2::BranchTag& t) { switch (t) { - case LPOComparator2::BranchTag::T_EQUAL: - out << "="; + case LPOComparator2::BranchTag::T_NOT_GREATER: + out << "!>"; break; case LPOComparator2::BranchTag::T_GREATER: out << ">"; break; - case LPOComparator2::BranchTag::T_INCOMPARABLE: - out << "!"; - break; case LPOComparator2::BranchTag::T_COMPARISON: out << "$"; break; @@ -539,8 +531,8 @@ std::ostream& operator<<(std::ostream& str, const LPOComparator2::Branch& branch ostream& operator<<(ostream& str, const LPOComparator2& comp) { str << "comparator for " << comp._lhs << " > " << comp._rhs << endl; - Stack> todo; - todo.push(make_pair(comp._root,0)); + Stack> todo; + todo.push(make_pair(&comp._root,0)); unsigned cnt = 1; while (todo.isNonEmpty()) { @@ -549,11 +541,11 @@ ostream& operator<<(ostream& str, const LPOComparator2& comp) for (unsigned i = 0; i < kv.second; i++) { str << " "; } - str << kv.first << endl; - if (kv.first.n) { - todo.push(make_pair(kv.first.n->incBranch,kv.second+1)); - todo.push(make_pair(kv.first.n->gtBranch,kv.second+1)); - todo.push(make_pair(kv.first.n->eqBranch,kv.second+1)); + str << *kv.first << endl; + if (kv.first->n) { + todo.push(make_pair(&kv.first->n->incBranch,kv.second+1)); + todo.push(make_pair(&kv.first->n->gtBranch,kv.second+1)); + todo.push(make_pair(&kv.first->n->eqBranch,kv.second+1)); } } return str; @@ -571,6 +563,7 @@ std::string LPOComparator2::toString() const */ void LPOComparator2::majoChain(Branch* branch, TermList tl1, Term* t, unsigned i, Branch success, Branch fail) { + ASS(branch); for (unsigned j = i; j < t->arity(); j++) { *branch = Branch(new Node(tl1,*t->nthArgument(j))); branch->n->eqBranch = fail; @@ -585,6 +578,7 @@ void LPOComparator2::majoChain(Branch* branch, TermList tl1, Term* t, unsigned i */ void LPOComparator2::alphaChain(Branch* branch, Term* s, unsigned i, TermList tl2, Branch success, Branch fail) { + ASS(branch); for (unsigned j = i; j < s->arity(); j++) { *branch = Branch(new Node(*s->nthArgument(j),tl2)); branch->n->eqBranch = success; @@ -600,6 +594,8 @@ void LPOComparator2::expand(Branch& branch, const LPO& lpo) // Use compare here to filter out as many // precomputable comparisons as possible. auto node = branch.n; + node->acquire(); + auto comp = lpo.compare(node->lhs,node->rhs); if (comp != Ordering::INCOMPARABLE) { if (comp == Ordering::LESS) { @@ -609,67 +605,63 @@ void LPOComparator2::expand(Branch& branch, const LPO& lpo) } else { branch = node->eqBranch; } - continue; + goto loop_end; } // If we have a variable, we cannot preprocess further. if (node->lhs.isVar() || node->rhs.isVar()) { branch.tag = BranchTag::T_COMPARISON; - continue; + goto loop_end; } - auto s = node->lhs.term(); - auto t = node->rhs.term(); - - // preserve original branches - auto eqBranch = node->eqBranch; - auto gtBranch = node->gtBranch; - auto incBranch = node->incBranch; - - switch (lpo.comparePrecedences(s, t)) { - case Ordering::EQUAL: { - ASS(s->arity()); // constants cannot be incomparable - - // copies for unification - auto lhs = node->lhs; - auto rhs = node->rhs; - - // TODO release branch->n somewhere - auto curr = &branch; - - // lexicographic comparisons - for (unsigned i = 0; i < s->arity(); i++) - { - auto s_arg = *lhs.term()->nthArgument(i); - auto t_arg = *rhs.term()->nthArgument(i); - *curr = Branch(new Node(s_arg,t_arg)); - // greater branch is a majo chain - majoChain(&curr->n->gtBranch, lhs, rhs.term(), i+1, gtBranch, incBranch); - // incomparable branch is an alpha chain - alphaChain(&curr->n->incBranch, lhs.term(), i+1, rhs, gtBranch, incBranch); - curr = &curr->n->eqBranch; - if (!unify(s_arg,t_arg,lhs,rhs)) { - *curr = incBranch; - goto loop_end; + { + auto s = node->lhs.term(); + auto t = node->rhs.term(); + + switch (lpo.comparePrecedences(s, t)) { + case Ordering::EQUAL: { + ASS(s->arity()); // constants cannot be incomparable + + // copies for unification + auto lhs = node->lhs; + auto rhs = node->rhs; + + auto curr = &branch; + + // lexicographic comparisons + for (unsigned i = 0; i < s->arity(); i++) + { + auto s_arg = *lhs.term()->nthArgument(i); + auto t_arg = *rhs.term()->nthArgument(i); + *curr = Branch(new Node(s_arg,t_arg)); + // greater branch is a majo chain + majoChain(&curr->n->gtBranch, lhs, rhs.term(), i+1, node->gtBranch, node->incBranch); + // incomparable branch is an alpha chain + alphaChain(&curr->n->incBranch, lhs.term(), i+1, rhs, node->gtBranch, node->incBranch); + curr = &curr->n->eqBranch; + if (!unify(s_arg,t_arg,lhs,rhs)) { + *curr = node->incBranch; + goto loop_end; + } } + *curr = node->eqBranch; + break; + } + case Ordering::GREATER: { + ASS(t->arity()); + majoChain(&branch, node->lhs, t, 0, node->gtBranch, node->incBranch); + break; + } + case Ordering::LESS: { + ASS(s->arity()); + alphaChain(&branch, s, 0, node->rhs, node->gtBranch, node->incBranch); + break; + } + default: + ASSERTION_VIOLATION; } - *curr = eqBranch; - break; - } - case Ordering::GREATER: { - ASS(t->arity()); - majoChain(&branch, node->lhs, t, 0, gtBranch, incBranch); - break; - } - case Ordering::LESS: { - ASS(s->arity()); - alphaChain(&branch, s, 0, node->rhs, gtBranch, incBranch); - break; - } - default: - ASSERTION_VIOLATION; } loop_end: - continue; + node->release(); } } @@ -678,9 +670,9 @@ bool LPOComparator2::check(const SubstApplicator* applicator) const auto& lpo = static_cast(_ord); auto curr = &_root; - while (curr->tag == BranchTag::T_COMPARISON || curr->tag == BranchTag::T_UNKNOWN) { + while (curr->n) { expand(*curr, lpo); - if (curr->tag != BranchTag::T_COMPARISON) { + if (!curr->n) { break; } ASS(curr->n); diff --git a/Kernel/LPOComparator.hpp b/Kernel/LPOComparator.hpp index 5cf5786bc..831461dee 100644 --- a/Kernel/LPOComparator.hpp +++ b/Kernel/LPOComparator.hpp @@ -123,16 +123,14 @@ class LPOComparator2 public: /** The runtime specialization happens in the constructor. */ LPOComparator2(TermList lhs, TermList rhs, const LPO& lpo); - ~LPOComparator2() override; /** Executes the runtime specialized instructions with concrete substitution. */ bool check(const SubstApplicator* applicator) override; std::string toString() const override; enum class BranchTag : uint8_t { - T_EQUAL, T_GREATER, - T_INCOMPARABLE, + T_NOT_GREATER, T_COMPARISON, T_UNKNOWN, }; @@ -143,13 +141,36 @@ class LPOComparator2 BranchTag tag; Node* n; - explicit Branch(BranchTag t) : tag(t), n(nullptr) {} - explicit Branch(Node* n) : tag(BranchTag::T_UNKNOWN), n(n) {} + explicit Branch(BranchTag t) : tag(t), n(nullptr) { ASS(t==BranchTag::T_GREATER || t==BranchTag::T_NOT_GREATER); } + explicit Branch(Node* n) : tag(BranchTag::T_UNKNOWN), n(n) { ASS(n); n->acquire(); } + Branch(const Branch& other) : tag(other.tag), n(other.n) { if (n) { n->acquire(); } } + ~Branch() { if (n) { n->release(); } } + Branch& operator=(const Branch& other) { + if (this != &other) { + tag = other.tag; + if (n) { n->release(); } + n = other.n; + if (n) { n->acquire(); } + } + return *this; + } }; struct Node { Node(TermList lhs, TermList rhs) - : lhs(lhs), rhs(rhs), eqBranch(BranchTag::T_EQUAL), gtBranch(BranchTag::T_GREATER), incBranch(BranchTag::T_INCOMPARABLE) {} + : refcnt(0), lhs(lhs), rhs(rhs), eqBranch(BranchTag::T_NOT_GREATER), gtBranch(BranchTag::T_GREATER), incBranch(BranchTag::T_NOT_GREATER) {} + + void acquire() { + refcnt++; + } + + void release() { + ASS(refcnt); + refcnt--; + if (!refcnt) { + delete this; + } + } auto& getBranch(Ordering::Result r) { switch (r) { @@ -164,6 +185,7 @@ class LPOComparator2 } } + unsigned refcnt; TermList lhs; TermList rhs; Branch eqBranch; From 1c235ee62eda61f37df3fdc8ba1090785ef14a32 Mon Sep 17 00:00:00 2001 From: Marton Hajdu Date: Wed, 4 Sep 2024 16:00:05 +0200 Subject: [PATCH 4/6] Clean up branch --- Kernel/KBOComparator.cpp | 15 +- Kernel/LPO.cpp | 3 +- Kernel/LPOComparator.cpp | 471 ++------------------------------------- Kernel/LPOComparator.hpp | 103 +-------- 4 files changed, 28 insertions(+), 564 deletions(-) diff --git a/Kernel/KBOComparator.cpp b/Kernel/KBOComparator.cpp index fd1fc3c4c..66db8574b 100644 --- a/Kernel/KBOComparator.cpp +++ b/Kernel/KBOComparator.cpp @@ -92,10 +92,12 @@ void KBOComparator::makeReady() // we only care about the non-zero weights and counts bool varInbalance = false; - // TODO kbo.state could be nulled out until this - // to make sure no one overwrites the values auto state = kbo._state; - auto w = kbo._state->_weightDiff; +#if VDEBUG + // we make sure kbo._state is not used while we're using it + kbo._state = nullptr; +#endif + auto w = state->_weightDiff; decltype(state->_varDiffs)::Iterator vit(state->_varDiffs); Stack> nonzeros; while (vit.hasNext()) { @@ -110,6 +112,11 @@ void KBOComparator::makeReady() varInbalance = true; } } +#if VDEBUG + kbo._state = state; + state = nullptr; +#endif + // if the condition below does not hold, the weight/var balances are satisfied if (w < 0 || varInbalance) { // reinterpret weight here to unsigned because the compiler might not do it @@ -185,7 +192,7 @@ bool KBOComparator::check(const SubstApplicator* applicator) auto var = _instructions[j]._firstUint(); auto coeff = _instructions[j]._coeff(); - AppliedTerm tt(TermList(var,false), applicator, true); + AppliedTerm tt(TermList::var(var), applicator, true); VariableIterator vit(tt.term); while (vit.hasNext()) { diff --git a/Kernel/LPO.cpp b/Kernel/LPO.cpp index 65057fdf3..ccf3b56ab 100644 --- a/Kernel/LPO.cpp +++ b/Kernel/LPO.cpp @@ -249,8 +249,7 @@ Ordering::Result LPO::majo(AppliedTerm s, AppliedTerm t, const TermList* tl, uns OrderingComparatorUP LPO::createComparator(TermList lhs, TermList rhs) const { - // return make_unique(lhs, rhs, *this); - return make_unique(lhs, rhs, *this); + return make_unique(lhs, rhs, *this); } void LPO::showConcrete(ostream&) const diff --git a/Kernel/LPOComparator.cpp b/Kernel/LPOComparator.cpp index 34c024a79..c36fbecf4 100644 --- a/Kernel/LPOComparator.cpp +++ b/Kernel/LPOComparator.cpp @@ -18,10 +18,6 @@ using namespace std; using namespace Lib; using namespace Shell; -using Instruction = LPOComparator::Instruction; -using Branch = Instruction::Branch; -using BranchTag = Instruction::BranchTag; - bool unify(TermList tl1, TermList tl2, TermList& orig1, TermList& orig2) { RobSubstitution rsubst; @@ -52,472 +48,31 @@ bool unify(TermList tl1, TermList tl2, TermList& orig1, TermList& orig2) return true; } -ostream& operator<<(ostream& out, const Branch& b) -{ - switch (b.tag) { - case BranchTag::T_EQUAL: - out << "="; - break; - case BranchTag::T_GREATER: - out << ">"; - break; - case BranchTag::T_INCOMPARABLE: - out << "?"; - break; - case BranchTag::T_JUMP: - out << b.jump_pos; - break; - } - return out; -} - -ostream& operator<<(ostream& out, const Instruction& n) -{ - out << "instr " << n.lhs << " " << n.rhs << " " << n.bs[0] << " " << n.bs[1] << " " << n.bs[2]; - return out; -} - -std::string LPOComparator::toString() const -{ - std::stringstream str; - switch (_res) { - case BranchTag::T_EQUAL: - str << "equal" << endl; - break; - case BranchTag::T_GREATER: - str << "greater" << endl; - break; - case BranchTag::T_INCOMPARABLE: - str << "incomparable" << endl; - break; - case BranchTag::T_JUMP: - for (unsigned i = 0; i < _instructions.size(); i++) { - str << i << " " << _instructions[i] << endl; - } - break; - } - return str.str(); -} - -void Branch::update(Branch eqBranch, Branch gtBranch, Branch incBranch, unsigned jump_offset) -{ - switch (tag) { - case BranchTag::T_EQUAL: - *this = eqBranch; - break; - case BranchTag::T_GREATER: - *this = gtBranch; - break; - case BranchTag::T_INCOMPARABLE: - *this = incBranch; - break; - case BranchTag::T_JUMP: - jump_pos += jump_offset; - break; - } -} - -void updateBranchInRange(Stack& st, unsigned startIndex, unsigned endIndex, Branch prevBranch, Branch newBranch) -{ - for (unsigned i = startIndex; i < endIndex; i++) { - for (unsigned j = 0; j < 3; j++) { - if (st[i].bs[j] == prevBranch) { - st[i].bs[j] = newBranch; - } - } - } -} - -/** - * Pushes instructions to the end of an instruction stack, while replacing each - * equal/greater/incomparable branch with @b eqBranch, @b gtBranch, @b incBranch, - * respectively, and adding an appropriate offset to any jump operation. - */ -void pushInstructions(Stack& st, const Stack& other, Branch eqBranch, Branch gtBranch, Branch incBranch) -{ - auto startIndex = st.size(); - for (const auto& n : other) { - st.push(n); - for (unsigned j = 0; j < 3; j++) { - st.top().bs[j].update(eqBranch, gtBranch, incBranch, startIndex); - } - } -} - -void deleteDuplicates(Stack& st) -{ - unsigned removedCnt = 0; - Map lastPos; - std::vector removedAfter(st.size(),0); - - // First pass, remember the last position of - // any duplicate instruction, and update every - // jump to the respective last instruction. - // Also, save how many elements we marked removed - // after each not-removed element. - for (int i = st.size()-1; i >= 0; i--) { - auto& curr = st[i]; - for (unsigned j = 0; j < 3; j++) { - if (curr.bs[j].tag == BranchTag::T_JUMP) { - auto& jpos = curr.bs[j].jump_pos; - jpos = lastPos.get(st[jpos]); - } - } - - unsigned* ptr; - if (lastPos.getValuePtr(curr, ptr, i)) { - removedAfter[i] = removedCnt; - } else { - removedCnt++; - } - } - if (!removedCnt) { - return; - } - // The first instruction should be completely - // unique, otherwise we would be looping. - ASS_EQ(lastPos.get(st[0]),0); - - // Second pass, create the resulting stack - // without duplicates, and doing so apply - // appropriate offsets to each jump where we - // removed elements in between. - Stack res; - for (unsigned i = 0; i < st.size(); i++) { - auto curr = st[i]; - if (lastPos.get(curr)!=i) { - continue; - } - for (unsigned j = 0; j < 3; j++) { - if (curr.bs[j].tag == BranchTag::T_JUMP) { - auto& jpos = curr.bs[j].jump_pos; - jpos -= removedAfter[i]-removedAfter[jpos]; - jpos -= i-res.size(); - } - } - res.push(curr); - } - swap(res,st); -} - -#define INDEX_UNINITIALIZED -1 - -/** - * Implements an @b LPO::majo call via instructions. - */ -pair,BranchTag> LPOComparator::majoChain(const LPO& lpo, TermList tl1, Term* t, unsigned i) -{ - Stack res; - int prevIndex = INDEX_UNINITIALIZED; - for (unsigned j = i; j < t->arity(); j++) { - auto compRes = createHelper(tl1,*t->nthArgument(j),lpo); - // If the comparison is 'greater', we skip this iteration. - if (compRes->second == BranchTag::T_GREATER) { - continue; - } - // Otherwise, if the comparison is 'equal' or 'incomparable', - // we can return early with a negative result. - if (compRes->second != BranchTag::T_JUMP) { - res.reset(); - return make_pair(res,BranchTag::T_INCOMPARABLE); - } - // Update previous 'greater' branches - // to a jump to the current index. - if (prevIndex != INDEX_UNINITIALIZED) { - updateBranchInRange(res, prevIndex, (unsigned)res.size(), - Branch::gt(), Branch::jump(res.size())); - } - prevIndex = res.size(); - // Push the comparison instructions and replace each - // 'equal' with 'incomparable' in them. - pushInstructions(res, compRes->first, Branch::inc(), Branch::gt(), Branch::inc()); - } - return make_pair(res,res.isEmpty() ? BranchTag::T_GREATER : BranchTag::T_JUMP); -} - -/** - * Implements an @b LPO::alpha call via instructions. - */ -pair,BranchTag> LPOComparator::alphaChain(const LPO& lpo, Term* s, unsigned i, TermList tl2) -{ - Stack res; - int prevIndex = INDEX_UNINITIALIZED; - for (unsigned j = i; j < s->arity(); j++) { - auto compRes = createHelper(*s->nthArgument(j),tl2,lpo); - // If the comparison is 'incomparable', we skip this iteration. - if (compRes->second == BranchTag::T_INCOMPARABLE) { - continue; - } - // Otherwise, if the comparison is 'greater' or 'equal', - // we can return early with a positive result. - if (compRes->second != BranchTag::T_JUMP) { - res.reset(); - return make_pair(res,BranchTag::T_GREATER); - } - // Update previous 'incomparable' branches - // to a jump to the current index. - if (prevIndex != INDEX_UNINITIALIZED) { - updateBranchInRange(res, prevIndex, (unsigned)res.size(), - Branch::inc(), Branch::jump(res.size())); - } - prevIndex = res.size(); - // Push the comparison instructions and replace each - // 'equal' with 'greater' in them. - pushInstructions(res, std::move(compRes->first), Branch::gt(), Branch::gt(), Branch::inc()); - } - return make_pair(res,res.isEmpty() ? BranchTag::T_INCOMPARABLE : BranchTag::T_JUMP); -} - -pair,BranchTag>* LPOComparator::createHelper(TermList tl1, TermList tl2, const LPO& lpo) -{ - static DHMap,pair,Instruction::BranchTag>*> _cache; - - pair,BranchTag>** ptr; - // We have a local cache for subresults to avoid too much computation. - if (!_cache.getValuePtr(make_pair(tl1,tl2),ptr)) { - return *ptr; - } - // Allocate on heap so that cache reallocation - // won't affect partial results. - auto res = new pair(Stack(), BranchTag::T_JUMP); - *ptr = res; - - // Use compare here to filter out as many - // precomputable comparisons as possible. - auto comp = lpo.compare(tl1,tl2); - if (comp != Ordering::INCOMPARABLE) { - if (comp == Ordering::LESS) { - (*ptr)->second = BranchTag::T_INCOMPARABLE; - } else if (comp == Ordering::GREATER) { - (*ptr)->second = BranchTag::T_GREATER; - } else { - (*ptr)->second = BranchTag::T_EQUAL; - } - return *ptr; - } - // If we have a variable, we cannot preprocess further. - if (tl1.isVar() || tl2.isVar()) { - (*ptr)->first.push(Instruction(tl1,tl2)); - (*ptr)->second = BranchTag::T_JUMP; - return *ptr; - } - - auto s = tl1.term(); - auto t = tl2.term(); - - switch (lpo.comparePrecedences(s, t)) { - case Ordering::EQUAL: { - ASS(s->arity()); // constants cannot be incomparable - - int prevStartIndex = INDEX_UNINITIALIZED; - unsigned prevEndIndex = 0; // to silence a gcc warning (we overwrite the value below anyway, at least where it matters) - - // copies for unification - auto tl1s = tl1; - auto tl2s = tl2; - - // lexicographic comparisons - for (unsigned i = 0; i < s->arity(); i++) { - auto s_arg = *tl1s.term()->nthArgument(i); - auto t_arg = *tl2s.term()->nthArgument(i); - auto compRes = createHelper(s_arg,t_arg,lpo); - - // If the comparison is 'equal', we skip this iteration. - if (compRes->second == BranchTag::T_EQUAL) { - // In the next iteration these two arguments are - // assumed to be equal, so we can restrict the - // comparisons by doing a unification here. - ALWAYS(unify(s_arg,t_arg,tl1s,tl2s)); - continue; - } - - auto majoRes = majoChain(lpo, tl1s, tl2s.term(), i+1); - // If the comparison is 'greater', the rest of the - // instructions consist of the majo chain. - if (compRes->second == BranchTag::T_GREATER) { - if (majoRes.second != BranchTag::T_JUMP) { - // The majo chain is empty. - if (prevStartIndex != INDEX_UNINITIALIZED) { - // Update previous 'equal' values to the new return value. - updateBranchInRange(res->first, prevStartIndex, prevEndIndex, - Branch::eq(), Branch{ majoRes.second, 0 }); - } else { - // Update the overall return value to the new return value. - res->first.reset(); - res->second = majoRes.second; - } - } else { - // The majo chain is non-empty, update the previous 'equal' - // values to the new jump offset and push the chain. - if (prevStartIndex != INDEX_UNINITIALIZED) { - updateBranchInRange(res->first, prevStartIndex, prevEndIndex, - Branch::eq(), Branch::jump(res->first.size())); - } - pushInstructions(res->first, majoRes.first, Branch::eq(), Branch::gt(), Branch::inc()); - } - break; - } - - auto alphaRes = alphaChain(lpo, tl1s.term(), i+1, tl2s); - // If the comparison is 'incomparable', the rest of the - // instructions consist of the alpha chain. - if (compRes->second == BranchTag::T_INCOMPARABLE) { - // The alpha chain is empty. - if (alphaRes.second != BranchTag::T_JUMP) { - if (prevStartIndex != INDEX_UNINITIALIZED) { - // Update previous 'equal' values to the new return value. - updateBranchInRange(res->first, prevStartIndex, prevEndIndex, - Branch::eq(), Branch{ alphaRes.second, 0 }); - } else { - // Update the overall return value to the new return value. - res->first.reset(); - res->second = alphaRes.second; - } - } else { - // The alpha chain is non-empty, update the previous 'equal' - // values to the new jump offset and push the chain. - if (prevStartIndex != INDEX_UNINITIALIZED) { - updateBranchInRange(res->first, prevStartIndex, prevEndIndex, - Branch::eq(), Branch::jump(res->first.size())); - } - pushInstructions(res->first, alphaRes.first, Branch::eq(), Branch::gt(), Branch::inc()); - } - break; - } - - // Otherwise, we have to create a branching point here. - - // Replace the previous 'equal' instructions with a jump to here. - if (prevStartIndex != INDEX_UNINITIALIZED) { - updateBranchInRange(res->first, prevStartIndex, prevEndIndex, - Branch::eq(), Branch::jump(res->first.size())); - } - prevStartIndex = res->first.size(); - prevEndIndex = res->first.size() + compRes->first.size(); - // The majo branch will be after the comparison instructions - Branch majoBranch{ - majoRes.second, - (uint16_t)(majoRes.second == BranchTag::T_JUMP ? res->first.size() + compRes->first.size() : 0) - }; - // The alpha branch will be after the majo branch instructions - Branch alphaBranch{ - alphaRes.second, - (uint16_t)(alphaRes.second == BranchTag::T_JUMP ? res->first.size() + compRes->first.size() + majoRes.first.size() : 0) - }; - - // push all three sets of instructions if needed - pushInstructions(res->first, compRes->first, Branch::eq(), majoBranch, alphaBranch); - - if (majoRes.second == BranchTag::T_JUMP) { - pushInstructions(res->first, majoRes.first, Branch::eq(), Branch::gt(), Branch::inc()); - } - - if (alphaRes.second == BranchTag::T_JUMP) { - pushInstructions(res->first, alphaRes.first, Branch::eq(), Branch::gt(), Branch::inc()); - } - - if (!unify(s_arg,t_arg,tl1s,tl2s)) { - // If we cannot unify, the rest of the iterations will be - // 'incomparable', update the previous 'equal' values to that. - updateBranchInRange(res->first, prevStartIndex, prevEndIndex, Branch::eq(), Branch::inc()); - break; - } - } - break; - } - case Ordering::GREATER: { - auto subres = majoChain(lpo, tl1, t, 0); - if (subres.second == BranchTag::T_JUMP) { - pushInstructions(res->first, subres.first, Branch::eq(), Branch::gt(), Branch::inc()); - } else { - res->second = subres.second; - } - break; - } - case Ordering::LESS: { - auto subres = alphaChain(lpo, s, 0, tl2); - if (subres.second == BranchTag::T_JUMP) { - pushInstructions(res->first, subres.first, Branch::eq(), Branch::gt(), Branch::inc()); - } else { - res->second = subres.second; - } - break; - } - default: - ASSERTION_VIOLATION; - } - ASS((res->second != BranchTag::T_JUMP) == res->first.isEmpty()); - deleteDuplicates(res->first); - ASS((res->second != BranchTag::T_JUMP) == res->first.isEmpty()); - ASS(res->second != BranchTag::T_GREATER || lpo.compare(tl1,tl2)==Ordering::GREATER); - ASS(res->second != BranchTag::T_EQUAL || lpo.compare(tl1,tl2)==Ordering::EQUAL); - ASS(res->second != BranchTag::T_INCOMPARABLE || lpo.compare(tl1,tl2)==Ordering::LESS || lpo.compare(tl1,tl2)==Ordering::INCOMPARABLE); - ptr = _cache.findPtr(make_pair(tl1,tl2)); - ASS(ptr); - *ptr = res; - return *ptr; -} - LPOComparator::LPOComparator(TermList lhs, TermList rhs, const LPO& lpo) - : OrderingComparator(lhs, rhs, lpo), _ready(false), _instructions(), _res(BranchTag::T_JUMP) -{ -} - -bool LPOComparator::check(const SubstApplicator* applicator) -{ - const auto& lpo = static_cast(_ord); - if (!_ready) { - auto kv = createHelper(_lhs, _rhs, lpo); - _res = kv->second; - _instructions = kv->first; - _ready = true; - } - - // we calculate all three values in each iteration - // to optimise CPU branch prediction (the values are - // computed regardless and hence no branching is needed) - bool cond = _res == BranchTag::T_JUMP; - bool res = _res == BranchTag::T_GREATER; - auto curr = _instructions.begin(); - - while (cond) { - auto comp = lpo.lpo(AppliedTerm(curr->lhs,applicator,true),AppliedTerm(curr->rhs,applicator,true)); - const auto& branch = curr->getBranch(comp); - - cond = branch.tag == BranchTag::T_JUMP; - res = branch.tag == BranchTag::T_GREATER; - curr = _instructions.begin() + branch.jump_pos; - } - return res; -} - -// LPOComparator2 - -LPOComparator2::LPOComparator2(TermList lhs, TermList rhs, const LPO& lpo) : OrderingComparator(lhs, rhs, lpo), _root(new Node(lhs, rhs)) { } -ostream& operator<<(ostream& out, const LPOComparator2::BranchTag& t) +ostream& operator<<(ostream& out, const LPOComparator::BranchTag& t) { switch (t) { - case LPOComparator2::BranchTag::T_NOT_GREATER: + case LPOComparator::BranchTag::T_NOT_GREATER: out << "!>"; break; - case LPOComparator2::BranchTag::T_GREATER: + case LPOComparator::BranchTag::T_GREATER: out << ">"; break; - case LPOComparator2::BranchTag::T_COMPARISON: + case LPOComparator::BranchTag::T_COMPARISON: out << "$"; break; - case LPOComparator2::BranchTag::T_UNKNOWN: + case LPOComparator::BranchTag::T_UNKNOWN: out << "?"; break; } return out; } -std::ostream& operator<<(std::ostream& str, const LPOComparator2::Branch& branch) +std::ostream& operator<<(std::ostream& str, const LPOComparator::Branch& branch) { str << branch.tag << " "; if (branch.n) { @@ -528,10 +83,10 @@ std::ostream& operator<<(std::ostream& str, const LPOComparator2::Branch& branch return str; } -ostream& operator<<(ostream& str, const LPOComparator2& comp) +ostream& operator<<(ostream& str, const LPOComparator& comp) { str << "comparator for " << comp._lhs << " > " << comp._rhs << endl; - Stack> todo; + Stack> todo; todo.push(make_pair(&comp._root,0)); unsigned cnt = 1; @@ -551,7 +106,7 @@ ostream& operator<<(ostream& str, const LPOComparator2& comp) return str; } -std::string LPOComparator2::toString() const +std::string LPOComparator::toString() const { std::stringstream str; str << *this << endl; @@ -561,7 +116,7 @@ std::string LPOComparator2::toString() const /** * Implements an @b LPO::majo call via instructions. */ -void LPOComparator2::majoChain(Branch* branch, TermList tl1, Term* t, unsigned i, Branch success, Branch fail) +void LPOComparator::majoChain(Branch* branch, TermList tl1, Term* t, unsigned i, Branch success, Branch fail) { ASS(branch); for (unsigned j = i; j < t->arity(); j++) { @@ -576,7 +131,7 @@ void LPOComparator2::majoChain(Branch* branch, TermList tl1, Term* t, unsigned i /** * Implements an @b LPO::alpha call via instructions. */ -void LPOComparator2::alphaChain(Branch* branch, Term* s, unsigned i, TermList tl2, Branch success, Branch fail) +void LPOComparator::alphaChain(Branch* branch, Term* s, unsigned i, TermList tl2, Branch success, Branch fail) { ASS(branch); for (unsigned j = i; j < s->arity(); j++) { @@ -588,7 +143,7 @@ void LPOComparator2::alphaChain(Branch* branch, Term* s, unsigned i, TermList tl *branch = fail; } -void LPOComparator2::expand(Branch& branch, const LPO& lpo) +void LPOComparator::expand(Branch& branch, const LPO& lpo) { while (branch.tag == BranchTag::T_UNKNOWN) { // Use compare here to filter out as many @@ -665,7 +220,7 @@ void LPOComparator2::expand(Branch& branch, const LPO& lpo) } } -bool LPOComparator2::check(const SubstApplicator* applicator) +bool LPOComparator::check(const SubstApplicator* applicator) { const auto& lpo = static_cast(_ord); auto curr = &_root; diff --git a/Kernel/LPOComparator.hpp b/Kernel/LPOComparator.hpp index 831461dee..5a250ae2c 100644 --- a/Kernel/LPOComparator.hpp +++ b/Kernel/LPOComparator.hpp @@ -20,114 +20,18 @@ namespace Kernel { using namespace Lib; using namespace std; -/** - * Runtime specialized LPO ordering check, based on the LPO check - * that has quadratic time complexity @see LPO::lpo. - */ class LPOComparator : public OrderingComparator { public: - /** The runtime specialization happens in the constructor. */ LPOComparator(TermList lhs, TermList rhs, const LPO& lpo); /** Executes the runtime specialized instructions with concrete substitution. */ bool check(const SubstApplicator* applicator) override; std::string toString() const override; - /** - * Represents comparing check between two terms and branching - * information based on the result. The comparison results in - * either @b GREATER, @b EQUAL or @b INCOMPARABLE, hence there - * are three branches. - */ - struct Instruction { - /** - * Possible values for a branch, i.e. return the result - * of the comparison, or jump to a different instruction. - */ - enum class BranchTag : uint8_t { - T_EQUAL, - T_GREATER, - T_INCOMPARABLE, - T_JUMP, - }; - - struct Branch { - BranchTag tag; - uint16_t jump_pos; // jump positions are absolute - - std::tuple asTuple() const - { return std::make_tuple(tag, jump_pos); } - - IMPL_COMPARISONS_FROM_TUPLE(Branch); - IMPL_HASH_FROM_TUPLE(Branch); - - static constexpr Branch eq() { return Branch{ BranchTag::T_EQUAL, 0 }; } - static constexpr Branch gt() { return Branch{ BranchTag::T_GREATER, 0 }; } - static constexpr Branch inc() { return Branch{ BranchTag::T_INCOMPARABLE, 0 }; } - static constexpr Branch jump(uint16_t pos) { return Branch{ BranchTag::T_JUMP, pos }; } - - void update(Branch eqBranch, Branch gtBranch, Branch incBranch, unsigned jump_offset); - }; - - Instruction(TermList lhs, TermList rhs) - : lhs(lhs), rhs(rhs), bs() { bs[0] = Branch::eq(); bs[1] = Branch::gt(); bs[2] = Branch::inc(); } - - constexpr const auto& getBranch(Ordering::Result r) const { - switch (r) { - case Ordering::EQUAL: - return bs[0]; - case Ordering::GREATER: - return bs[1]; - case Ordering::INCOMPARABLE: - return bs[2]; - default: - ASSERTION_VIOLATION; - } - } - - std::tuple asTuple() const - { return std::make_tuple(lhs, rhs, bs[0], bs[1], bs[2]); } - - IMPL_COMPARISONS_FROM_TUPLE(Instruction); - IMPL_HASH_FROM_TUPLE(Instruction); - - // two terms for the comparison - TermList lhs; - TermList rhs; - // three branches for the three possible comparison results - Branch bs[3]; - - }; - -private: - static pair,Instruction::BranchTag> majoChain(const LPO& lpo, TermList tl1, Term* t, unsigned i); - static pair,Instruction::BranchTag> alphaChain(const LPO& lpo, Term* s, unsigned i, TermList tl2); - static pair,Instruction::BranchTag>* createHelper(TermList tl1, TermList tl2, const LPO& lpo); - - bool _ready; - - /** This is non-empty if @b _res is @b BranchTag::T_JUMP */ - Stack _instructions; - - /** It contains the result of the comparison if the terms - * are comparable, otherwise it contains @b BranchTag::T_JUMP - * to indicate that @b _instructions have to be executed. */ - Instruction::BranchTag _res; -}; - -class LPOComparator2 -: public OrderingComparator -{ -public: - /** The runtime specialization happens in the constructor. */ - LPOComparator2(TermList lhs, TermList rhs, const LPO& lpo); - - /** Executes the runtime specialized instructions with concrete substitution. */ - bool check(const SubstApplicator* applicator) override; - std::string toString() const override; - + /* A branch initially has a T_UNKNOWN tag, and after first processing becomes either + * a specific result T_GREATER/T_NOT_GREATER or a pointer to a comparison node. */ enum class BranchTag : uint8_t { T_GREATER, T_NOT_GREATER, @@ -192,14 +96,13 @@ class LPOComparator2 Branch gtBranch; Branch incBranch; }; - private: static void majoChain(Branch* branch, TermList tl1, Term* t, unsigned i, Branch success, Branch fail); static void alphaChain(Branch* branch, Term* s, unsigned i, TermList tl2, Branch success, Branch fail); static void expand(Branch& branch, const LPO& lpo); - friend ostream& operator<<(ostream& str, const LPOComparator2& comp); + friend ostream& operator<<(ostream& str, const LPOComparator& comp); Branch _root; }; From bb7fac815a6e21e6f1b3f66306f7b22899c9c74b Mon Sep 17 00:00:00 2001 From: Marton Hajdu Date: Tue, 10 Sep 2024 09:55:08 +0200 Subject: [PATCH 5/6] Clean up OrderingComparator initializations --- Kernel/KBOComparator.cpp | 2 +- Kernel/KBOComparator.hpp | 2 +- Kernel/LPOComparator.cpp | 24 +++++++++++++----------- Kernel/LPOComparator.hpp | 36 ++++++++---------------------------- 4 files changed, 23 insertions(+), 41 deletions(-) diff --git a/Kernel/KBOComparator.cpp b/Kernel/KBOComparator.cpp index 66db8574b..f9879142b 100644 --- a/Kernel/KBOComparator.cpp +++ b/Kernel/KBOComparator.cpp @@ -21,7 +21,7 @@ using namespace Lib; using namespace Shell; KBOComparator::KBOComparator(TermList lhs, TermList rhs, const KBO& kbo) - : OrderingComparator(lhs, rhs, kbo), _ready(false), _instructions() + : OrderingComparator(lhs, rhs, kbo) { } diff --git a/Kernel/KBOComparator.hpp b/Kernel/KBOComparator.hpp index b1fc33326..15adc4578 100644 --- a/Kernel/KBOComparator.hpp +++ b/Kernel/KBOComparator.hpp @@ -111,7 +111,7 @@ class KBOComparator private: uint64_t _content; }; - bool _ready; + bool _ready = false; Stack _instructions; }; diff --git a/Kernel/LPOComparator.cpp b/Kernel/LPOComparator.cpp index c36fbecf4..51e513d2c 100644 --- a/Kernel/LPOComparator.cpp +++ b/Kernel/LPOComparator.cpp @@ -49,7 +49,7 @@ bool unify(TermList tl1, TermList tl2, TermList& orig1, TermList& orig2) } LPOComparator::LPOComparator(TermList lhs, TermList rhs, const LPO& lpo) - : OrderingComparator(lhs, rhs, lpo), _root(new Node(lhs, rhs)) + : OrderingComparator(lhs, rhs, lpo), _root(lhs, rhs) { } @@ -120,7 +120,7 @@ void LPOComparator::majoChain(Branch* branch, TermList tl1, Term* t, unsigned i, { ASS(branch); for (unsigned j = i; j < t->arity(); j++) { - *branch = Branch(new Node(tl1,*t->nthArgument(j))); + *branch = Branch(tl1,*t->nthArgument(j)); branch->n->eqBranch = fail; branch->n->incBranch = fail; branch = &branch->n->gtBranch; @@ -135,7 +135,7 @@ void LPOComparator::alphaChain(Branch* branch, Term* s, unsigned i, TermList tl2 { ASS(branch); for (unsigned j = i; j < s->arity(); j++) { - *branch = Branch(new Node(*s->nthArgument(j),tl2)); + *branch = Branch(*s->nthArgument(j),tl2); branch->n->eqBranch = success; branch->n->gtBranch = success; branch = &branch->n->incBranch; @@ -145,12 +145,14 @@ void LPOComparator::alphaChain(Branch* branch, Term* s, unsigned i, TermList tl2 void LPOComparator::expand(Branch& branch, const LPO& lpo) { - while (branch.tag == BranchTag::T_UNKNOWN) { + while (branch.tag == BranchTag::T_UNKNOWN) + { + // take temporary ownership of node + Branch nodeHolder = branch; + auto node = nodeHolder.n.get(); + // Use compare here to filter out as many // precomputable comparisons as possible. - auto node = branch.n; - node->acquire(); - auto comp = lpo.compare(node->lhs,node->rhs); if (comp != Ordering::INCOMPARABLE) { if (comp == Ordering::LESS) { @@ -160,12 +162,12 @@ void LPOComparator::expand(Branch& branch, const LPO& lpo) } else { branch = node->eqBranch; } - goto loop_end; + continue; } // If we have a variable, we cannot preprocess further. if (node->lhs.isVar() || node->rhs.isVar()) { branch.tag = BranchTag::T_COMPARISON; - goto loop_end; + continue; } { @@ -187,7 +189,7 @@ void LPOComparator::expand(Branch& branch, const LPO& lpo) { auto s_arg = *lhs.term()->nthArgument(i); auto t_arg = *rhs.term()->nthArgument(i); - *curr = Branch(new Node(s_arg,t_arg)); + *curr = Branch(s_arg,t_arg); // greater branch is a majo chain majoChain(&curr->n->gtBranch, lhs, rhs.term(), i+1, node->gtBranch, node->incBranch); // incomparable branch is an alpha chain @@ -216,7 +218,7 @@ void LPOComparator::expand(Branch& branch, const LPO& lpo) } } loop_end: - node->release(); + continue; } } diff --git a/Kernel/LPOComparator.hpp b/Kernel/LPOComparator.hpp index 5a250ae2c..cf184c678 100644 --- a/Kernel/LPOComparator.hpp +++ b/Kernel/LPOComparator.hpp @@ -39,43 +39,24 @@ class LPOComparator T_UNKNOWN, }; - struct Node; + class Node; struct Branch { BranchTag tag; - Node* n; + std::shared_ptr n; explicit Branch(BranchTag t) : tag(t), n(nullptr) { ASS(t==BranchTag::T_GREATER || t==BranchTag::T_NOT_GREATER); } - explicit Branch(Node* n) : tag(BranchTag::T_UNKNOWN), n(n) { ASS(n); n->acquire(); } - Branch(const Branch& other) : tag(other.tag), n(other.n) { if (n) { n->acquire(); } } - ~Branch() { if (n) { n->release(); } } - Branch& operator=(const Branch& other) { - if (this != &other) { - tag = other.tag; - if (n) { n->release(); } - n = other.n; - if (n) { n->acquire(); } - } - return *this; - } + explicit Branch(TermList lhs, TermList rhs) : tag(BranchTag::T_UNKNOWN), n(new Node(lhs, rhs)) {} }; - struct Node { + class Node { Node(TermList lhs, TermList rhs) - : refcnt(0), lhs(lhs), rhs(rhs), eqBranch(BranchTag::T_NOT_GREATER), gtBranch(BranchTag::T_GREATER), incBranch(BranchTag::T_NOT_GREATER) {} - - void acquire() { - refcnt++; - } + : lhs(lhs), rhs(rhs), eqBranch(BranchTag::T_NOT_GREATER), gtBranch(BranchTag::T_GREATER), incBranch(BranchTag::T_NOT_GREATER) {} - void release() { - ASS(refcnt); - refcnt--; - if (!refcnt) { - delete this; - } - } + // only allow calling ctor from Branch + friend struct Branch; + public: auto& getBranch(Ordering::Result r) { switch (r) { case Ordering::EQUAL: @@ -89,7 +70,6 @@ class LPOComparator } } - unsigned refcnt; TermList lhs; TermList rhs; Branch eqBranch; From df71b217fbd09d04b9c7a898f506627cc0f36940 Mon Sep 17 00:00:00 2001 From: Marton Hajdu Date: Thu, 12 Sep 2024 17:28:25 +0200 Subject: [PATCH 6/6] Change std::shared_ptr to Lib::SmartPtr --- Kernel/LPOComparator.cpp | 2 +- Kernel/LPOComparator.hpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Kernel/LPOComparator.cpp b/Kernel/LPOComparator.cpp index 51e513d2c..ec8ec443c 100644 --- a/Kernel/LPOComparator.cpp +++ b/Kernel/LPOComparator.cpp @@ -149,7 +149,7 @@ void LPOComparator::expand(Branch& branch, const LPO& lpo) { // take temporary ownership of node Branch nodeHolder = branch; - auto node = nodeHolder.n.get(); + auto node = nodeHolder.n.ptr(); // Use compare here to filter out as many // precomputable comparisons as possible. diff --git a/Kernel/LPOComparator.hpp b/Kernel/LPOComparator.hpp index cf184c678..15cb18283 100644 --- a/Kernel/LPOComparator.hpp +++ b/Kernel/LPOComparator.hpp @@ -43,7 +43,7 @@ class LPOComparator struct Branch { BranchTag tag; - std::shared_ptr n; + SmartPtr n; explicit Branch(BranchTag t) : tag(t), n(nullptr) { ASS(t==BranchTag::T_GREATER || t==BranchTag::T_NOT_GREATER); } explicit Branch(TermList lhs, TermList rhs) : tag(BranchTag::T_UNKNOWN), n(new Node(lhs, rhs)) {}