Skip to content

Commit

Permalink
all gradient approaches now implemented
Browse files Browse the repository at this point in the history
  • Loading branch information
msuchard committed Oct 6, 2023
1 parent 142c699 commit 8e07963
Show file tree
Hide file tree
Showing 10 changed files with 294 additions and 81 deletions.
203 changes: 203 additions & 0 deletions ci/TestXML/testAffineCorrectionGradient.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
<?xml version="1.0" encoding="utf-8"?>
<beast>

<taxa id="taxa">
<taxon id="A"/>
<taxon id="B"/>
</taxa>

<alignment id="alignment1" dataType="nucleotide">
<sequence>
<taxon idref="A"/>
A
</sequence>
<sequence>
<taxon idref="B"/>
-
</sequence>
</alignment>

<patterns id="loc.pattern" from="1">
<alignment idref="alignment1"/>
</patterns>

<newick id="startingTree">
(A:1.0,B:1.0);
</newick>

<treeModel id="treeModel">
<tree idref="startingTree"/>
<rootHeight>
<parameter id="treeModel.rootHeight"/>
</rootHeight>
<nodeHeights internalNodes="true">
<parameter id="treeModel.internalNodeHeights"/>
</nodeHeights>
<nodeHeights internalNodes="true" rootNode="true">
<parameter id="treeModel.allInternalNodeHeights"/>
</nodeHeights>
</treeModel>

<strictClockBranchRates id="loc.branchRates">
<rate>
<parameter id="loc.clock.rate" value="1E-0" lower="0.0"/>
</rate>
</strictClockBranchRates>

<glmSubstitutionModel id="loc.model" normalize="false" dataType="nucleotide">
<!--
<rootFrequencies>
<frequencyModel id="loc.rootModel" dataType="nucleotide">
<frequencies>
<parameter id="loc.root" value="1 0 0 0"/>
</frequencies>
</frequencyModel>
</rootFrequencies>
-->

<rootFrequencies>
<frequencyModel id="loc.frequencyModel" dataType="nucleotide">
<frequencies>
<parameter id="loc.frequencies" value="0.1 0.3 0.2 0.4"/>
<!-- <parameter id="loc.frequencies" value="0.25 0.25 0.25 0.25"/> -->
</frequencies>
</frequencyModel>
</rootFrequencies>

<glmModel family="logLinear" checkIdentifiability="false">

<independentVariables>
<parameter id="log.kappa" value="0.6931472"/>
<!-- <parameter id="log.kappa" value="0.0"/> -->
<designMatrix id="kappa.designMatrix">
<parameter value="0 1 0 0 0 0
0 0 0 0 0 0"/>

<!--
<parameter value="0 1 0 0 0 0
0 0 0 0 0 0"/>
-->
</designMatrix>
</independentVariables>
</glmModel>
</glmSubstitutionModel>

<siteModel id="loc.siteModel">
<substitutionModel>
<glmSubstitutionModel idref="loc.model"/>
</substitutionModel>
</siteModel>

<treeDataLikelihood id="treeLikelihood" useAmbiguities="true" usePreOrder="true" scalingScheme="never" delayScaling="false">
<patterns idref="loc.pattern"/>
<treeModel idref="treeModel"/>
<siteModel idref="loc.siteModel"/>
<strictClockBranchRates idref="loc.branchRates"/>
<frequencyModel id="loc.rootModel" dataType="nucleotide">
<frequencies>
<parameter id="loc.root" value="1 0 0 0"/>
</frequencies>
</frequencyModel>
</treeDataLikelihood>

<glmSubstitutionModelGradient id="gradient.fo" traitName="loc" mode="firstOrder">
<treeDataLikelihood idref="treeLikelihood"/>
<glmSubstitutionModel idref="loc.model"/>
</glmSubstitutionModelGradient>

<glmSubstitutionModelGradient id="gradient.ac" traitName="loc" mode="affineCorrected">
<treeDataLikelihood idref="treeLikelihood"/>
<glmSubstitutionModel idref="loc.model"/>
</glmSubstitutionModelGradient>

<cachedReport id="report.gradient.fo">
<report>
Approximation via cross-products (first-order)
<glmSubstitutionModelGradient idref="gradient.fo"/>
</report>
</cachedReport>

<cachedReport id="report.gradient.ac">
<report>
Approximation via cross-products (affine-corrected)
<glmSubstitutionModelGradient idref="gradient.ac"/>
</report>
</cachedReport>

