diff --git a/src/dr/evolution/coalescent/TreeIntervalList.java b/src/dr/evolution/coalescent/TreeIntervalList.java index 3eb93f1e08..34d43eefcb 100644 --- a/src/dr/evolution/coalescent/TreeIntervalList.java +++ b/src/dr/evolution/coalescent/TreeIntervalList.java @@ -29,6 +29,7 @@ import dr.evolution.tree.NodeRef; import dr.evolution.tree.Tree; +import dr.evomodel.coalescent.IntervalNodeMapProvider; /** * This is interface for an Interval list that wraps a tree and provides @@ -88,4 +89,11 @@ public interface TreeIntervalList extends IntervalList{ * @return the tree */ Tree getTree(); + + /** gets the intervalNodeMap + * @return the intervalNodeMap + */ + IntervalNodeMapProvider getIntervalNodeMap(); + + void setBuildIntervalNodeMapping(boolean buildIntervalNodeMapping); } diff --git a/src/dr/evomodel/bigfasttree/BigFastTreeIntervals.java b/src/dr/evomodel/bigfasttree/BigFastTreeIntervals.java index 721cbe2a6a..383bdaa164 100644 --- a/src/dr/evomodel/bigfasttree/BigFastTreeIntervals.java +++ b/src/dr/evomodel/bigfasttree/BigFastTreeIntervals.java @@ -32,6 +32,7 @@ import dr.evolution.tree.NodeRef; import dr.evolution.tree.Tree; import dr.evolution.util.Units; +import dr.evomodel.coalescent.IntervalNodeMapProvider; import dr.evomodel.tree.TreeChangedEvent; import dr.evomodel.tree.TreeModel; import dr.inference.model.AbstractModel; @@ -49,7 +50,7 @@ * Smart intervals that don't need a full recalculation. * author: JT */ -public class BigFastTreeIntervals extends AbstractModel implements Units, TreeIntervalList { +public class BigFastTreeIntervals extends AbstractModel implements Units, TreeIntervalList, IntervalNodeMapProvider { public BigFastTreeIntervals(TreeModel tree) { this("bigFastIntervals",tree); } @@ -169,6 +170,14 @@ public boolean isCoalescentOnly() { return true; } + public void setBuildIntervalNodeMapping(boolean buildIntervalNodeMapping) { + // nothing done this is done by default with this tree model + } + public IntervalNodeMapProvider getIntervalNodeMap() { + return this; + } + + // Interval Node mapping @Override @@ -700,4 +709,5 @@ public void setNodeOrder(int nodeNum, int position) { private int intervalCount = 0; + } \ No newline at end of file diff --git a/src/dr/evomodel/coalescent/GMRFSkygridLikelihood.java b/src/dr/evomodel/coalescent/GMRFSkygridLikelihood.java index 44d7b5653a..3df851bb96 100644 --- a/src/dr/evomodel/coalescent/GMRFSkygridLikelihood.java +++ b/src/dr/evomodel/coalescent/GMRFSkygridLikelihood.java @@ -29,6 +29,8 @@ import dr.evolution.coalescent.IntervalList; import dr.evolution.coalescent.IntervalType; +import dr.evolution.coalescent.TreeIntervalList; +import dr.evolution.tree.Tree; import dr.evomodelxml.coalescent.GMRFSkyrideLikelihoodParser; import dr.inference.hmc.GradientWrtParameterProvider; import dr.inference.model.Likelihood; @@ -96,9 +98,9 @@ public class GMRFSkygridLikelihood extends GMRFSkyrideLikelihood // private List treeList; // private List intervalsList; - private List intervalsList; + private List intervalsList; - public GMRFSkygridLikelihood(List intervalsList, + public GMRFSkygridLikelihood(List intervalsList, Parameter popParameter, Parameter groupParameter, Parameter precParameter, @@ -156,7 +158,8 @@ public GMRFSkygridLikelihood(List intervalsList, addVariable(ploidyFactors); this.intervalsList = intervalsList; - //this.numTrees = setTree(treeList); + //this.numTrees = setTree(treeList); + //TODO are we accounting for what setTree does? for (IntervalList intervalList : intervalsList) { addModel((Model) intervalList); @@ -219,7 +222,7 @@ public GMRFSkygridLikelihood(List intervalsList, //rewrite this constructor without duplicating so much code - public GMRFSkygridLikelihood(List intervalsList, + public GMRFSkygridLikelihood(List intervalsList, Parameter popParameter, Parameter groupParameter, Parameter precParameter, @@ -921,13 +924,13 @@ public int nLoci() { return intervalsList.size(); } - /* public Tree getTree(int nt) { - return treeList.get(nt); + public Tree getTree(int nt) { + return intervalsList.get(nt).getTree(); } - public TreeIntervals getTreeIntervals(int nt) { + public TreeIntervalList getTreeIntervals(int nt) { return intervalsList.get(nt); - } */ + } public double getPopulationFactor(int nt) { return ploidyFactors.getParameterValue(nt); diff --git a/src/dr/evomodel/coalescent/GMRFSkyrideGradient.java b/src/dr/evomodel/coalescent/GMRFSkyrideGradient.java index d4f02de6bd..4c83210147 100644 --- a/src/dr/evomodel/coalescent/GMRFSkyrideGradient.java +++ b/src/dr/evomodel/coalescent/GMRFSkyrideGradient.java @@ -42,13 +42,13 @@ */ public class GMRFSkyrideGradient implements GradientWrtParameterProvider, HessianWrtParameterProvider, Reportable { - private final OldGMRFSkyrideLikelihood skyrideLikelihood; + private final UnifiedGMRFSkyrideLikelihood skyrideLikelihood; private final WrtParameter wrtParameter; private final Parameter parameter; - private final OldAbstractCoalescentLikelihood.IntervalNodeMapping intervalNodeMapping; + private final IntervalNodeMapProvider intervalNodeMapping; private final NodeHeightTransform nodeHeightTransform; - public GMRFSkyrideGradient(OldGMRFSkyrideLikelihood gmrfSkyrideLikelihood, + public GMRFSkyrideGradient(UnifiedGMRFSkyrideLikelihood gmrfSkyrideLikelihood, WrtParameter wrtParameter, TreeModel tree, NodeHeightTransform nodeHeightTransform) { @@ -158,8 +158,8 @@ public enum WrtParameter { COALESCENT_INTERVAL { @Override - double[] getGradientLogDensity(OldGMRFSkyrideLikelihood skyrideLikelihood, - OldAbstractCoalescentLikelihood.IntervalNodeMapping intervalNodeMapping) { + double[] getGradientLogDensity(UnifiedGMRFSkyrideLikelihood skyrideLikelihood, + IntervalNodeMapProvider intervalNodeMapping) { double[] unSortedNodeHeightGradient = super.getGradientLogDensityWrtUnsortedNodeHeight(skyrideLikelihood); double[] intervalGradient = new double[unSortedNodeHeightGradient.length]; double accumulatedGradient = 0.0; @@ -178,8 +178,8 @@ void update(NodeHeightTransform nodeHeightTransform, double[] values) { NODE_HEIGHTS { @Override - double[] getGradientLogDensity(OldGMRFSkyrideLikelihood skyrideLikelihood, - OldAbstractCoalescentLikelihood.IntervalNodeMapping intervalNodeMapping) { + double[] getGradientLogDensity(UnifiedGMRFSkyrideLikelihood skyrideLikelihood, + IntervalNodeMapProvider intervalNodeMapping) { double[] unSortedNodeHeightGradient = getGradientLogDensityWrtUnsortedNodeHeight(skyrideLikelihood); return intervalNodeMapping.sortByNodeNumbers(unSortedNodeHeightGradient); } @@ -190,18 +190,19 @@ void update(NodeHeightTransform nodeHeightTransform, double[] values) { } }; - abstract double[] getGradientLogDensity(OldGMRFSkyrideLikelihood skyrideLikelihood, - OldAbstractCoalescentLikelihood.IntervalNodeMapping intervalNodeMapping); + abstract double[] getGradientLogDensity(UnifiedGMRFSkyrideLikelihood skyrideLikelihood, + IntervalNodeMapProvider intervalNodeMapping); abstract void update(NodeHeightTransform nodeHeightTransform, double[] values); - double[] getGradientLogDensityWrtUnsortedNodeHeight(OldGMRFSkyrideLikelihood skyrideLikelihood) { + double[] getGradientLogDensityWrtUnsortedNodeHeight(UnifiedGMRFSkyrideLikelihood skyrideLikelihood) { double[] unSortedNodeHeightGradient = new double[skyrideLikelihood.getCoalescentIntervalDimension()]; double[] gamma = skyrideLikelihood.getPopSizeParameter().getParameterValues(); int index = 0; for (int i = 0; i < skyrideLikelihood.getIntervalCount(); i++) { - if (skyrideLikelihood.getIntervalType(i) == OldAbstractCoalescentLikelihood.CoalescentEventType.COALESCENT) { + // if (skyrideLikelihood.getIntervalType(i) == OldAbstractCoalescentLikelihood.CoalescentEventType.COALESCENT) { + if (skyrideLikelihood.isCoalescentInterval(i)) { double weight = -Math.exp(-gamma[index]) * skyrideLikelihood.getLineageCount(i) * (skyrideLikelihood.getLineageCount(i) - 1); if (index < skyrideLikelihood.getCoalescentIntervalDimension() - 1 && i < skyrideLikelihood.getIntervalCount() - 1) { weight -= -Math.exp(-gamma[index + 1]) * skyrideLikelihood.getLineageCount(i + 1) * (skyrideLikelihood.getLineageCount(i + 1) - 1); diff --git a/src/dr/evomodel/coalescent/GMRFSkyrideLikelihood.java b/src/dr/evomodel/coalescent/GMRFSkyrideLikelihood.java index b139bedab4..c82d9126ab 100644 --- a/src/dr/evomodel/coalescent/GMRFSkyrideLikelihood.java +++ b/src/dr/evomodel/coalescent/GMRFSkyrideLikelihood.java @@ -27,14 +27,13 @@ package dr.evomodel.coalescent; -import dr.evolution.coalescent.IntervalList; import dr.evolution.coalescent.IntervalType; import dr.evolution.coalescent.TreeIntervalList; import dr.evolution.tree.NodeRef; +import dr.evolution.tree.Tree; import dr.evomodel.tree.TreeModel; import dr.evomodelxml.coalescent.GMRFSkyrideLikelihoodParser; import dr.inference.model.MatrixParameter; -import dr.inference.model.Model; import dr.inference.model.Parameter; import dr.inference.model.Variable; import dr.math.MathUtils; @@ -44,6 +43,7 @@ import no.uib.cipr.matrix.DenseVector; import no.uib.cipr.matrix.SymmTridiagMatrix; +import java.util.ArrayList; import java.util.Collections; import java.util.List; @@ -55,7 +55,7 @@ * @author Vladimir Minin * @author Marc Suchard */ -public class GMRFSkyrideLikelihood extends AbstractCoalescentLikelihood implements Citable { +public class GMRFSkyrideLikelihood extends AbstractCoalescentLikelihood implements UnifiedGMRFSkyrideLikelihood,Citable { // PUBLIC STUFF @@ -64,7 +64,7 @@ public class GMRFSkyrideLikelihood extends AbstractCoalescentLikelihood implemen // PRIVATE STUFF - private IntervalList intervalList; + private TreeIntervalList intervalList; protected Parameter popSizeParameter; protected Parameter groupSizeParameter; protected Parameter precisionParameter; @@ -87,7 +87,7 @@ public class GMRFSkyrideLikelihood extends AbstractCoalescentLikelihood implemen protected MatrixParameter dMatrix; protected boolean timeAwareSmoothing = TIME_AWARE_IS_ON_BY_DEFAULT; protected boolean rescaleByRootHeight; - + private boolean buildIntervalNodeMapping; public GMRFSkyrideLikelihood() { super(GMRFSkyrideLikelihoodParser.SKYLINE_LIKELIHOOD); } @@ -96,76 +96,31 @@ public GMRFSkyrideLikelihood(String name) { super(name); } - public GMRFSkyrideLikelihood(IntervalList intervalList, Parameter popParameter, Parameter groupParameter, Parameter precParameter, + public GMRFSkyrideLikelihood(TreeIntervalList intervalList, Parameter popParameter, Parameter groupParameter, Parameter precParameter, Parameter lambda, Parameter beta, MatrixParameter dMatrix, boolean timeAwareSmoothing, boolean rescaleByRootHeight) { + this(wrapIntervals(intervalList), popParameter, groupParameter, precParameter, lambda, beta, dMatrix, timeAwareSmoothing, rescaleByRootHeight, false); + } - super(GMRFSkyrideLikelihoodParser.SKYLINE_LIKELIHOOD,intervalList); - - // adding the key word to the the model means the keyword will be logged in the - // header of the logfile. - this.addKeyword("skyride"); - - this.intervalList = this.getIntervalList(); // pull this down from the AbtractCoalescentLikelihood for ease of access. - this.popSizeParameter = popParameter; - this.groupSizeParameter = groupParameter; - this.precisionParameter = precParameter; - this.lambdaParameter = lambda; - this.betaParameter = beta; - this.dMatrix = dMatrix; - this.timeAwareSmoothing = timeAwareSmoothing; - this.rescaleByRootHeight = rescaleByRootHeight; - - addVariable(popSizeParameter); - addVariable(precisionParameter); - addVariable(lambdaParameter); - if (betaParameter != null) { - addVariable(betaParameter); - } - addModel((Model)intervalList); - - //setTree(treeList); - - int correctFieldLength = getCorrectFieldLength(); - - if (popSizeParameter.getDimension() <= 1) { - // popSize dimension hasn't been set yet, set it here: - popSizeParameter.setDimension(correctFieldLength); - } - - fieldLength = popSizeParameter.getDimension(); - if (correctFieldLength != fieldLength) { - throw new IllegalArgumentException("Population size parameter should have length " + correctFieldLength); - } - - // Field length must be set by this point - //wrapSetupIntervals(); - coalescentIntervals = new double[fieldLength]; - storedCoalescentIntervals = new double[fieldLength]; - sufficientStatistics = new double[fieldLength]; - storedSufficientStatistics = new double[fieldLength]; - - setupGMRFWeights(); - - addStatistic(new DeltaStatistic()); - - initializationReport(); - - /* Force all entries in groupSizeParameter = 1 for compatibility with Tracer */ - if (groupSizeParameter != null) { - for (int i = 0; i < groupSizeParameter.getDimension(); i++) - groupSizeParameter.setParameterValue(i, 1.0); - } + private static List wrapIntervals(TreeIntervalList intervals) { + List intervalList = new ArrayList(); + intervalList.add(intervals); + return intervalList; } - - - public GMRFSkyrideLikelihood(List intervalsList, Parameter popParameter, Parameter groupParameter, Parameter precParameter, + public GMRFSkyrideLikelihood(List intervalsList, Parameter popParameter, Parameter groupParameter, Parameter precParameter, Parameter lambda, Parameter beta, MatrixParameter dMatrix, - boolean timeAwareSmoothing, boolean rescaleByRootHeight) { + boolean timeAwareSmoothing, boolean rescaleByRootHeight, + boolean buildIntervalNodeMapping) { - super(GMRFSkyrideLikelihoodParser.SKYLINE_LIKELIHOOD,intervalsList.get(0)); + super(GMRFSkyrideLikelihoodParser.SKYLINE_LIKELIHOOD, intervalsList.get(0)); // intervals watched by AbstractCoalescentLikelihood + if (intervalsList.size() != 1) { + throw new RuntimeException("GMRFSkyrideLikelihood only implemented for one tree"); + } + this.intervalList = intervalsList.get(0); // pull this down from the AbtractCoalescentLikelihood for ease of access. + + // adding the key word to the the model means the keyword will be logged in the // header of the logfile. this.addKeyword("skyride"); @@ -186,8 +141,7 @@ public GMRFSkyrideLikelihood(List intervalsList, Parameter popPara addVariable(betaParameter); } - //setTree(treeList); - setIntervalList(intervalsList); + int correctFieldLength = getCorrectFieldLength(); @@ -201,6 +155,8 @@ public GMRFSkyrideLikelihood(List intervalsList, Parameter popPara throw new IllegalArgumentException("Population size parameter should have length " + correctFieldLength); } + this.buildIntervalNodeMapping = buildIntervalNodeMapping; + // Field length must be set by this point //wrapSetupIntervals(); coalescentIntervals = new double[fieldLength]; @@ -208,6 +164,18 @@ public GMRFSkyrideLikelihood(List intervalsList, Parameter popPara sufficientStatistics = new double[fieldLength]; storedSufficientStatistics = new double[fieldLength]; +// coalescentIntervals = new Parameter.Default(fieldLength); + if (buildIntervalNodeMapping) { + if(!intervalList.isBuildIntervalNodeMapping()){ + intervalList.setBuildIntervalNodeMapping(true); + } + } else { + if(intervalList.isBuildIntervalNodeMapping()){ + intervalList.setBuildIntervalNodeMapping(false); + } + } + + setupGMRFWeights(); addStatistic(new DeltaStatistic()); @@ -222,10 +190,8 @@ public GMRFSkyrideLikelihood(List intervalsList, Parameter popPara } protected int getCorrectFieldLength() { - //return tree.getExternalNodeCount() - 1; - - // Change later in case we are not working with entire tree - return intervalList.getSampleCount()-1; + Tree tree = intervalList.getTree(); + return tree.getExternalNodeCount() - 1; } /* @@ -246,22 +212,15 @@ protected int setTree(List treeList) { return 1; } */ - protected int setIntervalList(List intervalsList) { - if (intervalsList.size() != 1) { - throw new RuntimeException("GMRFSkyrideLikelihood only implemented for one tree"); - } - this.intervalList = this.getIntervalList(); // pull this down from the AbtractCoalescentLikelihood for ease of access. - return 1; - } // public double[] getCopyOfCoalescentIntervals() { // return coalescentIntervals.clone(); // } // -// public double[] getCoalescentIntervals() { -// return coalescentIntervals; -// } + public double[] getCoalescentIntervals() { + return coalescentIntervals; + } public void initializationReport() { System.out.println("Creating a GMRF smoothed skyride model:"); @@ -335,12 +294,18 @@ protected void setupSufficientStatistics() { } } } + //TODO update this so we can use the NodeMapping in the intervalList + // TREE INTERVAL LIST is an intervarnodeMapping. BFTreeIntervalList could be as well + public IntervalNodeMapProvider getIntervalNodeMapping() { + return this.intervalList.getIntervalNodeMap(); + } protected double getFieldScalar() { final double rootHeight; if (rescaleByRootHeight) { // rootHeight = tree.getNodeHeight(tree.getRoot()); - rootHeight = intervalList.getTotalDuration(); + Tree tree = intervalList.getTree(); + rootHeight = tree.getNodeHeight(tree.getRoot()); } else { rootHeight = 1.0; } @@ -391,11 +356,6 @@ public SymmTridiagMatrix getScaledWeightMatrix(double precision) { return a; } - public void setupCoalescentIntervals() { -// this.intervalList.calculateIntervals(); // Done lazily in IntervalList - setupSufficientStatistics(); - } - public SymmTridiagMatrix getStoredScaledWeightMatrix(double precision) { SymmTridiagMatrix a = storedWeightMatrix.copy(); for (int i = 0; i < a.numRows() - 1; i++) { @@ -422,7 +382,6 @@ public SymmTridiagMatrix getScaledWeightMatrix(double precision, double lambda) private void makeIntervalsKnown() { if (!intervalsKnown) { - // intervalsKnown -> false when handleModelChanged event occurs in super. intervalList.calculateIntervals(); setupGMRFWeights(); intervalsKnown = true; @@ -448,16 +407,20 @@ public IntervalType getCoalescentIntervalType(int i) { }*/ public int getNumberOfCoalescentEvents() { - // return tree.getExternalNodeCount() - 1; - - // Change later to deal with case where we aren't working with entire tree - return intervalList.getSampleCount() - 1; + Tree tree = intervalList.getTree(); + return tree.getExternalNodeCount() - 1; } public double getCoalescentEventsStatisticValue(int i) { return sufficientStatistics[i]; } + public void setupCoalescentIntervals() { + // this.intervalList.calculateIntervals(); // Done lazily in IntervalList + setupSufficientStatistics(); + } + + public double[] getCoalescentIntervalHeights() { makeIntervalsKnown(); double[] a = new double[coalescentIntervals.length]; @@ -470,6 +433,22 @@ public double[] getCoalescentIntervalHeights() { return a; } + @Override + public int getIntervalCount() { + // TODO Auto-generated method stub + return this.intervalList.getIntervalCount(); + } + + @Override + public boolean isCoalescentInterval(int interval) { + return this.intervalList.getIntervalType(interval) == IntervalType.COALESCENT; + } + + @Override + public int getLineageCount(int interval) { + return this.intervalList.getLineageCount(interval); + } + public SymmTridiagMatrix getCopyWeightMatrix() { return weightMatrix.copy(); } @@ -493,6 +472,7 @@ protected void storeState() { super.storeState(); System.arraycopy(coalescentIntervals, 0, storedCoalescentIntervals, 0, coalescentIntervals.length); System.arraycopy(sufficientStatistics, 0, storedSufficientStatistics, 0, sufficientStatistics.length); + storedWeightMatrix = weightMatrix.copy(); storedLogFieldLikelihood = logFieldLikelihood; } @@ -500,9 +480,17 @@ protected void storeState() { protected void restoreState() { super.restoreState(); - // TODO Just swap pointers - System.arraycopy(storedCoalescentIntervals, 0, coalescentIntervals, 0, storedCoalescentIntervals.length); - System.arraycopy(storedSufficientStatistics, 0, sufficientStatistics, 0, storedSufficientStatistics.length); + // TODO Just swap pointers XJ: there you go +// System.arraycopy(storedCoalescentIntervals, 0, coalescentIntervals, 0, storedCoalescentIntervals.length); +// System.arraycopy(storedSufficientStatistics, 0, sufficientStatistics, 0, storedSufficientStatistics.length); + + double[] tmp = coalescentIntervals; + coalescentIntervals = storedCoalescentIntervals; + storedCoalescentIntervals = tmp; + tmp = sufficientStatistics; + sufficientStatistics = storedSufficientStatistics; + storedSufficientStatistics = tmp; + weightMatrix = storedWeightMatrix; logFieldLikelihood = storedLogFieldLikelihood; } @@ -668,6 +656,7 @@ public List getCitations() { 25, 1459, 1471, "10.1093/molbev/msn090" ); + } /* diff --git a/src/dr/evomodel/coalescent/IntervalNodeMapProvider.java b/src/dr/evomodel/coalescent/IntervalNodeMapProvider.java new file mode 100644 index 0000000000..e1160c2120 --- /dev/null +++ b/src/dr/evomodel/coalescent/IntervalNodeMapProvider.java @@ -0,0 +1,244 @@ +package dr.evomodel.coalescent; + +import java.util.ArrayList; +import java.util.Arrays; + +import dr.evolution.tree.Tree; +import dr.util.ComparableDouble; +import dr.util.HeapSort; + +public interface IntervalNodeMapProvider { + + + int[] getIntervalsForNode(int nodeNumber); + int[] getNodeNumbersForInterval(int interval); + double[] sortByNodeNumbers(double[] byIntervalOrder); + + interface IntervalNodeMapping extends IntervalNodeMapProvider { + void addNode(int nodeNumbe); + void setIntervalStartIndices(int intervalCount); + void initializeMaps(); + public void storeMapping(); + public void restoreMapping(); + } + class Default implements IntervalNodeMapping { + private int[] nodeNumbersInIntervals; + private int[] intervalStartIndices; + private int[] intervalNumberOfNodes; + + private int[] storedNodeNumbersInIntervals; + private int[] storedIntervalStartIndices; + private int[] storedIntervalNumberOfNodes; + + private int nextIndex = 0; + private int nIntervals; + private Tree tree; + private final int maxIndicesPerNode = 3; + + public Default (int maxIntervalCount, Tree tree) { + nodeNumbersInIntervals = new int[maxIndicesPerNode * maxIntervalCount]; + storedNodeNumbersInIntervals = new int[maxIndicesPerNode * maxIntervalCount]; + + intervalStartIndices = new int[maxIntervalCount]; + storedIntervalStartIndices = new int[maxIntervalCount]; + + intervalNumberOfNodes = new int[maxIndicesPerNode * maxIntervalCount]; + storedIntervalNumberOfNodes = new int[maxIndicesPerNode * maxIntervalCount]; + this.tree = tree; + } + + public void addNode(int nodeNumber) { + nodeNumbersInIntervals[nextIndex] = nodeNumber; + nextIndex++; + } + + private void mapNodeInterval(int nodeNumber, int intervalNumber) { + int index = 0; + while(index < maxIndicesPerNode) { + if (intervalNumberOfNodes[maxIndicesPerNode * nodeNumber + index] == -1) { + intervalNumberOfNodes[maxIndicesPerNode * nodeNumber + index] = intervalNumber; + break; + } else { + index++; + } + } + if (index == maxIndicesPerNode) { + throw new RuntimeException("The node appears in more than" + maxIndicesPerNode + " intervals!"); + } +// if (intervalNumberOfNodes[maxIndicesPerNode * nodeNumber] == -1 || intervalNumberOfNodes[maxIndicesPerNode * nodeNumber] == intervalNumber) { +// intervalNumberOfNodes[3 * nodeNumber] = intervalNumber; +// } else if (intervalNumberOfNodes[2 * nodeNumber + 1] == -1) { +// intervalNumberOfNodes[3 * nodeNumber + 1] = intervalNumber; +// } else { +// double[] testIntervals = new double[nIntervals]; +// for (int i = 0; i < nIntervals - 1; i++) { +// testIntervals[i] = tree.getNodeHeight(tree.getNode(nodeNumbersInIntervals[intervalStartIndices[i + 1]])) +// - tree.getNodeHeight(tree.getNode(nodeNumbersInIntervals[intervalStartIndices[i]])); +// } +// throw new RuntimeException("The node appears in more than two intervals!"); +// } + } + + public void setIntervalStartIndices(int intervalCount) { + + if (nodeNumbersInIntervals[nextIndex - 1] == nodeNumbersInIntervals[nextIndex - 2]) { + nodeNumbersInIntervals[nextIndex - 1] = 0; + nextIndex--; + } + + int index = 1; + mapNodeInterval(nodeNumbersInIntervals[0], 0); + + for (int i = 1; i < intervalCount; i++) { + + while(nodeNumbersInIntervals[index] != nodeNumbersInIntervals[index - 1]) { + mapNodeInterval(nodeNumbersInIntervals[index], i - 1); + index++; + } + + intervalStartIndices[i] = index; + mapNodeInterval(nodeNumbersInIntervals[index], i); + index++; + + } + + while(index < nextIndex) { + mapNodeInterval(nodeNumbersInIntervals[index], intervalCount - 1); + index++; + } + + nIntervals = intervalCount; + } + + public void initializeMaps() { + Arrays.fill(intervalNumberOfNodes, -1); + Arrays.fill(intervalStartIndices, 0); + Arrays.fill(nodeNumbersInIntervals, 0); + nextIndex = 0; + } + + @Override + public int[] getIntervalsForNode(int nodeNumber) { + int nonZeros = 0; + while(intervalNumberOfNodes[maxIndicesPerNode * nodeNumber + nonZeros] != -1) { + nonZeros++; + } + int[] result = new int[nonZeros]; + for(int i = 0; i < nonZeros; i++) { + result[i] = intervalNumberOfNodes[maxIndicesPerNode * nodeNumber + i]; + } + return result; +// if(intervalNumberOfNodes[maxIndicesPerNode * nodeNumber + 1] == -1) { +// return new int[]{intervalNumberOfNodes[maxIndicesPerNode * nodeNumber]}; +// } else { +// return new int[]{intervalNumberOfNodes[maxIndicesPerNode * nodeNumber], intervalNumberOfNodes[maxIndicesPerNode * nodeNumber + 1]}; +// } + } + + @Override + public int[] getNodeNumbersForInterval(int interval) { + assert(interval < nIntervals); + + final int startIndex = intervalStartIndices[interval]; + int endIndex; + if (interval == nIntervals - 1) { + endIndex = nextIndex - 1; + } else { + endIndex = intervalStartIndices[interval + 1] - 1; + } + + int[] nodeNumbers = new int[endIndex - startIndex + 1]; + + for (int i = 0; i < endIndex - startIndex + 1; i++) { + nodeNumbers[i] = nodeNumbersInIntervals[startIndex + i]; + } + return nodeNumbers; + } + + @Override + public double[] sortByNodeNumbers(double[] byIntervalOrder) { + double[] sortedValues = new double[byIntervalOrder.length]; + int[] nodeIndices = new int[byIntervalOrder.length]; + ArrayList mappedIntervals = new ArrayList(); + for (int i = 0; i < nodeIndices.length; i++) { + mappedIntervals.add(new ComparableDouble(getIntervalsForNode(i + tree.getExternalNodeCount())[0])); + } + HeapSort.sort(mappedIntervals, nodeIndices); + for (int i = 0; i < nodeIndices.length; i++) { + sortedValues[nodeIndices[i]] = byIntervalOrder[i]; + } + return sortedValues; + } + + /** + * Additional state information, outside of the sub-model is stored by this call. + */ + public void storeMapping() { + System.arraycopy(nodeNumbersInIntervals,0,storedNodeNumbersInIntervals,0,nodeNumbersInIntervals.length); + System.arraycopy(intervalNumberOfNodes,0,storedIntervalNumberOfNodes,0,intervalNumberOfNodes.length); + System.arraycopy(intervalStartIndices,0,storedIntervalStartIndices,0,intervalStartIndices.length); + } + + /** + * After this call the model is guaranteed to have returned its extra state information to + * the values coinciding with the last storeState call. + * Sub-models are handled automatically and do not need to be considered in this method. + */ + public void restoreMapping() { + int[] tmp = storedNodeNumbersInIntervals; + storedNodeNumbersInIntervals = nodeNumbersInIntervals; + nodeNumbersInIntervals = tmp; + + int[] tmp2 = storedIntervalNumberOfNodes; + storedIntervalNumberOfNodes=intervalNumberOfNodes; + intervalNumberOfNodes = tmp2; + + int[] tmp3= storedIntervalStartIndices; + storedIntervalStartIndices = intervalStartIndices; + intervalStartIndices =tmp3; + } + } + + class None implements IntervalNodeMapping { + + @Override + public void addNode(int nodeNumber) { + // Do nothing + } + + @Override + public void setIntervalStartIndices(int intervalCount) { + // Do nothing + } + + @Override + public void initializeMaps() { + // Do nothing + } + + @Override + public int[] getIntervalsForNode(int nodeNumber) { + throw new RuntimeException("No intervalNodeMapping available. This function should not be called."); + } + + @Override + public int[] getNodeNumbersForInterval(int interval) { + throw new RuntimeException("No intervalNodeMapping available. This function should not be called."); + } + + @Override + public double[] sortByNodeNumbers(double[] byIntervalOrder) { + throw new RuntimeException("No intervalNodeMapping available. This function should not be called."); + } + public void storeMapping() { + // Do nothing + } + public void restoreMapping() { + // Do nothing + } + } +} + + + + diff --git a/src/dr/evomodel/coalescent/OldAbstractCoalescentLikelihood.java b/src/dr/evomodel/coalescent/OldAbstractCoalescentLikelihood.java index 861ba50f09..ba8e901ee1 100644 --- a/src/dr/evomodel/coalescent/OldAbstractCoalescentLikelihood.java +++ b/src/dr/evomodel/coalescent/OldAbstractCoalescentLikelihood.java @@ -423,9 +423,9 @@ public final void setupIntervals() { storedIntervals = new double[maxIntervalCount]; storedLineageCounts = new int[maxIntervalCount]; if (buildIntervalNodeMapping) { - intervalNodeMapping = new IntervalNodeMapping.Default(tree.getNodeCount(), tree); + intervalNodeMapping = new IntervalNodeMapProvider.IntervalNodeMapping.Default(tree.getNodeCount(), tree); } else { - intervalNodeMapping = new IntervalNodeMapping.None(); + intervalNodeMapping = new IntervalNodeMapProvider.IntervalNodeMapping.None(); } } @@ -490,192 +490,8 @@ public XTreeIntervals(double[] intervals, int[] lineageCounts) { } - public interface IntervalNodeMapping { - void addNode(int nodeNumbe); - void setIntervalStartIndices(int intervalCount); - void initializeMaps(); - - int[] getIntervalsForNode(int nodeNumber); - int[] getNodeNumbersForInterval(int interval); - double[] sortByNodeNumbers(double[] byIntervalOrder); - - class Default implements IntervalNodeMapping { - final int[] nodeNumbersInIntervals; - final int[] intervalStartIndices; - final int[] intervalNumberOfNodes; - private int nextIndex = 0; - private int nIntervals; - private Tree tree; - - private final int maxIndicesPerNode = 3; - - public Default (int maxIntervalCount, Tree tree) { - nodeNumbersInIntervals = new int[maxIndicesPerNode * maxIntervalCount]; - intervalStartIndices = new int[maxIntervalCount]; - intervalNumberOfNodes = new int[maxIndicesPerNode * maxIntervalCount]; - this.tree = tree; - } - - public void addNode(int nodeNumber) { - nodeNumbersInIntervals[nextIndex] = nodeNumber; - nextIndex++; - } - - private void mapNodeInterval(int nodeNumber, int intervalNumber) { - int index = 0; - while(index < maxIndicesPerNode) { - if (intervalNumberOfNodes[maxIndicesPerNode * nodeNumber + index] == -1) { - intervalNumberOfNodes[maxIndicesPerNode * nodeNumber + index] = intervalNumber; - break; - } else { - index++; - } - } - if (index == maxIndicesPerNode) { - throw new RuntimeException("The node appears in more than" + maxIndicesPerNode + " intervals!"); - } -// if (intervalNumberOfNodes[maxIndicesPerNode * nodeNumber] == -1 || intervalNumberOfNodes[maxIndicesPerNode * nodeNumber] == intervalNumber) { -// intervalNumberOfNodes[3 * nodeNumber] = intervalNumber; -// } else if (intervalNumberOfNodes[2 * nodeNumber + 1] == -1) { -// intervalNumberOfNodes[3 * nodeNumber + 1] = intervalNumber; -// } else { -// double[] testIntervals = new double[nIntervals]; -// for (int i = 0; i < nIntervals - 1; i++) { -// testIntervals[i] = tree.getNodeHeight(tree.getNode(nodeNumbersInIntervals[intervalStartIndices[i + 1]])) -// - tree.getNodeHeight(tree.getNode(nodeNumbersInIntervals[intervalStartIndices[i]])); -// } -// throw new RuntimeException("The node appears in more than two intervals!"); -// } - } - - public void setIntervalStartIndices(int intervalCount) { - - if (nodeNumbersInIntervals[nextIndex - 1] == nodeNumbersInIntervals[nextIndex - 2]) { - nodeNumbersInIntervals[nextIndex - 1] = 0; - nextIndex--; - } - - int index = 1; - mapNodeInterval(nodeNumbersInIntervals[0], 0); - - for (int i = 1; i < intervalCount; i++) { - - while(nodeNumbersInIntervals[index] != nodeNumbersInIntervals[index - 1]) { - mapNodeInterval(nodeNumbersInIntervals[index], i - 1); - index++; - } - - intervalStartIndices[i] = index; - mapNodeInterval(nodeNumbersInIntervals[index], i); - index++; - - } - - while(index < nextIndex) { - mapNodeInterval(nodeNumbersInIntervals[index], intervalCount - 1); - index++; - } - - nIntervals = intervalCount; - } - - public void initializeMaps() { - Arrays.fill(intervalNumberOfNodes, -1); - Arrays.fill(intervalStartIndices, 0); - Arrays.fill(nodeNumbersInIntervals, 0); - nextIndex = 0; - } - - @Override - public int[] getIntervalsForNode(int nodeNumber) { - int nonZeros = 0; - while(intervalNumberOfNodes[maxIndicesPerNode * nodeNumber + nonZeros] != -1) { - nonZeros++; - } - int[] result = new int[nonZeros]; - for(int i = 0; i < nonZeros; i++) { - result[i] = intervalNumberOfNodes[maxIndicesPerNode * nodeNumber + i]; - } - return result; -// if(intervalNumberOfNodes[maxIndicesPerNode * nodeNumber + 1] == -1) { -// return new int[]{intervalNumberOfNodes[maxIndicesPerNode * nodeNumber]}; -// } else { -// return new int[]{intervalNumberOfNodes[maxIndicesPerNode * nodeNumber], intervalNumberOfNodes[maxIndicesPerNode * nodeNumber + 1]}; -// } - } - - @Override - public int[] getNodeNumbersForInterval(int interval) { - assert(interval < nIntervals); - - final int startIndex = intervalStartIndices[interval]; - int endIndex; - if (interval == nIntervals - 1) { - endIndex = nextIndex - 1; - } else { - endIndex = intervalStartIndices[interval + 1] - 1; - } - - int[] nodeNumbers = new int[endIndex - startIndex + 1]; - - for (int i = 0; i < endIndex - startIndex + 1; i++) { - nodeNumbers[i] = nodeNumbersInIntervals[startIndex + i]; - } - return nodeNumbers; - } - - @Override - public double[] sortByNodeNumbers(double[] byIntervalOrder) { - double[] sortedValues = new double[byIntervalOrder.length]; - int[] nodeIndices = new int[byIntervalOrder.length]; - ArrayList mappedIntervals = new ArrayList(); - for (int i = 0; i < nodeIndices.length; i++) { - mappedIntervals.add(new ComparableDouble(getIntervalsForNode(i + tree.getExternalNodeCount())[0])); - } - HeapSort.sort(mappedIntervals, nodeIndices); - for (int i = 0; i < nodeIndices.length; i++) { - sortedValues[nodeIndices[i]] = byIntervalOrder[i]; - } - return sortedValues; - } - } - - class None implements IntervalNodeMapping { - - @Override - public void addNode(int nodeNumber) { - // Do nothing - } - - @Override - public void setIntervalStartIndices(int intervalCount) { - // Do nothing - } - - @Override - public void initializeMaps() { - // Do nothing - } - - @Override - public int[] getIntervalsForNode(int nodeNumber) { - throw new RuntimeException("No intervalNodeMapping available. This function should not be called."); - } - - @Override - public int[] getNodeNumbersForInterval(int interval) { - throw new RuntimeException("No intervalNodeMapping available. This function should not be called."); - } - - @Override - public double[] sortByNodeNumbers(double[] byIntervalOrder) { - throw new RuntimeException("No intervalNodeMapping available. This function should not be called."); - } - } - } - - private static void getTreeIntervals(Tree tree, NodeRef root, NodeRef[] exclude, XTreeIntervals ti, IntervalNodeMapping intervalNodeMapping) { + private static void getTreeIntervals(Tree tree, NodeRef root, NodeRef[] exclude, XTreeIntervals ti, IntervalNodeMapProvider.IntervalNodeMapping intervalNodeMapping) { double MULTIFURCATION_LIMIT = 1e-9; ArrayList times = new ArrayList(); @@ -850,7 +666,7 @@ public final boolean isCoalescentOnly() { return true; } - public IntervalNodeMapping getIntervalNodeMapping() { + public IntervalNodeMapProvider getIntervalNodeMapping() { return intervalNodeMapping; } @@ -1051,7 +867,7 @@ public String getParserDescription() { int[] lineageCounts; private int[] storedLineageCounts; - IntervalNodeMapping intervalNodeMapping; + IntervalNodeMapProvider.IntervalNodeMapping intervalNodeMapping; boolean intervalsKnown = false; private boolean storedIntervalsKnown = false; @@ -1063,4 +879,5 @@ public String getParserDescription() { int intervalCount = 0; private int storedIntervalCount = 0; + } diff --git a/src/dr/evomodel/coalescent/OldGMRFSkyrideLikelihood.java b/src/dr/evomodel/coalescent/OldGMRFSkyrideLikelihood.java index a84eda8f93..8ed643ac27 100644 --- a/src/dr/evomodel/coalescent/OldGMRFSkyrideLikelihood.java +++ b/src/dr/evomodel/coalescent/OldGMRFSkyrideLikelihood.java @@ -56,7 +56,7 @@ * @author Marc Suchard */ @Deprecated -public class OldGMRFSkyrideLikelihood extends OldAbstractCoalescentLikelihood implements CoalescentIntervalProvider, Citable { +public class OldGMRFSkyrideLikelihood extends OldAbstractCoalescentLikelihood implements CoalescentIntervalProvider, Citable,UnifiedGMRFSkyrideLikelihood { // PUBLIC STUFF @@ -166,9 +166,9 @@ public OldGMRFSkyrideLikelihood(List treeList, Parameter popParameter, Par // coalescentIntervals = new Parameter.Default(fieldLength); if (buildIntervalNodeMapping) { - coalesentIntervalNodeMapping = new IntervalNodeMapping.Default(tree.getNodeCount(), tree); + coalesentIntervalNodeMapping = new IntervalNodeMapProvider.IntervalNodeMapping.Default(tree.getNodeCount(), tree); } else { - coalesentIntervalNodeMapping = new IntervalNodeMapping.None(); + coalesentIntervalNodeMapping = new IntervalNodeMapProvider.IntervalNodeMapping.None(); } @@ -212,6 +212,11 @@ protected int setTree(List treeList) { public double[] getCoalescentIntervals() { return coalescentIntervals; } + + @Override + public boolean isCoalescentInterval(int interval) { + return getIntervalType(interval) == CoalescentEventType.COALESCENT; + } public void initializationReport() { System.out.println("Creating a GMRF smoothed skyride model:"); @@ -270,7 +275,7 @@ public String toString() { return getId() + "(" + Double.toString(getLogLikelihood()) + ")"; } - IntervalNodeMapping coalesentIntervalNodeMapping; + IntervalNodeMapProvider.IntervalNodeMapping coalesentIntervalNodeMapping; protected void setupSufficientStatistics() { int index = 0; @@ -302,7 +307,7 @@ protected void setupSufficientStatistics() { coalesentIntervalNodeMapping.setIntervalStartIndices(index); } - public IntervalNodeMapping getIntervalNodeMapping() { + public IntervalNodeMapProvider getIntervalNodeMapping() { return coalesentIntervalNodeMapping; } @@ -622,6 +627,7 @@ public List getCitations() { 25, 1459, 1471, "10.1093/molbev/msn090" ); + } /* diff --git a/src/dr/evomodel/coalescent/TreeIntervals.java b/src/dr/evomodel/coalescent/TreeIntervals.java index afbd298c89..5c5b450d41 100644 --- a/src/dr/evomodel/coalescent/TreeIntervals.java +++ b/src/dr/evomodel/coalescent/TreeIntervals.java @@ -35,9 +35,6 @@ import dr.evolution.util.Units; import dr.evomodel.tree.TreeModel; import dr.inference.model.*; -import dr.util.ComparableDouble; -import dr.util.HeapSort; - import java.util.*; @@ -113,18 +110,24 @@ private void setup(Tree tree) { } eventsKnown = false; - this.intervalNodeMapping = buildIntervalNodeMapping ?new IntervalNodeMapping.Default(tree.getNodeCount(), tree):new IntervalNodeMapping.None(); + this.intervalNodeMapping = buildIntervalNodeMapping ?new IntervalNodeMapProvider.IntervalNodeMapping.Default(tree.getNodeCount(), tree):new IntervalNodeMapProvider.IntervalNodeMapping.None(); addStatistic(new DeltaStatistic()); } - // This option is set in the constructor as final -// public void setBuildIntervalNodeMapping(boolean buildIntervalNodeMapping){ -// this.buildIntervalNodeMapping = buildIntervalNodeMapping; -// this.intervalNodeMapping = buildIntervalNodeMapping ? new IntervalNodeMapping.Default(tree.getNodeCount(),tree):new IntervalNodeMapping.None(); -// //Force a recalculation here -// eventsKnown = false; -// } + public void setBuildIntervalNodeMapping(boolean buildIntervalNodeMapping){ + this.buildIntervalNodeMapping = buildIntervalNodeMapping; + if (this.buildIntervalNodeMapping) { + intervals = new Intervals(tree.getNodeCount()); + storedIntervals = new Intervals(tree.getNodeCount()); + } else { + intervals = new FastIntervals(tree.getExternalNodeCount(), tree.getInternalNodeCount()); + storedIntervals = new FastIntervals(tree.getExternalNodeCount(), tree.getInternalNodeCount()); + } + //Force a recalculation here + eventsKnown = false; + this.intervalNodeMapping = buildIntervalNodeMapping ?new IntervalNodeMapProvider.IntervalNodeMapping.Default(tree.getNodeCount(), tree):new IntervalNodeMapProvider.IntervalNodeMapping.None(); + } // ************************************************************** // ModelListener IMPLEMENTATION // ************************************************************** @@ -390,6 +393,10 @@ public void setUnits(Type units) { intervals.setUnits(units); } + public IntervalNodeMapProvider getIntervalNodeMap() { + return intervalNodeMapping; + } + // Interval Node mapping @@ -443,7 +450,6 @@ public double[] getCoalescentIntervals() { } return coalIntervals; } - // **************************************************************** // Inner classes // **************************************************************** @@ -464,268 +470,6 @@ public double getStatisticValue(int i) { } } - public interface IntervalNodeMapping { - - void addNode(int nodeNumbe); - void setIntervalStartIndices(int intervalCount); - void initializeMaps(); - - void mapNodeInterval(int nodeNumber, int intervalNumber); - //could put this in store restore model - void storeMapping(); - void restoreMapping(); - int[] getIntervalsForNode(int nodeNumber); - int[] getNodeNumbersForInterval(int interval); - double[] sortByNodeNumbers(double[] byIntervalOrder); - - class Default implements IntervalNodeMapping { - private int[] nodeNumbersInIntervals; - private int[] intervalStartIndices; - private int[] intervalNumberOfNodes; - - private int[] storedNodeNumbersInIntervals; - private int[] storedIntervalStartIndices; - private int[] storedIntervalNumberOfNodes; - - private int nextIndex = 0; - private int nIntervals; - private Tree tree; - - private final int maxIndicesPerNode = 3; - - public Default (int maxIntervalCount, Tree tree) { - nodeNumbersInIntervals = new int[maxIndicesPerNode * maxIntervalCount]; - storedNodeNumbersInIntervals = new int[maxIndicesPerNode * maxIntervalCount]; - - intervalStartIndices = new int[maxIntervalCount]; - storedIntervalStartIndices = new int[maxIntervalCount]; - - intervalNumberOfNodes = new int[maxIndicesPerNode * maxIntervalCount]; - storedIntervalNumberOfNodes = new int[maxIndicesPerNode * maxIntervalCount]; - - this.tree = tree; - - } - - public void addNode(int nodeNumber) { - if (nextIndex > 500) { -// System.err.println("why"); //Why not? -JT - } - nodeNumbersInIntervals[nextIndex] = nodeNumber; - nextIndex++; - } - - public void mapNodeInterval(int nodeNumber, int intervalNumber) { - int index = 0; - while(index < maxIndicesPerNode) { - if (intervalNumberOfNodes[maxIndicesPerNode * nodeNumber + index] == -1) { - intervalNumberOfNodes[maxIndicesPerNode * nodeNumber + index] = intervalNumber; - break; - } else { - index++; - } - } - if (index == maxIndicesPerNode) { - throw new RuntimeException("The node appears in more than" + maxIndicesPerNode + " intervals!"); - } -// if (intervalNumberOfNodes[maxIndicesPerNode * nodeNumber] == -1 || intervalNumberOfNodes[maxIndicesPerNode * nodeNumber] == intervalNumber) { -// intervalNumberOfNodes[3 * nodeNumber] = intervalNumber; -// } else if (intervalNumberOfNodes[2 * nodeNumber + 1] == -1) { -// intervalNumberOfNodes[3 * nodeNumber + 1] = intervalNumber; -// } else { -// double[] testIntervals = new double[nIntervals]; -// for (int i = 0; i < nIntervals - 1; i++) { -// testIntervals[i] = tree.getNodeHeight(tree.getNode(nodeNumbersInIntervals[intervalStartIndices[i + 1]])) -// - tree.getNodeHeight(tree.getNode(nodeNumbersInIntervals[intervalStartIndices[i]])); -// } -// throw new RuntimeException("The node appears in more than two intervals!"); -// } - } - - public void setIntervalStartIndices(int intervalCount) { - - if (nodeNumbersInIntervals[nextIndex - 1] == nodeNumbersInIntervals[nextIndex - 2]) { - nodeNumbersInIntervals[nextIndex - 1] = 0; - nextIndex--; - } - - int index = 1; - mapNodeInterval(nodeNumbersInIntervals[0], 0); - - for (int i = 1; i < intervalCount; i++) { - - while(nodeNumbersInIntervals[index] != nodeNumbersInIntervals[index - 1]) { - mapNodeInterval(nodeNumbersInIntervals[index], i - 1); - index++; - } - - intervalStartIndices[i] = index; - mapNodeInterval(nodeNumbersInIntervals[index], i); - index++; - - } - - while(index < nextIndex) { - mapNodeInterval(nodeNumbersInIntervals[index], intervalCount - 1); - index++; - } - - nIntervals = intervalCount; - } - - public void initializeMaps() { - Arrays.fill(intervalNumberOfNodes, -1); - Arrays.fill(intervalStartIndices, 0); - Arrays.fill(nodeNumbersInIntervals, 0); - nextIndex = 0; - } - - @Override - public int[] getIntervalsForNode(int nodeNumber) { - int nonZeros = 0; - while(intervalNumberOfNodes[maxIndicesPerNode * nodeNumber + nonZeros] != -1) { - nonZeros++; - } - int[] result = new int[nonZeros]; - for(int i = 0; i < nonZeros; i++) { - result[i] = intervalNumberOfNodes[maxIndicesPerNode * nodeNumber + i]; - } - return result; -// if(intervalNumberOfNodes[maxIndicesPerNode * nodeNumber + 1] == -1) { -// return new int[]{intervalNumberOfNodes[maxIndicesPerNode * nodeNumber]}; -// } else { -// return new int[]{intervalNumberOfNodes[maxIndicesPerNode * nodeNumber], intervalNumberOfNodes[maxIndicesPerNode * nodeNumber + 1]}; -// } - } - - @Override - public int[] getNodeNumbersForInterval(int interval) { - assert(interval < nIntervals); - - final int startIndex = intervalStartIndices[interval]; - int endIndex; - if (interval == nIntervals - 1) { - endIndex = nextIndex - 1; - } else { - endIndex = intervalStartIndices[interval + 1] - 1; - } - - int[] nodeNumbers = new int[endIndex - startIndex + 1]; - - for (int i = 0; i < endIndex - startIndex + 1; i++) { - nodeNumbers[i] = nodeNumbersInIntervals[startIndex + i]; - } - return nodeNumbers; - } - - @Override - public double[] sortByNodeNumbers(double[] byIntervalOrder) { - double[] sortedValues = new double[byIntervalOrder.length]; - int[] nodeIndices = new int[byIntervalOrder.length]; - ArrayList mappedIntervals = new ArrayList(); - for (int i = 0; i < nodeIndices.length; i++) { - mappedIntervals.add(new ComparableDouble(getIntervalsForNode(i + tree.getExternalNodeCount())[0])); - } - HeapSort.sort(mappedIntervals, nodeIndices); - for (int i = 0; i < nodeIndices.length; i++) { - sortedValues[nodeIndices[i]] = byIntervalOrder[i]; - } - return sortedValues; - } - - - - - /** - * Additional state information, outside of the sub-model is stored by this call. - */ - public void storeMapping() { - System.arraycopy(nodeNumbersInIntervals,0,storedNodeNumbersInIntervals,0,nodeNumbersInIntervals.length); - System.arraycopy(intervalNumberOfNodes,0,storedIntervalNumberOfNodes,0,intervalNumberOfNodes.length); - System.arraycopy(intervalStartIndices,0,storedIntervalStartIndices,0,intervalStartIndices.length); - } - - /** - * After this call the model is guaranteed to have returned its extra state information to - * the values coinciding with the last storeState call. - * Sub-models are handled automatically and do not need to be considered in this method. - */ - public void restoreMapping() { - int[] tmp = storedNodeNumbersInIntervals; - storedNodeNumbersInIntervals = nodeNumbersInIntervals; - nodeNumbersInIntervals = tmp; - - int[] tmp2 = storedIntervalNumberOfNodes; - storedIntervalNumberOfNodes=intervalNumberOfNodes; - intervalNumberOfNodes = tmp2; - - int[] tmp3= storedIntervalStartIndices; - storedIntervalStartIndices = intervalStartIndices; - intervalStartIndices =tmp3; - } - - - } - - class None implements IntervalNodeMapping { - - @Override - public void addNode(int nodeNumber) { - // Do nothing - } - - @Override - public void setIntervalStartIndices(int intervalCount) { - // Do nothing - } - - @Override - public void initializeMaps() { - // Do nothing - } - - @Override - public void mapNodeInterval(int nodeNumber, int intervalNumber) { - // Do nothing - - } - - @Override - public int[] getIntervalsForNode(int nodeNumber) { - throw new RuntimeException("No intervalNodeMapping available. This function should not be called."); - } - - @Override - public int[] getNodeNumbersForInterval(int interval) { - throw new RuntimeException("No intervalNodeMapping available. This function should not be called."); - } - - @Override - public double[] sortByNodeNumbers(double[] byIntervalOrder) { - throw new RuntimeException("No intervalNodeMapping available. This function should not be called."); - } - - /** - * Additional state information, outside of the sub-model is stored by this call. - */ - public void storeMapping() { - //nothing - } - - /** - * After this call the model is guaranteed to have returned its extra state information to - * the values coinciding with the last storeState call. - * Sub-models are handled automatically and do not need to be considered in this method. - */ - public void restoreMapping() { - //Nothing - } - - - } - } - - // **************************************************************** // Private and protected stuff // **************************************************************** @@ -736,13 +480,13 @@ public void restoreMapping() { private Tree tree = null; private Set includedLeafSet = null; private Set[] excludedLeafSets = null; - private final boolean buildIntervalNodeMapping; + private boolean buildIntervalNodeMapping; /** * The intervals. */ private MutableIntervalList intervals = null; - private IntervalNodeMapping intervalNodeMapping; + private IntervalNodeMapProvider.IntervalNodeMapping intervalNodeMapping; /** * The stored values for intervals. */ @@ -750,4 +494,5 @@ public void restoreMapping() { private boolean eventsKnown = false; private boolean storedEventsKnown = false; + } diff --git a/src/dr/evomodel/coalescent/UnifiedGMRFSkyrideLikelihood.java b/src/dr/evomodel/coalescent/UnifiedGMRFSkyrideLikelihood.java new file mode 100644 index 0000000000..69c5a71b2a --- /dev/null +++ b/src/dr/evomodel/coalescent/UnifiedGMRFSkyrideLikelihood.java @@ -0,0 +1,22 @@ +package dr.evomodel.coalescent; + +import dr.inference.model.Likelihood; +import dr.inference.model.Parameter; +import no.uib.cipr.matrix.SymmTridiagMatrix; + +public interface UnifiedGMRFSkyrideLikelihood extends CoalescentIntervalProvider,Likelihood{ + IntervalNodeMapProvider getIntervalNodeMapping(); + int getCoalescentIntervalDimension(); + Parameter getPopSizeParameter(); + Parameter getPrecisionParameter(); + Parameter getLambdaParameter(); + + double[] getSufficientStatistics(); + SymmTridiagMatrix getStoredScaledWeightMatrix(double currentPrecision, double currentLambda); + SymmTridiagMatrix getScaledWeightMatrix(double proposedPrecision, double proposedLambda); + + //Interval stuff + int getIntervalCount(); + boolean isCoalescentInterval(int interval); // because interval types are linked to classes for now + int getLineageCount(int interval); +} diff --git a/src/dr/evomodel/coalescent/operators/GMRFSkyrideBlockUpdateOperator.java b/src/dr/evomodel/coalescent/operators/GMRFSkyrideBlockUpdateOperator.java index e510f57ef8..78395decd6 100644 --- a/src/dr/evomodel/coalescent/operators/GMRFSkyrideBlockUpdateOperator.java +++ b/src/dr/evomodel/coalescent/operators/GMRFSkyrideBlockUpdateOperator.java @@ -27,7 +27,7 @@ package dr.evomodel.coalescent.operators; -import dr.evomodel.coalescent.OldGMRFSkyrideLikelihood; +import dr.evomodel.coalescent.UnifiedGMRFSkyrideLikelihood; import dr.evomodelxml.coalescent.operators.GMRFSkyrideBlockUpdateOperatorParser; import dr.inference.model.Parameter; import dr.inference.operators.*; @@ -56,11 +56,11 @@ public class GMRFSkyrideBlockUpdateOperator extends AbstractAdaptableOperator { private Parameter precisionParameter; private Parameter lambdaParameter; - OldGMRFSkyrideLikelihood gmrfField; + UnifiedGMRFSkyrideLikelihood gmrfField; private double[] zeros; - public GMRFSkyrideBlockUpdateOperator(OldGMRFSkyrideLikelihood gmrfLikelihood, + public GMRFSkyrideBlockUpdateOperator(UnifiedGMRFSkyrideLikelihood gmrfLikelihood, double weight, AdaptationMode mode, double scaleFactor, int maxIterations, double stopValue) { super(mode); diff --git a/src/dr/evomodelxml/coalescent/GMRFSkyrideGradientParser.java b/src/dr/evomodelxml/coalescent/GMRFSkyrideGradientParser.java index e8abb55726..4e5b4fb40a 100644 --- a/src/dr/evomodelxml/coalescent/GMRFSkyrideGradientParser.java +++ b/src/dr/evomodelxml/coalescent/GMRFSkyrideGradientParser.java @@ -29,6 +29,7 @@ import dr.evomodel.coalescent.GMRFSkyrideGradient; import dr.evomodel.coalescent.GMRFMultilocusSkyrideLikelihood; import dr.evomodel.coalescent.OldGMRFSkyrideLikelihood; +import dr.evomodel.coalescent.UnifiedGMRFSkyrideLikelihood; import dr.evomodel.coalescent.hmc.GMRFGradient; import dr.evomodel.tree.TreeModel; import dr.evomodel.treedatalikelihood.discrete.NodeHeightTransform; @@ -56,7 +57,7 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { // Parameter parameter = (Parameter) xo.getChild(Parameter.class); TreeModel tree = (TreeModel) xo.getChild(TreeModel.class); - OldGMRFSkyrideLikelihood skyrideLikelihood = (OldGMRFSkyrideLikelihood) xo.getChild(OldGMRFSkyrideLikelihood.class); + UnifiedGMRFSkyrideLikelihood skyrideLikelihood = (UnifiedGMRFSkyrideLikelihood) xo.getChild(UnifiedGMRFSkyrideLikelihood.class); String wrtParameterCase = (String) xo.getAttribute(WRT_PARAMETER); @@ -100,7 +101,7 @@ public XMLSyntaxRule[] getSyntaxRules() { private final XMLSyntaxRule[] rules = { AttributeRule.newStringRule(WRT_PARAMETER), new ElementRule(TreeModel.class, true), - new ElementRule(OldGMRFSkyrideLikelihood.class), + new ElementRule(UnifiedGMRFSkyrideLikelihood.class), new ElementRule(NodeHeightTransform.class, true), AttributeRule.newDoubleRule(TOLERANCE, true), AttributeRule.newBooleanRule(IGNORE_WARNING, true), diff --git a/src/dr/evomodelxml/coalescent/GMRFSkyrideLikelihoodParser.java b/src/dr/evomodelxml/coalescent/GMRFSkyrideLikelihoodParser.java index b8e9a5dc48..1cdf0ab7f5 100644 --- a/src/dr/evomodelxml/coalescent/GMRFSkyrideLikelihoodParser.java +++ b/src/dr/evomodelxml/coalescent/GMRFSkyrideLikelihoodParser.java @@ -26,6 +26,7 @@ package dr.evomodelxml.coalescent; import dr.evolution.coalescent.IntervalList; +import dr.evolution.coalescent.TreeIntervalList; import dr.evomodel.coalescent.*; import dr.evolution.tree.Tree; import dr.evomodel.tree.TreeModel; @@ -104,7 +105,7 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { boolean buildIntervalNodeMapping = xo.getAttribute(BUILD_MAPPING, false); - List intervalsList = new ArrayList(); + List intervalsList = new ArrayList(); List treeList = new ArrayList(); if(xo.getChild(POPULATION_TREE) != null) { @@ -136,11 +137,11 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException { if (xo.getChild(INTERVALS) != null) { cxo = xo.getChild(INTERVALS); - intervalsList = new ArrayList(); + intervalsList = new ArrayList(); for (int i = 0; i < cxo.getChildCount(); i++) { Object testObject = cxo.getChild(i); - if (testObject instanceof IntervalList) { - intervalsList.add((IntervalList) testObject); + if (testObject instanceof TreeIntervalList) { + intervalsList.add((TreeIntervalList) testObject); } } } diff --git a/src/dr/evomodelxml/coalescent/operators/GMRFSkyrideBlockUpdateOperatorParser.java b/src/dr/evomodelxml/coalescent/operators/GMRFSkyrideBlockUpdateOperatorParser.java index c6f9884d40..d9944e32c5 100644 --- a/src/dr/evomodelxml/coalescent/operators/GMRFSkyrideBlockUpdateOperatorParser.java +++ b/src/dr/evomodelxml/coalescent/operators/GMRFSkyrideBlockUpdateOperatorParser.java @@ -29,6 +29,7 @@ import dr.evomodel.coalescent.GMRFMultilocusSkyrideLikelihood; import dr.evomodel.coalescent.OldGMRFSkyrideLikelihood; +import dr.evomodel.coalescent.UnifiedGMRFSkyrideLikelihood; import dr.evomodel.coalescent.operators.GMRFMultilocusSkyrideBlockUpdateOperator; import dr.evomodel.coalescent.operators.GMRFSkyrideBlockUpdateOperator; import dr.inference.operators.AdaptableMCMCOperator; @@ -121,7 +122,7 @@ public String format(LogRecord record) { if (xo.getAttribute(OLD_SKYRIDE, true) && !(xo.getName().compareTo(GRID_BLOCK_UPDATE_OPERATOR) == 0) ) { - OldGMRFSkyrideLikelihood gmrfLikelihood = (OldGMRFSkyrideLikelihood) xo.getChild(OldGMRFSkyrideLikelihood.class); + UnifiedGMRFSkyrideLikelihood gmrfLikelihood = (UnifiedGMRFSkyrideLikelihood) xo.getChild(UnifiedGMRFSkyrideLikelihood.class); return new GMRFSkyrideBlockUpdateOperator(gmrfLikelihood, weight, mode, scaleFactor, maxIterations, stopValue); } else {