From 5a3419fb6b09ba410f71d67100648f083a31d242 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Fri, 20 Sep 2024 21:50:04 +0200 Subject: [PATCH] fix next_lags --- examples/anomaly_detection_pytorch.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/examples/anomaly_detection_pytorch.py b/examples/anomaly_detection_pytorch.py index b7343e4f58..d3b8c24fc8 100644 --- a/examples/anomaly_detection_pytorch.py +++ b/examples/anomaly_detection_pytorch.py @@ -152,6 +152,7 @@ def main(args): distr = model.output_distribution( params, trailing_n=1, scale=scale ) + scaled_past_target = inputs["past_target"] / scale batch_anomalies = [] for i in tqdm( range(inputs["future_target"].shape[1]), @@ -178,13 +179,16 @@ def main(args): ) next_lags = lagged_sequence_values( model.lags_seq, - inputs["past_target"] / scale, + scaled_past_target, target / scale, dim=-1, ) rnn_input = torch.cat((next_lags, next_features), dim=-1) - output, state = model.rnn(rnn_input, state) + scaled_past_target = torch.cat( + (scaled_past_target, target / scale), dim=1 + ) + params = model.param_proj(output) distr = model.output_distribution(params, scale=scale) # stack the batch_anomalies along the prediction length dimension