Skip to content

Commit

Permalink
BinomialDistribution: fix methods for truncated distribution, issue #128
Browse files Browse the repository at this point in the history
  • Loading branch information
pr0m1th3as committed Aug 10, 2024
1 parent 335421f commit f8eb307
Showing 1 changed file with 60 additions and 35 deletions.
95 changes: 60 additions & 35 deletions inst/dist_obj/BinomialDistribution.m
Original file line number Diff line number Diff line change
Expand Up @@ -190,8 +190,8 @@ function disp (this)
ub = x > ux;
p(lb) = 0;
p(ub) = 1;
p(! (lb | ub)) -= binocdf (lx, this.N, this.p);
p(! (lb | ub)) /= diff (binocdf ([lx, ux], this.N, this.p));
p(! (lb | ub)) -= binocdf (lx - 1, this.N, this.p);
p(! (lb | ub)) /= diff (binocdf ([lx-1, ux], this.N, this.p));
endif
## Apply uflag
if (utail)
Expand All @@ -213,18 +213,20 @@ function disp (this)
if (! isscalar (this))
error ("icdf: requires a scalar probability distribution.");
endif
umax = binoinv (1, this.N, this.p);
if (this.IsTruncated && this.Truncation(2) >= umax)
## Get lower and upper boundaries
lx = ceil (this.Truncation(1));
ux = floor (this.Truncation(2));
ux = min (ux, umax);
lp = binocdf (lx - 1, this.N, this.p);
up = binocdf (ux, this.N, this.p);
p = lp + p * (up - lp);
endif
x = binoinv (p, this.N, this.p);
if (this.IsTruncated)
lp = binocdf (this.Truncation(1), this.N, this.p);
up = binocdf (this.Truncation(2), this.N, this.p);
## Adjust p values within range of p @ lower limit and p @ upper limit
is_nan = p < 0 | p > 1;
p(is_nan) = NaN;
np = lp + (up - lp) .* p;
x = binoinv (np, this.N, this.p);
x(x < this.Truncation(1)) = this.Truncation(1);
x(x > this.Truncation(2)) = this.Truncation(2);
else
x = binoinv (p, this.N, this.p);
endif
endfunction

Expand Down Expand Up @@ -257,11 +259,13 @@ function disp (this)
if (! isscalar (this))
error ("mean: requires a scalar probability distribution.");
endif
m = binostat (this.N, this.p);
if (this.IsTruncated)
fm = @(x) x .* pdf (this, x);
m = integral (fm, this.Truncation(1), this.Truncation(2));
else
m = binostat (this.N, this.p);
lx = ceil (this.Truncation(1));
ux = floor (this.Truncation(2));
ux = min (ux, binoinv (1, this.N, this.p));
x = [lx:ux];
m = sum (x .* pdf (this, x));
endif
endfunction

Expand Down Expand Up @@ -375,7 +379,7 @@ function disp (this)
ux = this.Truncation(2);
ub = x > ux;
y(lb | ub) = 0;
y(! (lb | ub)) /= diff (binocdf ([lx, ux], this.N, this.p));
y(! (lb | ub)) /= diff (binocdf ([lx-1, ux], this.N, this.p));
endif
endfunction

Expand Down Expand Up @@ -510,7 +514,7 @@ function disp (this)
## pick the appropriate size from
lx = this.Truncation(1);
ux = this.Truncation(2);
ratio = 1 / diff (poisscdf ([lx, ux], this.N, this.p));
ratio = 1 / diff (binocdf ([lx-1, ux], this.N, this.p));
nsize = fix (2 * ratio * ps); # times 2 to be on the safe side
## Generate the numbers and remove out-of-bound random samples
r = binornd (this.N, this.p, nsize, 1);
Expand Down Expand Up @@ -589,10 +593,24 @@ function disp (this)
error ("var: requires a scalar probability distribution.");
endif
if (this.IsTruncated)
fm = @(x) x .* pdf (this, x);
m = integral (fm, this.Truncation(1), this.Truncation(2));
fv = @(x) ((x - m) .^ 2) .* pdf (this, x);
v = integral (fv, this.Truncation(1), this.Truncation(2));
## Calculate untruncated mean and variance
[um, uv] = binostat (this.N, this.p);
## Calculate truncated mean
m = mean (this);
## Get lower and upper boundaries
lx = ceil (this.Truncation(1));
ux = floor (this.Truncation(2));
ux = min (ux, binoinv (1, this.N, this.p));
## Handle infinite support on the right
if (isequal (ux, Inf))
ratio = 1 / diff (binocdf ([lx-1, ux], this.N, this.p));
x = 0:lx-1;
v = ratio * (uv + (um - m) ^ 2 - sum (((x - m) .^ 2) .* ...
binopdf (x, this.N, this.p)));
else
x = lx:ux;
v = sum (((x - m) .^ 2) .* pdf (this, x));
endif
else
[~, v] = binostat (this.N, this.p);
endif
Expand Down Expand Up @@ -651,34 +669,41 @@ function checkparams (N, p)
endfunction