<branchSubstitutionParameterGradient id="kappaGradient.exact" traitName="kappaParameterGradient.exact"
useHessian="false" homogeneous="true" mode="exact">
<treeDataLikelihood idref="treeLikelihood"/>
<parameter idref="log.kappa"/>
</branchSubstitutionParameterGradient>

<branchSubstitutionParameterGradient id="kappaGradient.fo" traitName="kappaParameterGradient.fo"
useHessian="false" homogeneous="true" mode="firstOrder">
<treeDataLikelihood idref="treeLikelihood"/>
<parameter idref="log.kappa"/>
</branchSubstitutionParameterGradient>

<branchSubstitutionParameterGradient id="kappaGradient.ac" traitName="kappaParameterGradient.ac"
useHessian="false" homogeneous="true" mode="affineCorrected">
<treeDataLikelihood idref="treeLikelihood"/>
<parameter idref="log.kappa"/>
</branchSubstitutionParameterGradient>

<cachedReport id="report.xiang.exact">
<report>
Exact via Xiang-magic
<branchSubstitutionParameterGradient idref="kappaGradient.exact"/>
</report>
</cachedReport>

<cachedReport id="report.xiang.fo">
<report>
Exact via Xiang-magic
<branchSubstitutionParameterGradient idref="kappaGradient.fo"/>
</report>
</cachedReport>

<cachedReport id="report.xiang.ac">
<report>
Exact via Xiang-magic
<branchSubstitutionParameterGradient idref="kappaGradient.ac"/>
</report>
</cachedReport>

<assertEqual tolerance="1E-6" verbose="true" charactersToStrip="\[\],">
<message>
Check exact solution
</message>
<actual regex="analytic: (.*?)\n">
<cachedReport idref="report.xiang.exact"/>
</actual>
<expected regex="numeric : (.*?)\n">
<cachedReport idref="report.xiang.exact"/>
</expected>
</assertEqual>

<assertEqual tolerance="1E-6" verbose="true" charactersToStrip="\[\],">
<message>
Check first-order solutions
</message>
<actual regex="analytic: (.*?)\n">
<cachedReport idref="report.xiang.fo"/>
</actual>
<expected regex="analytic: (.*?)\n">
<cachedReport idref="report.gradient.fo"/>
</expected>
</assertEqual>

<assertEqual tolerance="1E-6" verbose="true" charactersToStrip="\[\],">
<message>
Check affine-corrected solutions
</message>
<actual regex="analytic: (.*?)\n">
<cachedReport idref="report.xiang.ac"/>
</actual>
<expected regex="analytic: (.*?)\n">
<cachedReport idref="report.gradient.ac"/>
</expected>
</assertEqual>


