Skip to content

Commit

Permalink
[bugfix][Favor] Same feature map for q and k (#183)
Browse files Browse the repository at this point in the history
* same feature map for q and k
* nasty bugfix, thanks @fmassa
  • Loading branch information
blefaudeux authored Jan 18, 2022
1 parent 60e94e5 commit 5ce6428
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 16 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## TBD
### Fixed
- Favor bugfix: use a single shared feature map for queries and keys [#183]

## [0.0.8] - 2022-01-07
### Fixed
Expand Down
16 changes: 6 additions & 10 deletions tests/test_favor.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,8 @@ def test_feature_map_shape():
)
_ = att(batch, batch, batch)

assert att.feature_map_key.features.shape[0] == batch.shape[-1]
assert att.feature_map_key.features.shape[1] == nb_random_features
assert att.feature_map.features.shape[0] == batch.shape[-1]
assert att.feature_map.features.shape[1] == nb_random_features


def test_feature_map_redraw():
Expand All @@ -102,20 +102,16 @@ def check(should_redraw: bool):
iter_before_redraw=1 if should_redraw else 100,
)
v0 = att(batch, batch, batch)
assert att.feature_map_query is not None
assert att.feature_map_key is not None
assert att.feature_map is not None

fq0 = att.feature_map_query.features
fk0 = att.feature_map_key.features
f0 = att.feature_map.features

v1 = att(batch, batch, batch)
fq1 = att.feature_map_query.features
fk1 = att.feature_map_key.features
f1 = att.feature_map.features

# There should not have been a redraw after v0
assert should_redraw != torch.allclose(v0, v1)
assert should_redraw != torch.allclose(fq0, fq1) # type: ignore
assert should_redraw != torch.allclose(fk0, fk1) # type: ignore
assert should_redraw != torch.allclose(f0, f1) # type: ignore

check(should_redraw=True)
check(should_redraw=False)
Expand Down
15 changes: 9 additions & 6 deletions xformers/components/attention/favor.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def __init__(
Args:
dropout (float): the probability of an output to be randomly dropped at training time
dim_features (int): the dimension of the random features space
iter_before_redraw (int): the number of iterations before a redraw of the features
iter_before_redraw (int): the number of steps (forward calls) before a redraw of the features
feature_map_type (FeatureMapType): the type of feature map being used,
for instance orthogonal random features.
Expand All @@ -67,7 +67,11 @@ def __init__(
super().__init__()

self.causal = causal
self.iter_before_redraw = iter_before_redraw
self.iter_before_redraw = (
(2 * iter_before_redraw)
if iter_before_redraw is not None
else iter_before_redraw
) # This will be used for both key and query
self.normalize_inputs = normalize_inputs
self.feature_map_type = feature_map_type
self.attn_drop = nn.Dropout(dropout, inplace=True)
Expand Down Expand Up @@ -98,8 +102,7 @@ def __init__(
"normalize_inputs": self.normalize_inputs,
}

self.feature_map_query: FeatureMap = feature_map_constructor(**feature_settings) # type: ignore
self.feature_map_key: FeatureMap = feature_map_constructor(**feature_settings) # type: ignore
self.feature_map: FeatureMap = feature_map_constructor(**feature_settings) # type: ignore

@staticmethod
def _maybe_promote(x: torch.Tensor) -> torch.Tensor:
Expand Down Expand Up @@ -135,8 +138,8 @@ def forward(
):

# Project key and queries onto the feature map space
k_prime = self.feature_map_key(k)
q_prime = self.feature_map_query(q)
k_prime = self.feature_map(k)
q_prime = self.feature_map(q)

with autocast(enabled=False):
# The softmax kernel approximation for Favor will easily overflow
Expand Down

0 comments on commit 5ce6428

Please sign in to comment.