Skip to content

Commit

Permalink
working gradients for different rate parameter transforms
Browse files Browse the repository at this point in the history
  • Loading branch information
msuchard committed Oct 2, 2024
1 parent a7f8970 commit 19bebea
Show file tree
Hide file tree
Showing 12 changed files with 66 additions and 14 deletions.
2 changes: 1 addition & 1 deletion src/dr/evomodel/substmodel/BaseSubstitutionModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ private void decompose() {
updateMatrix = false;
}

protected double setupMatrix() {
public double setupMatrix() {
setupRelativeRates(relativeRates);
double[] pi = getPi();
setupQMatrix(relativeRates, pi, q);
Expand Down
4 changes: 4 additions & 0 deletions src/dr/evomodel/substmodel/ComplexSubstitutionModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,10 @@ public void setScaleRatesByFrequencies(boolean frequencyScaling) {
this.frequencyScaling = frequencyScaling;
}

public boolean getScaleRatesByFrequencies() {
return frequencyScaling;
}

public boolean getNormalization() {
return doNormalization;
}
Expand Down
7 changes: 7 additions & 0 deletions src/dr/evomodel/substmodel/LogAdditiveCtmcRateProvider.java
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ default double[] getRates() {
return rates;
}

default Transform getTransform() {
return null;
}

interface Integrated extends LogAdditiveCtmcRateProvider { }

interface DataAugmented extends LogAdditiveCtmcRateProvider {
Expand Down Expand Up @@ -108,6 +112,9 @@ public ArbitraryTransform(String name,
this.transform = transform;
}

@Override
public Transform getTransform() { return transform; }

@Override
public double[] getRates() {
double[] rates = transformedRateParameter.getParameterValues();
Expand Down
5 changes: 5 additions & 0 deletions src/dr/evomodel/substmodel/LogRateSubstitutionModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import dr.inference.model.Model;
import dr.util.Citation;
import dr.util.CommonCitations;
import dr.util.Transform;

import java.util.*;

Expand Down Expand Up @@ -112,4 +113,8 @@ public List<Citation> getCitations() {

private final LogAdditiveCtmcRateProvider lrm;
private final double[] testProbabilities;

public Transform getTransform() {
return lrm.getTransform();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import dr.inference.model.Parameter;
import dr.util.Author;
import dr.util.Citation;
import dr.util.Transform;

import java.util.ArrayList;
import java.util.Collections;
Expand Down Expand Up @@ -113,7 +114,8 @@ protected double preProcessNormalization(double[] differentials, double[] genera

double processSingleGradientDimension(int i,
double[] differentials, double[] generator, double[] pi,
boolean normalize, double normalizationConstant) {
boolean normalize, double normalizationConstant,
double rateScalar, Transform transform, boolean scaleByFrequencies) {

double[] covariate = parameterMap.getCovariateColumn(i);
return calculateCovariateDifferential(generator, differentials, covariate, pi, normalize);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
import dr.inference.model.Model;
import dr.inference.model.ModelListener;
import dr.util.Citable;
import dr.util.Transform;
import dr.xml.Reportable;

import java.util.ArrayList;
Expand All @@ -65,6 +66,7 @@ public abstract class AbstractLogAdditiveSubstitutionModelGradient implements
protected final ComplexSubstitutionModel substitutionModel;
protected final int stateCount;
protected final List<Integer> crossProductAccumulationMap;
protected final boolean scaleRatesByFrequencies;

private final ApproximationMode mode;

Expand Down Expand Up @@ -168,6 +170,7 @@ public AbstractLogAdditiveSubstitutionModelGradient(String traitName,
this.tree = treeDataLikelihood.getTree();
this.branchModel = likelihoodDelegate.getBranchModel();
this.substitutionModel = substitutionModel;
this.scaleRatesByFrequencies = substitutionModel.getScaleRatesByFrequencies();
this.stateCount = substitutionModel.getDataType().getStateCount();
this.crossProductAccumulationMap = createCrossProductAccumulationMap(likelihoodDelegate.getBranchModel(),
substitutionModel);
Expand Down Expand Up @@ -200,7 +203,8 @@ public AbstractLogAdditiveSubstitutionModelGradient(String traitName,

abstract double processSingleGradientDimension(int dim,
double[] differentials, double[] generator, double[] pi,
boolean normalize, double normalizationConstant);
boolean normalize, double normalizationConstant, double rateScalar,
Transform transform, boolean scaleByFrequencies);

@Override
public double[] getGradientLogDensity() {
Expand All @@ -218,15 +222,20 @@ public double[] getGradientLogDensity() {
substitutionModel.getInfinitesimalMatrix(generator);

double[] pi = substitutionModel.getFrequencyModel().getFrequencies();
boolean normalize = substitutionModel.getNormalization();

double rateScalar = normalize ? 1 / substitutionModel.setupMatrix() : 0.0;

double normalizationConstant = preProcessNormalization(crossProducts, generator,
substitutionModel.getNormalization());

Transform transform = (substitutionModel instanceof LogRateSubstitutionModel) ?
((LogRateSubstitutionModel) substitutionModel).getTransform() : null;

final double[] gradient = new double[getParameter().getDimension()];
for (int i = 0; i < getParameter().getDimension(); ++i) {
gradient[i] = processSingleGradientDimension(i, crossProducts, generator, pi,
substitutionModel.getNormalization(),
normalizationConstant);
normalize, normalizationConstant, rateScalar, transform, scaleRatesByFrequencies);
}

if (COUNT_TOTAL_OPERATIONS) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import dr.inference.loggers.LogColumn;
import dr.inference.model.Parameter;
import dr.util.Citation;
import dr.util.Transform;

import java.util.List;

Expand Down Expand Up @@ -80,7 +81,8 @@ protected double preProcessNormalization(double[] differentials, double[] genera

@Override
double processSingleGradientDimension(int j, double[] differentials, double[] generator, double[] pi,
boolean normalize, double normalizationConstant) {
boolean normalize, double normalizationConstant,
double rateScalar, Transform transform, boolean scaleByFrequencies) {
// derivative wrt pi[j]
double total = 0.0;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import dr.inference.model.DesignMatrix;
import dr.inference.model.MaskedParameter;
import dr.inference.model.Parameter;
import dr.util.Transform;

/**
* @author Marc A. Suchard
Expand Down Expand Up @@ -113,7 +114,8 @@ protected double preProcessNormalization(double[] differentials, double[] genera
@Override
double processSingleGradientDimension(int k,
double[] differentials, double[] generator, double[] pi,
boolean normalize, double normalizationConstant) {
boolean normalize, double normalizationConstant,
double rateScalar, Transform transform, boolean scaleByFrequencies) {

int whichCoefficient = indexK(k);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import dr.inference.model.Parameter;
import dr.util.Citation;
import dr.util.CommonCitations;
import dr.util.Transform;

import java.util.Collections;
import java.util.List;
Expand Down Expand Up @@ -104,7 +105,8 @@ private int[][] makeAsymmetricMap() {

@Override
double processSingleGradientDimension(int k, double[] differentials, double[] generator, double[] pi,
boolean normalize, double normalizationConstant) {
boolean normalize, double normalizationConstant,
double rateScalar, Transform transform, boolean scaleByFrequencies) {

final int i = mapEffectToIndices[k][0], j = mapEffectToIndices[k][1];
final int ii = i * stateCount + i;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import dr.inference.model.Parameter;
import dr.util.Citation;
import dr.util.CommonCitations;
import dr.util.Transform;

import java.util.Collections;
import java.util.List;
Expand Down Expand Up @@ -81,7 +82,9 @@ private LogAdditiveCtmcRateProvider.DataAugmented extractRateProvider(ComplexSub

@Override
protected double preProcessNormalization(double[] differentials, double[] generator,
boolean normalize) {
boolean normalize
// , Transform transform, boolean scaleByFrequencies
) {
double total = 0.0;
if (normalize) {
for (int i = 0; i < stateCount; ++i) {
Expand Down Expand Up @@ -115,13 +118,27 @@ private int[][] makeAsymmetricMap() {

@Override
double processSingleGradientDimension(int k, double[] differentials, double[] generator, double[] pi,
boolean normalize, double normalizationConstant) {
boolean normalize, double normalizationConstant, double rateScalar,
Transform transform, boolean scaleByFrequencies) {

final int i = mapEffectToIndices[k][0], j = mapEffectToIndices[k][1];
final int ii = i * stateCount + i;
final int ij = i * stateCount + j;
final Parameter transformedParameter = rateProvider.getLogRateParameter();

double element;
if (transform == null) {
element = generator[ij]; // Default is exp()
} else {
element = transform.gradient(transformedParameter.getParameterValue(k))
if (normalize) {
element *= rateScalar;
}
if (scaleByFrequencies) {
element *= pi[i];
}
}

double element = generator[ij];
double total = (differentials[ij] - differentials[ii]) * element;

if (normalize) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import dr.evomodel.treedatalikelihood.TreeDataLikelihood;
import dr.inference.distribution.GeneralizedLinearModel;
import dr.inference.model.Parameter;
import dr.util.Transform;

/**
* @author Marc A. Suchard
Expand Down Expand Up @@ -88,7 +89,8 @@ protected double preProcessNormalization(double[] differentials, double[] genera
@Override
double processSingleGradientDimension(int k,
double[] differentials, double[] generator, double[] pi,
boolean normalize, double normalizationConstant) {
boolean normalize, double normalizationConstant,
double rateScalar, Transform transform, boolean scaleByFrequencies) {

double elementUpper = generator[indexIJ(k)];
double total = (differentials[indexIJ(k)] - differentials[indexII(k)]) * elementUpper;
Expand Down
4 changes: 2 additions & 2 deletions src/dr/util/Transform.java
Original file line number Diff line number Diff line change
Expand Up @@ -688,7 +688,7 @@ public double updateOffdiagonalHessianLogDensity(double offdiagonalHessian, doub

@Override
public double gradient(double value) {
throw new RuntimeException("Not yet implemented");
return Math.exp(value);
}

public String getTransformName() { return "exp"; }
Expand Down Expand Up @@ -739,7 +739,7 @@ public double updateOffdiagonalHessianLogDensity(double offdiagonalHessian, doub

@Override
public double gradient(double value) {
throw new RuntimeException("Not yet implemented");
return 2 * value;
}

public String getTransformName() { return "squared"; }
Expand Down

0 comments on commit 19bebea

Please sign in to comment.