From fedfa806024ca84467eba15f2b4e8eb15e52491b Mon Sep 17 00:00:00 2001 From: Lorenzo Stella Date: Fri, 31 May 2024 15:28:49 +0200 Subject: [PATCH 1/4] Docs: fix custom pytorch model tutorial --- .../howto_pytorch_lightning.md.template | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/docs/tutorials/advanced_topics/howto_pytorch_lightning.md.template b/docs/tutorials/advanced_topics/howto_pytorch_lightning.md.template index 34e6937b19..e4e240e5a4 100644 --- a/docs/tutorials/advanced_topics/howto_pytorch_lightning.md.template +++ b/docs/tutorials/advanced_topics/howto_pytorch_lightning.md.template @@ -110,10 +110,10 @@ class FeedForwardNetwork(nn.Module): torch.nn.init.zeros_(lin.bias) return lin - def forward(self, context): - scale = self.scaling(context) - scaled_context = context / scale - nn_out = self.nn(scaled_context) + def forward(self, past_target): + scale = self.scaling(past_target) + scaled_past_target = past_target / scale + nn_out = self.nn(scaled_past_target) nn_out_reshaped = nn_out.reshape(-1, self.prediction_length, self.hidden_dimensions[-1]) distr_args = self.args_proj(nn_out_reshaped) return distr_args, torch.zeros_like(scale), scale @@ -143,13 +143,13 @@ class LightningFeedForwardNetwork(FeedForwardNetwork, pl.LightningModule): super().__init__(*args, **kwargs) def training_step(self, batch, batch_idx): - context = batch["past_target"] - target = batch["future_target"] + past_target = batch["past_target"] + future_target = batch["future_target"] - assert context.shape[-1] == self.context_length - assert target.shape[-1] == self.prediction_length + assert past_target.shape[-1] == self.context_length + assert future_target.shape[-1] == self.prediction_length - distr_args, loc, scale = self(context) + distr_args, loc, scale = self(past_target) distr = self.distr_output.distribution(distr_args, loc, scale) loss = -distr.log_prob(target) From 5cee03c06c5671c963898d8b7ab0ac479f72c133 Mon Sep 17 00:00:00 2001 From: Lorenzo Stella Date: Fri, 31 May 2024 15:34:37 +0200 Subject: [PATCH 2/4] more fixing --- .../advanced_topics/howto_pytorch_lightning.md.template | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/tutorials/advanced_topics/howto_pytorch_lightning.md.template b/docs/tutorials/advanced_topics/howto_pytorch_lightning.md.template index e4e240e5a4..9290260a13 100644 --- a/docs/tutorials/advanced_topics/howto_pytorch_lightning.md.template +++ b/docs/tutorials/advanced_topics/howto_pytorch_lightning.md.template @@ -151,7 +151,7 @@ class LightningFeedForwardNetwork(FeedForwardNetwork, pl.LightningModule): distr_args, loc, scale = self(past_target) distr = self.distr_output.distribution(distr_args, loc, scale) - loss = -distr.log_prob(target) + loss = -distr.log_prob(future_target) return loss.mean() From 7a6750517f7076147fc03f215482ee589297d51f Mon Sep 17 00:00:00 2001 From: Lorenzo Stella Date: Fri, 31 May 2024 17:45:44 +0200 Subject: [PATCH 3/4] update check for mxnet --- src/gluonts/model/forecast_generator.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/src/gluonts/model/forecast_generator.py b/src/gluonts/model/forecast_generator.py index 33b0320808..f43152abf1 100644 --- a/src/gluonts/model/forecast_generator.py +++ b/src/gluonts/model/forecast_generator.py @@ -83,12 +83,15 @@ def make_distribution_forecast(distr, *args, **kwargs) -> Forecast: def make_predictions(prediction_net, inputs: dict): - # MXNet predictors only support positional arguments - class_name = prediction_net.__class__.__module__ - if class_name.startswith("gluonts.mx") or class_name.startswith("mxnet"): - return prediction_net(*inputs.values()) - else: - return prediction_net(**inputs) + try: + # Feed inputs as positional arguments for MXNet block predictors + import mxnet as mx + + if isinstance(prediction_net, mx.gluon.HybridBlock): + return prediction_net(*inputs.values()) + except ImportError: + pass + return prediction_net(**inputs) class ForecastGenerator: From 582ca375578b3e509bd4d75399fda9eb0ffd9b2f Mon Sep 17 00:00:00 2001 From: Lorenzo Stella Date: Fri, 31 May 2024 18:34:30 +0200 Subject: [PATCH 4/4] fixup --- src/gluonts/model/forecast_generator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gluonts/model/forecast_generator.py b/src/gluonts/model/forecast_generator.py index f43152abf1..0148a8e1e6 100644 --- a/src/gluonts/model/forecast_generator.py +++ b/src/gluonts/model/forecast_generator.py @@ -87,7 +87,7 @@ def make_predictions(prediction_net, inputs: dict): # Feed inputs as positional arguments for MXNet block predictors import mxnet as mx - if isinstance(prediction_net, mx.gluon.HybridBlock): + if isinstance(prediction_net, mx.gluon.Block): return prediction_net(*inputs.values()) except ImportError: pass