Skip to content

Commit

Permalink
Tableau merge cols (#68)
Browse files Browse the repository at this point in the history
* check if an equation is an identity equation

* supprot for merging columns in the tableau when a special equation is encountered

* debug stuff

* store and restore merged vars in the tableau

* cleanup

* cleanup

* a more efficient way for merging: without re-computing a set of basic variables

* dont use new technique by default
  • Loading branch information
guykatzz authored Jun 15, 2018
1 parent 9ec64dd commit f13af12
Show file tree
Hide file tree
Showing 11 changed files with 230 additions and 19 deletions.
2 changes: 2 additions & 0 deletions src/configuration/GlobalConfiguration.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ const double GlobalConfiguration::PIVOT_CHANGE_COLUMN_TOLERANCE = 0.000000001;
const unsigned GlobalConfiguration::DEGRADATION_CHECKING_FREQUENCY = 10;
const double GlobalConfiguration::DEGRADATION_THRESHOLD = 0.1;
const double GlobalConfiguration::ACCEPTABLE_SIMPLEX_PIVOT_THRESHOLD = 0.0001;
const bool GlobalConfiguration::USE_COLUMN_MERGING_EQUATIONS = false;
const double GlobalConfiguration::GAUSSIAN_ELIMINATION_PIVOT_SCALE_THRESHOLD = 0.1;
const unsigned GlobalConfiguration::MAX_SIMPLEX_PIVOT_SEARCH_ITERATIONS = 5;
const unsigned GlobalConfiguration::CONSTRAINT_VIOLATION_THRESHOLD = 20;
Expand Down Expand Up @@ -64,6 +65,7 @@ void GlobalConfiguration::print()
printf( " DEGRADATION_CHECKING_FREQUENCY: %u\n", DEGRADATION_CHECKING_FREQUENCY );
printf( " DEGRADATION_THRESHOLD: %.15lf\n", DEGRADATION_THRESHOLD );
printf( " ACCEPTABLE_SIMPLEX_PIVOT_THRESHOLD: %.15lf\n", ACCEPTABLE_SIMPLEX_PIVOT_THRESHOLD );
printf( " USE_COLUMN_MERGING_EQUATIONS: %s\n", USE_COLUMN_MERGING_EQUATIONS ? "Yes" : "No" );
printf( " GAUSSIAN_ELIMINATION_PIVOT_SCALE_THRESHOLD: %.15lf\n", GAUSSIAN_ELIMINATION_PIVOT_SCALE_THRESHOLD );
printf( " MAX_SIMPLEX_PIVOT_SEARCH_ITERATIONS: %u\n", MAX_SIMPLEX_PIVOT_SEARCH_ITERATIONS );
printf( " CONSTRAINT_VIOLATION_THRESHOLD: %u\n", CONSTRAINT_VIOLATION_THRESHOLD );
Expand Down
4 changes: 4 additions & 0 deletions src/configuration/GlobalConfiguration.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@ class GlobalConfiguration
// to pick another element.
static const double ACCEPTABLE_SIMPLEX_PIVOT_THRESHOLD;

// If true, column-merging equations are given special treatment and cause columns in the tableau
// to be merged (instead of a new row added).
static const bool USE_COLUMN_MERGING_EQUATIONS;

// If a pivot element in a Gaussian elimination iteration is smaller than this threshold times
// the largest element in the column, the elimination engine will attempt to pick another pivot.
static const double GAUSSIAN_ELIMINATION_PIVOT_SCALE_THRESHOLD;
Expand Down
137 changes: 120 additions & 17 deletions src/engine/Engine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -743,27 +743,130 @@ void Engine::applySplit( const PiecewiseLinearCaseSplit &split )
List<Equation> equations = split.getEquations();
for ( auto &equation : equations )
{
unsigned auxVariable = _tableau->addEquation( equation );
_activeEntryStrategy->resizeHook( _tableau );

switch ( equation._type )
/*
In the general case, we just add the new equation to the tableau.
However, we also support a very common case: equations of the form
x1 = x2, which are common, e.g., with ReLUs. For these equations we
may be able to merge two columns of the tableau.
*/
unsigned x1, x2;
bool canMergeColumns =
// Only if the flag is on
GlobalConfiguration::USE_COLUMN_MERGING_EQUATIONS &&
// Only if the equation has the correct form
equation.isVariableMergingEquation( x1, x2 ) &&
// And only if the variables are not out of bounds
( !_tableau->isBasic( x1 ) ||
!_tableau->basicOutOfBounds( _tableau->variableToIndex( x1 ) ) )
&&
( !_tableau->isBasic( x2 ) ||
!_tableau->basicOutOfBounds( _tableau->variableToIndex( x2 ) ) );

if ( canMergeColumns )
{
case Equation::GE:
bounds.append( Tightening( auxVariable, 0.0, Tightening::UB ) );
break;
/*
Special case: x1 and x2 need to be merged.
First, we need to ensure they are both non-basic.
*/
unsigned n = _tableau->getN();
unsigned m = _tableau->getM();

if ( _tableau->isBasic( x1 ) )
{
TableauRow x1Row( n - m );
_tableau->getTableauRow( _tableau->variableToIndex( x1 ), &x1Row );

case Equation::LE:
bounds.append( Tightening( auxVariable, 0.0, Tightening::LB ) );
break;
bool found = false;
unsigned nonBasic;
for ( unsigned i = 0; i < n - m; ++i )
{
if ( !FloatUtils::isZero( x1Row._row[i]._coefficient ) && ( x1Row._row[i]._var != x2 ) )
{
found = true;
nonBasic = x1Row._row[i]._var;
break;
}
}

case Equation::EQ:
bounds.append( Tightening( auxVariable, 0.0, Tightening::LB ) );
bounds.append( Tightening( auxVariable, 0.0, Tightening::UB ) );
break;
if ( !found )
throw ReluplexError( ReluplexError::ENGINE_APPLY_SPLIT_FAILED,
"Could not find a variable to pivot with" );

default:
ASSERT( false );
break;
_tableau->setEnteringVariableIndex( _tableau->variableToIndex( nonBasic ) );
_tableau->setLeavingVariableIndex( _tableau->variableToIndex( x1 ) );

// Make sure the change column and pivot row are up-to-date - strategies
// such as projected steepest edge need these for their internal updates.
_tableau->computeChangeColumn();
_tableau->computePivotRow();

_activeEntryStrategy->prePivotHook( _tableau, false );
_tableau->performDegeneratePivot();
_activeEntryStrategy->prePivotHook( _tableau, false );
}

if ( _tableau->isBasic( x2 ) )
{
TableauRow x2Row( n - m );
_tableau->getTableauRow( _tableau->variableToIndex( x2 ), &x2Row );

bool found = false;
unsigned nonBasic;
for ( unsigned i = 0; i < n - m; ++i )
{
if ( !FloatUtils::isZero( x2Row._row[i]._coefficient ) && ( x2Row._row[i]._var != x1 ) )
{
found = true;
nonBasic = x2Row._row[i]._var;
break;
}
}

if ( !found )
throw ReluplexError( ReluplexError::ENGINE_APPLY_SPLIT_FAILED,
"Could not find a variable to pivot with" );

_tableau->setEnteringVariableIndex( _tableau->variableToIndex( nonBasic ) );
_tableau->setLeavingVariableIndex( _tableau->variableToIndex( x2 ) );

// Make sure the change column and pivot row are up-to-date - strategies
// such as projected steepest edge need these for their internal updates.
_tableau->computeChangeColumn();
_tableau->computePivotRow();

_activeEntryStrategy->prePivotHook( _tableau, false );
_tableau->performDegeneratePivot();
_activeEntryStrategy->prePivotHook( _tableau, false );
}

// Both variables are now non-basic, so we can merge their columns
_tableau->mergeColumns( x1, x2 );
}
else
{
// General case: add a new equation to the tableau
unsigned auxVariable = _tableau->addEquation( equation );
_activeEntryStrategy->resizeHook( _tableau );

switch ( equation._type )
{
case Equation::GE:
bounds.append( Tightening( auxVariable, 0.0, Tightening::UB ) );
break;

case Equation::LE:
bounds.append( Tightening( auxVariable, 0.0, Tightening::LB ) );
break;

case Equation::EQ:
bounds.append( Tightening( auxVariable, 0.0, Tightening::LB ) );
bounds.append( Tightening( auxVariable, 0.0, Tightening::UB ) );
break;

default:
ASSERT( false );
break;
}
}
}

Expand Down
24 changes: 24 additions & 0 deletions src/engine/Equation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,30 @@ void Equation::dump() const
printf( "%.2lf\n", _scalar );
}