</beast>
18 changes: 11 additions & 7 deletions src/dr/evomodel/substmodel/DifferentialMassProvider.java
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ public String getReport() {
return "Exact";
}
},
FIRST_ORDER("approximate") {
FIRST_ORDER("firstOrder") {
@Override
public double[] dispatch(double time,
DifferentiableSubstitutionModel model,
Expand All @@ -66,7 +66,7 @@ public String getReport() {
return "Approximate wrt parameter";
}
},
AFFINE("affine") {
AFFINE("affineCorrected") {
@Override
public double[] dispatch(double time,
DifferentiableSubstitutionModel model,
Expand All @@ -85,10 +85,14 @@ public String getReport() {
}
};

private final String name;
private final String label;

Mode(String name) {
this.name = name;
Mode(String label) {
this.label = label;
}

public String getLabel() {
return label;
}

public abstract double[] dispatch(double time,
Expand All @@ -97,9 +101,9 @@ public abstract double[] dispatch(double time,

public abstract String getReport();

public static Mode parse(String name) {
public static Mode parse(String label) {
for (Mode mode : Mode.values()) {
if (mode.name.equalsIgnoreCase(name)) {
if (mode.label.equalsIgnoreCase(label)) {
return mode;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,10 @@ public abstract class AbstractGlmSubstitutionModelGradient extends AbstractLogAd
public AbstractGlmSubstitutionModelGradient(String traitName,
TreeDataLikelihood treeDataLikelihood,
BeagleDataLikelihoodDelegate likelihoodDelegate,
GlmSubstitutionModel substitutionModel) {
GlmSubstitutionModel substitutionModel,
ApproximationMode mode) {

super(traitName, treeDataLikelihood, likelihoodDelegate, substitutionModel);
super(traitName, treeDataLikelihood, likelihoodDelegate, substitutionModel, mode);
this.glm = substitutionModel.getGeneralizedLinearModel();
this.parameterMap = makeParameterMap(glm);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,21 +67,41 @@ public abstract class AbstractLogAdditiveSubstitutionModelGradient implements

private final ApproximationMode mode;

enum ApproximationMode {
FIRST_ORDER {
public enum ApproximationMode {

FIRST_ORDER("firstOrder") {
@Override
public String getInfo() {
return "a first-order";
}
},
AFFINE_CORRECTED {
AFFINE_CORRECTED("affineCorrected") {
@Override
public String getInfo() {
return "an affine-corrected";
}
};

ApproximationMode(String label) {
this.label = label;
}

public abstract String getInfo();

public String getLabel() {
return label;
}

private String label;

public static ApproximationMode factory(String label) {
for (ApproximationMode mode : ApproximationMode.values()) {
if (mode.getLabel().equalsIgnoreCase(label)) {
return mode;
}
}
throw new IllegalArgumentException("Unknown approximation mode");
}
}

private static final ApproximationMode DEFAULT_MODE = ApproximationMode.FIRST_ORDER;
Expand Down Expand Up @@ -253,39 +273,9 @@ private int index21(int i, int j) {
return j * stateCount + i;
}

// private int determineSubstitutionNumber(BranchModel branchModel,
// ComplexSubstitutionModel substitutionModel) {
//
// List<SubstitutionModel> substitutionModels = branchModel.getSubstitutionModels();
// for (int i = 0; i < substitutionModels.size(); ++i) {
// if (substitutionModel == substitutionModels.get(i)) {
// return i;
// }
// }
// throw new IllegalArgumentException("Unknown substitution model");
// }

// private int determineSubstitutionModelCount(BranchModel branchModel) {
// List<SubstitutionModel> substitutionModels = branchModel.getSubstitutionModels();
// return substitutionModels.size();
// }

private void accumulateAcrossSubstitutionModelInstances(double[] crossProducts) {
final int length = stateCount * stateCount;

// // copy first set of entries instead of accumulating
// System.arraycopy(
// crossProducts, whichSubstitutionModel * length,
// crossProducts, 0, length);
//
// if ( crossProductAccumulationMap.length > 0 ) {
// for (int i : crossProductAccumulationMap) {
// for (int j = 0; j < length; j++) {
// crossProducts[j] += crossProducts[i * length + j];
// }
// }
// }

int firstModel = crossProductAccumulationMap.get(0);
if (firstModel > 0) {
// Copy first set of entries
Expand Down Expand Up @@ -317,26 +307,6 @@ private List<Integer> createCrossProductAccumulationMap(BranchModel branchModel,
return map;
}

// private void updateCrossProductAccumulationMap() {
//// System.err.println("Updating crossProductAccumulationMap");
// List<Integer> matchingModels = new ArrayList<>();
// List<SubstitutionModel> substitutionModels = branchModel.getSubstitutionModels();
//
// // We copy whichSubstitutionModel instead of accumulating it
// for (int i = 0; i < substitutionModels.size(); ++i) {
// if (i != whichSubstitutionModel && substitutionModel == substitutionModels.get(i)) {
// matchingModels.add(i);
// }
// }
//
// crossProductAccumulationMap = new int[matchingModels.size()];
// if (matchingModels.size() > 0) {
// for (int i = 0; i < matchingModels.size(); ++i) {
// crossProductAccumulationMap[i] = matchingModels.get(i);
// }
// }
// }

@Override
public Likelihood getLikelihood() {
return treeDataLikelihood;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,9 @@ public DesignMatrixSubstitutionModelGradient(String traitName,
BeagleDataLikelihoodDelegate likelihoodDelegate,
GlmSubstitutionModel substitutionModel,
DesignMatrix matrix,
MaskedParameter parameter) {
super(traitName, treeDataLikelihood, likelihoodDelegate, substitutionModel);
MaskedParameter parameter,
ApproximationMode mode) {
super(traitName, treeDataLikelihood, likelihoodDelegate, substitutionModel, mode);

this.parameter = parameter;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ public class FixedEffectSubstitutionModelGradient extends AbstractGlmSubstitutio
public FixedEffectSubstitutionModelGradient(String traitName,
TreeDataLikelihood treeDataLikelihood,
BeagleDataLikelihoodDelegate likelihoodDelegate,
GlmSubstitutionModel substitutionModel) {
super(traitName, treeDataLikelihood, likelihoodDelegate, substitutionModel);
GlmSubstitutionModel substitutionModel,
ApproximationMode mode) {
super(traitName, treeDataLikelihood, likelihoodDelegate, substitutionModel, mode);
}
}
Loading

0 comments on commit 8e07963

Please sign in to comment.