## Test output
%!shared pd, t
%!shared pd, t, t_inf
%! pd = BinomialDistribution (5, 0.5);
%! t = truncate (pd, 2, 4);
%! t_inf = truncate (pd, 2, Inf);
%!assert (cdf (pd, [0:5]), [0.0312, 0.1875, 0.5, 0.8125, 0.9688, 1], 1e-4);
#%!assert (cdf (t, [0:5]), [0, 0, 0.4, 0.8, 1, 1], 1e-4);
%!assert (cdf (t, [0:5]), [0, 0, 0.4, 0.8, 1, 1], 1e-4);
%!assert (cdf (t_inf, [0:5]), [0, 0, 0.3846, 0.7692, 0.9615, 1], 1e-4);
%!assert (cdf (pd, [1.5, 2, 3, 4, NaN]), [0.1875, 0.5, 0.8125, 0.9688, NaN], 1e-4);
#%!assert (cdf (t, [1.5, 2, 3, 4, NaN]), [0, 0.4, 0.8, 1, NaN], 1e-4);
%!assert (cdf (t, [1.5, 2, 3, 4, NaN]), [0, 0.4, 0.8, 1, NaN], 1e-4);
%!assert (icdf (pd, [0:0.2:1]), [0, 2, 2, 3, 3, 5], 1e-4);
#%!assert (icdf (t, [0:0.2:1]), [2, 2, 2, 3, 3, 4], 1e-4);
%!assert (icdf (t, [0:0.2:1]), [2, 2, 2, 3, 3, 4], 1e-4);
%!assert (icdf (t_inf, [0:0.2:1]), [2, 2, 3, 3, 4, 5], 1e-4);
%!assert (icdf (pd, [-1, 0.4:0.2:1, NaN]), [NaN, 2, 3, 3, 5, NaN], 1e-4);
#%!assert (icdf (t, [-1, 0.4:0.2:1, NaN]), [NaN, 2, 3, 3, 4, NaN], 1e-4);
%!assert (icdf (t, [-1, 0.4:0.2:1, NaN]), [NaN, 2, 3, 3, 4, NaN], 1e-4);
%!assert (iqr (pd), 1);
#%!assert (iqr (t), 1);
%!assert (iqr (t), 1);
%!assert (mean (pd), 2.5, 1e-10);
#%!assert (mean (t), 2.8, 1e-10);
%!assert (mean (t), 2.8, 1e-10);
%!assert (mean (t_inf), 2.8846, 1e-4);
%!assert (median (pd), 2.5);
#%!assert (median (t), 3);
%!assert (median (t), 3);
%!assert (pdf (pd, [0:5]), [0.0312, 0.1562, 0.3125, 0.3125, 0.1562, 0.0312], 1e-4);
#%!assert (pdf (t, [0:5]), [0, 0, 0.4, 0.4, 0.2, 0], 1e-4);
%!assert (pdf (t, [0:5]), [0, 0, 0.4, 0.4, 0.2, 0], 1e-4);
%!assert (pdf (t_inf, [0:5]), [0, 0, 0.3846, 0.3846, 0.1923, 0.0385], 1e-4);
%!assert (pdf (pd, [-1, 1.5, NaN]), [0, 0, NaN], 1e-4);
#%!assert (pdf (t, [-1, 1.5, NaN]), [0, 0, NaN], 1e-4);
%!assert (pdf (t, [-1, 1.5, NaN]), [0, 0, NaN], 1e-4);
%!assert (isequal (size (random (pd, 100, 50)), [100, 50]))
#%!assert (any (random (t, 1000, 1) < 2), false);
#%!assert (any (random (t, 1000, 1) > 4), false);
%!assert (any (random (t, 1000, 1) < 2), false);
%!assert (any (random (t, 1000, 1) > 4), false);
%!assert (std (pd), 1.1180, 1e-4);
#%!assert (std (t), 0.7483, 1e-4);
%!assert (std (t), 0.7483, 1e-4);
%!assert (std (t_inf), 0.8470, 1e-4);
%!assert (var (pd), 1.2500, 1e-4);
#%!assert (var (t), 0.5600, 1e-4);
%!assert (var (t), 0.5600, 1e-4);
%!assert (var (t_inf), 0.7175, 1e-4);

## Test input validation
## 'BinomialDistribution' constructor
Expand Down

0 comments on commit f8eb307

Please sign in to comment.