Update attention calculation
Jan Beitner committed Mar 23, 2022
1 parent ceeec4f commit eeec569
Showing 3 changed files with 148 additions and 52 deletions.
124 changes: 72 additions & 52 deletions pytorch_forecasting/models/temporal_fusion_transformer/__init__.py
@@ -23,7 +23,7 @@
InterpretableMultiHeadAttention,
VariableSelectionNetwork,
)
from pytorch_forecasting.utils import autocorrelation, create_mask, detach, integer_histogram, padded_stack, to_list
from pytorch_forecasting.utils import create_mask, detach, integer_histogram, masked_op, padded_stack, to_list


class TemporalFusionTransformer(BaseModelWithCovariates):
@@ -501,7 +501,8 @@ def forward(self, x: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:

return self.to_network_output(
prediction=self.transform_output(output, target_scale=x["target_scale"]),
attention=attn_output_weights,
encoder_attention=attn_output_weights[..., :max_encoder_length],
decoder_attention=attn_output_weights[..., max_encoder_length:],
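# attn_output_weights cover the full sequence; they are split at max_encoder_length into encoder and decoder parts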
static_variables=static_variable_selection,
encoder_variables=encoder_sparse_weights,
decoder_variables=decoder_sparse_weights,
@@ -540,7 +541,6 @@ def interpret_output(
out: Dict[str, torch.Tensor],
reduction: str = "none",
attention_prediction_horizon: int = 0,
attention_as_autocorrelation: bool = False,
) -> Dict[str, torch.Tensor]:
"""
interpret output of model
@@ -550,29 +550,77 @@ def interpret_output(
reduction: "none" for no averaging over batches, "sum" for summing attentions, "mean" for
normalizing by encode lengths
attention_prediction_horizon: which prediction horizon to use for attention
attention_as_autocorrelation: whether to record attention as autocorrelation - set this to true if
``reduction != "none"`` and the samples have differing prediction times. Defaults to False
Returns:
interpretations that can be plotted with ``plot_interpretation()``
"""
# if attention arrives as a list, concatenate it into a single attention tensor
if isinstance(out["attention"], (list, tuple)):
if isinstance(out["decoder_attention"], (list, tuple)):
batch_size = len(out["decoder_attention"])
# start with decoder attention
# samples can differ in the last dimension; find its maximum size
max_last_dimension = max(x.size(-1) for x in out["attention"])
first_elm = out["attention"][0]
max_last_dimension = max(x.size(-1) for x in out["decoder_attention"])
first_elm = out["decoder_attention"][0]
# create new attention tensor into which we will scatter
attention = torch.full(
(len(out["attention"]), *first_elm.shape[:-1], max_last_dimension),
decoder_attention = torch.full(
(batch_size, *first_elm.shape[:-1], max_last_dimension),
float("nan"),
dtype=first_elm.dtype,
device=first_elm.device,
)
# scatter into tensor
for idx, x in enumerate(out["attention"]):
attention[idx, :, :, : x.size(-1)] = x
for idx, x in enumerate(out["decoder_attention"]):
decoder_length = out["decoder_lengths"][idx]
decoder_attention[idx, :, :, :decoder_length] = x[..., :decoder_length]
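# positions beyond each sample's decoder_length remain NaN and are ignored by the masked operations below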

# same game for encoder attention
# create new attention tensor into which we will scatter
first_elm = out["encoder_attention"][0]
encoder_attention = torch.full(
(batch_size, *first_elm.shape[:-1], self.hparams.max_encoder_length),
float("nan"),
dtype=first_elm.dtype,
device=first_elm.device,
)
# scatter into tensor
for idx, x in enumerate(out["encoder_attention"]):
encoder_length = out["encoder_lengths"][idx]
encoder_attention[idx, :, :, self.hparams.max_encoder_length - encoder_length :] = x[
..., :encoder_length
]
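# each sample is right-aligned so its last encoder step sits at index max_encoder_length - 1; leading positions remain NaN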
else:
attention = out["attention"]
decoder_attention = out["decoder_attention"]
decoder_mask = create_mask(out["decoder_attention"].size(1), out["decoder_lengths"])
decoder_attention[decoder_mask[..., None, None].expand_as(decoder_attention)] = float("nan")
# roll encoder attention (so that the last encoder value is on the right)
encoder_attention = out["encoder_attention"]
shifts = encoder_attention.size(3) - out["encoder_lengths"]
new_index = (
torch.arange(encoder_attention.size(3))[None, None, None].expand_as(encoder_attention)
- shifts[:, None, None, None]
) % encoder_attention.size(3)
encoder_attention = torch.gather(encoder_attention, dim=3, index=new_index)
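# e.g. for encoder_length 3 of max length 5, [a, b, c, p, p] becomes [p, p, a, b, c] (p = padded steps)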
# expand encoder_attention to full size
if encoder_attention.size(-1) < self.hparams.max_encoder_length:
encoder_attention = torch.concat(
[
torch.full(
(
*encoder_attention.shape[:-1],
self.hparams.max_encoder_length - out["encoder_lengths"].max(),
),
float("nan"),
dtype=encoder_attention.dtype,
device=encoder_attention.device,
),
encoder_attention,
],
dim=-1,
)

# combine encoder and decoder attention into one vector
attention = torch.concat([encoder_attention, decoder_attention], dim=-1)
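# near-zero attention stems from masked timesteps; mark it NaN so the reductions below ignore it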
attention[attention < 1e-5] = float("nan")

# histogram of decoder and encoder lengths
encoder_length_histogram = integer_histogram(out["encoder_lengths"], min=0, max=self.hparams.max_encoder_length)
@@ -599,53 +647,25 @@ def interpret_output(
static_variables = out["static_variables"].squeeze(1)
# attention is batch x time x heads x time_to_attend
# average over heads + only keep prediction attention and attention on observed timesteps
attention = attention[
:, attention_prediction_horizon, :, : out["encoder_lengths"].max() + attention_prediction_horizon
].mean(1)
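# use masked_op so NaN entries from padded timesteps are ignored when averaging over heads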
attention = masked_op(
attention[
:, attention_prediction_horizon, :, : self.hparams.max_encoder_length + attention_prediction_horizon
],
op="mean",
dim=1,
)

if reduction != "none": # if to average over batches
static_variables = static_variables.sum(dim=0)
encoder_variables = encoder_variables.sum(dim=0)
decoder_variables = decoder_variables.sum(dim=0)

# reorder attention for averaging
for i in range(len(attention)): # very inefficient but does the trick
if 0 < out["encoder_lengths"][i] < attention.size(1) - attention_prediction_horizon - 1:
relevant_attention = attention[
i, : out["encoder_lengths"][i] + attention_prediction_horizon
].clone()
if attention_as_autocorrelation:
relevant_attention = autocorrelation(relevant_attention)
attention[i, -out["encoder_lengths"][i] - attention_prediction_horizon :] = relevant_attention
attention[i, : attention.size(1) - out["encoder_lengths"][i] - attention_prediction_horizon] = 0.0
elif attention_as_autocorrelation:
attention[i] = autocorrelation(attention[i])

attention = attention.sum(dim=0)
if reduction == "mean":
attention = attention / encoder_length_histogram[1:].flip(0).cumsum(0).clamp(1)
attention = attention / attention.sum(-1).unsqueeze(-1) # renormalize
elif reduction == "sum":
pass
else:
raise ValueError(f"Unknown reduction {reduction}")

attention = torch.zeros(
self.hparams.max_encoder_length + attention_prediction_horizon, device=self.device
).scatter(
dim=0,
index=torch.arange(
self.hparams.max_encoder_length + attention_prediction_horizon - attention.size(-1),
self.hparams.max_encoder_length + attention_prediction_horizon,
device=self.device,
),
src=attention,
)
attention = masked_op(attention, dim=0, op=reduction)
else:
attention = attention / attention.sum(-1).unsqueeze(-1) # renormalize
attention = attention / masked_op(attention, dim=1, op="sum").unsqueeze(-1) # renormalize

interpretation = dict(
attention=attention,
attention=attention.masked_fill(torch.isnan(attention), 0.0),
static_variables=static_variables,
encoder_variables=encoder_variables,
decoder_variables=decoder_variables,
@@ -702,7 +722,7 @@ def plot_prediction(
encoder_length = x["encoder_lengths"][0]
ax2.plot(
torch.arange(-encoder_length, 0),
interpretation["attention"][0, :encoder_length].detach().cpu(),
interpretation["attention"][0, -encoder_length:].detach().cpu(),
alpha=0.2,
color="k",
)
25 changes: 25 additions & 0 deletions pytorch_forecasting/utils.py
@@ -447,3 +447,28 @@ def detach(
return [detach(xi) for xi in x]
else:
return x


def masked_op(tensor: torch.Tensor, op: str = "mean", dim: int = 0, mask: torch.Tensor = None) -> torch.Tensor:
"""Calculate operation on masked tensor.
Args:
tensor (torch.Tensor): tensor to conduct operation over
op (str): operation to apply. One of ["mean", "sum"]. Defaults to "mean".
dim (int, optional): dimension to apply the operation over. Defaults to 0.
mask (torch.Tensor, optional): boolean mask to apply (True=include in operation, False=ignore).
Defaults to masking NaN values.
Returns:
torch.Tensor: tensor with the operation applied over dimension ``dim``
"""
if mask is None:
mask = ~torch.isnan(tensor)
masked = tensor.masked_fill(~mask, 0.0)
summed = masked.sum(dim=dim)
if op == "mean":
return summed / mask.sum(dim=dim) # Find the average
elif op == "sum":
return summed
else:
raise ValueError(f"unkown operation {op}")
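
A minimal usage sketch of masked_op (illustrative values, not part of the commit): unlike a plain mean, which propagates NaN, masked_op reduces over the non-NaN entries only.

x = torch.tensor([[1.0, 2.0, float("nan")], [3.0, float("nan"), float("nan")]])
masked_op(x, op="mean", dim=1)  # tensor([1.5000, 3.0000])
x.mean(dim=1)  # tensor([nan, nan])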
51 changes: 51 additions & 0 deletions tests/test_models/test_temporal_fusion_transformer.py
@@ -2,6 +2,7 @@
import shutil
import sys

import numpy as np
import pytest
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
@@ -256,6 +257,56 @@ def test_prediction_with_dataloder(model, dataloaders_with_covariates, kwargs):
model.predict(val_dataloader, fast_dev_run=True, **kwargs)


def test_prediction_with_dataloder_raw(data_with_covariates, tmp_path):
# tests correct concatenation of raw output
test_data = data_with_covariates.copy()
np.random.seed(2)
test_data = test_data.sample(frac=0.5)

dataset = TimeSeriesDataSet(
test_data,
time_idx="time_idx",
max_encoder_length=24,
max_prediction_length=10,
min_prediction_length=1,
min_encoder_length=1,
target="volume",
group_ids=["agency", "sku"],
constant_fill_strategy=dict(volume=0.0),
allow_missing_timesteps=True,
time_varying_unknown_reals=["volume"],
time_varying_known_reals=["time_idx"],
)

net = TemporalFusionTransformer.from_dataset(
dataset,
learning_rate=1e-6,
hidden_size=4,
attention_head_size=1,
dropout=0.2,
hidden_continuous_size=2,
# loss=PoissonLoss(),
log_interval=1,
log_val_interval=1,
log_gradient_flow=True,
)
logger = TensorBoardLogger(tmp_path)
trainer = pl.Trainer(
max_epochs=1,
gradient_clip_val=0.1,
logger=logger,
)
trainer.fit(net, train_dataloaders=dataset.to_dataloader(batch_size=4, num_workers=0))

# choose small batch size to provoke issue
res = net.predict(dataset.to_dataloader(batch_size=2, num_workers=0), mode="raw")
# check that interpretation works
net.interpret_output(res)["attention"]
assert net.interpret_output(res.iget(slice(1)))["attention"].size() == torch.Size(
(1, net.hparams.max_encoder_length)
)
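
# For reference, a minimal sketch (assumed usage, not part of the commit) of how the
# raw output feeds into interpretation and plotting:
#     interpretation = net.interpret_output(res, reduction="sum")
#     net.plot_interpretation(interpretation)  # attention and variable importance figures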


def test_prediction_with_dataset(model, dataloaders_with_covariates):
val_dataloader = dataloaders_with_covariates["val"]
model.predict(val_dataloader.dataset, fast_dev_run=True)
