From ceeec4ff1d428958ea9aa725a71c3fe99762ead8 Mon Sep 17 00:00:00 2001 From: Jan Beitner Date: Wed, 23 Mar 2022 09:01:18 +0000 Subject: [PATCH] Fix attention concat --- pytorch_forecasting/models/base_model.py | 20 ++++++++++++++- .../temporal_fusion_transformer/__init__.py | 25 ++++++++++++++++--- pytorch_forecasting/utils.py | 11 ++++++++ 3 files changed, 51 insertions(+), 5 deletions(-) diff --git a/pytorch_forecasting/models/base_model.py b/pytorch_forecasting/models/base_model.py index 79126c4c..f66a05e9 100644 --- a/pytorch_forecasting/models/base_model.py +++ b/pytorch_forecasting/models/base_model.py @@ -6,6 +6,7 @@ from copy import deepcopy import inspect from typing import Any, Callable, Dict, Iterable, List, Tuple, Union +import warnings import matplotlib.pyplot as plt import numpy as np @@ -75,7 +76,24 @@ def _torch_cat_na(x: List[torch.Tensor]) -> torch.Tensor: ) for xi in x ] - return torch.cat(x, dim=0) + + # check if remaining dimensions are all equal + if x[0].ndim > 2: + remaining_dimensions_equal = all([all([xi.size(i) == x[0].size(i) for xi in x]) for i in range(2, x[0].ndim)]) + else: + remaining_dimensions_equal = True + + # deaggregate + if remaining_dimensions_equal: + return torch.cat(x, dim=0) + else: + # make list instead but warn + warnings.warn( + f"Not all dimensions are equal for tensors shapes. Example tensor {x[0].shape}. " + "Returning list instead of torch.Tensor.", + UserWarning, + ) + return [xii for xi in x for xii in xi] def _concatenate_output( diff --git a/pytorch_forecasting/models/temporal_fusion_transformer/__init__.py b/pytorch_forecasting/models/temporal_fusion_transformer/__init__.py index a716539f..428510a3 100644 --- a/pytorch_forecasting/models/temporal_fusion_transformer/__init__.py +++ b/pytorch_forecasting/models/temporal_fusion_transformer/__init__.py @@ -556,6 +556,23 @@ def interpret_output( Returns: interpretations that can be plotted with ``plot_interpretation()`` """ + # take attention and concatenate if a list to proper attention object + if isinstance(out["attention"], (list, tuple)): + # assume issue is in last dimension, we need to find max + max_last_dimension = max(x.size(-1) for x in out["attention"]) + first_elm = out["attention"][0] + # create new attention tensor into which we will scatter + attention = torch.full( + (len(out["attention"]), *first_elm.shape[:-1], max_last_dimension), + float("nan"), + dtype=first_elm.dtype, + device=first_elm.device, + ) + # scatter into tensor + for idx, x in enumerate(out["attention"]): + attention[idx, :, :, : x.size(-1)] = x + else: + attention = out["attention"] # histogram of decode and encode lengths encoder_length_histogram = integer_histogram(out["encoder_lengths"], min=0, max=self.hparams.max_encoder_length) @@ -582,7 +599,7 @@ def interpret_output( static_variables = out["static_variables"].squeeze(1) # attention is batch x time x heads x time_to_attend # average over heads + only keep prediction attention and attention on observed timesteps - attention = out["attention"][ + attention = attention[ :, attention_prediction_horizon, :, : out["encoder_lengths"].max() + attention_prediction_horizon ].mean(1) @@ -677,15 +694,15 @@ def plot_prediction( # add attention on secondary axis if plot_attention: - interpretation = self.interpret_output(out) + interpretation = self.interpret_output(out.iget(slice(idx, idx + 1))) for f in to_list(fig): ax = f.axes[0] ax2 = ax.twinx() ax2.set_ylabel("Attention") - encoder_length = x["encoder_lengths"][idx] + encoder_length = x["encoder_lengths"][0] ax2.plot( torch.arange(-encoder_length, 0), - interpretation["attention"][idx, :encoder_length].detach().cpu(), + interpretation["attention"][0, :encoder_length].detach().cpu(), alpha=0.2, color="k", ) diff --git a/pytorch_forecasting/utils.py b/pytorch_forecasting/utils.py index 407aa428..58feb0af 100644 --- a/pytorch_forecasting/utils.py +++ b/pytorch_forecasting/utils.py @@ -338,6 +338,17 @@ def items(self): def keys(self): return self._fields + def iget(self, idx: Union[int, slice]): + """Select item(s) row-wise. + + Args: + idx ([int, slice]): item to select + + Returns: + Output of single item. + """ + return self.__class__(*(x[idx] for x in self)) + class TupleOutputMixIn: """MixIn to give output a namedtuple-like access capabilities with ``to_network_output() function``."""