[block_factory] Tentative improvements in handling the residual path (facebookresearch#186)

* Tentative improvements in handling the residual path, and cleaning up unused fields
* Split residual and norm paths (pre vs post layer norm; see the sketch below)
blefaudeux authored Jul 8, 2021
1 parent ab1e91d commit 9273d7c
Showing 3 changed files with 182 additions and 97 deletions.
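The "pre" and "post" values of the new layer_norm_style parameter tested below refer to where the LayerNorm sits relative to the residual connection, which is the residual/norm split this commit is about. A minimal sketch of the two wirings (illustrative only, with placeholder class names, not the actual xformers block_factory code):

```python
import torch
import torch.nn as nn


class PreNormResidual(nn.Module):
    """'pre' style: normalize the input, run the sublayer, then add the residual."""

    def __init__(self, dim: int, sublayer: nn.Module):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.sublayer = sublayer

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.sublayer(self.norm(x))


class PostNormResidual(nn.Module):
    """'post' style: run the sublayer, add the residual, then normalize the sum."""

    def __init__(self, dim: int, sublayer: nn.Module):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.sublayer = sublayer

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.norm(x + self.sublayer(x))
```

Pre-norm keeps an unnormalized residual stream (often easier to train deep stacks), while post-norm normalizes the summed output, as in the original Transformer; parametrizing the tests over both exercises the two paths.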
tests/test_block_factory.py (9 changes: 7 additions & 2 deletions)
@@ -34,6 +34,7 @@
 @pytest.mark.parametrize("activation", [a.value for a in Activation])
 @pytest.mark.parametrize("attention_name", ATTENTION_REGISTRY.keys())
 @pytest.mark.parametrize("feedforward_name", FEEDFORWARD_REGISTRY.keys())
+@pytest.mark.parametrize("layer_norm_style", ["pre", "post"])
 @pytest.mark.parametrize("device", DEVICES)
 def test_xformer_encoder_block(
     attention_name: str,
@@ -43,6 +44,7 @@ def test_xformer_encoder_block(
     residual_dropout: float,
     causal: bool,
     activation: Activation,
+    layer_norm_style: str,
     device: torch.device,
 ):

@@ -84,6 +86,7 @@
         multi_head_config=multi_head_config,
         feedforward_config=feedforward_config,
         position_encoding_config=position_encoding_config,
+        layer_norm_style=layer_norm_style,
     )

     # Test that the whole block can be instantiated
@@ -110,6 +113,7 @@
 @pytest.mark.parametrize("activation", [a.value for a in Activation])
 @pytest.mark.parametrize("attention_name", ATTENTION_REGISTRY.keys())
 @pytest.mark.parametrize("feedforward_name", FEEDFORWARD_REGISTRY.keys())
+@pytest.mark.parametrize("layer_norm_style", ["pre", "post"])
 @pytest.mark.parametrize("device", DEVICES)
 def test_xformer_decoder_block(
     attention_name: str,
@@ -119,6 +123,7 @@ def test_xformer_decoder_block(
     residual_dropout: float,
     causal: bool,
     activation: Activation,
+    layer_norm_style: str,
     device: torch.device,
 ):

@@ -164,8 +169,8 @@ def test_xformer_decoder_block(

     decoder_block_config = xFormerDecoderConfig(
         dim_model=MODEL,
-        multi_head_config_pre_encoder=multi_head_config,
-        multi_head_config_post_encoder=multi_head_config,
+        multi_head_config_masked=multi_head_config,
+        multi_head_config_cross=multi_head_config,
         feedforward_config=feedforward_config,
         position_encoding_config=position_encoding_config,
     )
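The rename from multi_head_config_pre_encoder / multi_head_config_post_encoder to multi_head_config_masked / multi_head_config_cross names each attention by what it does in a decoder block: causal (masked) self-attention over the target sequence, followed by cross-attention against the encoder memory. A rough sketch of that data flow (a hypothetical helper for illustration, not the xformers implementation; residual adds are shown, norms omitted):

```python
from typing import Callable

import torch

# Each attention stage takes (query, key, value) and returns a tensor.
Attention = Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]


def decoder_block(
    target: torch.Tensor,      # decoder input (causally masked path)
    memory: torch.Tensor,      # encoder output
    masked_attn: Attention,    # corresponds to multi_head_config_masked
    cross_attn: Attention,     # corresponds to multi_head_config_cross
    feedforward: Callable[[torch.Tensor], torch.Tensor],
) -> torch.Tensor:
    x = target + masked_attn(target, target, target)  # masked self-attention
    x = x + cross_attn(x, memory, memory)             # cross-attention on memory
    return x + feedforward(x)
```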
tests/test_model_factory.py (4 changes: 2 additions & 2 deletions)
@@ -55,7 +55,7 @@
             "vocab_size": 64,
         },
         "num_layers": 2,
-        "multi_head_config_pre_encoder": {
+        "multi_head_config_masked": {
             "num_heads": 4,
             "dim_model": 384,
             "residual_dropout": 0,
@@ -66,7 +66,7 @@
                 "seq_len": 512,
             },
         },
-        "multi_head_config_post_encoder": {
+        "multi_head_config_cross": {
             "num_heads": 4,
             "dim_model": 384,
             "residual_dropout": 0,
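Since the model-factory configuration is a plain dict, existing configs that still use the old key names need the same rename. A hypothetical one-off migration helper (the key names come from this commit; the function itself is not part of xformers):

```python
def migrate_decoder_config(cfg: dict) -> dict:
    """Rename the old pre/post-encoder attention keys to their new names."""
    renames = {
        "multi_head_config_pre_encoder": "multi_head_config_masked",
        "multi_head_config_post_encoder": "multi_head_config_cross",
    }
    return {renames.get(key, key): value for key, value in cfg.items()}
```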
(diff for the third changed file not shown)
