diff --git a/src/gluonts/torch/distributions/generalized_pareto.py b/src/gluonts/torch/distributions/generalized_pareto.py index 60ce3d626d..fec06548c1 100644 --- a/src/gluonts/torch/distributions/generalized_pareto.py +++ b/src/gluonts/torch/distributions/generalized_pareto.py @@ -187,6 +187,12 @@ def domain_map( concentration: torch.Tensor, ): # type: ignore scale = F.softplus(scale) + # Clamp concentration to avoid numerical issues + concentration = torch.tanh(concentration) + + # Adjust loc for negative concentration + neg_conc = concentration < 0 + loc = torch.where(neg_conc, loc - scale / concentration, loc) return loc.squeeze(-1), scale.squeeze(-1), concentration.squeeze(-1) @property