bool Equation::isVariableMergingEquation( unsigned &x1, unsigned &x2 ) const
{
if ( _addends.size() != 2 )
return false;

if ( !FloatUtils::isZero( _scalar ) )
return false;

double coefficientOne = _addends.front()._coefficient;
double coefficientTwo = _addends.back()._coefficient;

if ( FloatUtils::isZero( coefficientOne ) || FloatUtils::isZero( coefficientTwo ) )
return false;

if ( FloatUtils::areEqual( coefficientOne, -coefficientTwo ) )
{
x1 = _addends.front()._variable;
x2 = _addends.back()._variable;
return true;
}

return false;
}

//
// Local Variables:
// compile-command: "make -C ../.. "
Expand Down
7 changes: 7 additions & 0 deletions src/engine/Equation.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,13 @@ class Equation
*/
void updateVariableIndex( unsigned oldVar, unsigned newVar );

/*
Return true iff the variable is a "variable merging equation",
i.e. an equation of the form x = y. If true is returned, x1 and
x2 are the merged variables.
*/
bool isVariableMergingEquation( unsigned &x1, unsigned &x2 ) const;

List<Addend> _addends;
double _scalar;
EquationType _type;
Expand Down
1 change: 1 addition & 0 deletions src/engine/ITableau.h
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ class ITableau
virtual Equation *getBasisEquation( unsigned row ) const = 0;
virtual double *getInverseBasisMatrix() const = 0;
virtual void refreshBasisFactorization() = 0;
virtual void mergeColumns( unsigned x1, unsigned x2 ) = 0;
};

