diff --git a/CHANGELOG.md b/CHANGELOG.md
index 46a30e15eb..f01c4704b1 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,6 +5,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 
 ## TBD
+### Fixed
+- Favor attention: use a single feature map for keys and queries [#183]
 
 ## [0.0.8] - 2022-01-07
 ### Fixed
diff --git a/tests/test_favor.py b/tests/test_favor.py
index c060a7467f..2f5ba37305 100644
--- a/tests/test_favor.py
+++ b/tests/test_favor.py
@@ -85,8 +85,8 @@ def test_feature_map_shape():
     )
     _ = att(batch, batch, batch)
 
-    assert att.feature_map_key.features.shape[0] == batch.shape[-1]
-    assert att.feature_map_key.features.shape[1] == nb_random_features
+    assert att.feature_map.features.shape[0] == batch.shape[-1]
+    assert att.feature_map.features.shape[1] == nb_random_features
 
 
 def test_feature_map_redraw():
@@ -102,20 +102,16 @@ def check(should_redraw: bool):
             iter_before_redraw=1 if should_redraw else 100,
         )
         v0 = att(batch, batch, batch)
-        assert att.feature_map_query is not None
-        assert att.feature_map_key is not None
+        assert att.feature_map is not None
 
-        fq0 = att.feature_map_query.features
-        fk0 = att.feature_map_key.features
+        f0 = att.feature_map.features
 
         v1 = att(batch, batch, batch)
-        fq1 = att.feature_map_query.features
-        fk1 = att.feature_map_key.features
+        f1 = att.feature_map.features
 
         # There should not have been a redraw after v0
         assert should_redraw != torch.allclose(v0, v1)
-        assert should_redraw != torch.allclose(fq0, fq1)  # type: ignore
-        assert should_redraw != torch.allclose(fk0, fk1)  # type: ignore
+        assert should_redraw != torch.allclose(f0, f1)  # type: ignore
 
     check(should_redraw=True)
     check(should_redraw=False)
diff --git a/xformers/components/attention/favor.py b/xformers/components/attention/favor.py
index 88be336b73..3f80a80657 100644
--- a/xformers/components/attention/favor.py
+++ b/xformers/components/attention/favor.py
@@ -57,7 +57,7 @@ def __init__(
         Args:
             dropout (float): the probability of an output to be randomly dropped at training time
             dim_features (int): the dimension of the random features space
-            iter_before_redraw (int): the number of iterations before a redraw of the features
+            iter_before_redraw (int): the number of steps (forward calls) before a redraw of the features
             feature_map_type (FeatureMapType): the type of feature map being used,
             for instance orthogonal random features.
 
@@ -67,7 +67,11 @@ def __init__(
 
         super().__init__()
         self.causal = causal
-        self.iter_before_redraw = iter_before_redraw
+        self.iter_before_redraw = (
+            (2 * iter_before_redraw)
+            if iter_before_redraw is not None
+            else iter_before_redraw
+        )  # This will be used for both key and query
         self.normalize_inputs = normalize_inputs
         self.feature_map_type = feature_map_type
         self.attn_drop = nn.Dropout(dropout, inplace=True)
@@ -98,8 +102,7 @@ def __init__(
             "normalize_inputs": self.normalize_inputs,
         }
 
-        self.feature_map_query: FeatureMap = feature_map_constructor(**feature_settings)  # type: ignore
-        self.feature_map_key: FeatureMap = feature_map_constructor(**feature_settings)  # type: ignore
+        self.feature_map: FeatureMap = feature_map_constructor(**feature_settings)  # type: ignore
 
     @staticmethod
     def _maybe_promote(x: torch.Tensor) -> torch.Tensor:
@@ -135,8 +138,8 @@ def forward(
     ):
 
         # Project key and queries onto the feature map space
-        k_prime = self.feature_map_key(k)
-        q_prime = self.feature_map_query(q)
+        k_prime = self.feature_map(k)
+        q_prime = self.feature_map(q)
 
         with autocast(enabled=False):
             # The softmax kernel approximation for Favor will easily overflow
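
Note on the doubled iter_before_redraw: the shared feature map is now invoked twice per forward pass (once for the keys, once for the queries), so a redraw counter that advances per call would fire twice as often as the docstring promises. Doubling the threshold keeps the redraw cadence, expressed in forward calls, unchanged. Below is a minimal, hypothetical sketch of that reasoning; CountingFeatureMap and its counter are illustrative stand-ins for the FeatureMap internals, not the xformers API.

import torch


class CountingFeatureMap:
    """Toy stand-in for a FAVOR feature map with periodic redraws.

    Hypothetical sketch only: it assumes the map keeps a per-call counter
    and redraws its random projection once the counter hits the threshold,
    which is the behavior test_feature_map_redraw exercises.
    """

    def __init__(self, dim_input: int, dim_features: int, iter_before_redraw: int):
        self.dim_input = dim_input
        self.dim_features = dim_features
        self.iter_before_redraw = iter_before_redraw
        self.calls = 0
        self.features = torch.randn(dim_input, dim_features)

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        # Redraw the random projection every `iter_before_redraw` calls
        if self.calls >= self.iter_before_redraw:
            self.features = torch.randn(self.dim_input, self.dim_features)
            self.calls = 0
        self.calls += 1
        return x @ self.features


# Before this patch, the separate query and key maps were each called once
# per forward, so each redrew every `iter_before_redraw` forwards. The shared
# map is called twice per forward, so its counter advances twice as fast;
# the constructor compensates with 2 * iter_before_redraw.
fm = CountingFeatureMap(dim_input=16, dim_features=32, iter_before_redraw=2 * 10)
k = q = torch.randn(4, 8, 16)
for _ in range(20):
    k_prime = fm(k)  # first call of this forward pass
    q_prime = fm(q)  # second call of this forward pass
# With the doubled threshold the projection is redrawn about every 10
# forward passes, matching the original cadence of the two separate maps.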