diff --git a/src/gluonts/model/forecast_generator.py b/src/gluonts/model/forecast_generator.py index 42caf1fa9e..33b0320808 100644 --- a/src/gluonts/model/forecast_generator.py +++ b/src/gluonts/model/forecast_generator.py @@ -82,6 +82,15 @@ def make_distribution_forecast(distr, *args, **kwargs) -> Forecast: raise NotImplementedError +def make_predictions(prediction_net, inputs: dict): + # MXNet predictors only support positional arguments + class_name = prediction_net.__class__.__module__ + if class_name.startswith("gluonts.mx") or class_name.startswith("mxnet"): + return prediction_net(*inputs.values()) + else: + return prediction_net(**inputs) + + class ForecastGenerator: """ Classes used to bring the output of a network into a class. @@ -115,7 +124,7 @@ def __call__( ) -> Iterator[Forecast]: for batch in inference_data_loader: inputs = select(input_names, batch, ignore_missing=True) - (outputs,), loc, scale = prediction_net(*inputs.values()) + (outputs,), loc, scale = make_predictions(prediction_net, inputs) outputs = to_numpy(outputs) if scale is not None: outputs = outputs * to_numpy(scale[..., None]) @@ -159,14 +168,16 @@ def __call__( ) -> Iterator[Forecast]: for batch in inference_data_loader: inputs = select(input_names, batch, ignore_missing=True) - outputs = to_numpy(prediction_net(*inputs.values())) + outputs = to_numpy(make_predictions(prediction_net, inputs)) if output_transform is not None: outputs = output_transform(batch, outputs) if num_samples: num_collected_samples = outputs[0].shape[0] collected_samples = [outputs] while num_collected_samples < num_samples: - outputs = to_numpy(prediction_net(*inputs.values())) + outputs = to_numpy( + make_predictions(prediction_net, inputs) + ) if output_transform is not None: outputs = output_transform(batch, outputs) collected_samples.append(outputs) @@ -209,7 +220,7 @@ def __call__( ) -> Iterator[Forecast]: for batch in inference_data_loader: inputs = select(input_names, batch, ignore_missing=True) - outputs = prediction_net(*inputs.values()) + outputs = make_predictions(prediction_net, inputs) if output_transform: log_once(OUTPUT_TRANSFORM_NOT_SUPPORTED_MSG)