Skip to content

Commit

Permalink
change name to AdditiveGaussianProcess
Browse files Browse the repository at this point in the history
  • Loading branch information
msuchard committed Nov 23, 2023
1 parent 8b93fbe commit 90a960f
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 32 deletions.
33 changes: 26 additions & 7 deletions src/dr/inferencexml/distribution/GaussianProcessFieldParser.java
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,14 @@
import dr.inference.distribution.RandomField;
import dr.inference.model.DesignMatrix;
import dr.inference.model.Parameter;
import dr.math.distributions.gp.GaussianProcessDistribution;
import dr.math.distributions.gp.Kernel;
import dr.math.distributions.gp.AdditiveGaussianProcessDistribution;
import dr.math.distributions.gp.AdditiveKernel;
import dr.xml.*;

import java.util.ArrayList;
import java.util.List;

import static dr.math.distributions.gp.AdditiveGaussianProcessDistribution.BasisDimension;
import static dr.inferencexml.distribution.RandomFieldParser.WEIGHTS_RULE;
import static dr.inferencexml.distribution.RandomFieldParser.parseWeightProvider;

Expand All @@ -55,9 +59,23 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException {

String id = xo.hasId() ? xo.getId() : PARSER_NAME;

return new GaussianProcessDistribution(id, dim, mean,
new Kernel.DotProduct(null, null),
weights);
int order = 1;

List<BasisDimension> bases = parseBases(xo);

return new AdditiveGaussianProcessDistribution(id, order, dim, mean, bases, weights);
}

private List<BasisDimension> parseBases(XMLObject xo) {
List<BasisDimension> bases = new ArrayList<>();
for (XMLObject cxo : xo.getAllChildren(BASES)) {
bases.add(new BasisDimension(
new AdditiveKernel.DotProduct(null, null),
(DesignMatrix) cxo.getChild(DesignMatrix.class)
));
}

return bases;
}

public XMLSyntaxRule[] getSyntaxRules() { return rules; }
Expand All @@ -67,8 +85,9 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException {
new ElementRule(MEAN,
new XMLSyntaxRule[]{new ElementRule(Parameter.class)}, true),
new ElementRule(BASES, new XMLSyntaxRule[] {
new ElementRule(DesignMatrix.class) },
0, Integer.MAX_VALUE),
new ElementRule(DesignMatrix.class),
// TODO parse kernel
}, 0, Integer.MAX_VALUE),
WEIGHTS_RULE,
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,21 +32,26 @@
import org.ejml.factory.LinearSolverFactory;
import org.ejml.interfaces.decomposition.CholeskyDecomposition;
import org.ejml.interfaces.linsol.LinearSolver;
import org.ejml.ops.CommonOps;

import java.util.Arrays;
import java.util.List;

