Skip to content

Commit

Permalink
let PhyloCTMC use branch rates from tree if not given #491 #496
Browse files Browse the repository at this point in the history
  • Loading branch information
yxu927 committed Jun 14, 2024
1 parent d075c49 commit e8c2e12
Show file tree
Hide file tree
Showing 4 changed files with 115 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import lphy.base.evolution.alignment.SimpleAlignment;
import lphy.base.evolution.tree.TimeTree;
import lphy.base.evolution.tree.TimeTreeNode;
import lphy.core.logger.LoggerUtils;
import lphy.core.model.GenerativeDistribution;
import lphy.core.model.Value;
import lphy.core.simulator.RandomUtils;
Expand Down Expand Up @@ -48,15 +47,11 @@ public abstract class AbstractPhyloCTMC implements GenerativeDistribution<Alignm
protected Value<Double[]> rootFreqs;
protected SortedMap<String, Integer> idMap = new TreeMap<>();
protected double[][] transProb;
/**
* <code>e^{Qt} = Ee^{At}E^-1</code>, where A is a diagonal matrix of eigenvalues (Eval),
* E is the matrix of right eigenvectors (Evec), and E^-1 is the matrix of left eigenvectors (Ievc).
*/
private EigenDecomposition decomposition;
private double[][] Ievc; // inverse Eigen vectors
private double[][] Evec; // Eigen vectors
private double[][] iexp; // intermediate matrix
private double[] Eval; // Eigenvalues
private double[][] Ievc;
private double[][] Evec;
private double[][] iexp;
private double[] Eval;


