From eeec569329aac393c65d7313ba5daa0ddbb9e260 Mon Sep 17 00:00:00 2001
From: Jan Beitner
Date: Wed, 23 Mar 2022 11:46:45 +0000
Subject: [PATCH] Update attention calculation

---
 .../temporal_fusion_transformer/__init__.py   | 124 ++++++++++--------
 pytorch_forecasting/utils.py                  |  25 ++++
 .../test_temporal_fusion_transformer.py       |  51 +++++++
 3 files changed, 148 insertions(+), 52 deletions(-)

diff --git a/pytorch_forecasting/models/temporal_fusion_transformer/__init__.py b/pytorch_forecasting/models/temporal_fusion_transformer/__init__.py
index 428510a3..9012d8b9 100644
--- a/pytorch_forecasting/models/temporal_fusion_transformer/__init__.py
+++ b/pytorch_forecasting/models/temporal_fusion_transformer/__init__.py
@@ -23,7 +23,7 @@
     InterpretableMultiHeadAttention,
     VariableSelectionNetwork,
 )
-from pytorch_forecasting.utils import autocorrelation, create_mask, detach, integer_histogram, padded_stack, to_list
+from pytorch_forecasting.utils import create_mask, detach, integer_histogram, masked_op, padded_stack, to_list


 class TemporalFusionTransformer(BaseModelWithCovariates):
@@ -501,7 +501,8 @@ def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:

         return self.to_network_output(
             prediction=self.transform_output(output, target_scale=x["target_scale"]),
-            attention=attn_output_weights,
+            encoder_attention=attn_output_weights[..., :max_encoder_length],
+            decoder_attention=attn_output_weights[..., max_encoder_length:],
             static_variables=static_variable_selection,
             encoder_variables=encoder_sparse_weights,
             decoder_variables=decoder_sparse_weights,
@@ -540,7 +541,6 @@ def interpret_output(
         out: Dict[str, torch.Tensor],
         reduction: str = "none",
         attention_prediction_horizon: int = 0,
-        attention_as_autocorrelation: bool = False,
     ) -> Dict[str, torch.Tensor]:
         """
         interpret output of model
@@ -550,29 +550,77 @@
             reduction: "none" for no averaging over batches, "sum" for summing attentions, "mean" for
                 normalizing by encoder lengths
             attention_prediction_horizon: which prediction horizon to use for attention
-            attention_as_autocorrelation: if to record attention as autocorrelation - this should be set to true in
-                case of ``reduction != "none"`` and differing prediction times of the samples. Defaults to False

         Returns:
             interpretations that can be plotted with ``plot_interpretation()``
         """
         # if attention is given as a list, concatenate it into a proper attention tensor
-        if isinstance(out["attention"], (list, tuple)):
+        if isinstance(out["decoder_attention"], (list, tuple)):
+            batch_size = len(out["decoder_attention"])
+            # start with decoder attention
             # assume issue is in last dimension, we need to find max
-            max_last_dimension = max(x.size(-1) for x in out["attention"])
-            first_elm = out["attention"][0]
+            max_last_dimension = max(x.size(-1) for x in out["decoder_attention"])
+            first_elm = out["decoder_attention"][0]
             # create new attention tensor into which we will scatter
-            attention = torch.full(
-                (len(out["attention"]), *first_elm.shape[:-1], max_last_dimension),
+            decoder_attention = torch.full(
+                (batch_size, *first_elm.shape[:-1], max_last_dimension),
                 float("nan"),
                 dtype=first_elm.dtype,
                 device=first_elm.device,
             )
             # scatter into tensor
-            for idx, x in enumerate(out["attention"]):
-                attention[idx, :, :, : x.size(-1)] = x
+            for idx, x in enumerate(out["decoder_attention"]):
+                decoder_length = out["decoder_lengths"][idx]
+                decoder_attention[idx, :, :, :decoder_length] = x[..., :decoder_length]
+
+            # do the same for the encoder attention
+            # create new attention tensor into which we will scatter
+            first_elm = out["encoder_attention"][0]
+            encoder_attention = torch.full(
+                (batch_size, *first_elm.shape[:-1], self.hparams.max_encoder_length),
+                float("nan"),
+                dtype=first_elm.dtype,
+                device=first_elm.device,
+            )
+            # scatter into tensor
+            for idx, x in enumerate(out["encoder_attention"]):
+                encoder_length = out["encoder_lengths"][idx]
+                encoder_attention[idx, :, :, self.hparams.max_encoder_length - encoder_length :] = x[
+                    ..., :encoder_length
+                ]
         else:
-            attention = out["attention"]
+            decoder_attention = out["decoder_attention"]
+            decoder_mask = create_mask(out["decoder_attention"].size(1), out["decoder_lengths"])
+            decoder_attention[decoder_mask[..., None, None].expand_as(decoder_attention)] = float("nan")
+            # roll encoder attention (so that the last encoder value is on the right)
+            encoder_attention = out["encoder_attention"]
+            shifts = encoder_attention.size(3) - out["encoder_lengths"]
+            new_index = (
+                torch.arange(encoder_attention.size(3))[None, None, None].expand_as(encoder_attention)
+                - shifts[:, None, None, None]
+            ) % encoder_attention.size(3)
+            encoder_attention = torch.gather(encoder_attention, dim=3, index=new_index)
+            # expand encoder_attention to full size
+            if encoder_attention.size(-1) < self.hparams.max_encoder_length:
+                encoder_attention = torch.concat(
+                    [
+                        torch.full(
+                            (
+                                *encoder_attention.shape[:-1],
+                                self.hparams.max_encoder_length - out["encoder_lengths"].max(),
+                            ),
+                            float("nan"),
+                            dtype=encoder_attention.dtype,
+                            device=encoder_attention.device,
+                        ),
+                        encoder_attention,
+                    ],
+                    dim=-1,
+                )
+
+        # combine attention vectors
+        attention = torch.concat([encoder_attention, decoder_attention], dim=-1)
+        attention[attention < 1e-5] = float("nan")

         # histogram of decoder and encoder lengths
         encoder_length_histogram = integer_histogram(out["encoder_lengths"], min=0, max=self.hparams.max_encoder_length)
@@ -599,53 +647,25 @@
         static_variables = out["static_variables"].squeeze(1)
         # attention is batch x time x heads x time_to_attend
         # average over heads + only keep prediction attention and attention on observed timesteps
-        attention = attention[
-            :, attention_prediction_horizon, :, : out["encoder_lengths"].max() + attention_prediction_horizon
-        ].mean(1)
+        attention = masked_op(
+            attention[
+                :, attention_prediction_horizon, :, : self.hparams.max_encoder_length + attention_prediction_horizon
+            ],
+            op="mean",
+            dim=1,
+        )

         if reduction != "none":  # average over batches
             static_variables = static_variables.sum(dim=0)
             encoder_variables = encoder_variables.sum(dim=0)
             decoder_variables = decoder_variables.sum(dim=0)

-            # reorder attention or averaging
-            for i in range(len(attention)):  # very inefficient but does the trick
-                if 0 < out["encoder_lengths"][i] < attention.size(1) - attention_prediction_horizon - 1:
-                    relevant_attention = attention[
-                        i, : out["encoder_lengths"][i] + attention_prediction_horizon
-                    ].clone()
-                    if attention_as_autocorrelation:
-                        relevant_attention = autocorrelation(relevant_attention)
-                    attention[i, -out["encoder_lengths"][i] - attention_prediction_horizon :] = relevant_attention
-                    attention[i, : attention.size(1) - out["encoder_lengths"][i] - attention_prediction_horizon] = 0.0
-                elif attention_as_autocorrelation:
-                    attention[i] = autocorrelation(attention[i])
-
-            attention = attention.sum(dim=0)
-            if reduction == "mean":
-                attention = attention / encoder_length_histogram[1:].flip(0).cumsum(0).clamp(1)
-                attention = attention / attention.sum(-1).unsqueeze(-1)  # renormalize
-            elif reduction == "sum":
-                pass
-            else:
-                raise ValueError(f"Unknown reduction {reduction}")
-
-            attention = torch.zeros(
-                self.hparams.max_encoder_length + attention_prediction_horizon, device=self.device
-            ).scatter(
-                dim=0,
-                index=torch.arange(
-                    self.hparams.max_encoder_length + attention_prediction_horizon - attention.size(-1),
-                    self.hparams.max_encoder_length + attention_prediction_horizon,
-                    device=self.device,
-                ),
-                src=attention,
-            )
+            attention = masked_op(attention, dim=0, op=reduction)
         else:
-            attention = attention / attention.sum(-1).unsqueeze(-1)  # renormalize
+            attention = attention / masked_op(attention, dim=1, op="sum").unsqueeze(-1)  # renormalize

         interpretation = dict(
-            attention=attention,
+            attention=attention.masked_fill(torch.isnan(attention), 0.0),
             static_variables=static_variables,
             encoder_variables=encoder_variables,
             decoder_variables=decoder_variables,
@@ -702,7 +722,7 @@
             encoder_length = x["encoder_lengths"][0]
             ax2.plot(
                 torch.arange(-encoder_length, 0),
-                interpretation["attention"][0, :encoder_length].detach().cpu(),
+                interpretation["attention"][0, -encoder_length:].detach().cpu(),
                 alpha=0.2,
                 color="k",
             )
diff --git a/pytorch_forecasting/utils.py b/pytorch_forecasting/utils.py
index 58feb0af..8e8f3420 100644
--- a/pytorch_forecasting/utils.py
+++ b/pytorch_forecasting/utils.py
@@ -447,3 +447,28 @@ def detach(
         return [detach(xi) for xi in x]
     else:
         return x
+
+
+def masked_op(tensor: torch.Tensor, op: str = "mean", dim: int = 0, mask: torch.Tensor = None) -> torch.Tensor:
+    """Calculate operation on masked tensor.
+
+    Args:
+        tensor (torch.Tensor): tensor to conduct operation over
+        op (str): operation to apply. One of ["mean", "sum"]. Defaults to "mean".
+        dim (int, optional): dimension over which to apply the operation. Defaults to 0.
+        mask (torch.Tensor, optional): boolean mask to apply (True = include in operation, False = ignore).
+            Defaults to masking NaN values.
+
+    Returns:
+        torch.Tensor: tensor with the operation applied over the given dimension
+    """
+    if mask is None:
+        mask = ~torch.isnan(tensor)
+    masked = tensor.masked_fill(~mask, 0.0)
+    summed = masked.sum(dim=dim)
+    if op == "mean":
+        return summed / mask.sum(dim=dim)  # divide by the number of unmasked entries
+    elif op == "sum":
+        return summed
+    else:
+        raise ValueError(f"unknown operation {op}")
diff --git a/tests/test_models/test_temporal_fusion_transformer.py b/tests/test_models/test_temporal_fusion_transformer.py
index ef81907b..fdea8076 100644
--- a/tests/test_models/test_temporal_fusion_transformer.py
+++ b/tests/test_models/test_temporal_fusion_transformer.py
@@ -2,6 +2,7 @@
 import shutil
 import sys

+import numpy as np
 import pytest
 import pytorch_lightning as pl
 from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
@@ -256,6 +257,56 @@ def test_prediction_with_dataloder(model, dataloaders_with_covariates, kwargs):
     model.predict(val_dataloader, fast_dev_run=True, **kwargs)


+def test_prediction_with_dataloder_raw(data_with_covariates, tmp_path):
+    # tests correct concatenation of raw output
+    test_data = data_with_covariates.copy()
+    np.random.seed(2)
+    test_data = test_data.sample(frac=0.5)
+
+    dataset = TimeSeriesDataSet(
+        test_data,
+        time_idx="time_idx",
+        max_encoder_length=24,
+        max_prediction_length=10,
+        min_prediction_length=1,
+        min_encoder_length=1,
+        target="volume",
+        group_ids=["agency", "sku"],
+        constant_fill_strategy=dict(volume=0.0),
+        allow_missing_timesteps=True,
+        time_varying_unknown_reals=["volume"],
+        time_varying_known_reals=["time_idx"],
+    )
+
+    net = TemporalFusionTransformer.from_dataset(
+        dataset,
+        learning_rate=1e-6,
+        hidden_size=4,
+        attention_head_size=1,
+        dropout=0.2,
+        hidden_continuous_size=2,
+        # loss=PoissonLoss(),
+        log_interval=1,
+        log_val_interval=1,
+        log_gradient_flow=True,
+    )
+    logger = TensorBoardLogger(tmp_path)
+    trainer = pl.Trainer(
+        max_epochs=1,
+        gradient_clip_val=0.1,
+        logger=logger,
+    )
+    trainer.fit(net, train_dataloaders=dataset.to_dataloader(batch_size=4, num_workers=0))
+
+    # choose a small batch size to provoke the issue
+    res = net.predict(dataset.to_dataloader(batch_size=2, num_workers=0), mode="raw")
+    # check that interpretation works
+    net.interpret_output(res)["attention"]
+    assert net.interpret_output(res.iget(slice(1)))["attention"].size() == torch.Size(
+        (1, net.hparams.max_encoder_length)
+    )
+
+
 def test_prediction_with_dataset(model, dataloaders_with_covariates):
     val_dataloader = dataloaders_with_covariates["val"]
     model.predict(val_dataloader.dataset, fast_dev_run=True)
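
---
A quick sketch of the semantics of the new ``masked_op`` helper added in
pytorch_forecasting/utils.py above (the example tensor below is made up for
illustration):

    import torch

    from pytorch_forecasting.utils import masked_op  # added by this patch

    x = torch.tensor([[1.0, 2.0, float("nan")],
                      [3.0, float("nan"), float("nan")]])

    # NaN entries are masked out by default, so each row is reduced over
    # its valid entries only
    print(masked_op(x, op="mean", dim=1))  # tensor([1.5000, 3.0000])

    # a plain mean would be poisoned by the NaNs
    print(x.mean(dim=1))  # tensor([nan, nan])

This is why the patch marks padded and near-zero attention values as NaN
before reducing: they drop out of the mean/sum instead of biasing it.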
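The right-alignment of encoder attention in interpret_output is a per-sample
roll implemented as a modular-shifted torch.gather. A 2-D sketch of the same
trick (the patch applies it on dim=3 of the 4-D attention tensor; the toy
values and names below are made up):

    import torch

    # two samples, padded to 4 encoder steps, with true lengths 2 and 3
    attn = torch.tensor([[0.2, 0.8, 0.0, 0.0],
                         [0.1, 0.6, 0.3, 0.0]])
    lengths = torch.tensor([2, 3])

    max_len = attn.size(1)
    shifts = max_len - lengths  # how far each row must move to the right
    index = (torch.arange(max_len)[None].expand_as(attn) - shifts[:, None]) % max_len
    rolled = torch.gather(attn, dim=1, index=index)
    print(rolled)
    # tensor([[0.0000, 0.0000, 0.2000, 0.8000],
    #         [0.0000, 0.1000, 0.6000, 0.3000]])

After the roll, the last observed encoder step of every sample sits in the
last column, so attention can be aggregated across samples with differing
encoder lengths.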