
Commit

Fix attention concat

Jan Beitner committed Mar 23, 2022
1 parent eea8c16 commit ceeec4f
Showing 3 changed files with 51 additions and 5 deletions.
20 changes: 19 additions & 1 deletion pytorch_forecasting/models/base_model.py
@@ -6,6 +6,7 @@
 from copy import deepcopy
 import inspect
 from typing import Any, Callable, Dict, Iterable, List, Tuple, Union
+import warnings
 
 import matplotlib.pyplot as plt
 import numpy as np
@@ -75,7 +76,24 @@ def _torch_cat_na(x: List[torch.Tensor]) -> torch.Tensor:
                 )
                 for xi in x
             ]
-    return torch.cat(x, dim=0)
+
+    # check if remaining dimensions are all equal
+    if x[0].ndim > 2:
+        remaining_dimensions_equal = all([all([xi.size(i) == x[0].size(i) for xi in x]) for i in range(2, x[0].ndim)])
+    else:
+        remaining_dimensions_equal = True
+
+    # concatenate if all remaining dimensions match, otherwise deaggregate into a list of samples
+    if remaining_dimensions_equal:
+        return torch.cat(x, dim=0)
+    else:
+        # make a list instead, but warn about it
+        warnings.warn(
+            f"Not all tensor shapes have equal remaining dimensions. Example tensor {x[0].shape}. "
+            "Returning list instead of torch.Tensor.",
+            UserWarning,
+        )
+        return [xii for xi in x for xii in xi]
 
 
 def _concatenate_output(
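Note on the behavior change above: the function first pads ``dim=1`` with NaNs, but if the tensors still disagree in a later dimension (for example, attention tensors whose last dimension depends on each batch's encoder length), ``torch.cat`` would raise, so the function now falls back to a flat list of per-sample tensors. A minimal, self-contained sketch of that fallback, with illustrative shapes that are not from the library:

import torch

# two batches whose attention tensors disagree in the last dimension
a = torch.rand(4, 2, 3, 10)  # 4 samples, attending over 10 steps
b = torch.rand(4, 2, 3, 7)   # 4 samples, attending over 7 steps

# the check added above: compare every dimension from index 2 onwards
remaining_dimensions_equal = all(
    all(xi.size(i) == a.size(i) for xi in (a, b)) for i in range(2, a.ndim)
)
assert not remaining_dimensions_equal  # 10 != 7 in the last dimension

# the fallback: iterate over dim 0 of each batch, yielding one tensor per sample
out = [sample for batch in (a, b) for sample in batch]
assert len(out) == 8 and out[0].shape == (2, 3, 10) and out[4].shape == (2, 3, 7)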
25 changes: 21 additions & 4 deletions pytorch_forecasting/models/temporal_fusion_transformer/__init__.py
@@ -556,6 +556,23 @@ def interpret_output(
         Returns:
             interpretations that can be plotted with ``plot_interpretation()``
         """
+        # if attention is a list or tuple, concatenate it into a proper attention tensor
+        if isinstance(out["attention"], (list, tuple)):
+            # assume the mismatch is in the last dimension; find its maximum
+            max_last_dimension = max(x.size(-1) for x in out["attention"])
+            first_elm = out["attention"][0]
+            # create new attention tensor into which we will scatter
+            attention = torch.full(
+                (len(out["attention"]), *first_elm.shape[:-1], max_last_dimension),
+                float("nan"),
+                dtype=first_elm.dtype,
+                device=first_elm.device,
+            )
+            # scatter the ragged tensors into the padded tensor
+            for idx, x in enumerate(out["attention"]):
+                attention[idx, :, :, : x.size(-1)] = x
+        else:
+            attention = out["attention"]
 
         # histogram of decode and encode lengths
         encoder_length_histogram = integer_histogram(out["encoder_lengths"], min=0, max=self.hparams.max_encoder_length)
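This padding step is what makes the list fallback of ``_torch_cat_na`` recoverable: each list element is one sample's attention of shape ``time x heads x time_to_attend``, and shorter samples are scattered into a NaN-initialised tensor of the maximum width. A standalone sketch with toy shapes, where all names are illustrative:

import torch

# per-sample attention tensors of shape (time, heads, time_to_attend)
chunks = [torch.rand(5, 4, 10), torch.rand(5, 4, 7)]
max_last = max(x.size(-1) for x in chunks)

# NaN-initialised target so downstream reductions can mask the padding
attention = torch.full(
    (len(chunks), *chunks[0].shape[:-1], max_last),
    float("nan"),
    dtype=chunks[0].dtype,
)
for idx, x in enumerate(chunks):
    attention[idx, :, :, : x.size(-1)] = x

assert attention.shape == (2, 5, 4, 10)
assert torch.isnan(attention[1, :, :, 7:]).all()  # padding stays NaN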
@@ -582,7 +599,7 @@
         static_variables = out["static_variables"].squeeze(1)
         # attention is batch x time x heads x time_to_attend
         # average over heads + only keep prediction attention and attention on observed timesteps
-        attention = out["attention"][
+        attention = attention[
             :, attention_prediction_horizon, :, : out["encoder_lengths"].max() + attention_prediction_horizon
         ].mean(1)
 
@@ -677,15 +694,15 @@ def plot_prediction(
 
         # add attention on secondary axis
         if plot_attention:
-            interpretation = self.interpret_output(out)
+            interpretation = self.interpret_output(out.iget(slice(idx, idx + 1)))
             for f in to_list(fig):
                 ax = f.axes[0]
                 ax2 = ax.twinx()
                 ax2.set_ylabel("Attention")
-                encoder_length = x["encoder_lengths"][idx]
+                encoder_length = x["encoder_lengths"][0]
                 ax2.plot(
                     torch.arange(-encoder_length, 0),
-                    interpretation["attention"][idx, :encoder_length].detach().cpu(),
+                    interpretation["attention"][0, :encoder_length].detach().cpu(),
                     alpha=0.2,
                     color="k",
                 )
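After this change, ``plot_prediction`` interprets a one-sample slice instead of the full batch, so every subsequent lookup uses row 0 rather than ``idx``. A toy check of the plotting indices, with illustrative numbers that are not library code:

import torch

encoder_length = 4
attention = torch.rand(1, 10)          # single interpreted row, padded width 10
xs = torch.arange(-encoder_length, 0)  # [-4, -3, -2, -1]: steps before the prediction
ys = attention[0, :encoder_length]     # row 0, observed encoder steps only
assert xs.shape == ys.shape == (4,)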
11 changes: 11 additions & 0 deletions pytorch_forecasting/utils.py
@@ -338,6 +338,17 @@ def items(self):
     def keys(self):
         return self._fields
 
+    def iget(self, idx: Union[int, slice]):
+        """Select item(s) row-wise.
+
+        Args:
+            idx (Union[int, slice]): item(s) to select
+
+        Returns:
+            Output of the selected item(s).
+        """
+        return self.__class__(*(x[idx] for x in self))
+
 
 class TupleOutputMixIn:
     """MixIn to give output namedtuple-like access capabilities with the ``to_network_output()`` function."""
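One subtlety worth noting: ``iget`` with an integer index drops the row dimension, while a slice preserves it, which is why ``plot_prediction`` above passes ``slice(idx, idx + 1)`` rather than ``idx``. A sketch using a toy namedtuple as a stand-in for the network output class, not the library type itself:

from collections import namedtuple

import torch

Row = namedtuple("Row", ["prediction", "encoder_lengths"])

def iget(out, idx):
    # same idea as the method above: row-wise selection on every field
    return Row(*(x[idx] for x in out))

out = Row(prediction=torch.zeros(8, 6), encoder_lengths=torch.ones(8))
print(iget(out, 0).prediction.shape)            # torch.Size([6])    - row dim dropped
print(iget(out, slice(0, 1)).prediction.shape)  # torch.Size([1, 6]) - row dim kept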
