Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/master' into beauti-thorne
Browse files Browse the repository at this point in the history
  • Loading branch information
jtmccr1 committed Dec 18, 2024
2 parents 8226901 + 1c9935d commit d3bd9e6
Show file tree
Hide file tree
Showing 6 changed files with 80 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@
import dr.evolution.util.Taxa;
import dr.evolution.util.Units;
import dr.evomodel.branchratemodel.BranchRateModel;
import dr.evomodel.branchratemodel.BranchSpecificFixedEffects;
import dr.inference.distribution.DistributionLikelihood;
import dr.evomodel.tree.DefaultTreeModel;
import dr.evomodelxml.TreeWorkingPriorParsers;
import dr.evomodelxml.branchratemodel.*;
Expand Down Expand Up @@ -1026,6 +1028,20 @@ public void writeMLE(XMLWriter writer, MarginalLikelihoodEstimationOptions optio
break;

case MIXED_EFFECTS_CLOCK:

writer.writeIDref(DistributionLikelihood.DISTRIBUTION_LIKELIHOOD, model.getPrefix() + BranchSpecificFixedEffects.RATES_PRIOR);
writer.writeIDref(DistributionLikelihood.DISTRIBUTION_LIKELIHOOD, model.getPrefix() + BranchSpecificFixedEffects.SCALE_PRIOR);
writer.writeIDref(DistributionLikelihood.DISTRIBUTION_LIKELIHOOD, model.getPrefix() + BranchSpecificFixedEffects.INTERCEPT_PRIOR);

String coeff = BranchSpecificFixedEffectsParser.COEFFICIENT;
int number = 1;
String concat = coeff + number;
while (model.hasParameter(concat)) {
writer.writeIDref(DistributionLikelihood.DISTRIBUTION_LIKELIHOOD, model.getPrefix() + BranchSpecificFixedEffectsParser.FIXED_EFFECTS_LIKELIHOOD + number);
number++;
concat = coeff + number;
}

break;

default:
Expand Down
25 changes: 14 additions & 11 deletions src/dr/app/beauti/generator/ClockModelGenerator.java
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
import dr.app.beauti.components.ComponentFactory;
import dr.app.beauti.options.*;
import dr.app.beauti.types.ClockType;
import dr.app.beauti.types.OperatorType;
import dr.app.beauti.util.XMLWriter;
import dr.evolution.util.Taxa;
import dr.evomodel.branchratemodel.ArbitraryBranchRates;
Expand Down Expand Up @@ -66,8 +65,6 @@
import dr.util.Attribute;
import dr.xml.XMLParser;

import java.util.Map;

import static dr.inference.model.ParameterParser.PARAMETER;
import static dr.inferencexml.distribution.PriorParsers.*;
import static dr.inferencexml.distribution.shrinkage.BayesianBridgeLikelihoodParser.*;
Expand Down Expand Up @@ -301,16 +298,19 @@ public void writeBranchRatesModel(PartitionClockModel clockModel, XMLWriter writ

writeCovarianceStatistic(writer, tag, prefix, treePrefix);

//TODO add more String constants for this type of code
boolean generateRatesGradient = false;
boolean generateScaleGradient = false;

for (Operator operator : options.selectOperators()) {
if (operator.getName().equals("HMC relaxed clock location and scale") && operator.isUsed()) {
if (operator.getName().equals(ClockType.HMC_CLOCK_RATES_DESCRIPTION) && operator.isUsed()) {
generateRatesGradient = true;
}
if (operator.getName().equals(ClockType.HMC_CLOCK_LOCATION_SCALE_DESCRIPTION) && operator.isUsed()) {
generateScaleGradient = true;
}
}

if (generateScaleGradient) {
if (generateRatesGradient) {

//scale prior
writer.writeOpenTag(DistributionLikelihood.DISTRIBUTION_LIKELIHOOD,
Expand Down Expand Up @@ -352,6 +352,9 @@ public void writeBranchRatesModel(PartitionClockModel clockModel, XMLWriter writ
writer.writeCloseTag(LocationScaleGradientParser.LOCATION);
writer.writeCloseTag(LocationScaleGradientParser.NAME);

}

if (generateScaleGradient){
//scale gradient
writer.writeOpenTag(LocationScaleGradientParser.NAME, new Attribute[]{
new Attribute.Default<>(XMLParser.ID, prefix + ScaleGradient.SCALE_GRADIENT),
Expand Down Expand Up @@ -958,18 +961,18 @@ public static void writeBranchRatesModelRef(PartitionClockModel model, XMLWriter

case MIXED_EFFECTS_CLOCK:
//always write distribution likelihoods for rate, scale and intercept
writer.writeIDref(DistributionLikelihood.DISTRIBUTION_LIKELIHOOD, BranchSpecificFixedEffects.RATES_PRIOR);
writer.writeIDref(DistributionLikelihood.DISTRIBUTION_LIKELIHOOD, BranchSpecificFixedEffects.SCALE_PRIOR);
writer.writeIDref(DistributionLikelihood.DISTRIBUTION_LIKELIHOOD, BranchSpecificFixedEffects.INTERCEPT_PRIOR);
//writer.writeIDref(DistributionLikelihood.DISTRIBUTION_LIKELIHOOD, BranchSpecificFixedEffects.RATES_PRIOR);
//writer.writeIDref(DistributionLikelihood.DISTRIBUTION_LIKELIHOOD, BranchSpecificFixedEffects.SCALE_PRIOR);
//writer.writeIDref(DistributionLikelihood.DISTRIBUTION_LIKELIHOOD, BranchSpecificFixedEffects.INTERCEPT_PRIOR);
//check for coefficients
String coeff = BranchSpecificFixedEffectsParser.COEFFICIENT;
/*String coeff = BranchSpecificFixedEffectsParser.COEFFICIENT;
int number = 1;
String concat = coeff + number;
while (model.hasParameter(concat)) {
writer.writeIDref(DistributionLikelihood.DISTRIBUTION_LIKELIHOOD, BranchSpecificFixedEffectsParser.FIXED_EFFECTS_LIKELIHOOD + number);
number++;
concat = coeff + number;
}
}*/
tag = ArbitraryBranchRatesParser.ARBITRARY_BRANCH_RATES;
id = model.getPrefix() + ArbitraryBranchRates.BRANCH_RATES;
break;
Expand Down
45 changes: 39 additions & 6 deletions src/dr/app/beauti/generator/ParameterPriorGenerator.java
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,13 @@
import dr.evolution.util.Taxa;
import dr.evomodel.branchratemodel.BranchSpecificFixedEffects;
import dr.evomodel.tree.DefaultTreeModel;
import dr.evomodelxml.branchratemodel.BranchSpecificFixedEffectsParser;
import dr.evomodelxml.coalescent.GMRFSkyrideLikelihoodParser;
import dr.evomodelxml.tree.CTMCScalePriorParser;
import dr.evomodelxml.tree.MonophylyStatisticParser;
import dr.inference.distribution.DistributionLikelihood;
import dr.inference.model.ParameterParser;
import dr.inferencexml.distribution.CachedDistributionLikelihoodParser;
import dr.inferencexml.distribution.DistributionLikelihoodParser;
import dr.inferencexml.distribution.PriorParsers;
import dr.inferencexml.model.BooleanLikelihoodParser;
import dr.inferencexml.model.OneOnXPriorParser;
Expand All @@ -58,19 +58,48 @@
*/
public class ParameterPriorGenerator extends Generator {

//map parameters to prior IDs, for use with HMC
private HashMap<String, String> mapParameterToPrior;
//map parameters to prior IDs, for use with HMC or other approaches that define their prior befor the <mcmc> XML block
private final HashMap<String, String> mapParameterToPrior;

public ParameterPriorGenerator(BeautiOptions options, ComponentFactory[] components) {
super(options, components);
//TODO don't like this being here, but will see how things pan out as more HMC approaches are added
mapParameterToPrior = new HashMap<String, String>();
}

/**
* Add all possibly previously defined priors to a HashMap
* Cannot be done in constructor as the models have not been defined by the user at that point
*/
public void addParametersToPrior() {
int totalModels = options.getPartitionClockModels().size();
List<PartitionClockModel> partitionClockModels = options.getPartitionClockModels();
//HMC skygrid
mapParameterToPrior.put(GMRFSkyrideLikelihoodParser.SKYGRID_PRECISION, GMRFSkyrideLikelihoodParser.SKYGRID_PRECISION_PRIOR);
//HMC relaxed clock
mapParameterToPrior.put(ClockType.HMC_CLOCK_LOCATION, BranchSpecificFixedEffects.LOCATION_PRIOR);
mapParameterToPrior.put(ClockType.HMC_CLOCK_BRANCH_RATES, BranchSpecificFixedEffects.RATES_PRIOR);
mapParameterToPrior.put(ClockType.HMCLN_SCALE, BranchSpecificFixedEffects.SCALE_PRIOR);
for (int i = 0; i < totalModels; i++) {
String prefix = partitionClockModels.get(i).getPrefix();
mapParameterToPrior.put(ClockType.HMC_CLOCK_LOCATION, prefix + BranchSpecificFixedEffects.LOCATION_PRIOR);
mapParameterToPrior.put(ClockType.HMC_CLOCK_BRANCH_RATES, prefix + BranchSpecificFixedEffects.RATES_PRIOR);
mapParameterToPrior.put(ClockType.HMCLN_SCALE, prefix + BranchSpecificFixedEffects.SCALE_PRIOR);
}
//mixed effects clock
//always write distribution likelihoods for rate, scale and intercept
for (int i = 0; i < totalModels; i++) {
String prefix = partitionClockModels.get(i).getPrefix();
mapParameterToPrior.put(ClockType.ME_CLOCK_LOCATION, prefix + BranchSpecificFixedEffects.RATES_PRIOR);
mapParameterToPrior.put(ClockType.ME_CLOCK_SCALE, prefix + BranchSpecificFixedEffects.SCALE_PRIOR);
mapParameterToPrior.put(BranchSpecificFixedEffectsParser.INTERCEPT, prefix + BranchSpecificFixedEffects.INTERCEPT_PRIOR);
//check for coefficients
String coeff = BranchSpecificFixedEffectsParser.COEFFICIENT;
int number = 1;
String concat = coeff + number;
while (partitionClockModels.get(i).hasParameter(concat)) {
mapParameterToPrior.put(concat, prefix + BranchSpecificFixedEffectsParser.FIXED_EFFECTS_LIKELIHOOD + number);
number++;
concat = coeff + number;
}
}
}

/**
Expand All @@ -79,6 +108,10 @@ public ParameterPriorGenerator(BeautiOptions options, ComponentFactory[] compone
* @param writer the writer
*/
public void writeParameterPriors(XMLWriter writer) {

//first make sure that all possibly previously defined priors are part of the HashMap
addParametersToPrior();

boolean first = true;

for (Map.Entry<Taxa, Boolean> taxaBooleanEntry : options.taxonSetsMono.entrySet()) {
Expand Down
6 changes: 4 additions & 2 deletions src/dr/app/beauti/generator/SubstitutionModelGenerator.java
Original file line number Diff line number Diff line change
Expand Up @@ -753,7 +753,7 @@ private void writeTwoStateSiteModel(XMLWriter writer, PartitionSubstitutionModel
if (options.useNuRelativeRates()) {
Parameter parameter = model.getParameter("nu");
String prefix1 = options.getPrefix();
if (!parameter.getSubParameters().isEmpty()) {
if (parameter.getParent() != null && !parameter.getSubParameters().isEmpty()) {
writeNuRelativeRateBlock(writer, prefix1, parameter);
}
} else {
Expand Down Expand Up @@ -802,7 +802,9 @@ private void writeAASiteModel(XMLWriter writer, PartitionSubstitutionModel model

if (options.useNuRelativeRates()) {
Parameter parameter = model.getParameter("nu");
writeNuRelativeRateBlock(writer, prefix, parameter);
if (parameter.getParent() != null && !parameter.getSubParameters().isEmpty()) {
writeNuRelativeRateBlock(writer, prefix, parameter);
}
} else {
writeParameter(SiteModelParser.RELATIVE_RATE, "mu", model, writer);
}
Expand Down
8 changes: 4 additions & 4 deletions src/dr/app/beauti/options/PartitionClockModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ public void initModelParametersAndOpererators() {
.initial(1.0).mean(1.0).offset(0.0).partitionOptions(this).isPriorFixed(true)
.isAdaptiveMultivariateCompatible(false).build(parameters);

new Parameter.Builder(ClockType.HMC_CLOCK_BRANCH_RATES, "HMC relaxed clock branch rates")
new Parameter.Builder(ClockType.HMC_CLOCK_BRANCH_RATES, ClockType.HMC_CLOCK_RATES_DESCRIPTION)
.prior(PriorType.LOGNORMAL_HPM_PRIOR).initial(0.001).isNonNegative(true)
.partitionOptions(this).isPriorFixed(true)
.isAdaptiveMultivariateCompatible(false).build(parameters);
Expand Down Expand Up @@ -216,11 +216,11 @@ public void initModelParametersAndOpererators() {
createScaleOperator(ClockType.UCGD_SHAPE, demoTuning, rateWeights);

//HMC relaxed clock
createOperator("HMCRCR", "HMC relaxed clock branch rates",
createOperator("HMCRCR", ClockType.HMC_CLOCK_RATES_DESCRIPTION,
"Hamiltonian Monte Carlo relaxed clock branch rates operator", null, OperatorType.RELAXED_CLOCK_HMC_RATE_OPERATOR,-1 , 1.0);
createOperator("HMCRCS", "HMC relaxed clock location and scale",
createOperator("HMCRCS", ClockType.HMC_CLOCK_LOCATION_SCALE_DESCRIPTION,
"Hamiltonian Monte Carlo relaxed clock scale operator", null, OperatorType.RELAXED_CLOCK_HMC_SCALE_OPERATOR,-1 , 0.5);
//for the time being turn off the HMC relaxed clock scale kernel
//turn off the HMC relaxed clock scale kernel by default
getOperator("HMCRCS").setUsed(false);
createScaleOperator(ClockType.HMC_CLOCK_LOCATION, demoTuning, rateWeights);
createScaleOperator(ClockType.HMCLN_SCALE, demoTuning, rateWeights);
Expand Down
3 changes: 3 additions & 0 deletions src/dr/app/beauti/types/ClockType.java
Original file line number Diff line number Diff line change
Expand Up @@ -68,4 +68,7 @@ public String toString() {

final public static String ACLD_MEAN = "acld.mean";
final public static String ACLD_STDEV = "acld.stdev";

final public static String HMC_CLOCK_RATES_DESCRIPTION = "HMC relaxed clock branch rates";
final public static String HMC_CLOCK_LOCATION_SCALE_DESCRIPTION = "HMC relaxed clock location and scale";
}

0 comments on commit d3bd9e6

Please sign in to comment.