Skip to content

Commit

Permalink
refactor affine correction
Browse files Browse the repository at this point in the history
  • Loading branch information
msuchard committed Oct 8, 2023
1 parent c8e443b commit 390bdd8
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 77 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -323,11 +323,47 @@ public static double[] getQQPlus(double[] eigenVectors, double[] inverseEigenVec
}
}

// double[] reduced = new double[stateCount];
// for (int j = 0; j < stateCount; ++j) {
// double sum = 0.0;
// for (int k = 0; k < stateCount; ++k) {
// if (k != index) {
// sum += eigenVectors[k] * inverseEigenVectors[k * stateCount + j];
// }
// }
// reduced[j] = sum;
// }
// reduced[0] -=1;

// TODO Determine the stateCount unique values and just return them

return result;
}

public static double[] getQQPlus(final double[] eigenVectors,
final double[] inverseEigenVectors,
final double[] eigenValues,
final int stateCount) {

double[] result = new double[stateCount * stateCount];

for (int i = 0; i < stateCount; ++i) {
for (int j = 0; j < stateCount; ++j) {
double sum = 0.0;
for (int k = 0; k < stateCount; ++k) {
if (eigenValues[k] != 0.0) {
sum += eigenVectors[i * stateCount + k] * inverseEigenVectors[k * stateCount + j];
}
}
result[i * stateCount + j] = sum;
}
}

double[] reduced = new double[stateCount];
for (int j = 0; j < stateCount; ++j) {
double sum = 0.0;
for (int k = 0; k < stateCount; ++k) {
if (k != index) {
if (eigenValues[k] != 0.0) {
sum += eigenVectors[k] * inverseEigenVectors[k * stateCount + j];
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
import java.util.ArrayList;
import java.util.List;
import java.util.HashMap;
import java.util.Map;
import java.util.logging.Logger;

/**
Expand Down Expand Up @@ -74,14 +75,15 @@ public String getInfo() {
}

@Override
CorrectionTermCache createCache() { return null; }
CorrectionTermCache createCache(SubstitutionModel model,
List<Integer> accumulateMap) { return null; }

@Override
void emptyCache(CorrectionTermCache cache) { }

@Override
double computeCorrection(int i, int j, double[] crossProducts, int stateCount,
AbstractLogAdditiveSubstitutionModelGradient gradient) {
CorrectionTermCache correctionTermCache) {
return 0.0;
}
},
Expand All @@ -92,8 +94,12 @@ public String getInfo() {
}

@Override
CorrectionTermCache createCache() {
return new CorrectionTermCache();
CorrectionTermCache createCache(SubstitutionModel model,
List<Integer> accumulateMap) {
if (accumulateMap.size() > 1) {
throw new RuntimeException("Not yet implemented");
}
return new CorrectionTermCache(model);
}

@Override
Expand All @@ -103,9 +109,9 @@ void emptyCache(CorrectionTermCache cache) {

@Override
double computeCorrection(int i, int j, double[] crossProducts, int stateCount,
AbstractLogAdditiveSubstitutionModelGradient gradient) {
CorrectionTermCache correctionTermCache) {

double[] affineMatrix = gradient.getAffineCorrectionMatrix(i, j);
double[] affineMatrix = correctionTermCache.getAffineMatrix(i, j);

double correction = 0.0;
for (int m = 0; m < stateCount; ++m) {
Expand All @@ -126,12 +132,12 @@ void emptyCache(CorrectionTermCache cache) {

abstract String getInfo();

abstract CorrectionTermCache createCache();
abstract CorrectionTermCache createCache(SubstitutionModel model, List<Integer> accumulateMap);

abstract void emptyCache(CorrectionTermCache cache);

abstract double computeCorrection(int i, int j, double[] crossProducts, int stateCount,
AbstractLogAdditiveSubstitutionModelGradient gradient);
CorrectionTermCache correctionTermCache);

public String getLabel() {
return label;
Expand All @@ -150,14 +156,7 @@ public static ApproximationMode factory(String label) {
}

private static final ApproximationMode DEFAULT_MODE = ApproximationMode.FIRST_ORDER;

public AbstractLogAdditiveSubstitutionModelGradient(String traitName,
TreeDataLikelihood treeDataLikelihood,
BeagleDataLikelihoodDelegate likelihoodDelegate,
ComplexSubstitutionModel substitutionModel) {
this(traitName, treeDataLikelihood, likelihoodDelegate, substitutionModel, DEFAULT_MODE); // TODO Remove this constructor
}


public AbstractLogAdditiveSubstitutionModelGradient(String traitName,
TreeDataLikelihood treeDataLikelihood,
BeagleDataLikelihoodDelegate likelihoodDelegate,
Expand All @@ -171,7 +170,7 @@ public AbstractLogAdditiveSubstitutionModelGradient(String traitName,
this.crossProductAccumulationMap = createCrossProductAccumulationMap(likelihoodDelegate.getBranchModel(),
substitutionModel);
this.mode = mode;
this.correctionTermCache = mode.createCache();
this.correctionTermCache = mode.createCache(substitutionModel, crossProductAccumulationMap);

String name = SubstitutionModelCrossProductDelegate.getName(traitName);

Expand Down Expand Up @@ -237,64 +236,8 @@ public double[] getGradientLogDensity() {
return gradient;
}

private double[] getAffineCorrectionMatrix(int i, int j) {

double[] affineMatrix = correctionTermCache.get(i * stateCount + j);

if (affineMatrix == null) {

affineMatrix = new double[stateCount * stateCount]; // TODO there are only stateCount unique values

if (crossProductAccumulationMap.size() > 1) {
throw new RuntimeException("Not yet implemented");
}

// TODO Start cache for each (i,j) .. only depends on substitutionModel
EigenDecomposition ed = substitutionModel.getEigenDecomposition();
int index = findZeroEigenvalueIndex(ed.getEigenValues());

double[] eigenVectors = ed.getEigenVectors();
double[] inverseEigenVectors = ed.getInverseEigenVectors();

double[] qQPlus = getQQPlus(eigenVectors, inverseEigenVectors, index);

for (int m = 0; m < stateCount; ++m) {
for (int n = 0; n < stateCount; n++) {
// TODO there are only stateCount unique values
affineMatrix[index12(m, n)] = (m == i) ?
(qQPlus[index12(m, i)] - 1.0) * qQPlus[index12(j, n)] :
qQPlus[index12(m, i)] * qQPlus[index12(j, n)];
}
}
// TODO End cache

correctionTermCache.put(i * stateCount + j, affineMatrix);
} else {
System.err.println("Using cached value");
}

return affineMatrix;
}

double correction(int i, int j, double[] crossProducts) {
return mode.computeCorrection(i, j, crossProducts, stateCount, this);
}

private int findZeroEigenvalueIndex(double[] eigenvalues) {
for (int i = 0; i < stateCount; ++i) {
if (eigenvalues[i] == 0) {
return i;
}
}
return -1;
}

private double[] getQQPlus(double[] eigenVectors, double[] inverseEigenVectors, int index) {
return DifferentiableSubstitutionModelUtil.getQQPlus(eigenVectors, inverseEigenVectors, index, stateCount);
}

private int index12(int i, int j) {
return i * stateCount + j;
return mode.computeCorrection(i, j, crossProducts, stateCount, correctionTermCache);
}

private void accumulateAcrossSubstitutionModelInstances(double[] crossProducts) {
Expand Down Expand Up @@ -387,7 +330,62 @@ public void modelRestored(Model model) {

}

static class CorrectionTermCache extends HashMap<Integer, double[]> { }
static class CorrectionTermCache {

private final SubstitutionModel model;
private final Map<Integer, double[]> map;
private final int stateCount;
private double[] qQPlus;

CorrectionTermCache(SubstitutionModel model) {
this.model = model;
this.map = new HashMap<>();
this.stateCount = model.getDataType().getStateCount();
this.qQPlus = null;
}

private int index12(int i, int j) {
return i * stateCount + j;
}

private double[] getQQPlus() {
if (qQPlus == null) {
EigenDecomposition ed = model.getEigenDecomposition();
qQPlus = DifferentiableSubstitutionModelUtil.getQQPlus(ed.getEigenVectors(),
ed.getInverseEigenVectors(), ed.getEigenValues(), stateCount);
}
return qQPlus;
}

double[] getAffineMatrix(int i, int j) {

double[] affineMatrix = map.get(i * stateCount + j);

if (affineMatrix == null) {

affineMatrix = new double[stateCount * stateCount]; // TODO there are only stateCount unique values

double[] qQPlus = getQQPlus();

for (int m = 0; m < stateCount; ++m) {
for (int n = 0; n < stateCount; n++) {
affineMatrix[index12(m, n)] = (m == i) ?
(qQPlus[index12(m, i)] - 1.0) * qQPlus[index12(j, n)] :
qQPlus[index12(m, i)] * qQPlus[index12(j, n)];
}
}

map.put(i * stateCount + j, affineMatrix);
}

return affineMatrix;
}

public void clear() {
map.clear();
qQPlus = null;
}
}

private final CorrectionTermCache correctionTermCache;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ public GamGpSubstitutionModelGradient(String traitName,
TreeDataLikelihood treeDataLikelihood,
BeagleDataLikelihoodDelegate likelihoodDelegate,
GlmSubstitutionModel substitutionModel) {
super(traitName, treeDataLikelihood, likelihoodDelegate, substitutionModel);
super(traitName, treeDataLikelihood, likelihoodDelegate, substitutionModel,
ApproximationMode.FIRST_ORDER);
throw new RuntimeException("Not yet implemented");
}

Expand Down

0 comments on commit 390bdd8

Please sign in to comment.