Skip to content

Commit

Permalink
generalize LogAdditiveSubstitutionModel for arbitrary transform; grad…
Browse files Browse the repository at this point in the history
…ient is still not working; could break FM's code
  • Loading branch information
msuchard committed Oct 1, 2024
1 parent 0970429 commit a7f8970
Show file tree
Hide file tree
Showing 5 changed files with 129 additions and 149 deletions.
46 changes: 33 additions & 13 deletions src/dr/evomodel/substmodel/LogAdditiveCtmcRateProvider.java
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

import dr.inference.loggers.LogColumn;
import dr.inference.model.*;
import dr.util.Transform;

public interface LogAdditiveCtmcRateProvider extends Model, Likelihood {

Expand All @@ -54,27 +55,20 @@ interface DataAugmented extends LogAdditiveCtmcRateProvider {

class Basic extends AbstractModelLikelihood implements DataAugmented {

private final Parameter logRateParameter;
final Parameter transformedRateParameter;

public Basic(String name, Parameter logRateParameter) {
public Basic(String name, Parameter transformedRateParameter) {
super(name);
this.logRateParameter = logRateParameter;
this.transformedRateParameter = transformedRateParameter;

addVariable(logRateParameter);
addVariable(transformedRateParameter);
}

public Parameter getLogRateParameter() { return logRateParameter; }
public Parameter getLogRateParameter() { return transformedRateParameter; }

@Override
public double[] getXBeta() { // TODO this function should _not_ exponentiate

final int fieldDim = logRateParameter.getDimension();
double[] rates = new double[fieldDim];

for (int i = 0; i < fieldDim; ++i) {
rates[i] = Math.exp(logRateParameter.getParameterValue(i));
}
return rates;
return transformedRateParameter.getParameterValues();
}

@Override
Expand Down Expand Up @@ -102,5 +96,31 @@ protected void acceptState() { }
@Override
public void makeDirty() { }
}

class ArbitraryTransform extends Basic {

private final Transform transform;

public ArbitraryTransform(String name,
Parameter transformedRateParameter,
Transform transform) {
super(name, transformedRateParameter);
this.transform = transform;
}

@Override
public double[] getRates() {
double[] rates = transformedRateParameter.getParameterValues();
for (int i = 0; i < rates.length; ++i) {
rates[i] = transform.transform(rates[i]);
}
return rates;
}

@Override @Deprecated
public double[] getXBeta() { // TODO this function should NOT transform
throw new RuntimeException("Deprecated function");
}
}
}
}
129 changes: 2 additions & 127 deletions src/dr/evomodel/substmodel/LogRateSubstitutionModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,6 @@
import dr.inference.model.BayesianStochasticSearchVariableSelection;
import dr.inference.model.Likelihood;
import dr.inference.model.Model;
import dr.inference.model.Parameter;
import dr.math.matrixAlgebra.WrappedMatrix;
import dr.util.Citation;
import dr.util.CommonCitations;

Expand Down Expand Up @@ -61,15 +59,9 @@ public LogAdditiveCtmcRateProvider getRateProvider() {
return lrm;
}

// public GeneralizedLinearModel getGeneralizedLinearModel() { // TODO re-check if this can be translated in this context
// if (glm instanceof GeneralizedLinearModel) {
// return (GeneralizedLinearModel) glm;
// }
// throw new RuntimeException("Not yet implemented");
// }

protected void setupRelativeRates(double[] rates) {
System.arraycopy(lrm.getLogRateParameter().getParameterValues(),0,rates,0,rates.length);
double[] transformedRates = lrm.getRates();
System.arraycopy(transformedRates,0,rates,0,rates.length);
}

@Override
Expand Down Expand Up @@ -120,121 +112,4 @@ public List<Citation> getCitations() {

private final LogAdditiveCtmcRateProvider lrm;
private final double[] testProbabilities;

// @Override // TODO check if this can be translated in this context
// public ParameterReplaceableSubstitutionModel factory(List<Parameter> oldParameters, List<Parameter> newParameters) {
// LogLinearModel newGLM = ((LogLinearModel)lrm).factory(oldParameters, newParameters);
// return new LogRateSubstitutionModel(getModelName(), dataType, freqModel, newGLM);
// }

// @Override
// public WrappedMatrix getInfinitesimalDifferentialMatrix(DifferentialMassProvider.DifferentialWrapper.WrtParameter wrt) {
// // TODO all instantiations of this function currently do the same thing; remove duplication
// return DifferentiableSubstitutionModelUtil.getInfinitesimalDifferentialMatrix(wrt, this);
// }

// static class WrtOldGLMSubstitutionModelParameter implements DifferentialMassProvider.DifferentialWrapper.WrtParameter { // TODO check if this can be translated in this context
//
// final private int dim;
// final private int fixedEffectIndex;
// final private int stateCount;
// final private LogLinearModel glm;
//
// public WrtOldGLMSubstitutionModelParameter(LogLinearModel glm, int fixedEffectIndex, int dim, int stateCount) {
// this.glm = glm;
// this.fixedEffectIndex = fixedEffectIndex;
// this.dim = dim;
// this.stateCount = stateCount;
// }
//
// @Override
// public double getRate(int switchCase) {
// throw new RuntimeException("Should not be called.");
// }
//
// @Override
// public double getNormalizationDifferential() {
// return 0;
// }
//
// @Override
// public void setupDifferentialFrequencies(double[] differentialFrequencies, double[] frequencies) {
// Arrays.fill(differentialFrequencies, 1);
// }
//
// @Override
// public void setupDifferentialRates(double[] differentialRates, double[] Q, double normalizingConstant) {
//
// final double[] covariate = glm.getDesignMatrix(fixedEffectIndex).getColumnValues(dim);
//
// int k = 0;
// for (int i = 0; i < stateCount; ++i) {
// for (int j = i + 1; j < stateCount; ++j) {
//
// differentialRates[k] = covariate[k] * Q[index(i, j)];
// k++;
//
// }
// }
//
// for (int j = 0; j < stateCount; ++j) {
// for (int i = j + 1; i < stateCount; ++i) {
//
// differentialRates[k] = covariate[k] * Q[index(i, j)];
// k++;
//
// }
// }
//
//// final double chainRule = getChainRule();
////// double[][] design = glm.getX(effect);
////
//// for (int i = 0; i < differentialRates.length; ++i) {
//// differentialRates[i] = covariate[i] / normalizingConstant * chainRule;
//// }
// }
//
// double getChainRule() {
// return Math.exp(glm.getFixedEffect(fixedEffectIndex).getParameterValue(dim));
// }
//
// private int index(int i, int j) {
// return i * stateCount + j;
// }
// }

// @Override // TODO check if this can be translated in this context
// public DifferentialMassProvider.DifferentialWrapper.WrtParameter factory(Parameter parameter, int dim) {
//
// final int effectIndex = ((LogLinearModel)lrm).getEffectNumber(parameter);
// if (effectIndex == -1) {
// throw new RuntimeException("Only implemented for single dimensions, break up beta to one for each block for now please.");
// }
// return new WrtOldGLMSubstitutionModelParameter((LogLinearModel) lrm, effectIndex, dim, stateCount);
// }


// @Override
// public void setupDifferentialRates(DifferentialMassProvider.DifferentialWrapper.WrtParameter wrt, double[] differentialRates, double normalizingConstant) {
// final double[] Q = new double[stateCount * stateCount];
// getInfinitesimalMatrix(Q); // TODO These are large; should cache
// wrt.setupDifferentialRates(differentialRates, Q, normalizingConstant);
// }
//
// @Override
// public void setupDifferentialFrequency(DifferentialMassProvider.DifferentialWrapper.WrtParameter wrt, double[] differentialFrequency) {
// wrt.setupDifferentialFrequencies(differentialFrequency, getFrequencyModel().getFrequencies());
// }
//
// @Override
// public double getWeightedNormalizationGradient(DifferentialMassProvider.DifferentialWrapper.WrtParameter wrt, double[][] differentialMassMatrix, double[] differentialFrequencies) {
// double derivative = 0;
//
// if (getNormalization()) {
// for (int i = 0; i < stateCount; ++i) {
// derivative -= differentialMassMatrix[i][i] * getFrequencyModel().getFrequency(i);
// }
// }
// return derivative;
// }
}
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,15 @@
public class ApproximateLogCtmcRateGradientParser extends AbstractXMLObjectParser {

private static final String PARSER_NAME = "approximateLogCtmcRateGradient";
private static final String TRANSFORMED_PARSER_NAME = "approximateTransformedCtmcRateGradient";
private static final String TRAIT_NAME = TreeTraitParserUtilities.TRAIT_NAME;

public String getParserName(){ return PARSER_NAME; }

public String[] getParserNames() {
return new String[] { PARSER_NAME, TRANSFORMED_PARSER_NAME };
}

public Object parseXMLObject(XMLObject xo) throws XMLParseException {

String traitName = xo.getAttribute(TRAIT_NAME, DEFAULT_TRAIT_NAME);
Expand Down
45 changes: 36 additions & 9 deletions src/dr/evomodelxml/substmodel/LogRateSubstitutionModelParser.java
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import dr.evomodel.substmodel.LogRateSubstitutionModel;
import dr.evoxml.util.DataTypeUtils;
import dr.inference.model.Parameter;
import dr.util.Transform;
import dr.xml.*;

/**
Expand All @@ -44,15 +45,20 @@
public class LogRateSubstitutionModelParser extends AbstractXMLObjectParser {

public static final String LOG_RATE_SUBSTITUTION_MODEL = "logRateSubstitutionModel";
private static final String TRANSFORMED_RATE_SUBSTITUTION_MODEL = "transformedRateSubstitutionModel";
private static final String NORMALIZE = "normalize";
private static final String LOG_RATES = "logRates";
private static final String TRANSFORMED_RATES = "transformedRates";
public static final String SCALE_RATES_BY_FREQUENCIES = "scaleRatesByFrequencies";


public String getParserName() {
return LOG_RATE_SUBSTITUTION_MODEL;
}

public String[] getParserNames() {
return new String[] {LOG_RATE_SUBSTITUTION_MODEL, TRANSFORMED_RATE_SUBSTITUTION_MODEL};
}

public Object parseXMLObject(XMLObject xo) throws XMLParseException {

DataType dataType = DataTypeUtils.getDataType(xo);
Expand All @@ -61,16 +67,31 @@ public Object parseXMLObject(XMLObject xo) throws XMLParseException {

int rateCount = (dataType.getStateCount() - 1) * dataType.getStateCount(); // TODO not used anymore

Parameter logRates = (Parameter) xo.getElementFirstChild(LOG_RATES);
final LogAdditiveCtmcRateProvider lrm;
if (xo.hasChildNamed(LOG_RATES)) {
Parameter logRates = (Parameter) xo.getElementFirstChild(LOG_RATES);

int length = logRates.getDimension();
if (length != rateCount) {
throw new XMLParseException("Rates parameter in " + getParserName() + " element should have " + (rateCount) + " dimensions. However the log rates dimension is " + length + ".");
}
int length = logRates.getDimension();
if (length != rateCount) {
throw new XMLParseException("Rates parameter in " + getParserName() + " element should have " + (rateCount) + " dimensions. However the log rates dimension is " + length + ".");
}

lrm = new LogAdditiveCtmcRateProvider.DataAugmented.Basic(logRates.getId(), logRates);
} else {
XMLObject cxo = xo.getChild(TRANSFORMED_RATES);
Parameter transformedRates = (Parameter) cxo.getChild(Parameter.class);
Transform.ParsedTransform parsedTransform = (Transform.ParsedTransform) cxo.getChild(Transform.ParsedTransform.class);

LogAdditiveCtmcRateProvider lrm = new LogAdditiveCtmcRateProvider.DataAugmented.Basic(logRates.getId(), logRates);
int length = transformedRates.getDimension();
if (length != rateCount) {
throw new XMLParseException("Rates parameter in " + getParserName() + " element should have " + (rateCount) + " dimensions. However the transformed rates dimension is " + length + ".");
}

XMLObject cxo = xo.getChild(dr.oldevomodelxml.substmodel.ComplexSubstitutionModelParser.ROOT_FREQUENCIES);
lrm = new LogAdditiveCtmcRateProvider.DataAugmented.ArbitraryTransform(
transformedRates.getId(), transformedRates, parsedTransform.transform);
}

XMLObject cxo = xo.getChild(ComplexSubstitutionModelParser.ROOT_FREQUENCIES);
FrequencyModel rootFreq = (FrequencyModel) cxo.getChild(FrequencyModel.class);

if (dataType != rootFreq.getDataType()) {
Expand Down Expand Up @@ -110,7 +131,13 @@ public XMLSyntaxRule[] getSyntaxRules() {
new ElementRule(DataType.class)
),
new ElementRule(ComplexSubstitutionModelParser.ROOT_FREQUENCIES, FrequencyModel.class),
new ElementRule(LOG_RATES, Parameter.class),
new XORRule(
new ElementRule(LOG_RATES, Parameter.class),
new ElementRule(TRANSFORMED_RATES, new XMLSyntaxRule[] {
new ElementRule(Parameter.class),
new ElementRule(Transform.ParsedTransform.class),
})
),
AttributeRule.newBooleanRule(NORMALIZE, true),
AttributeRule.newBooleanRule(SCALE_RATES_BY_FREQUENCIES, true),
};
Expand Down
53 changes: 53 additions & 0 deletions src/dr/util/Transform.java
Original file line number Diff line number Diff line change
Expand Up @@ -696,6 +696,57 @@ public double gradient(double value) {
public double logJacobian(double x) { return x; }
}

// y = x^2
class SquaredTransform extends UnivariableTransform {
@Override
public Transform inverseTransform() {
throw new RuntimeException("Not yet implemented");
}

public double transform(double x) {
return x * x;
}

public double inverse(double y) {
return Math.sqrt(y);
}

public boolean isInInteriorDomain(double x) {
return !Double.isInfinite(x);
}

public double gradientInverse(double y) { return 0.5 / y; }

public double updateGradientLogDensity(double gradientWrtX, double x) {
double y = transform(x);
double dXdY = gradientInverse(y);
return gradientWrtX * dXdY + gradientLogJacobianInverse(y);
}

public double gradientLogJacobianInverse(double y) {
throw new RuntimeException("Mot yet implemented");
}

@Override
public double updateDiagonalHessianLogDensity(double diagonalHessian, double gradient, double value) {
throw new RuntimeException("Not yet implemented");
}

@Override
public double updateOffdiagonalHessianLogDensity(double offdiagonalHessian, double transfomationHessian, double gradientI, double gradientJ, double valueI, double valueJ) {
throw new RuntimeException("Not yet implemented");
}

@Override
public double gradient(double value) {
throw new RuntimeException("Not yet implemented");
}

public String getTransformName() { return "squared"; }

public double logJacobian(double x) { return Math.log(2 * x); }
}

// y = log(x)
class LogTransform extends UnivariableTransform {

Expand Down Expand Up @@ -2610,6 +2661,7 @@ public static MultivariableTransform parseMultivariableTransform(Object obj) {
LogTransform LOG = new LogTransform();
ExpTransform EXP = new ExpTransform();
NegateTransform NEGATE = new NegateTransform();
SquaredTransform SQUARED = new SquaredTransform();
Compose LOG_NEGATE = new Compose(new LogTransform(), new NegateTransform());
LogConstrainedSumTransform LOG_CONSTRAINED_SUM = new LogConstrainedSumTransform();
LogitTransform LOGIT = new LogitTransform();
Expand All @@ -2625,6 +2677,7 @@ enum Type {
LOGIT("logit", new LogitTransform()),
FISHER_Z("fisherZ",new FisherZTransform()),
INVERSE_SUM("inverseSum", new InverseSumTransform()),
SQUARED("squared", new SquaredTransform()),
POWER("power", new PowerTransform());

Type(String name, Transform transform) {
Expand Down

0 comments on commit a7f8970

Please sign in to comment.