public AbstractPhyloCTMC(Value<TimeTree> tree, Value<Number> clockRate, Value<Double[]> freq,
Expand All @@ -71,15 +66,11 @@ public AbstractPhyloCTMC(Value<TimeTree> tree, Value<Number> clockRate, Value<Do
this.random = RandomUtils.getRandom();

Double[] treeBranchRates = tree.value().getBranchRates();

if (treeBranchRates != null && treeBranchRates.length > 0) {
if (this.branchRates != null) { // have branchRates from input but tree also has branch rates
LoggerUtils.log.warning("PhyloCTMC has branchRates from input parameter and tree has branch rates, " +
"default to using input parameter branchRates.");
} else { // if tree has branch rates, then use them
this.branchRates = new Value<>("branchRates", treeBranchRates);
}
// if tree has branch rates, then use them
if (branchRates == null && treeBranchRates != null && treeBranchRates.length > 0) {
this.branchRates = new Value<>("branchRates", treeBranchRates);
}

// checkCompatibilities();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ public PhyloCTMC(@ParameterInfo(name = AbstractPhyloCTMC.treeParamName, verb = "
@ParameterInfo(name = AbstractPhyloCTMC.rootFreqParamName, verb = "are", narrativeName = "root frequencies", description = "the root probabilities. Optional parameter. If not specified then first row of e^{100*Q) is used.", optional = true) Value<Double[]> rootFreq,
@ParameterInfo(name = QParamName, narrativeName= "instantaneous rate matrix", description = "the instantaneous rate matrix.") Value<Double[][]> Q,
@ParameterInfo(name = siteRatesParamName, description = "a rate for each site in the alignment. Site rates are assumed to be 1.0 otherwise.", optional = true) Value<Double[]> siteRates,
@ParameterInfo(name = AbstractPhyloCTMC.branchRatesParamName, description = "a rate for each branch in the tree. Branch rates are assumed to be 1.0 otherwise.", optional = true) Value<Double[]> branchRates,
@ParameterInfo(name = AbstractPhyloCTMC.branchRatesParamName, description = "a rate for each branch in the tree. Original branch rates are used if rates not given. Branch rates are assumed to be 1.0 otherwise.", optional = true) Value<Double[]> branchRates,
@ParameterInfo(name = AbstractPhyloCTMC.LParamName, narrativeName= "alignment length",
description = "length of the alignment", optional = true) Value<Integer> L,
@ParameterInfo(name = AbstractPhyloCTMC.dataTypeParamName, description = "the data type used for simulations, default to nucleotide",
Expand Down Expand Up @@ -111,7 +111,7 @@ public void setParam(String paramName, Value value) {
else if (paramName.equals(siteRatesParamName)) siteRates = value;
else if (paramName.equals(AbstractPhyloCTMC.branchRatesParamName)) branchRates = value;
else if (paramName.equals(AbstractPhyloCTMC.LParamName)) L = value;
// else if (paramName.equals(stateNamesParamName)) stateNames = value;
// else if (paramName.equals(stateNamesParamName)) stateNames = value;
else if (paramName.equals(AbstractPhyloCTMC.dataTypeParamName)) dataType = value;
else if (paramName.equals(AbstractPhyloCTMC.rootSeqParamName)) rootSeq = value;
else throw new RuntimeException("Unrecognised parameter name: " + paramName);
Expand All @@ -123,7 +123,7 @@ public void setParam(String paramName, Value value) {
narrativeName = "phylogenetic continuous time Markov process",
category = GeneratorCategory.PHYLO_LIKELIHOOD, examples = {"gtrGammaCoalescent.lphy", "errorModel1.lphy"},
description = "The phylogenetic continuous-time Markov chain distribution. A sequence is simulated for every leaf node, and every direct ancestor node with an id." +
"(The sampling distribution that the phylogenetic likelihood is derived from.)")
"(The sampling distribution that the phylogenetic likelihood is derived from.)")
public RandomVariable<Alignment> sample() {
setup();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -337,8 +337,11 @@ public TimeTreeNode getLabeledNode(String label) {
public Double[] getBranchRates(){
List<Double> branchRates = new ArrayList<>();
for (TimeTreeNode node: nodes){
branchRates.add(node.getBranchRate());
if (node.getBranchRate() != null) {
branchRates.add(node.getBranchRate());
}
}

return branchRates.toArray(Double[]::new);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
package lphy.base.evolution.likelihood;

import lphy.base.evolution.branchrate.LocalClock;
import lphy.base.evolution.tree.TimeTree;
import lphy.base.evolution.tree.TimeTreeNode;
import lphy.base.function.tree.Newick;
import lphy.core.model.Value;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

import java.util.List;
import java.util.Objects;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;

public class PhyloCTMCTest {
String newickTree;

@BeforeEach
void setUp() {
newickTree = "((1:2.0, (2:1.0, 3:1.0):1.0):2.0, 4:4.0)";
}


@Test
void apply() {
// generate a tree with local clock
TimeTree tree = Newick.parseNewick(newickTree);
TimeTreeNode node1 = null;
TimeTreeNode node2 = null;

for (int i = 0; i<tree.getNodes().size(); i++){
if (Objects.equals(tree.getNodes().get(i).getId(), "4")){ //node2 is the leaf node 4
node2 = tree.getNodes().get(i);
} else if (tree.getNodes().get(i).getAllLeafNodes().size() == 2){ //node1 is the parent of (2,3)
node1 = tree.getNodes().get(i);
}
}

assertNotNull(node1);
assertNotNull(node2);

TimeTreeNode[] clades = {node1, node2};
Double[] cladeRates = {0.4, 0.3};
double rootRate = 0.2;

Value<TimeTree> treeValue = new Value<>("tree", tree);
Value<Object[]> cladesValue = new Value<>("clades", clades);
Value<Double[]> cladeRatesValue = new Value<>("cladeRates", cladeRates);
Value<Double> rootRateValue = new Value<>("rootRate" , rootRate);
Value<Boolean> includeStemValue = new Value<>("includeStem" , Boolean.TRUE);

LocalClock localClockInstance = new LocalClock(treeValue, cladesValue, cladeRatesValue, rootRateValue, includeStemValue);
TimeTree newTree = localClockInstance.apply().value();

// test PhyloCTMC
Double[][] Q = {
{ -1.0, 0.5, 0.3, 0.2 },
{ 0.4, -1.0, 0.1, 0.5 },
{ 0.3, 0.2, -1.0, 0.5 },
{ 0.2, 0.3, 0.5, -1.0 }
};


Value<TimeTree> newTreeValue = new Value<>("newTree", newTree);
Value<Double[][]> QValue = new Value<>("Q", Q);
Value<Double[]> siteRatesValue = new Value<Double[]>("siteRates", new Double[]{0.1, 0.1, 0.1, 0.1, 0.1});
Value<Integer> LValue = new Value<Integer>("L", 5);

PhyloCTMC phyloCTMCInstance = new PhyloCTMC(
newTreeValue,null,null,
QValue,siteRatesValue,null,
LValue,null, null);

phyloCTMCInstance.sample();

Double[] branchRates = phyloCTMCInstance.getBranchRates().value();

List<TimeTreeNode> allNodes = newTree.getNodes();
assertEquals(allNodes.size(), branchRates.length);

// index 0 : node 2
assertEquals(allNodes.get(0).getBranchRate(), branchRates[0]);
// index 1 : node 3
assertEquals(allNodes.get(1).getBranchRate(), branchRates[1]);
// index 2 : node 1
assertEquals(allNodes.get(2).getBranchRate(), branchRates[2]);
// index 3 : node 4
assertEquals("4" , allNodes.get(3).getId());
assertEquals(0.3, allNodes.get(3).getBranchRate());
assertEquals(allNodes.get(3).getBranchRate(), branchRates[3]);
// index 4 : node (2,3)
assertEquals(allNodes.get(4).getBranchRate(), branchRates[4]);
// index 5 : node ((2,3),1)
assertEquals(allNodes.get(5).getBranchRate(), branchRates[5]);
// index 6 : node (((2,3),1),4)
assertEquals(allNodes.get(6).getBranchRate(), branchRates[6]);
}
}

0 comments on commit e8c2e12

Please sign in to comment.