#endif // __ITableau_h__
Expand Down
1 change: 1 addition & 0 deletions src/engine/ReluplexError.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ class ReluplexError : public Error
CANNOT_RESTORE_TABLEAU = 12,
FAILURE_TO_ADD_NEW_EQUATION = 13,
RESTORATION_FAILED_TO_REFACTORIZE_BASIS = 14,
ENGINE_APPLY_SPLIT_FAILED = 15,

DEBUGGING_ERROR = 999,
};
Expand Down
48 changes: 46 additions & 2 deletions src/engine/Tableau.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -402,10 +402,17 @@ const double *Tableau::getUpperBounds() const

double Tableau::getValue( unsigned variable )
{
/*
If this variable has been merged into another,
we need to be reading the other variable's value
*/
if ( _mergedVariables.exists( variable ) )
variable = _mergedVariables[variable];

// The values of non-basics can be extracted even if the
// assignment is invalid
if ( !_basicVariables.exists( variable ) )
{
// The values of non-basics can be extracted even if the
// assignment is invalid
unsigned index = _variableToIndex[variable];
return _nonBasicAssignment[index];
}
Expand Down Expand Up @@ -1112,6 +1119,9 @@ void Tableau::storeState( TableauState &state ) const

// Store the _boundsValid indicator
state._boundsValid = _boundsValid;

// Store the merged variables
state._mergedVariables = _mergedVariables;
}

