Skip to content

Commit

Permalink
results still differ from old LogGaussianProcess, but suspect LGP is …
Browse files Browse the repository at this point in the history
…wrong in mean
  • Loading branch information
msuchard committed Nov 24, 2023
1 parent bb7d8c6 commit ed7214c
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ public class GaussianProcessFieldParser extends AbstractXMLObjectParser {
private static final String DIMENSION = "dim";
private static final String MEAN = "mean";
private static final String BASES = "bases";
private static final String NOISE = "gaussianNoise";

public String getParserName() { return PARSER_NAME; }

Expand All @@ -57,13 +58,15 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException {

RandomField.WeightProvider weights = parseWeightProvider(xo, dim);

Parameter noise = xo.hasChildNamed(NOISE) ? (Parameter) xo.getElementFirstChild(NOISE) : null;

String id = xo.hasId() ? xo.getId() : PARSER_NAME;

int order = 1;

List<BasisDimension> bases = parseBases(xo);

return new AdditiveGaussianProcessDistribution(id, order, dim, mean, null, bases, weights);
return new AdditiveGaussianProcessDistribution(id, order, dim, mean, noise, bases, weights);
}

private List<BasisDimension> parseBases(XMLObject xo) {
Expand All @@ -88,6 +91,7 @@ private List<BasisDimension> parseBases(XMLObject xo) {
new ElementRule(DesignMatrix.class),
// TODO parse kernel
}, 0, Integer.MAX_VALUE),
new ElementRule(NOISE, Parameter.class, "", true),
WEIGHTS_RULE,
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,10 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException {
String id = xo.hasId() ? xo.getId() : PARSER_NAME;

List<Parameter> parameters = new ArrayList<>();
parameters.add((Parameter) xo.getChild(Parameter.class));
Parameter parameter = (Parameter) xo.getChild(Parameter.class);
if (parameter != null) {
parameters.add(parameter);
}

final GaussianProcessKernel kernel;
try {
Expand All @@ -60,7 +63,7 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException {

private final XMLSyntaxRule[] rules = {
AttributeRule.newStringRule(TYPE),
new ElementRule(Parameter.class),
new ElementRule(Parameter.class, true),
};

public String getParserDescription() { // TODO update
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
import dr.inference.distribution.RandomField;
import dr.inference.model.*;
import dr.math.distributions.RandomFieldDistribution;
import org.ejml.alg.dense.decomposition.chol.CholeskyDecompositionCommon_D64;
import org.ejml.data.Complex64F;
import org.ejml.data.DenseMatrix64F;
import org.ejml.factory.LinearSolverFactory;
import org.ejml.interfaces.decomposition.CholeskyDecomposition;
Expand All @@ -36,6 +38,8 @@
import java.util.Arrays;
import java.util.List;

import static dr.math.matrixAlgebra.missingData.MissingOps.invertAndGetDeterminant;

/**
* @author Marc Suchard
* @author Filippo Monti
Expand Down Expand Up @@ -68,6 +72,8 @@ public class AdditiveGaussianProcessDistribution extends RandomFieldDistribution
private boolean precisionAndDeterminantKnown;
private boolean gramianAndVarianceKnown;

private static final boolean USE_CHOLESKY = true;

public AdditiveGaussianProcessDistribution(String name,
int order,
int dim,
Expand All @@ -94,7 +100,7 @@ public AdditiveGaussianProcessDistribution(String name,
this.precision = new DenseMatrix64F(dim, dim);
this.variance = new DenseMatrix64F(dim, dim);

this.solver = LinearSolverFactory.symmPosDef(dim);
this.solver = USE_CHOLESKY ? LinearSolverFactory.symmPosDef(dim) : null;

if (meanParameter != null) {
addVariable(meanParameter);
Expand Down Expand Up @@ -139,9 +145,33 @@ private void computeGramianAndVariance() {

private void computePrecisionAndDeterminant() {
DenseMatrix64F variance = getVariance();
solver.solve(variance, precision);
CholeskyDecomposition<DenseMatrix64F> d = solver.getDecomposition();
logDeterminant = Math.log(d.computeDeterminant().getReal());
if (USE_CHOLESKY) {
if (!solver.setA(variance)) {
throw new RuntimeException("Unable to decompose matrix");
}

solver.invert(precision);
logDeterminant = computeLogDeterminantFromTriangularMatrix(
((CholeskyDecompositionCommon_D64) solver.getDecomposition()).getT());
} else {
logDeterminant = invertAndGetDeterminant(variance, precision, true);
}
}

private double computeLogDeterminantFromTriangularMatrix(DenseMatrix64F T) {

final int n = T.numCols;
double[] t = T.getData();

double sum = 0.0;
int total = n * n;

for(int i = 0; i < total; i += n + 1) {
sum += Math.log(t[i]);
}

double logDet = 2 * sum;
return logDet;
}

private double[] getPrecision() {
Expand Down Expand Up @@ -236,8 +266,9 @@ public double logPdf(double[] x) {
exponent += diff[i] * precision[i * dim + j] * diff[j];
}
}

return getLogDeterminant() - exponent / 2; // TODO + normalizing constant

double logLikelihood = -0.5 * (dim * Math.log(2 * Math.PI) - getLogDeterminant()) - 0.5 * exponent;
return logLikelihood;
}

@Override
Expand Down Expand Up @@ -326,8 +357,8 @@ public static void computeAdditiveGramian(DenseMatrix64F gramian,

for (int i = 0; i < rowDim; ++i) {
for (int j = 0; j < colDim; ++j) {
double xi = design1.getParameterValue(0, i); // TODO make generic dimension
double xj = design2.getParameterValue(0, j); // TODO make generic dimension
double xi = design1.getParameterValue(i, 0); // TODO make generic dimension
double xj = design2.getParameterValue(j, 0); // TODO make generic dimension
gramian.add(i, j, scale * kernel.getCorrelation(xi, xj));
}
}
Expand Down
5 changes: 4 additions & 1 deletion src/dr/math/distributions/gp/GaussianProcessPrediction.java
Original file line number Diff line number Diff line change
Expand Up @@ -192,8 +192,11 @@ public void modelChangedEvent(Model model, Object object, int index) {

@Override
public void variableChangedEvent(Variable variable, int index, Variable.ChangeType type) {
if (variable == realizedValues || predictiveDesigns.contains(variable)) {
if (variable == realizedValues) {
predictionKnown = false;
} else if (variable instanceof DesignMatrix &&
predictiveDesigns.contains((DesignMatrix) variable)) {
throw new IllegalArgumentException("Not yet implemented");
} else {
throw new IllegalArgumentException("Unknown variable");
}
Expand Down

0 comments on commit ed7214c

Please sign in to comment.