From e8c2e1222db1edaca79c1281e17fa214b3b9f0e4 Mon Sep 17 00:00:00 2001 From: yxu927 <132981793+yxu927@users.noreply.github.com> Date: Fri, 14 Jun 2024 16:25:45 +1200 Subject: [PATCH] let PhyloCTMC use branch rates from tree if not given #491 #496 --- .../likelihood/AbstractPhyloCTMC.java | 25 ++--- .../base/evolution/likelihood/PhyloCTMC.java | 6 +- .../lphy/base/evolution/tree/TimeTree.java | 5 +- .../evolution/likelihood/PhyloCTMCTest.java | 100 ++++++++++++++++++ 4 files changed, 115 insertions(+), 21 deletions(-) create mode 100644 lphy-base/src/test/java/lphy/base/evolution/likelihood/PhyloCTMCTest.java diff --git a/lphy-base/src/main/java/lphy/base/evolution/likelihood/AbstractPhyloCTMC.java b/lphy-base/src/main/java/lphy/base/evolution/likelihood/AbstractPhyloCTMC.java index e8f9574b5..d60e75d4b 100644 --- a/lphy-base/src/main/java/lphy/base/evolution/likelihood/AbstractPhyloCTMC.java +++ b/lphy-base/src/main/java/lphy/base/evolution/likelihood/AbstractPhyloCTMC.java @@ -5,7 +5,6 @@ import lphy.base.evolution.alignment.SimpleAlignment; import lphy.base.evolution.tree.TimeTree; import lphy.base.evolution.tree.TimeTreeNode; -import lphy.core.logger.LoggerUtils; import lphy.core.model.GenerativeDistribution; import lphy.core.model.Value; import lphy.core.simulator.RandomUtils; @@ -48,15 +47,11 @@ public abstract class AbstractPhyloCTMC implements GenerativeDistribution rootFreqs; protected SortedMap idMap = new TreeMap<>(); protected double[][] transProb; - /** - * e^{Qt} = Ee^{At}E^-1, where A is a diagonal matrix of eigenvalues (Eval), - * E is the matrix of right eigenvectors (Evec), and E^-1 is the matrix of left eigenvectors (Ievc). - */ private EigenDecomposition decomposition; - private double[][] Ievc; // inverse Eigen vectors - private double[][] Evec; // Eigen vectors - private double[][] iexp; // intermediate matrix - private double[] Eval; // Eigenvalues + private double[][] Ievc; + private double[][] Evec; + private double[][] iexp; + private double[] Eval; public AbstractPhyloCTMC(Value tree, Value clockRate, Value freq, @@ -71,15 +66,11 @@ public AbstractPhyloCTMC(Value tree, Value clockRate, Value 0) { - if (this.branchRates != null) { // have branchRates from input but tree also has branch rates - LoggerUtils.log.warning("PhyloCTMC has branchRates from input parameter and tree has branch rates, " + - "default to using input parameter branchRates."); - } else { // if tree has branch rates, then use them - this.branchRates = new Value<>("branchRates", treeBranchRates); - } + // if tree has branch rates, then use them + if (branchRates == null && treeBranchRates != null && treeBranchRates.length > 0) { + this.branchRates = new Value<>("branchRates", treeBranchRates); } + // checkCompatibilities(); } diff --git a/lphy-base/src/main/java/lphy/base/evolution/likelihood/PhyloCTMC.java b/lphy-base/src/main/java/lphy/base/evolution/likelihood/PhyloCTMC.java index e5647d672..4356ede99 100644 --- a/lphy-base/src/main/java/lphy/base/evolution/likelihood/PhyloCTMC.java +++ b/lphy-base/src/main/java/lphy/base/evolution/likelihood/PhyloCTMC.java @@ -40,7 +40,7 @@ public PhyloCTMC(@ParameterInfo(name = AbstractPhyloCTMC.treeParamName, verb = " @ParameterInfo(name = AbstractPhyloCTMC.rootFreqParamName, verb = "are", narrativeName = "root frequencies", description = "the root probabilities. Optional parameter. If not specified then first row of e^{100*Q) is used.", optional = true) Value rootFreq, @ParameterInfo(name = QParamName, narrativeName= "instantaneous rate matrix", description = "the instantaneous rate matrix.") Value Q, @ParameterInfo(name = siteRatesParamName, description = "a rate for each site in the alignment. Site rates are assumed to be 1.0 otherwise.", optional = true) Value siteRates, - @ParameterInfo(name = AbstractPhyloCTMC.branchRatesParamName, description = "a rate for each branch in the tree. Branch rates are assumed to be 1.0 otherwise.", optional = true) Value branchRates, + @ParameterInfo(name = AbstractPhyloCTMC.branchRatesParamName, description = "a rate for each branch in the tree. Original branch rates are used if rates not given. Branch rates are assumed to be 1.0 otherwise.", optional = true) Value branchRates, @ParameterInfo(name = AbstractPhyloCTMC.LParamName, narrativeName= "alignment length", description = "length of the alignment", optional = true) Value L, @ParameterInfo(name = AbstractPhyloCTMC.dataTypeParamName, description = "the data type used for simulations, default to nucleotide", @@ -111,7 +111,7 @@ public void setParam(String paramName, Value value) { else if (paramName.equals(siteRatesParamName)) siteRates = value; else if (paramName.equals(AbstractPhyloCTMC.branchRatesParamName)) branchRates = value; else if (paramName.equals(AbstractPhyloCTMC.LParamName)) L = value; -// else if (paramName.equals(stateNamesParamName)) stateNames = value; + // else if (paramName.equals(stateNamesParamName)) stateNames = value; else if (paramName.equals(AbstractPhyloCTMC.dataTypeParamName)) dataType = value; else if (paramName.equals(AbstractPhyloCTMC.rootSeqParamName)) rootSeq = value; else throw new RuntimeException("Unrecognised parameter name: " + paramName); @@ -123,7 +123,7 @@ public void setParam(String paramName, Value value) { narrativeName = "phylogenetic continuous time Markov process", category = GeneratorCategory.PHYLO_LIKELIHOOD, examples = {"gtrGammaCoalescent.lphy", "errorModel1.lphy"}, description = "The phylogenetic continuous-time Markov chain distribution. A sequence is simulated for every leaf node, and every direct ancestor node with an id." + - "(The sampling distribution that the phylogenetic likelihood is derived from.)") + "(The sampling distribution that the phylogenetic likelihood is derived from.)") public RandomVariable sample() { setup(); diff --git a/lphy-base/src/main/java/lphy/base/evolution/tree/TimeTree.java b/lphy-base/src/main/java/lphy/base/evolution/tree/TimeTree.java index 0177e64cd..5954acf3b 100644 --- a/lphy-base/src/main/java/lphy/base/evolution/tree/TimeTree.java +++ b/lphy-base/src/main/java/lphy/base/evolution/tree/TimeTree.java @@ -337,8 +337,11 @@ public TimeTreeNode getLabeledNode(String label) { public Double[] getBranchRates(){ List branchRates = new ArrayList<>(); for (TimeTreeNode node: nodes){ - branchRates.add(node.getBranchRate()); + if (node.getBranchRate() != null) { + branchRates.add(node.getBranchRate()); + } } + return branchRates.toArray(Double[]::new); } } diff --git a/lphy-base/src/test/java/lphy/base/evolution/likelihood/PhyloCTMCTest.java b/lphy-base/src/test/java/lphy/base/evolution/likelihood/PhyloCTMCTest.java new file mode 100644 index 000000000..4d378ff59 --- /dev/null +++ b/lphy-base/src/test/java/lphy/base/evolution/likelihood/PhyloCTMCTest.java @@ -0,0 +1,100 @@ +package lphy.base.evolution.likelihood; + +import lphy.base.evolution.branchrate.LocalClock; +import lphy.base.evolution.tree.TimeTree; +import lphy.base.evolution.tree.TimeTreeNode; +import lphy.base.function.tree.Newick; +import lphy.core.model.Value; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.util.List; +import java.util.Objects; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; + +public class PhyloCTMCTest { + String newickTree; + + @BeforeEach + void setUp() { + newickTree = "((1:2.0, (2:1.0, 3:1.0):1.0):2.0, 4:4.0)"; + } + + + @Test + void apply() { + // generate a tree with local clock + TimeTree tree = Newick.parseNewick(newickTree); + TimeTreeNode node1 = null; + TimeTreeNode node2 = null; + + for (int i = 0; i treeValue = new Value<>("tree", tree); + Value cladesValue = new Value<>("clades", clades); + Value cladeRatesValue = new Value<>("cladeRates", cladeRates); + Value rootRateValue = new Value<>("rootRate" , rootRate); + Value includeStemValue = new Value<>("includeStem" , Boolean.TRUE); + + LocalClock localClockInstance = new LocalClock(treeValue, cladesValue, cladeRatesValue, rootRateValue, includeStemValue); + TimeTree newTree = localClockInstance.apply().value(); + + // test PhyloCTMC + Double[][] Q = { + { -1.0, 0.5, 0.3, 0.2 }, + { 0.4, -1.0, 0.1, 0.5 }, + { 0.3, 0.2, -1.0, 0.5 }, + { 0.2, 0.3, 0.5, -1.0 } + }; + + + Value newTreeValue = new Value<>("newTree", newTree); + Value QValue = new Value<>("Q", Q); + Value siteRatesValue = new Value("siteRates", new Double[]{0.1, 0.1, 0.1, 0.1, 0.1}); + Value LValue = new Value("L", 5); + + PhyloCTMC phyloCTMCInstance = new PhyloCTMC( + newTreeValue,null,null, + QValue,siteRatesValue,null, + LValue,null, null); + + phyloCTMCInstance.sample(); + + Double[] branchRates = phyloCTMCInstance.getBranchRates().value(); + + List allNodes = newTree.getNodes(); + assertEquals(allNodes.size(), branchRates.length); + + // index 0 : node 2 + assertEquals(allNodes.get(0).getBranchRate(), branchRates[0]); + // index 1 : node 3 + assertEquals(allNodes.get(1).getBranchRate(), branchRates[1]); + // index 2 : node 1 + assertEquals(allNodes.get(2).getBranchRate(), branchRates[2]); + // index 3 : node 4 + assertEquals("4" , allNodes.get(3).getId()); + assertEquals(0.3, allNodes.get(3).getBranchRate()); + assertEquals(allNodes.get(3).getBranchRate(), branchRates[3]); + // index 4 : node (2,3) + assertEquals(allNodes.get(4).getBranchRate(), branchRates[4]); + // index 5 : node ((2,3),1) + assertEquals(allNodes.get(5).getBranchRate(), branchRates[5]); + // index 6 : node (((2,3),1),4) + assertEquals(allNodes.get(6).getBranchRate(), branchRates[6]); + } +}