From f1e524a75d809fee8e41b80a053a93fc89363bd4 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 13 Sep 2024 19:33:22 +0200 Subject: [PATCH] clamp concentration --- src/gluonts/torch/distributions/generalized_pareto.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/gluonts/torch/distributions/generalized_pareto.py b/src/gluonts/torch/distributions/generalized_pareto.py index 60ce3d626d..fec06548c1 100644 --- a/src/gluonts/torch/distributions/generalized_pareto.py +++ b/src/gluonts/torch/distributions/generalized_pareto.py @@ -187,6 +187,12 @@ def domain_map( concentration: torch.Tensor, ): # type: ignore scale = F.softplus(scale) + # Clamp concentration to avoid numerical issues + concentration = torch.tanh(concentration) + + # Adjust loc for negative concentration + neg_conc = concentration < 0 + loc = torch.where(neg_conc, loc - scale / concentration, loc) return loc.squeeze(-1), scale.squeeze(-1), concentration.squeeze(-1) @property