Skip to content

Commit

Permalink
Extracts intervalNodeMap - towards unified sky*
Browse files Browse the repository at this point in the history
  • Loading branch information
jtmccr1 committed Nov 21, 2024
1 parent cc5700c commit 6e24230
Show file tree
Hide file tree
Showing 14 changed files with 445 additions and 597 deletions.
8 changes: 8 additions & 0 deletions src/dr/evolution/coalescent/TreeIntervalList.java
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evomodel.coalescent.IntervalNodeMapProvider;

/**
* This is interface for an Interval list that wraps a tree and provides
Expand Down Expand Up @@ -88,4 +89,11 @@ public interface TreeIntervalList extends IntervalList{
* @return the tree
*/
Tree getTree();

/** gets the intervalNodeMap
* @return the intervalNodeMap
*/
IntervalNodeMapProvider getIntervalNodeMap();

void setBuildIntervalNodeMapping(boolean buildIntervalNodeMapping);
}
12 changes: 11 additions & 1 deletion src/dr/evomodel/bigfasttree/BigFastTreeIntervals.java
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evolution.util.Units;
import dr.evomodel.coalescent.IntervalNodeMapProvider;
import dr.evomodel.tree.TreeChangedEvent;
import dr.evomodel.tree.TreeModel;
import dr.inference.model.AbstractModel;
Expand All @@ -49,7 +50,7 @@
* Smart intervals that don't need a full recalculation.
* author: JT
*/
public class BigFastTreeIntervals extends AbstractModel implements Units, TreeIntervalList {
public class BigFastTreeIntervals extends AbstractModel implements Units, TreeIntervalList, IntervalNodeMapProvider {
public BigFastTreeIntervals(TreeModel tree) {
this("bigFastIntervals",tree);
}
Expand Down Expand Up @@ -169,6 +170,14 @@ public boolean isCoalescentOnly() {
return true;
}

public void setBuildIntervalNodeMapping(boolean buildIntervalNodeMapping) {
// nothing done this is done by default with this tree model
}
public IntervalNodeMapProvider getIntervalNodeMap() {
return this;
}


// Interval Node mapping

@Override
Expand Down Expand Up @@ -700,4 +709,5 @@ public void setNodeOrder(int nodeNum, int position) {
private int intervalCount = 0;



}
19 changes: 11 additions & 8 deletions src/dr/evomodel/coalescent/GMRFSkygridLikelihood.java
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@

import dr.evolution.coalescent.IntervalList;
import dr.evolution.coalescent.IntervalType;
import dr.evolution.coalescent.TreeIntervalList;
import dr.evolution.tree.Tree;
import dr.evomodelxml.coalescent.GMRFSkyrideLikelihoodParser;
import dr.inference.hmc.GradientWrtParameterProvider;
import dr.inference.model.Likelihood;
Expand Down Expand Up @@ -96,9 +98,9 @@ public class GMRFSkygridLikelihood extends GMRFSkyrideLikelihood

// private List<Tree> treeList;
// private List<TreeIntervals> intervalsList;
private List<IntervalList> intervalsList;
private List<TreeIntervalList> intervalsList;

public GMRFSkygridLikelihood(List<IntervalList> intervalsList,
public GMRFSkygridLikelihood(List<TreeIntervalList> intervalsList,
Parameter popParameter,
Parameter groupParameter,
Parameter precParameter,
Expand Down Expand Up @@ -156,7 +158,8 @@ public GMRFSkygridLikelihood(List<IntervalList> intervalsList,
addVariable(ploidyFactors);

this.intervalsList = intervalsList;
//this.numTrees = setTree(treeList);
//this.numTrees = setTree(treeList);
//TODO are we accounting for what setTree does?

for (IntervalList intervalList : intervalsList) {
addModel((Model) intervalList);
Expand Down Expand Up @@ -219,7 +222,7 @@ public GMRFSkygridLikelihood(List<IntervalList> intervalsList,


//rewrite this constructor without duplicating so much code
public GMRFSkygridLikelihood(List<IntervalList> intervalsList,
public GMRFSkygridLikelihood(List<TreeIntervalList> intervalsList,
Parameter popParameter,
Parameter groupParameter,
Parameter precParameter,
Expand Down Expand Up @@ -921,13 +924,13 @@ public int nLoci() {
return intervalsList.size();
}

/* public Tree getTree(int nt) {
return treeList.get(nt);
public Tree getTree(int nt) {
return intervalsList.get(nt).getTree();
}

public TreeIntervals getTreeIntervals(int nt) {
public TreeIntervalList getTreeIntervals(int nt) {
return intervalsList.get(nt);
} */
}

public double getPopulationFactor(int nt) {
return ploidyFactors.getParameterValue(nt);
Expand Down
23 changes: 12 additions & 11 deletions src/dr/evomodel/coalescent/GMRFSkyrideGradient.java
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,13 @@
*/
public class GMRFSkyrideGradient implements GradientWrtParameterProvider, HessianWrtParameterProvider, Reportable {

private final OldGMRFSkyrideLikelihood skyrideLikelihood;
private final UnifiedGMRFSkyrideLikelihood skyrideLikelihood;
private final WrtParameter wrtParameter;
private final Parameter parameter;
private final OldAbstractCoalescentLikelihood.IntervalNodeMapping intervalNodeMapping;
private final IntervalNodeMapProvider intervalNodeMapping;
private final NodeHeightTransform nodeHeightTransform;

public GMRFSkyrideGradient(OldGMRFSkyrideLikelihood gmrfSkyrideLikelihood,
public GMRFSkyrideGradient(UnifiedGMRFSkyrideLikelihood gmrfSkyrideLikelihood,
WrtParameter wrtParameter,
TreeModel tree,
NodeHeightTransform nodeHeightTransform) {
Expand Down Expand Up @@ -158,8 +158,8 @@ public enum WrtParameter {

COALESCENT_INTERVAL {
@Override
double[] getGradientLogDensity(OldGMRFSkyrideLikelihood skyrideLikelihood,
OldAbstractCoalescentLikelihood.IntervalNodeMapping intervalNodeMapping) {
double[] getGradientLogDensity(UnifiedGMRFSkyrideLikelihood skyrideLikelihood,
IntervalNodeMapProvider intervalNodeMapping) {
double[] unSortedNodeHeightGradient = super.getGradientLogDensityWrtUnsortedNodeHeight(skyrideLikelihood);
double[] intervalGradient = new double[unSortedNodeHeightGradient.length];
double accumulatedGradient = 0.0;
Expand All @@ -178,8 +178,8 @@ void update(NodeHeightTransform nodeHeightTransform, double[] values) {

NODE_HEIGHTS {
@Override
double[] getGradientLogDensity(OldGMRFSkyrideLikelihood skyrideLikelihood,
OldAbstractCoalescentLikelihood.IntervalNodeMapping intervalNodeMapping) {
double[] getGradientLogDensity(UnifiedGMRFSkyrideLikelihood skyrideLikelihood,
IntervalNodeMapProvider intervalNodeMapping) {
double[] unSortedNodeHeightGradient = getGradientLogDensityWrtUnsortedNodeHeight(skyrideLikelihood);
return intervalNodeMapping.sortByNodeNumbers(unSortedNodeHeightGradient);
}
Expand All @@ -190,18 +190,19 @@ void update(NodeHeightTransform nodeHeightTransform, double[] values) {
}
};

abstract double[] getGradientLogDensity(OldGMRFSkyrideLikelihood skyrideLikelihood,
OldAbstractCoalescentLikelihood.IntervalNodeMapping intervalNodeMapping);
abstract double[] getGradientLogDensity(UnifiedGMRFSkyrideLikelihood skyrideLikelihood,
IntervalNodeMapProvider intervalNodeMapping);

abstract void update(NodeHeightTransform nodeHeightTransform, double[] values);

double[] getGradientLogDensityWrtUnsortedNodeHeight(OldGMRFSkyrideLikelihood skyrideLikelihood) {
double[] getGradientLogDensityWrtUnsortedNodeHeight(UnifiedGMRFSkyrideLikelihood skyrideLikelihood) {
double[] unSortedNodeHeightGradient = new double[skyrideLikelihood.getCoalescentIntervalDimension()];
double[] gamma = skyrideLikelihood.getPopSizeParameter().getParameterValues();

int index = 0;
for (int i = 0; i < skyrideLikelihood.getIntervalCount(); i++) {
if (skyrideLikelihood.getIntervalType(i) == OldAbstractCoalescentLikelihood.CoalescentEventType.COALESCENT) {
// if (skyrideLikelihood.getIntervalType(i) == OldAbstractCoalescentLikelihood.CoalescentEventType.COALESCENT) {
if (skyrideLikelihood.isCoalescentInterval(i)) {
double weight = -Math.exp(-gamma[index]) * skyrideLikelihood.getLineageCount(i) * (skyrideLikelihood.getLineageCount(i) - 1);
if (index < skyrideLikelihood.getCoalescentIntervalDimension() - 1 && i < skyrideLikelihood.getIntervalCount() - 1) {
weight -= -Math.exp(-gamma[index + 1]) * skyrideLikelihood.getLineageCount(i + 1) * (skyrideLikelihood.getLineageCount(i + 1) - 1);
Expand Down
Loading

0 comments on commit 6e24230

Please sign in to comment.