Commit

cache some affine correction calculations

msuchard committed Oct 8, 2023
1 parent 52331dc commit c8e443b
Showing 3 changed files with 121 additions and 88 deletions.
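
In brief, the commit memoizes the per-entry affine-correction matrices so they are built once per substitution-model state instead of on every gradient evaluation. The sketch below is a simplified, hypothetical condensation of the pattern visible in the AbstractLogAdditiveSubstitutionModelGradient diff, not the full BEAST class: CorrectionTermCache, the flattened key i * stateCount + j, and clearing the cache when the substitution model changes come straight from the diff, while the class name, computeMatrix(), and onSubstitutionModelChanged() are stand-ins for the surrounding machinery.

    import java.util.HashMap;

    // Minimal sketch of the caching pattern in this commit (not the real BEAST class).
    class AffineCorrectionCacheSketch {

        // In the commit, CorrectionTermCache is simply a HashMap<Integer, double[]>.
        static class CorrectionTermCache extends HashMap<Integer, double[]> { }

        private final int stateCount;
        private final CorrectionTermCache correctionTermCache = new CorrectionTermCache();

        AffineCorrectionCacheSketch(int stateCount) {
            this.stateCount = stateCount;
        }

        double[] getAffineCorrectionMatrix(int i, int j) {
            int key = i * stateCount + j;                // flattened (i, j) key, as in the diff
            double[] affineMatrix = correctionTermCache.get(key);
            if (affineMatrix == null) {                  // compute once, then reuse
                affineMatrix = computeMatrix(i, j);      // eigen-decomposition work in the real class
                correctionTermCache.put(key, affineMatrix);
            }
            return affineMatrix;
        }

        // Called when the substitution model changes; cached matrices are then stale.
        void onSubstitutionModelChanged() {
            correctionTermCache.clear();
        }

        private double[] computeMatrix(int i, int j) {
            return new double[stateCount * stateCount];  // placeholder for the qQPlus-based calculation
        }
    }

In the actual class the cache is only created for the affine-corrected mode (the first-order mode returns a null cache), and it is emptied from handleModelChangedEvent when the substitution model reports a change.
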
2 changes: 1 addition & 1 deletion ci/TestXML/testAffineCorrectionGradient.xml
@@ -22,7 +22,7 @@
</patterns>

<newick id="startingTree">
(A:1.0,B:1.0);
(A:2.0,B:1.0);
</newick>

<treeModel id="treeModel">
44 changes: 29 additions & 15 deletions src/dr/evomodel/substmodel/DifferentiableSubstitutionModelUtil.java
@@ -323,28 +323,42 @@ public static double[] getQQPlus(double[] eigenVectors, double[] inverseEigenVec
}
}

return result;
}

public static double[] getQPlusQ(double[] eigenVectors, double[] inverseEigenVectors, int index, int stateCount) {

double[] result = new double[stateCount * stateCount];

for (int i = 0; i < stateCount; ++i) {
for (int j = 0; j < stateCount; ++j) {
double sum = 0.0;
for (int k = 0; k < stateCount; ++k) {
if (k != index) {
sum += inverseEigenVectors[i * stateCount + k] * eigenVectors[k * stateCount + j];
}
double[] reduced = new double[stateCount];
for (int j = 0; j < stateCount; ++j) {
double sum = 0.0;
for (int k = 0; k < stateCount; ++k) {
if (k != index) {
sum += eigenVectors[k] * inverseEigenVectors[k * stateCount + j];
}
result[i * stateCount + j] = sum;
}
reduced[j] = sum;
}
reduced[0] -= 1;

// TODO Determine the stateCount unique values and just return them

return result;
}

// public static double[] getQPlusQ(double[] eigenVectors, double[] inverseEigenVectors, int index, int stateCount) {
//
// double[] result = new double[stateCount * stateCount];
//
// for (int i = 0; i < stateCount; ++i) {
// for (int j = 0; j < stateCount; ++j) {
// double sum = 0.0;
// for (int k = 0; k < stateCount; ++k) {
// if (k != index) {
// sum += inverseEigenVectors[i * stateCount + k] * eigenVectors[k * stateCount + j];
// }
// }
// result[i * stateCount + j] = sum;
// }
// }
//
// return result;
// }

private static int index12(int i, int j, int stateCount) {
return i * stateCount + j;
}
163 changes: 91 additions & 72 deletions src/dr/evomodel/treedatalikelihood/discrete/AbstractLogAdditiveSubstitutionModelGradient.java
@@ -39,12 +39,12 @@
import dr.inference.model.Likelihood;
import dr.inference.model.Model;
import dr.inference.model.ModelListener;
import dr.math.matrixAlgebra.WrappedVector;
import dr.util.Citable;
import dr.xml.Reportable;

import java.util.ArrayList;
import java.util.List;
import java.util.HashMap;
import java.util.logging.Logger;

/**
@@ -61,8 +61,6 @@ public abstract class AbstractLogAdditiveSubstitutionModelGradient implements

protected final ComplexSubstitutionModel substitutionModel;
protected final int stateCount;
// protected final int whichSubstitutionModel;
// protected final int substitutionModelCount;
protected final List<Integer> crossProductAccumulationMap;

private final ApproximationMode mode;
@@ -74,25 +72,72 @@ public enum ApproximationMode {
public String getInfo() {
return "a first-order";
}

@Override
CorrectionTermCache createCache() { return null; }

@Override
void emptyCache(CorrectionTermCache cache) { }

@Override
double computeCorrection(int i, int j, double[] crossProducts, int stateCount,
AbstractLogAdditiveSubstitutionModelGradient gradient) {
return 0.0;
}
},
AFFINE_CORRECTED("affineCorrected") {
@Override
public String getInfo() {
return "an affine-corrected";
}

@Override
CorrectionTermCache createCache() {
return new CorrectionTermCache();
}

@Override
void emptyCache(CorrectionTermCache cache) {
cache.clear();
}

@Override
double computeCorrection(int i, int j, double[] crossProducts, int stateCount,
AbstractLogAdditiveSubstitutionModelGradient gradient) {

double[] affineMatrix = gradient.getAffineCorrectionMatrix(i, j);

double correction = 0.0;
for (int m = 0; m < stateCount; ++m) {
for (int n = 0; n < stateCount; ++n) {
correction += crossProducts[m * stateCount + n] *
affineMatrix[m * stateCount + n];
}
}

return correction;
}

};

ApproximationMode(String label) {
this.label = label;
}

public abstract String getInfo();
abstract String getInfo();

abstract CorrectionTermCache createCache();

abstract void emptyCache(CorrectionTermCache cache);

abstract double computeCorrection(int i, int j, double[] crossProducts, int stateCount,
AbstractLogAdditiveSubstitutionModelGradient gradient);

public String getLabel() {
return label;
}

private String label;
private final String label;

public static ApproximationMode factory(String label) {
for (ApproximationMode mode : ApproximationMode.values()) {
@@ -123,20 +168,10 @@ public AbstractLogAdditiveSubstitutionModelGradient(String traitName,
this.branchModel = likelihoodDelegate.getBranchModel();
this.substitutionModel = substitutionModel;
this.stateCount = substitutionModel.getDataType().getStateCount();

// this.whichSubstitutionModel = determineSubstitutionNumber(
// likelihoodDelegate.getBranchModel(), substitutionModel);
// this.substitutionModelCount = determineSubstitutionModelCount(likelihoodDelegate.getBranchModel());

this.crossProductAccumulationMap = createCrossProductAccumulationMap(likelihoodDelegate.getBranchModel(),
substitutionModel);

this.mode = mode;

// this.crossProductAccumulationMap = new int[0];
// if (substitutionModelCount > 1) {
// updateCrossProductAccumulationMap();
// }
this.correctionTermCache = mode.createCache();

String name = SubstitutionModelCrossProductDelegate.getName(traitName);

@@ -153,6 +188,9 @@ public AbstractLogAdditiveSubstitutionModelGradient(String traitName,
treeTraitProvider = treeDataLikelihood.getTreeTrait(name);
assert (treeTraitProvider != null);

this.branchModel.addModelListener(this);
this.substitutionModel.addModelListener(this);

Logger.getLogger("dr.evomodel.treedatalikelihood.discrete").info(
"Gradient wrt " + traitName + " using " + mode.getInfo() + " approximation");
}
@@ -174,16 +212,9 @@ public double[] getGradientLogDensity() {
double[] crossProducts = (double[]) treeTraitProvider.getTrait(tree, null);
double[] generator = new double[crossProducts.length];

// if (whichSubstitutionModel > 1 || substitutionModelCount > 1) {
accumulateAcrossSubstitutionModelInstances(crossProducts);
// }

substitutionModel.getInfinitesimalMatrix(generator);
// crossProducts = correctDifferentials(crossProducts);

if (DEBUG_CROSS_PRODUCTS) {
savedDifferentials = crossProducts.clone();
}

double[] pi = substitutionModel.getFrequencyModel().getFrequencies();

@@ -206,45 +237,47 @@ public double[] getGradientLogDensity() {
return gradient;
}

double correction(int i, int j, double[] crossProducts) {
private double[] getAffineCorrectionMatrix(int i, int j) {

if (mode == ApproximationMode.FIRST_ORDER) {
return 0.0;
}
double[] affineMatrix = correctionTermCache.get(i * stateCount + j);

double[] affineMatrix = new double[stateCount * stateCount];
if (affineMatrix == null) {

if (crossProductAccumulationMap.size() > 1) {
throw new RuntimeException("Not yet implemented");
}
affineMatrix = new double[stateCount * stateCount]; // TODO there are only stateCount unique values

// TODO Start cache for each (i,j) .. only depends on substitutionModel
EigenDecomposition ed = substitutionModel.getEigenDecomposition();
int index = findZeroEigenvalueIndex(ed.getEigenValues());
if (crossProductAccumulationMap.size() > 1) {
throw new RuntimeException("Not yet implemented");
}

double[] eigenVectors = ed.getEigenVectors();
double[] inverseEigenVectors = ed.getInverseEigenVectors();
// TODO Start cache for each (i,j) .. only depends on substitutionModel
EigenDecomposition ed = substitutionModel.getEigenDecomposition();
int index = findZeroEigenvalueIndex(ed.getEigenValues());

double[] qQPlus = getQQPlus(eigenVectors, inverseEigenVectors, index);
double[] eigenVectors = ed.getEigenVectors();
double[] inverseEigenVectors = ed.getInverseEigenVectors();

for (int m = 0; m < stateCount; ++m) {
for (int n = 0; n < stateCount; n++) {
// TODO there are only stateCount unique values
affineMatrix[index12(m,n)] = (m == i) ?
(qQPlus[index12(m,i)] - 1.0) * qQPlus[index12(j,n)] :
qQPlus[index12(m,i)] * qQPlus[index12(j,n)];
}
}
// TODO End cache
double[] qQPlus = getQQPlus(eigenVectors, inverseEigenVectors, index);

double correction = 0.0;
for (int m = 0; m < stateCount; ++m) {
for (int n = 0; n < stateCount; ++n) {
correction += crossProducts[index12(m,n)] * affineMatrix[index12(m,n)];
for (int m = 0; m < stateCount; ++m) {
for (int n = 0; n < stateCount; n++) {
// TODO there are only stateCount unique values
affineMatrix[index12(m, n)] = (m == i) ?
(qQPlus[index12(m, i)] - 1.0) * qQPlus[index12(j, n)] :
qQPlus[index12(m, i)] * qQPlus[index12(j, n)];
}
}
// TODO End cache

correctionTermCache.put(i * stateCount + j, affineMatrix);
} else {
System.err.println("Using cached value");
}

return correction;
return affineMatrix;
}

double correction(int i, int j, double[] crossProducts) {
return mode.computeCorrection(i, j, crossProducts, stateCount, this);
}

private int findZeroEigenvalueIndex(double[] eigenvalues) {
@@ -260,19 +293,10 @@ private double[] getQQPlus(double[] eigenVectors, double[] inverseEigenVectors,
return DifferentiableSubstitutionModelUtil.getQQPlus(eigenVectors, inverseEigenVectors, index, stateCount);
}

private double[] getQPlusQ(double[] eigenVectors, double[] inverseEigenVectors, int index) {
return DifferentiableSubstitutionModelUtil.getQPlusQ(eigenVectors, inverseEigenVectors, index, stateCount);
}

private int index12(int i, int j) {
return i * stateCount + j;
}

@SuppressWarnings("unused")
private int index21(int i, int j) {
return j * stateCount + i;
}

private void accumulateAcrossSubstitutionModelInstances(double[] crossProducts) {
final int length = stateCount * stateCount;

@@ -325,10 +349,6 @@ public String getReport() {
String message = GradientWrtParameterProvider.getReportAndCheckForError(this, 0.0, Double.POSITIVE_INFINITY, getReportTolerance());
sb.append(message);

if (DEBUG_CROSS_PRODUCTS) {
sb.append("\n\tdifferentials: ").append(new WrappedVector.Raw(savedDifferentials, 0, savedDifferentials.length));
}

if (COUNT_TOTAL_OPERATIONS) {
sb.append("\n\tgetCrossProductGradientCount = ").append(gradientCount);
sb.append("\n\taverageGradientTime = ");
@@ -348,15 +368,14 @@ Double getReportTolerance() {
return null;
}

// This has not been rigorously tested for epochs that change structure
@SuppressWarnings("unused")
protected void handleModelChangedEvent(Model model, Object object, int index) {
// if (model == branchModel) {
// updateCrossProductAccumulationMap();
// }
if (model == branchModel) {
// crossProductAccumulationMap = createCrossProductAccumulationMap(branchModel, substitutionModel);
throw new RuntimeException("Not yet implemented");
} else if (model == substitutionModel) {
mode.emptyCache(correctionTermCache);
} else {
throw new RuntimeException("Unknown model");
}
}

@@ -368,11 +387,11 @@ public void modelRestored(Model model) {

}

protected static final boolean COUNT_TOTAL_OPERATIONS = false;
protected static final boolean DEBUG_CROSS_PRODUCTS = false;
static class CorrectionTermCache extends HashMap<Integer, double[]> { }

protected double[] savedDifferentials;
private final CorrectionTermCache correctionTermCache;

protected static final boolean COUNT_TOTAL_OPERATIONS = false;
protected long gradientCount = 0;
protected long totalGradientTime = 0;
}
