Skip to content

Commit

Permalink
formatting and loss
Browse files Browse the repository at this point in the history
  • Loading branch information
kashif committed Oct 29, 2024
1 parent f11d2d3 commit e54c3df
Showing 1 changed file with 69 additions and 32 deletions.
101 changes: 69 additions & 32 deletions src/gluonts/torch/distributions/bernstein_quantile.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,15 @@
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

from typing import Dict, List, Optional, Tuple
from typing import Dict, Optional, Tuple

import torch
import torch.nn.functional as F
from torch.distributions import Distribution, AffineTransform, TransformedDistribution
from torch.distributions import (
Distribution,
AffineTransform,
TransformedDistribution,
)

from gluonts.core.component import validated
from .distribution_output import DistributionOutput
Expand All @@ -24,7 +28,7 @@
class BernsteinQuantileDistribution(Distribution):
r"""
Distribution class for quantile function approximation using Bernstein polynomials.
Parameters
----------
coefficients
Expand All @@ -33,7 +37,7 @@ class BernsteinQuantileDistribution(Distribution):
degree
Degree of Bernstein polynomials.
"""

def __init__(
self,
coefficients: torch.Tensor,
Expand All @@ -42,7 +46,7 @@ def __init__(
) -> None:
self.coefficients = coefficients
self.degree = degree

batch_shape = coefficients.shape[:-1]
super().__init__(batch_shape=batch_shape, validate_args=validate_args)

Expand All @@ -51,50 +55,53 @@ def bernstein_basis(self, alpha: torch.Tensor, k: int) -> torch.Tensor:
n = self.degree
# Compute binomial coefficient
coef = torch.exp(
torch.lgamma(torch.tensor(n + 1.))
- torch.lgamma(torch.tensor(k + 1.))
- torch.lgamma(torch.tensor(n - k + 1.))
torch.lgamma(torch.tensor(n + 1.0))
- torch.lgamma(torch.tensor(k + 1.0))
- torch.lgamma(torch.tensor(n - k + 1.0))
)
return coef * (alpha ** k) * ((1 - alpha) ** (n - k))
return coef * (alpha**k) * ((1 - alpha) ** (n - k))

def quantile(self, alpha: torch.Tensor) -> torch.Tensor:
    """
    Evaluate the quantile function at the given levels.

    The value is the linear combination of the ``degree + 1`` Bernstein
    basis polynomials (as returned by ``self.bernstein_basis``) weighted by
    the last axis of ``self.coefficients``.

    Parameters
    ----------
    alpha
        Tensor of quantile levels. Values outside [0, 1] are clamped to
        that interval before evaluation.

    Returns
    -------
    Tensor
        Quantile values. ``alpha`` gains a trailing singleton axis before
        the basis is evaluated, so the result has the shape obtained by
        broadcasting ``alpha.shape + (1,)`` against
        ``self.coefficients.shape[:-1]``.
        NOTE(review): this equals ``batch_shape`` for a scalar ``alpha``
        but not in general for an ``alpha`` that already has
        ``batch_shape`` -- confirm the intended contract with callers.
    """
    # Levels outside the unit interval are clamped rather than rejected;
    # the trailing axis is what the coefficients broadcast against.
    levels = torch.clamp(alpha, 0, 1).unsqueeze(-1)

    # One basis polynomial per coefficient, stacked along a new last axis.
    basis_values = torch.stack(
        [self.bernstein_basis(levels, k) for k in range(self.degree + 1)],
        dim=-1,
    )

    # Weighted sum over the basis axis.
    return torch.sum(basis_values * self.coefficients, dim=-1)

def cdf(self, y: torch.Tensor, num_iterations: int = 10) -> torch.Tensor:
    """
    Approximate the CDF by bisection on the quantile function.

    For each value the search interval starts as [0, 1] and is halved
    ``num_iterations`` times, keeping the half selected by comparing
    ``self.quantile(mid)`` with ``y``. The midpoint of the final interval
    is returned.

    Parameters
    ----------
    y
        Tensor of values at which to evaluate the CDF.
    num_iterations
        Number of bisection steps. The final interval has width
        ``2 ** -num_iterations``; the default of 10 matches the previously
        hard-coded count.

    Returns
    -------
    Tensor
        Approximate CDF values, each strictly inside (0, 1).
        NOTE(review): bisection only locates the true level if the
        quantile function is non-decreasing in the level, which nothing in
        this method enforces -- confirm against how the coefficients are
        produced.
    """
    # Initial search bounds: the full range of quantile levels.
    lower = torch.zeros_like(y)
    upper = torch.ones_like(y)

    for _ in range(num_iterations):
        mid = (lower + upper) / 2
        # Evaluate the comparison once and reuse it for both bounds.
        below_target = self.quantile(mid) < y
        lower = torch.where(below_target, mid, lower)
        upper = torch.where(below_target, upper, mid)

    return (lower + upper) / 2

def rsample(self, sample_shape: torch.Size = torch.Size()) -> torch.Tensor:
Expand All @@ -130,34 +137,35 @@ def crps(self, y: torch.Tensor) -> torch.Tensor:
# Approximate CRPS using numerical integration
alpha = torch.linspace(0, 1, 100, device=y.device)
quantiles = self.quantile(alpha)

# Compute integrand
indicator = (quantiles.unsqueeze(-1) >= y.unsqueeze(-2)).float()
integrand = (indicator - alpha.unsqueeze(-1)) ** 2

# Numerical integration using trapezoidal rule
return torch.trapz(integrand, alpha, dim=-2)


class BernsteinQuantileOutput(DistributionOutput):
r"""
Distribution output class for quantile function approximation using Bernstein polynomials.
Parameters
----------
degree
Degree of Bernstein polynomials to use.
"""

distr_cls: type = BernsteinQuantileDistribution

@validated()
def __init__(self, degree: int) -> None:
    """
    Configure the output for Bernstein polynomials of a given degree.

    Parameters
    ----------
    degree
        Degree of the Bernstein polynomials; must be a positive integer.

    Raises
    ------
    AssertionError
        If ``degree`` is not an ``int`` or is not greater than zero.
    """
    super().__init__()

    # Checked once; the message is what callers see on failure.
    assert (
        isinstance(degree, int) and degree > 0
    ), "degree must be a positive integer"

    self.degree = degree
    # A polynomial of degree n has n + 1 coefficients.
    self.args_dim: Dict[str, int] = {"coefficients": degree + 1}

Expand All @@ -180,7 +188,7 @@ def distribution(
"""
coefficients = distr_args[0]
distr = self.distr_cls(coefficients, self.degree)

if scale is None:
return distr
else:
Expand All @@ -191,3 +199,32 @@ def distribution(
@property
def event_shape(self) -> Tuple:
    """Return the shape of a single event, which is the empty tuple."""
    return tuple()

def loss(
    self,
    target: torch.Tensor,
    distr_args: Tuple[torch.Tensor, ...],
    loc: Optional[torch.Tensor] = None,
    scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Compute the CRPS-based loss of ``target`` under the predicted distribution.

    The distribution is built with ``self.distribution(distr_args, loc=loc,
    scale=scale)`` and its ``crps`` method is evaluated at ``target``.

    Parameters
    ----------
    target
        Target values.
    distr_args
        Distribution arguments returned by the network.
    loc
        Location parameter forwarded to ``self.distribution``.
    scale
        Scale parameter forwarded to ``self.distribution``.

    Returns
    -------
    torch.Tensor
        Whatever ``crps(target)`` returns for the constructed distribution.
        NOTE(review): this requires the object returned by
        ``self.distribution`` to expose ``crps`` for every combination of
        ``loc``/``scale`` -- confirm this also holds when ``scale`` is not
        None.
    """
    return self.distribution(distr_args, loc=loc, scale=scale).crps(target)

0 comments on commit e54c3df

Please sign in to comment.