/**
* @author Marc Suchard
* @author Filippo Monti
* //
* Duvenaud DK, Nickisch H, Rasmussen C. Additive Gaussian processes. In Shawe-Taylor J, Zemel R, Bartlett P, Pereira F, Weinberger KQ (eds.), Advances in Neural Information Processing Systems, volume 24. Curran Associates, Inc., 2011.
* URL <a href="https://proceedings.neurips.cc/paper/2011/file/4c5bde74a8f110656874902f07378009-Paper.pdf"/>
*/
public class GaussianProcessDistribution extends RandomFieldDistribution {
public class AdditiveGaussianProcessDistribution extends RandomFieldDistribution {

public static final String TYPE = "GaussianProcess";

private final int order;

private final int dim;
private final Parameter meanParameter;
private final Kernel kernel;
private final List<BasisDimension> bases;
private final RandomField.WeightProvider weightProvider;

private final double[] mean;
Expand All @@ -61,16 +66,22 @@ public class GaussianProcessDistribution extends RandomFieldDistribution {
private boolean precisionAndDeterminantKnown;
private boolean varianceKnown;

public GaussianProcessDistribution(String name,
int dim,
Parameter meanParameter,
Kernel kernel,
RandomField.WeightProvider weightProvider) {
public AdditiveGaussianProcessDistribution(String name,
int order,
int dim,
Parameter meanParameter,
List<BasisDimension> bases,
RandomField.WeightProvider weightProvider) {
super(name);

if (order != 1) {
throw new RuntimeException("Not yet implemented");
}

this.order = order;
this.dim = dim;
this.meanParameter = meanParameter;
this.kernel = kernel;
this.bases = bases;
this.weightProvider = weightProvider;

this.mean = new double[dim];
Expand All @@ -82,33 +93,53 @@ public GaussianProcessDistribution(String name,

addVariable(meanParameter);

if (kernel instanceof AbstractModel) {
addModel((AbstractModel) kernel);
for (BasisDimension basis : bases) {
AdditiveKernel kernel = basis.getKernel();
if (kernel instanceof AbstractModel) {
addModel((AbstractModel) kernel);
}
addVariable(basis.getDesignMatrix());
}

if (weightProvider != null) {
addModel(weightProvider);
}
}

private DenseMatrix64F getPrecision() {
private double[] getPrecision() {
if (!precisionAndDeterminantKnown) {
DenseMatrix64F variance = getVariance();
solver.solve(variance, precision);
CholeskyDecomposition<DenseMatrix64F> d = solver.getDecomposition();
logDeterminant = Math.log(d.computeDeterminant().getReal());
precisionAndDeterminantKnown = true;
}
return precision;
return precision.getData();
}

private DenseMatrix64F getVariance() {
if (!varianceKnown) {
for (int i = 0; i < dim; ++i) {
for (int j = 0; j < dim; ++j) {
variance.set(i, j, kernel.getCorrelation(0, 0)); // TODO
variance.zero();

// 1st order contribution
for (BasisDimension basis : bases) {
final AdditiveKernel kernel = basis.getKernel();
final DesignMatrix design = basis.getDesignMatrix();
final double scale = kernel.getScale(); // TODO is this term necessary? or scale only needed at the order-level

for (int i = 0; i < dim; ++i) {
for (int j = 0; j < dim; ++j) {
double xi = design.getParameterValue(0, i); // TODO make generic dimension
double xj = design.getParameterValue(0, j); // TODO make generic dimension
variance.add(i, j, scale * kernel.getCorrelation(xi, xj));
}
}
}

for (int n = 1; n < order; ++n) {
// TODO higher-order terms via Newton-Girard formula
}

varianceKnown = true;
}
return variance;
Expand Down Expand Up @@ -161,7 +192,7 @@ public double logPdf(double[] x) {

final double[] mean = getMean();
final double[] diff = tmp;
final double[] precision = getPrecision().getData();
final double[] precision = getPrecision();

for (int i = 0; i < dim; ++i) {
diff[i] = x[i] - mean[i];
Expand Down Expand Up @@ -224,4 +255,20 @@ protected void restoreState() {

@Override
protected void acceptState() { }

public static class BasisDimension {

private final AdditiveKernel kernel;
private final DesignMatrix design;

public BasisDimension(AdditiveKernel kernel, DesignMatrix design) {
this.kernel = kernel;
this.design = design;
}

AdditiveKernel getKernel() { return kernel; }

DesignMatrix getDesignMatrix() { return design; }
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,22 @@
* @author Marc A. Suchard
* @author Filippo Monti
*/
public interface Kernel {
public interface AdditiveKernel {

double getCorrelation(double x, double y);

class DotProduct extends Base implements Kernel {
double getScale();

class DotProduct extends Base {

public DotProduct(String name, List<Parameter> parameters) {
super(name, parameters);
}

public double getCorrelation(double x, double y) {
double sigma = parameters.get(0).getParameterValue(0);
return sigma * x * y;
return x * y;
}

}

class RadialBasisFunction extends Base {
Expand All @@ -38,11 +40,11 @@ public double getCorrelation(double x, double y) {
double length = parameters.get(1).getParameterValue(0);
double diff = x - y;

return sigma * Math.exp(-(diff * diff) / (2 * length * length));
return Math.exp(-(diff * diff) / (2 * length * length));
}
}

class Base extends AbstractModel {
abstract class Base extends AbstractModel implements AdditiveKernel {

final List<Parameter> parameters;

Expand All @@ -57,6 +59,12 @@ public Base(String name,
}
}

@Override
public double getScale() {
return 1.0;
// return parameters.get(0).getParameterValue(0); // TODO
}

@Override
protected void handleModelChangedEvent(Model model, Object object, int index) { }

Expand Down
4 changes: 2 additions & 2 deletions src/dr/math/distributions/gp/GaussianProcessPrediction.java
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
*/
public class GaussianProcessPrediction implements Loggable, VariableListener, ModelListener {

private final GaussianProcessDistribution gp;
private final AdditiveGaussianProcessDistribution gp;
private final Parameter realizedValues;
private final Parameter predictivePoints;
private final int dim;
Expand All @@ -45,7 +45,7 @@ public class GaussianProcessPrediction implements Loggable, VariableListener, Mo
private boolean predictionKnown;
private LogColumn[] columns;

public GaussianProcessPrediction(GaussianProcessDistribution gp,
public GaussianProcessPrediction(AdditiveGaussianProcessDistribution gp,
Parameter realizedValues,
Parameter predictivePoints) {
this.gp = gp;
Expand Down

0 comments on commit 90a960f

Please sign in to comment.