[bugfix][Favor] Same feature map for q and k #183

Merged 2 commits on Jan 18, 2022
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -5,6 +5,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## TBD
+### Fixed
+- bugfix Favor, single feature map [#183]

## [0.0.8] - 2022-01-07
### Fixed
16 changes: 6 additions & 10 deletions tests/test_favor.py
@@ -85,8 +85,8 @@ def test_feature_map_shape():
)
_ = att(batch, batch, batch)

-    assert att.feature_map_key.features.shape[0] == batch.shape[-1]
-    assert att.feature_map_key.features.shape[1] == nb_random_features
+    assert att.feature_map.features.shape[0] == batch.shape[-1]
+    assert att.feature_map.features.shape[1] == nb_random_features


def test_feature_map_redraw():
@@ -102,20 +102,16 @@ def check(should_redraw: bool):
iter_before_redraw=1 if should_redraw else 100,
)
v0 = att(batch, batch, batch)
-        assert att.feature_map_query is not None
-        assert att.feature_map_key is not None
+        assert att.feature_map is not None

-        fq0 = att.feature_map_query.features
-        fk0 = att.feature_map_key.features
+        f0 = att.feature_map.features

v1 = att(batch, batch, batch)
-        fq1 = att.feature_map_query.features
-        fk1 = att.feature_map_key.features
+        f1 = att.feature_map.features

# There should not have been a redraw after v0
assert should_redraw != torch.allclose(v0, v1)
-        assert should_redraw != torch.allclose(fq0, fq1)  # type: ignore
-        assert should_redraw != torch.allclose(fk0, fk1)  # type: ignore
+        assert should_redraw != torch.allclose(f0, f1)  # type: ignore

check(should_redraw=True)
check(should_redraw=False)
15 changes: 9 additions & 6 deletions xformers/components/attention/favor.py
@@ -57,7 +57,7 @@ def __init__(
Args:
dropout (float): the probability of an output to be randomly dropped at training time
dim_features (int): the dimension of the random features space
-            iter_before_redraw (int): the number of iterations before a redraw of the features
+            iter_before_redraw (int): the number of steps (forward calls) before a redraw of the features
feature_map_type (FeatureMapType): the type of feature map being used,
for instance orthogonal random features.

@@ -67,7 +67,11 @@
super().__init__()

self.causal = causal
-        self.iter_before_redraw = iter_before_redraw
+        self.iter_before_redraw = (
+            (2 * iter_before_redraw)
+            if iter_before_redraw is not None
+            else iter_before_redraw
+        )  # This will be used for both key and query
self.normalize_inputs = normalize_inputs
self.feature_map_type = feature_map_type
self.attn_drop = nn.Dropout(dropout, inplace=True)
@@ -98,8 +102,7 @@ def __init__(
"normalize_inputs": self.normalize_inputs,
}

-        self.feature_map_query: FeatureMap = feature_map_constructor(**feature_settings)  # type: ignore
-        self.feature_map_key: FeatureMap = feature_map_constructor(**feature_settings)  # type: ignore
+        self.feature_map: FeatureMap = feature_map_constructor(**feature_settings)  # type: ignore
Contributor (reviewer):

We should probably enforce now that iter_before_redraw is a multiple of 2, or multiply it by 2 in the implementation of the FeatureMap; otherwise we can still sample different maps for k / q.

Contributor (author):

oww, that would have been an annoying bug, great catch, and I went too fast. I would do a x2 on the setting, because I think that from the outside "iter" means "number of steps we took".

Contributor (author):

I updated the docstring for the attention to make it clear that this counts the number of "forward" calls.
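
To make the counting argument above concrete, here is a minimal sketch; the CountingFeatureMap below is hypothetical, not the actual xformers FeatureMap API. Since the now-shared map is called once for k and once for q on every forward, its internal counter advances by two per attention step, which is why the constructor doubles the user-facing iter_before_redraw.

```python
# Minimal sketch of the redraw counter (hypothetical CountingFeatureMap,
# not the actual xformers FeatureMap implementation).
import torch


class CountingFeatureMap:
    def __init__(self, dim_input: int, dim_features: int, iter_before_redraw: int):
        # The caller passes 2 * iter_before_redraw, since each forward
        # consumes two calls: one for k, one for q.
        self.iter_before_redraw = iter_before_redraw
        self.iter_count = 0
        self.dims = (dim_input, dim_features)
        self.features = torch.randn(*self.dims)

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        if self.iter_count >= self.iter_before_redraw:
            self.features = torch.randn(*self.dims)  # redraw
            self.iter_count = 0
        self.iter_count += 1
        return x @ self.features


# With the raw setting (iter_before_redraw=1), the second call below would
# already trigger a redraw, giving q a different map than k - the case the
# reviewer points out. Doubling it keeps k and q on the same draw.
fmap = CountingFeatureMap(dim_input=16, dim_features=32, iter_before_redraw=2 * 1)
k, q = torch.randn(4, 16), torch.randn(4, 16)
k_prime, q_prime = fmap(k), fmap(q)  # same random features for both
```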


@staticmethod
def _maybe_promote(x: torch.Tensor) -> torch.Tensor:
@@ -135,8 +138,8 @@ def forward(
):

# Project key and queries onto the feature map space
-        k_prime = self.feature_map_key(k)
-        q_prime = self.feature_map_query(q)
+        k_prime = self.feature_map(k)
+        q_prime = self.feature_map(q)

with autocast(enabled=False):
# The softmax kernel approximation for Favor will easily overflow
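
As background on why sharing the map matters, here is a minimal sketch using simplified FAVOR+ positive random features; it omits the input normalization, numerical stabilization, and causal handling that favor.py actually performs. The estimate q_prime @ k_prime.T / m only approximates exp(q @ k.T) when both projections use the same draw of omega; with two independently drawn maps, as before this fix, the expectation of this simplified estimator collapses to a constant.

```python
# Simplified FAVOR+ positive random features. omega must be shared between
# q and k for E[q_prime @ k_prime.T / m] to equal exp(q @ k.T); with two
# independent draws the expectation degenerates to 1 for every (q, k) pair.
import torch


def positive_features(x: torch.Tensor, omega: torch.Tensor) -> torch.Tensor:
    # phi(x) = exp(x @ omega - |x|^2 / 2), with omega ~ N(0, I)
    return torch.exp(x @ omega - x.pow(2).sum(-1, keepdim=True) / 2)


d, m = 8, 100_000  # model dim, number of random features
omega = torch.randn(d, m)  # one shared draw

q, k = 0.2 * torch.randn(3, d), 0.2 * torch.randn(3, d)
q_prime, k_prime = positive_features(q, omega), positive_features(k, omega)

approx = q_prime @ k_prime.T / m  # Monte Carlo estimate of exp(q @ k.T)
exact = torch.exp(q @ k.T)
print(torch.allclose(approx, exact, rtol=0.1))  # should print True for large m
```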