Add attention map visualization func #475

Merged · 4 commits · Jul 25, 2024
12 changes: 8 additions & 4 deletions pypots/data/saving/pickle.py
@@ -32,12 +32,13 @@ def pickle_dump(data: object, path: str) -> None:
         create_dir_if_not_exist(extract_parent_dir(path))
         with open(path, "wb") as f:
             pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL)
-        logger.info(f"Successfully saved to {path}")
     except Exception as e:
         logger.error(
-            f"❌ Pickling failed. No cache data saved. Please investigate the error below.\n{e}"
+            f"❌ Pickling failed. No cache data saved. Investigate the error below:\n{e}"
         )
+        return None
+    logger.info(f"Successfully saved to {path}")
+
     return None


@@ -58,6 +59,9 @@ def pickle_load(path: str) -> object:
         with open(path, "rb") as f:
             data = pickle.load(f)
     except Exception as e:
-        logger.error(f"❌ Loading data failed. Operation aborted. See info below:\n{e}")
+        logger.error(
+            f"❌ Loading data failed. Operation aborted. Investigate the error below:\n{e}"
+        )
         return None

     return data
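
For reference, a minimal sketch of the new control flow: the success message is now logged only after the dump actually completes, and both helpers return None on failure instead of raising. The payload and cache path below are hypothetical.

from pypots.data.saving.pickle import pickle_dump, pickle_load

data = {"model": "SAITS", "n_epochs": 10}  # hypothetical payload
pickle_dump(data, "cache/run.pkl")  # parent directory is created if missing
restored = pickle_load("cache/run.pkl")
if restored is None:
    print("Loading failed; check the logged error for details.")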
5 changes: 3 additions & 2 deletions pypots/nn/modules/etsformer/layers.py
@@ -160,8 +160,9 @@ def forward(self, x):
         f = fft.rfftfreq(t)[self.low_freq :]

         x_freq, index_tuple = self.topk_freq(x_freq)
-        f = repeat(f, "f -> b f d", b=x_freq.size(0), d=x_freq.size(2)).to(x_freq.device)
-        f = rearrange(f[index_tuple], "b f d -> b f () d").to(x_freq.device)
+        device = x_freq.device
+        f = repeat(f, "f -> b f d", b=x_freq.size(0), d=x_freq.size(2)).to(device)
+        f = rearrange(f[index_tuple], "b f d -> b f () d").to(device)

         return self.extrapolate(x_freq, f, t)
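
This refactor only hoists the target device into a local variable. A standalone sketch with hypothetical tensor shapes (the top-k indexing via index_tuple is omitted here), showing both einops results landing on x_freq's device:

import torch
from einops import rearrange, repeat

x_freq = torch.randn(2, 5, 3)  # hypothetical (batch, freq, dim) tensor
f = torch.linspace(0.1, 0.5, steps=5)  # one value per kept frequency

device = x_freq.device  # hoisted once, as in the patch
f = repeat(f, "f -> b f d", b=x_freq.size(0), d=x_freq.size(2)).to(device)
f = rearrange(f, "b f d -> b f () d").to(device)
print(f.shape)  # torch.Size([2, 5, 1, 3])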
63 changes: 63 additions & 0 deletions pypots/utils/visual/attention_map.py
@@ -0,0 +1,63 @@
"""
Utilities for attention map visualization.
"""

# Created by Anshuman Swain <aswai@seas.upenn.edu> and Wenjie Du <wenjay.du@gmail.com>
# License: BSD-3-Clause

from typing import Optional

import matplotlib.pyplot as plt
import numpy as np
from numpy.typing import ArrayLike

try:
    import seaborn as sns
except ImportError:
    # Seaborn is an optional dependency; plot_attention() checks for it at call time.
    sns = None


def plot_attention(timeSteps: ArrayLike, attention: np.ndarray, fontscale: Optional[float] = None):
    """Visualize the map of attention weights from Transformer-based models.

    Parameters
    ----------
    timeSteps : 1D array-like, preferably a list of strings
        A vector containing the time steps of the input.
        The time steps will be converted to strings if they are not already.

    attention : 2D array-like
        A 2D matrix representing the attention weights.

    fontscale : float, optional
        Font scale for the Seaborn heatmap, passed to sns.set_theme(font_scale=fontscale).

    Returns
    -------
    fig : matplotlib Figure object
        The figure containing the rendered attention heatmap.

    """

    if sns is None:
        raise ImportError(
            "plot_attention() requires seaborn. Install it with `pip install seaborn`."
        )

    assert attention.ndim == 2, "The attention matrix is not two-dimensional"

    if not all(isinstance(ele, str) for ele in timeSteps):
        timeSteps = [str(step) for step in timeSteps]

    if fontscale is not None:
        sns.set_theme(font_scale=fontscale)

    fig, ax = plt.subplots()
    ax.tick_params(left=True, bottom=True, labelsize=10)

    sns.heatmap(
        attention,
        ax=ax,
        xticklabels=timeSteps,
        yticklabels=timeSteps,
        linewidths=0,
        cbar=True,
    )

    # Seaborn places one tick per cell, so thin them out after drawing:
    # keep every other tick (with its label) so long sequences stay readable.
    ax.set_xticks(ax.get_xticks()[::2], [lab.get_text() for lab in ax.get_xticklabels()][::2])
    ax.set_yticks(ax.get_yticks()[::2], [lab.get_text() for lab in ax.get_yticklabels()][::2])

    cb = ax.collections[0].colorbar
    cb.ax.tick_params(labelsize=10)

    return fig
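
A quick usage sketch for the new utility. The attention matrix below is random stand-in data; in practice it would come from the attention weights of a trained Transformer-based model.

import numpy as np
from pypots.utils.visual.attention_map import plot_attention

steps = [f"t{i}" for i in range(24)]  # hypothetical time-step labels
attn = np.random.rand(24, 24)  # stand-in for real attention weights
fig = plot_attention(steps, attn, fontscale=1.2)
fig.savefig("attention_map.png", dpi=150, bbox_inches="tight")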