-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add a general negative binomial distribution #536
- Loading branch information
1 parent
d114646
commit 2672b18
Showing
2 changed files
with
134 additions
and
0 deletions.
There are no files selected for viewing
85 changes: 85 additions & 0 deletions
85
lphy-base/src/main/java/lphy/base/distribution/GeneralNegativeBinomial.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
package lphy.base.distribution; | ||
|
||
import lphy.core.model.RandomVariable; | ||
import lphy.core.model.Value; | ||
import lphy.core.model.annotation.GeneratorCategory; | ||
import lphy.core.model.annotation.GeneratorInfo; | ||
import lphy.core.model.annotation.ParameterInfo; | ||
import org.apache.commons.math3.distribution.GammaDistribution; | ||
import org.apache.commons.math3.distribution.PoissonDistribution; | ||
import org.apache.commons.math3.random.RandomGenerator; | ||
|
||
import java.util.Map; | ||
import java.util.TreeMap; | ||
|
||
import static lphy.base.distribution.DistributionConstants.*; | ||
|
||
public class GeneralNegativeBinomial extends ParametricDistribution<Integer>{ | ||
|
||
private Value<Double> p; | ||
private Value<Double> r; | ||
private GammaDistribution gamma; | ||
private PoissonDistribution poisson; | ||
|
||
public GeneralNegativeBinomial(@ParameterInfo(name = rParamName, description = "the number of successes which can not be an integer.") Value<Double> r, | ||
@ParameterInfo(name = pParamName, description = "the probability of a success.") Value<Double> p) { | ||
super(); | ||
this.p = p; | ||
this.r = r; | ||
|
||
constructDistribution(random); | ||
} | ||
|
||
public GeneralNegativeBinomial() { | ||
constructDistribution(random); | ||
} | ||
|
||
|
||
@Override | ||
protected void constructDistribution(RandomGenerator random) { | ||
} | ||
@GeneratorInfo(name = "GNB", verbClause = "has", narrativeName = "generalised negative binomial distribution", | ||
category = GeneratorCategory.PRIOR, | ||
description = "The generalised negative binomial distribution which parameter r can not be an integer.") | ||
@Override | ||
public RandomVariable<Integer> sample() { | ||
double alpha = r.value(); | ||
double beta = (1 - p.value()) / p.value(); | ||
gamma = new GammaDistribution(random, alpha, beta); | ||
double lamda = gamma.sample(); | ||
poisson = new PoissonDistribution(random, lamda, PoissonDistribution.DEFAULT_EPSILON, PoissonDistribution.DEFAULT_MAX_ITERATIONS); | ||
Integer result = poisson.sample(); | ||
|
||
return new RandomVariable<>(null, result, this); | ||
} | ||
|
||
@Override | ||
public Map<String, Value> getParams() { | ||
return new TreeMap<>() {{ | ||
put(pParamName, p); | ||
put(nParamName, r); | ||
}}; | ||
} | ||
|
||
@Override | ||
public void setParam(String paramName, Value value) { | ||
switch (paramName) { | ||
case rParamName: | ||
this.r = value; | ||
break; | ||
case pParamName: | ||
this.p = value; | ||
break; | ||
default: | ||
throw new RuntimeException("Unrecognised parameter name: " + paramName); | ||
} | ||
|
||
|
||
// super.setParam(paramName, value); // constructDistribution | ||
} | ||
|
||
|
||
public String toString() { | ||
return getName(); | ||
} | ||
} |
49 changes: 49 additions & 0 deletions
49
lphy-base/src/test/java/lphy/base/distribution/GNBTest.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
package lphy.base.distribution; | ||
|
||
import lphy.core.model.Value; | ||
import lphy.core.simulator.RandomUtils; | ||
import org.apache.commons.math3.stat.descriptive.SummaryStatistics; | ||
import org.junit.jupiter.api.BeforeEach; | ||
import org.junit.jupiter.api.Test; | ||
|
||
|
||
public class GNBTest { | ||
@BeforeEach | ||
void setUp() { | ||
RandomUtils.setSeed(123); | ||
} | ||
|
||
@Test | ||
public void testGNB() { | ||
int rep = 100; | ||
int sampleSize = 10000; | ||
int failures = 0; | ||
Double pv = 0.4; | ||
Double rv = 3.3; | ||
double observedMean; | ||
double observedVariance; | ||
double expectedMean = rv * (1 - pv) / pv; | ||
double expectedVariance = rv * (1 - pv) / (pv * pv); | ||
double stdError = Math.sqrt(expectedVariance / sampleSize); | ||
double DELTA = 2 * stdError; | ||
Value<Double> p = new Value<Double>("p", pv); | ||
Value<Double> r = new Value<Double>("r", rv); | ||
GeneralNegativeBinomial negativeBinomial = new GeneralNegativeBinomial(); | ||
negativeBinomial.setParam("p", p); | ||
negativeBinomial.setParam("r", r); | ||
for (int j = 0; j < rep; j++) { | ||
SummaryStatistics results = new SummaryStatistics(); | ||
for (int i = 0; i < sampleSize; i++) { | ||
int result = negativeBinomial.sample().value(); | ||
results.addValue(result); | ||
} | ||
observedMean = results.getMean(); | ||
observedVariance = results.getVariance(); | ||
if (observedMean < expectedMean - DELTA || observedMean > expectedMean + DELTA) { | ||
failures++; | ||
} | ||
//assertEquals(expectedMean, observedMean, DELTA); | ||
} | ||
System.out.println(failures);//failures should be less than rep *1/20 | ||
} | ||
} |