diff --git a/src/configuration/GlobalConfiguration.cpp b/src/configuration/GlobalConfiguration.cpp index 11aa22507..f1d998890 100644 --- a/src/configuration/GlobalConfiguration.cpp +++ b/src/configuration/GlobalConfiguration.cpp @@ -22,6 +22,7 @@ const double GlobalConfiguration::PIVOT_CHANGE_COLUMN_TOLERANCE = 0.000000001; const unsigned GlobalConfiguration::DEGRADATION_CHECKING_FREQUENCY = 10; const double GlobalConfiguration::DEGRADATION_THRESHOLD = 0.1; const double GlobalConfiguration::ACCEPTABLE_SIMPLEX_PIVOT_THRESHOLD = 0.0001; +const bool GlobalConfiguration::USE_COLUMN_MERGING_EQUATIONS = false; const double GlobalConfiguration::GAUSSIAN_ELIMINATION_PIVOT_SCALE_THRESHOLD = 0.1; const unsigned GlobalConfiguration::MAX_SIMPLEX_PIVOT_SEARCH_ITERATIONS = 5; const unsigned GlobalConfiguration::CONSTRAINT_VIOLATION_THRESHOLD = 20; @@ -64,6 +65,7 @@ void GlobalConfiguration::print() printf( " DEGRADATION_CHECKING_FREQUENCY: %u\n", DEGRADATION_CHECKING_FREQUENCY ); printf( " DEGRADATION_THRESHOLD: %.15lf\n", DEGRADATION_THRESHOLD ); printf( " ACCEPTABLE_SIMPLEX_PIVOT_THRESHOLD: %.15lf\n", ACCEPTABLE_SIMPLEX_PIVOT_THRESHOLD ); + printf( " USE_COLUMN_MERGING_EQUATIONS: %s\n", USE_COLUMN_MERGING_EQUATIONS ? "Yes" : "No" ); printf( " GAUSSIAN_ELIMINATION_PIVOT_SCALE_THRESHOLD: %.15lf\n", GAUSSIAN_ELIMINATION_PIVOT_SCALE_THRESHOLD ); printf( " MAX_SIMPLEX_PIVOT_SEARCH_ITERATIONS: %u\n", MAX_SIMPLEX_PIVOT_SEARCH_ITERATIONS ); printf( " CONSTRAINT_VIOLATION_THRESHOLD: %u\n", CONSTRAINT_VIOLATION_THRESHOLD ); diff --git a/src/configuration/GlobalConfiguration.h b/src/configuration/GlobalConfiguration.h index 3e260215a..b4299bedc 100644 --- a/src/configuration/GlobalConfiguration.h +++ b/src/configuration/GlobalConfiguration.h @@ -55,6 +55,10 @@ class GlobalConfiguration // to pick another element. static const double ACCEPTABLE_SIMPLEX_PIVOT_THRESHOLD; + // If true, column-merging equations are given special treatment and cause columns in the tableau + // to be merged (instead of a new row added). + static const bool USE_COLUMN_MERGING_EQUATIONS; + // If a pivot element in a Gaussian elimination iteration is smaller than this threshold times // the largest element in the column, the elimination engine will attempt to pick another pivot. static const double GAUSSIAN_ELIMINATION_PIVOT_SCALE_THRESHOLD; diff --git a/src/engine/Engine.cpp b/src/engine/Engine.cpp index 2780adf76..a60acc31e 100644 --- a/src/engine/Engine.cpp +++ b/src/engine/Engine.cpp @@ -743,27 +743,130 @@ void Engine::applySplit( const PiecewiseLinearCaseSplit &split ) List equations = split.getEquations(); for ( auto &equation : equations ) { - unsigned auxVariable = _tableau->addEquation( equation ); - _activeEntryStrategy->resizeHook( _tableau ); - - switch ( equation._type ) + /* + In the general case, we just add the new equation to the tableau. + However, we also support a very common case: equations of the form + x1 = x2, which are common, e.g., with ReLUs. For these equations we + may be able to merge two columns of the tableau. + */ + unsigned x1, x2; + bool canMergeColumns = + // Only if the flag is on + GlobalConfiguration::USE_COLUMN_MERGING_EQUATIONS && + // Only if the equation has the correct form + equation.isVariableMergingEquation( x1, x2 ) && + // And only if the variables are not out of bounds + ( !_tableau->isBasic( x1 ) || + !_tableau->basicOutOfBounds( _tableau->variableToIndex( x1 ) ) ) + && + ( !_tableau->isBasic( x2 ) || + !_tableau->basicOutOfBounds( _tableau->variableToIndex( x2 ) ) ); + + if ( canMergeColumns ) { - case Equation::GE: - bounds.append( Tightening( auxVariable, 0.0, Tightening::UB ) ); - break; + /* + Special case: x1 and x2 need to be merged. + First, we need to ensure they are both non-basic. + */ + unsigned n = _tableau->getN(); + unsigned m = _tableau->getM(); + + if ( _tableau->isBasic( x1 ) ) + { + TableauRow x1Row( n - m ); + _tableau->getTableauRow( _tableau->variableToIndex( x1 ), &x1Row ); - case Equation::LE: - bounds.append( Tightening( auxVariable, 0.0, Tightening::LB ) ); - break; + bool found = false; + unsigned nonBasic; + for ( unsigned i = 0; i < n - m; ++i ) + { + if ( !FloatUtils::isZero( x1Row._row[i]._coefficient ) && ( x1Row._row[i]._var != x2 ) ) + { + found = true; + nonBasic = x1Row._row[i]._var; + break; + } + } - case Equation::EQ: - bounds.append( Tightening( auxVariable, 0.0, Tightening::LB ) ); - bounds.append( Tightening( auxVariable, 0.0, Tightening::UB ) ); - break; + if ( !found ) + throw ReluplexError( ReluplexError::ENGINE_APPLY_SPLIT_FAILED, + "Could not find a variable to pivot with" ); - default: - ASSERT( false ); - break; + _tableau->setEnteringVariableIndex( _tableau->variableToIndex( nonBasic ) ); + _tableau->setLeavingVariableIndex( _tableau->variableToIndex( x1 ) ); + + // Make sure the change column and pivot row are up-to-date - strategies + // such as projected steepest edge need these for their internal updates. + _tableau->computeChangeColumn(); + _tableau->computePivotRow(); + + _activeEntryStrategy->prePivotHook( _tableau, false ); + _tableau->performDegeneratePivot(); + _activeEntryStrategy->prePivotHook( _tableau, false ); + } + + if ( _tableau->isBasic( x2 ) ) + { + TableauRow x2Row( n - m ); + _tableau->getTableauRow( _tableau->variableToIndex( x2 ), &x2Row ); + + bool found = false; + unsigned nonBasic; + for ( unsigned i = 0; i < n - m; ++i ) + { + if ( !FloatUtils::isZero( x2Row._row[i]._coefficient ) && ( x2Row._row[i]._var != x1 ) ) + { + found = true; + nonBasic = x2Row._row[i]._var; + break; + } + } + + if ( !found ) + throw ReluplexError( ReluplexError::ENGINE_APPLY_SPLIT_FAILED, + "Could not find a variable to pivot with" ); + + _tableau->setEnteringVariableIndex( _tableau->variableToIndex( nonBasic ) ); + _tableau->setLeavingVariableIndex( _tableau->variableToIndex( x2 ) ); + + // Make sure the change column and pivot row are up-to-date - strategies + // such as projected steepest edge need these for their internal updates. + _tableau->computeChangeColumn(); + _tableau->computePivotRow(); + + _activeEntryStrategy->prePivotHook( _tableau, false ); + _tableau->performDegeneratePivot(); + _activeEntryStrategy->prePivotHook( _tableau, false ); + } + + // Both variables are now non-basic, so we can merge their columns + _tableau->mergeColumns( x1, x2 ); + } + else + { + // General case: add a new equation to the tableau + unsigned auxVariable = _tableau->addEquation( equation ); + _activeEntryStrategy->resizeHook( _tableau ); + + switch ( equation._type ) + { + case Equation::GE: + bounds.append( Tightening( auxVariable, 0.0, Tightening::UB ) ); + break; + + case Equation::LE: + bounds.append( Tightening( auxVariable, 0.0, Tightening::LB ) ); + break; + + case Equation::EQ: + bounds.append( Tightening( auxVariable, 0.0, Tightening::LB ) ); + bounds.append( Tightening( auxVariable, 0.0, Tightening::UB ) ); + break; + + default: + ASSERT( false ); + break; + } } } diff --git a/src/engine/Equation.cpp b/src/engine/Equation.cpp index c01deb4da..40fcd658c 100644 --- a/src/engine/Equation.cpp +++ b/src/engine/Equation.cpp @@ -114,6 +114,30 @@ void Equation::dump() const printf( "%.2lf\n", _scalar ); } +bool Equation::isVariableMergingEquation( unsigned &x1, unsigned &x2 ) const +{ + if ( _addends.size() != 2 ) + return false; + + if ( !FloatUtils::isZero( _scalar ) ) + return false; + + double coefficientOne = _addends.front()._coefficient; + double coefficientTwo = _addends.back()._coefficient; + + if ( FloatUtils::isZero( coefficientOne ) || FloatUtils::isZero( coefficientTwo ) ) + return false; + + if ( FloatUtils::areEqual( coefficientOne, -coefficientTwo ) ) + { + x1 = _addends.front()._variable; + x2 = _addends.back()._variable; + return true; + } + + return false; +} + // // Local Variables: // compile-command: "make -C ../.. " diff --git a/src/engine/Equation.h b/src/engine/Equation.h index 7a931e3cd..bf73d0e16 100644 --- a/src/engine/Equation.h +++ b/src/engine/Equation.h @@ -52,6 +52,13 @@ class Equation */ void updateVariableIndex( unsigned oldVar, unsigned newVar ); + /* + Return true iff the variable is a "variable merging equation", + i.e. an equation of the form x = y. If true is returned, x1 and + x2 are the merged variables. + */ + bool isVariableMergingEquation( unsigned &x1, unsigned &x2 ) const; + List _addends; double _scalar; EquationType _type; diff --git a/src/engine/ITableau.h b/src/engine/ITableau.h index a480ac913..ad35e7fe9 100644 --- a/src/engine/ITableau.h +++ b/src/engine/ITableau.h @@ -168,6 +168,7 @@ class ITableau virtual Equation *getBasisEquation( unsigned row ) const = 0; virtual double *getInverseBasisMatrix() const = 0; virtual void refreshBasisFactorization() = 0; + virtual void mergeColumns( unsigned x1, unsigned x2 ) = 0; }; #endif // __ITableau_h__ diff --git a/src/engine/ReluplexError.h b/src/engine/ReluplexError.h index c7151b2bc..2a11a9a2a 100644 --- a/src/engine/ReluplexError.h +++ b/src/engine/ReluplexError.h @@ -34,6 +34,7 @@ class ReluplexError : public Error CANNOT_RESTORE_TABLEAU = 12, FAILURE_TO_ADD_NEW_EQUATION = 13, RESTORATION_FAILED_TO_REFACTORIZE_BASIS = 14, + ENGINE_APPLY_SPLIT_FAILED = 15, DEBUGGING_ERROR = 999, }; diff --git a/src/engine/Tableau.cpp b/src/engine/Tableau.cpp index fb2556c47..eef56302c 100644 --- a/src/engine/Tableau.cpp +++ b/src/engine/Tableau.cpp @@ -402,10 +402,17 @@ const double *Tableau::getUpperBounds() const double Tableau::getValue( unsigned variable ) { + /* + If this variable has been merged into another, + we need to be reading the other variable's value + */ + if ( _mergedVariables.exists( variable ) ) + variable = _mergedVariables[variable]; + + // The values of non-basics can be extracted even if the + // assignment is invalid if ( !_basicVariables.exists( variable ) ) { - // The values of non-basics can be extracted even if the - // assignment is invalid unsigned index = _variableToIndex[variable]; return _nonBasicAssignment[index]; } @@ -1112,6 +1119,9 @@ void Tableau::storeState( TableauState &state ) const // Store the _boundsValid indicator state._boundsValid = _boundsValid; + + // Store the merged variables + state._mergedVariables = _mergedVariables; } void Tableau::restoreState( const TableauState &state ) @@ -1149,6 +1159,9 @@ void Tableau::restoreState( const TableauState &state ) // Restore the _boundsValid indicator _boundsValid = state._boundsValid; + // Restore the merged varaibles + _mergedVariables = state._mergedVariables; + computeAssignment(); _costFunctionManager->initialize(); computeCostFunction(); @@ -1911,6 +1924,8 @@ void Tableau::registerCostFunctionManager( ICostFunctionManager *costFunctionMan const double *Tableau::getColumnOfBasis( unsigned column ) const { ASSERT( column < _m ); + ASSERT( !_mergedVariables.exists( _basicIndexToVariable[column] ) ); + unsigned variable = _basicIndexToVariable[column]; return _A + ( variable * _m ); } @@ -1920,6 +1935,35 @@ void Tableau::refreshBasisFactorization() _basisFactorization->obtainFreshBasis(); } +void Tableau::mergeColumns( unsigned x1, unsigned x2 ) +{ + ASSERT( !isBasic( x1 ) ); + ASSERT( !isBasic( x2 ) ); + + /* + If x2 has tighter bounds than x1, adjust the bounds + for x1. + */ + if ( FloatUtils::lt( _upperBounds[x2], _upperBounds[x1] ) ) + tightenUpperBound( x1, _upperBounds[x2] ); + if ( FloatUtils::gt( _lowerBounds[x2], _lowerBounds[x1] ) ) + tightenLowerBound( x1, _lowerBounds[x2] ); + + /* + Merge column x2 of the constraint matrix into x1 + and zero-out column x2 + */ + for ( unsigned row = 0; row < _m; ++row ) + { + _A[(x1 * _m) + row] += _A[(x2 * _m) + row]; + _A[(x2 * _m) + row] = 0.0; + } + _mergedVariables[x2] = x1; + + computeAssignment(); + computeCostFunction(); +} + // // Local Variables: // compile-command: "make -C ../.. " diff --git a/src/engine/Tableau.h b/src/engine/Tableau.h index 4f15d7a9d..7e0840bbc 100644 --- a/src/engine/Tableau.h +++ b/src/engine/Tableau.h @@ -401,6 +401,12 @@ class Tableau : public ITableau, public IBasisFactorization::BasisColumnOracle */ void refreshBasisFactorization(); + /* + Merge two columns of the constraint matrix and re-initialize + the tableau. + */ + void mergeColumns( unsigned x1, unsigned x2 ); + private: /* Variable watchers @@ -547,6 +553,13 @@ class Tableau : public ITableau, public IBasisFactorization::BasisColumnOracle */ ICostFunctionManager *_costFunctionManager; + /* + _mergedVariables[x] = y means that x = y, and that + variable x has been merged into variable y. So, when + extracting a solution for x, we should read the value of y. + */ + Map _mergedVariables; + /* Free all allocated memory. */ diff --git a/src/engine/TableauState.h b/src/engine/TableauState.h index b117518e7..07876523c 100644 --- a/src/engine/TableauState.h +++ b/src/engine/TableauState.h @@ -15,6 +15,7 @@ #include "IBasisFactorization.h" #include "ITableau.h" +#include "Map.h" #include "Set.h" class TableauState @@ -104,6 +105,13 @@ class TableauState Indicator whether the bounds are valid */ bool _boundsValid; + + /* + _mergedVariables[x] = y means that x = y, and that + variable x has been merged into variable y. So, when + extracting a solution for x, we should read the value of y. + */ + Map _mergedVariables; }; #endif // __TableauState_h__ diff --git a/src/engine/tests/MockTableau.h b/src/engine/tests/MockTableau.h index 441d43a13..7bc30fc41 100644 --- a/src/engine/tests/MockTableau.h +++ b/src/engine/tests/MockTableau.h @@ -520,6 +520,10 @@ class MockTableau : public ITableau void refreshBasisFactorization() { } + + void mergeColumns( unsigned /* x1 */, unsigned /* x2 */ ) + { + } }; #endif // __MockTableau_h__