Skip to content

Commit

Permalink
Add a general negative binomial distribution #536
Browse files Browse the repository at this point in the history
  • Loading branch information
zjzxiaohei committed Dec 6, 2024
1 parent d114646 commit 2672b18
Show file tree
Hide file tree
Showing 2 changed files with 134 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
package lphy.base.distribution;

import lphy.core.model.RandomVariable;
import lphy.core.model.Value;
import lphy.core.model.annotation.GeneratorCategory;
import lphy.core.model.annotation.GeneratorInfo;
import lphy.core.model.annotation.ParameterInfo;
import org.apache.commons.math3.distribution.GammaDistribution;
import org.apache.commons.math3.distribution.PoissonDistribution;
import org.apache.commons.math3.random.RandomGenerator;

import java.util.Map;
import java.util.TreeMap;

import static lphy.base.distribution.DistributionConstants.*;

public class GeneralNegativeBinomial extends ParametricDistribution<Integer>{

private Value<Double> p;
private Value<Double> r;
private GammaDistribution gamma;
private PoissonDistribution poisson;

public GeneralNegativeBinomial(@ParameterInfo(name = rParamName, description = "the number of successes which can not be an integer.") Value<Double> r,
@ParameterInfo(name = pParamName, description = "the probability of a success.") Value<Double> p) {
super();
this.p = p;
this.r = r;

constructDistribution(random);
}

public GeneralNegativeBinomial() {
constructDistribution(random);
}


@Override
protected void constructDistribution(RandomGenerator random) {
}
@GeneratorInfo(name = "GNB", verbClause = "has", narrativeName = "generalised negative binomial distribution",
category = GeneratorCategory.PRIOR,
description = "The generalised negative binomial distribution which parameter r can not be an integer.")
@Override
public RandomVariable<Integer> sample() {
double alpha = r.value();
double beta = (1 - p.value()) / p.value();
gamma = new GammaDistribution(random, alpha, beta);
double lamda = gamma.sample();
poisson = new PoissonDistribution(random, lamda, PoissonDistribution.DEFAULT_EPSILON, PoissonDistribution.DEFAULT_MAX_ITERATIONS);
Integer result = poisson.sample();

return new RandomVariable<>(null, result, this);
}

@Override
public Map<String, Value> getParams() {
return new TreeMap<>() {{
put(pParamName, p);
put(nParamName, r);
}};
}

@Override
public void setParam(String paramName, Value value) {
switch (paramName) {
case rParamName:
this.r = value;
break;
case pParamName:
this.p = value;
break;
default:
throw new RuntimeException("Unrecognised parameter name: " + paramName);
}


// super.setParam(paramName, value); // constructDistribution
}


public String toString() {
return getName();
}
}
49 changes: 49 additions & 0 deletions lphy-base/src/test/java/lphy/base/distribution/GNBTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package lphy.base.distribution;

import lphy.core.model.Value;
import lphy.core.simulator.RandomUtils;
import org.apache.commons.math3.stat.descriptive.SummaryStatistics;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;


public class GNBTest {
@BeforeEach
void setUp() {
RandomUtils.setSeed(123);
}

@Test
public void testGNB() {
int rep = 100;
int sampleSize = 10000;
int failures = 0;
Double pv = 0.4;
Double rv = 3.3;
double observedMean;
double observedVariance;
double expectedMean = rv * (1 - pv) / pv;
double expectedVariance = rv * (1 - pv) / (pv * pv);
double stdError = Math.sqrt(expectedVariance / sampleSize);
double DELTA = 2 * stdError;
Value<Double> p = new Value<Double>("p", pv);
Value<Double> r = new Value<Double>("r", rv);
GeneralNegativeBinomial negativeBinomial = new GeneralNegativeBinomial();
negativeBinomial.setParam("p", p);
negativeBinomial.setParam("r", r);
for (int j = 0; j < rep; j++) {
SummaryStatistics results = new SummaryStatistics();
for (int i = 0; i < sampleSize; i++) {
int result = negativeBinomial.sample().value();
results.addValue(result);
}
observedMean = results.getMean();
observedVariance = results.getVariance();
if (observedMean < expectedMean - DELTA || observedMean > expectedMean + DELTA) {
failures++;
}
//assertEquals(expectedMean, observedMean, DELTA);
}
System.out.println(failures);//failures should be less than rep *1/20
}
}

0 comments on commit 2672b18

Please sign in to comment.