void Tableau::restoreState( const TableauState &state )
Expand Down Expand Up @@ -1149,6 +1159,9 @@ void Tableau::restoreState( const TableauState &state )
// Restore the _boundsValid indicator
_boundsValid = state._boundsValid;

// Restore the merged varaibles
_mergedVariables = state._mergedVariables;

computeAssignment();
_costFunctionManager->initialize();
computeCostFunction();
Expand Down Expand Up @@ -1911,6 +1924,8 @@ void Tableau::registerCostFunctionManager( ICostFunctionManager *costFunctionMan
const double *Tableau::getColumnOfBasis( unsigned column ) const
{
ASSERT( column < _m );
ASSERT( !_mergedVariables.exists( _basicIndexToVariable[column] ) );

unsigned variable = _basicIndexToVariable[column];
return _A + ( variable * _m );
}
Expand All @@ -1920,6 +1935,35 @@ void Tableau::refreshBasisFactorization()
_basisFactorization->obtainFreshBasis();
}

void Tableau::mergeColumns( unsigned x1, unsigned x2 )
{
ASSERT( !isBasic( x1 ) );
ASSERT( !isBasic( x2 ) );

/*
If x2 has tighter bounds than x1, adjust the bounds
for x1.
*/
if ( FloatUtils::lt( _upperBounds[x2], _upperBounds[x1] ) )
tightenUpperBound( x1, _upperBounds[x2] );
if ( FloatUtils::gt( _lowerBounds[x2], _lowerBounds[x1] ) )
tightenLowerBound( x1, _lowerBounds[x2] );

/*
Merge column x2 of the constraint matrix into x1
and zero-out column x2
*/
for ( unsigned row = 0; row < _m; ++row )
{
_A[(x1 * _m) + row] += _A[(x2 * _m) + row];
_A[(x2 * _m) + row] = 0.0;
}
_mergedVariables[x2] = x1;

computeAssignment();
computeCostFunction();
}

//
// Local Variables:
// compile-command: "make -C ../.. "
Expand Down
13 changes: 13 additions & 0 deletions src/engine/Tableau.h
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,12 @@ class Tableau : public ITableau, public IBasisFactorization::BasisColumnOracle
*/
void refreshBasisFactorization();

/*
Merge two columns of the constraint matrix and re-initialize
the tableau.
*/
void mergeColumns( unsigned x1, unsigned x2 );

private:
/*
Variable watchers
Expand Down Expand Up @@ -547,6 +553,13 @@ class Tableau : public ITableau, public IBasisFactorization::BasisColumnOracle
*/
ICostFunctionManager *_costFunctionManager;

/*
_mergedVariables[x] = y means that x = y, and that
variable x has been merged into variable y. So, when
extracting a solution for x, we should read the value of y.
*/
Map<unsigned, unsigned> _mergedVariables;

/*
Free all allocated memory.
*/
Expand Down
8 changes: 8 additions & 0 deletions src/engine/TableauState.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

#include "IBasisFactorization.h"
#include "ITableau.h"
#include "Map.h"
#include "Set.h"

class TableauState
Expand Down Expand Up @@ -104,6 +105,13 @@ class TableauState
Indicator whether the bounds are valid
*/
bool _boundsValid;

/*
_mergedVariables[x] = y means that x = y, and that
variable x has been merged into variable y. So, when
extracting a solution for x, we should read the value of y.
*/
Map<unsigned, unsigned> _mergedVariables;
};

#endif // __TableauState_h__
Expand Down
4 changes: 4 additions & 0 deletions src/engine/tests/MockTableau.h
Original file line number Diff line number Diff line change
Expand Up @@ -520,6 +520,10 @@ class MockTableau : public ITableau
void refreshBasisFactorization()
{
}

void mergeColumns( unsigned /* x1 */, unsigned /* x2 */ )
{
}
};

#endif // __MockTableau_h__
Expand Down

0 comments on commit f13af12

Please sign in to comment.