Commit f468472

- Input projection refactoring
- Adding several init options

blefaudeux committed May 31, 2022
1 parent 2957a71 commit f468472
Showing 20 changed files with 531 additions and 339 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -9,6 +9,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Fixed

### Added
- Support several initialization options [#312]


## [0.0.11] - 2022-05-30
### Fixed
1 change: 1 addition & 0 deletions README.md
@@ -249,3 +249,4 @@ The following repositories are used in xFormers, either in close to original for
* [RevTorch](https://github.com/RobinBruegger/RevTorch)
* [Nystromformer](https://github.com/mlpen/Nystromformer)
* [FairScale](https://github.com/facebookresearch/fairscale/)
* [Pytorch Image Models](https://github.com/rwightman/pytorch-image-models)
Binary file added docs/plots/mha/MHA_FW+bw_torch.float16.png
Binary file added docs/plots/mha/MHA_FW_torch.float16.png
Binary file added docs/plots/mha/MHA_FW_torch.float32.png
1 change: 1 addition & 0 deletions examples/cifarMetaformer.py
@@ -93,6 +93,7 @@ def __init__(

# Now instantiate the metaformer trunk
config = xFormerConfig(xformer_config)

print(config)
self.trunk = xFormer.from_config(config)
print(self.trunk)
6 changes: 2 additions & 4 deletions examples/microGPT.py
@@ -53,7 +53,7 @@ def __init__(
"block_type": "encoder",
"num_layers": self.hparams.n_layer,
"dim_model": self.hparams.n_embd,
"layer_norm_style": "pre",
"layer_norm_style": "post",
"position_encoding_config": {
"name": "vocab",
"seq_len": self.hparams.block_size,
@@ -274,7 +274,7 @@ def top_k_logits(logits, k):
# Adjust batch depending on the available memory on your machine.
# You can also use reversible layers to save memory
REF_BATCH = 512
BATCH = 256
BATCH = 128

WORKERS = 4
EPOCHS = 1
@@ -312,9 +312,7 @@ def top_k_logits(logits, k):
gpus=1,
max_epochs=EPOCHS,
precision=16,
gradient_clip_val=1, # Use to catch divergent gradients, if experimenting
log_every_n_steps=1,
# detect_anomaly=True, # Use to catch NaNs, if experimenting
accumulate_grad_batches=REF_BATCH // BATCH,
)

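The example halves the per-step batch (BATCH goes from 256 to 128) while REF_BATCH stays at 512; Lightning's accumulate_grad_batches makes up the difference so the effective batch size seen by the optimizer is unchanged. A minimal sketch of the arithmetic, using the values from the diff above:

```python
# Effective batch size under gradient accumulation (values from microGPT.py above)
REF_BATCH = 512
BATCH = 128

accumulate_grad_batches = REF_BATCH // BATCH   # 4 accumulation steps per optimizer update
effective_batch = BATCH * accumulate_grad_batches
assert effective_batch == REF_BATCH            # the optimizer still sees 512 samples per update
```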
30 changes: 22 additions & 8 deletions tests/test_attentions.py
@@ -40,6 +40,7 @@ def _get_multihead(
heads,
device,
skip_output_projection=False,
use_seperate_proj_weights=True,
):
test_config = {
"name": attention_name,
@@ -52,6 +53,7 @@
"num_heads": heads,
"dim_head": MODEL / heads,
"num_rules": 2, # Compositional Attention
"r": 0.5, # random attention, ratio of tokens that the attention can attend to
}

if skip_output_projection:
@@ -77,6 +79,7 @@ def noop(x):
residual_dropout=res_dropout,
num_heads=heads,
attention=attention,
use_separate_proj_weight=use_seperate_proj_weights,
).to(device)

return multi_head
@@ -101,9 +104,16 @@ def test_order_invariance(
pytest.skip(f"{attention_name} requires squared sequence lengths")

torch.manual_seed(42)
torch.cuda.manual_seed_all(42)

multi_head = _get_multihead(
attention_name, attn_dropout, residual_dropout, causal, heads, device
attention_name,
attn_dropout,
residual_dropout,
causal,
heads,
device,
use_seperate_proj_weights=False,
)

# Check that a shuffled input produces the same results
@@ -177,13 +187,10 @@ def test_kqv_ordering(
assert torch.allclose(res_false[0, :, :], res_false[1, :, :])


@pytest.mark.parametrize("small_init", [False, True])
@pytest.mark.parametrize("proj_bias", [False, True])
@pytest.mark.parametrize("same_sizes", [False, True])
@pytest.mark.parametrize("same_settings", [False, True])
def test_inproj(
small_init: bool, proj_bias: bool, same_sizes: bool, same_settings: bool
):
def test_inproj(proj_bias: bool, same_sizes: bool, same_settings: bool):

test_config = {
"name": "scaled_dot_product",
@@ -198,14 +205,19 @@
attention = build_attention(test_config)

# Construct the initial projection, test different options
in_params = InProjParams(MODEL, MODEL, proj_bias, small_init)
in_params = InProjParams(MODEL, MODEL, proj_bias)

if same_settings:
in_proj = InProjContainer(in_params, None, None)
out_features = MODEL
else:
out_features = MODEL if same_sizes else MODEL // 2
in_params_flip = InProjParams(MODEL, out_features, proj_bias, small_init)
in_proj = InProjContainer(in_params, in_params_flip, in_params_flip)
in_params_flip = InProjParams(MODEL, out_features, proj_bias)
in_proj = InProjContainer(
in_params_flip, # Q proj
in_params_flip, # K proj
in_params, # V proj
)

# build a multi head dispatch to test this attention mechanism
multi_head = MultiHeadDispatch(
@@ -215,6 +227,8 @@
num_heads=1,
attention=attention,
in_proj_container=in_proj,
dim_key=out_features,
dim_value=MODEL,
)

# Check kqv are not flipped
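test_inproj above exercises the refactored input projection: InProjParams no longer takes a small_init flag, and InProjContainer receives the query, key and value projection parameters in that order, with MultiHeadDispatch told about mismatched widths through dim_key and dim_value. Below is a standalone sketch of the same pattern, inferred from this diff; constants, dropout values and any keyword not visible in the hunks (e.g. dim_model) are assumptions to be checked against the xFormers source.

```python
import torch

from xformers.components import (
    InProjContainer,
    InProjParams,
    MultiHeadDispatch,
    build_attention,
)

MODEL, SEQ, BATCH = 384, 128, 2

attention = build_attention(
    {"name": "scaled_dot_product", "dropout": 0.0, "seq_len": SEQ}
)

# Q and K are projected to half the model width, V keeps the full width
qk_params = InProjParams(MODEL, MODEL // 2, True)  # in_features, out_features, bias
v_params = InProjParams(MODEL, MODEL, True)

in_proj = InProjContainer(
    qk_params,  # Q projection
    qk_params,  # K projection
    v_params,   # V projection
)

multi_head = MultiHeadDispatch(
    dim_model=MODEL,
    residual_dropout=0.0,
    num_heads=1,
    attention=attention,
    in_proj_container=in_proj,
    dim_key=MODEL // 2,  # matches the Q/K projection width
    dim_value=MODEL,     # matches the V projection width
)

x = torch.randn(BATCH, SEQ, MODEL)
_ = multi_head(x, x, x)  # query, key, value
```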
22 changes: 21 additions & 1 deletion tests/test_model_factory.py
@@ -8,7 +8,8 @@
import pytest
import torch

from xformers.factory.model_factory import xFormer, xFormerConfig
import xformers.factory.weight_init as xformers_weight_init
from xformers.factory import xFormer, xFormerConfig, xFormerWeightInit

BATCH = 2
SEQ = 16
@@ -195,3 +196,22 @@ def check_against_default(p):
# If we requested tied embedding weights, check that this is the case indeed
if tie_embedding_weights and not reversible:
assert model.encoders[0].pose_encoding == model.decoders[0].pose_encoding


@pytest.mark.parametrize("weight_init", [w.value for w in xFormerWeightInit])
@pytest.mark.parametrize("device", DEVICES)
def test_weight_init(weight_init, device):
torch.cuda.manual_seed(42)
torch.manual_seed(42)

config = test_configs_dict

# Make sure that all the init methods catch all the weights
xformers_weight_init._assert_if_not_initialized = True

# Build the model
config_instance = xFormerConfig( # noqa
config, tie_embedding_weights=False, weight_init=weight_init
)

_ = xFormer.from_config(config_instance).to(device)
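The new weight_init option is passed straight to xFormerConfig, and the supported schemes are the members of the xFormerWeightInit enum that test_weight_init loops over. A hedged usage sketch follows; the stack configuration below mirrors the fragments visible in this diff and in the microGPT example, so individual field values are illustrative rather than part of the commit.

```python
from xformers.factory import xFormer, xFormerConfig, xFormerWeightInit

# A minimal single-encoder stack, field names following the microGPT example above
encoder_config = {
    "block_type": "encoder",
    "num_layers": 2,
    "dim_model": 384,
    "layer_norm_style": "post",
    "position_encoding_config": {"name": "vocab", "seq_len": 128, "vocab_size": 64},
    "multi_head_config": {
        "num_heads": 4,
        "residual_dropout": 0.0,
        "attention": {
            "name": "scaled_dot_product",
            "dropout": 0.0,
            "causal": False,
            "seq_len": 128,
        },
    },
    "feedforward_config": {
        "name": "MLP",
        "dropout": 0.0,
        "activation": "gelu",
        "hidden_layer_multiplier": 4,
    },
}

# The available schemes; the test above parametrizes over these values
print([w.value for w in xFormerWeightInit])

# Pick one of them when building the model
config = xFormerConfig(
    [encoder_config],
    tie_embedding_weights=False,
    weight_init=list(xFormerWeightInit)[0].value,
)
model = xFormer.from_config(config)
```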
12 changes: 4 additions & 8 deletions tests/test_pytorch_transformer_parity.py
@@ -35,7 +35,6 @@
"attention": {
"name": "scaled_dot_product",
"dropout": DROP,
"causal": False,
"seq_len": SEQ,
},
"dim_model": EMB,
@@ -62,7 +61,6 @@
"attention": {
"name": "scaled_dot_product",
"dropout": DROP,
"causal": False,
"seq_len": SEQ,
},
},
@@ -74,7 +72,6 @@
"attention": {
"name": "scaled_dot_product",
"dropout": DROP,
"causal": False,
"seq_len": SEQ,
},
},
@@ -90,8 +87,9 @@
_test_config = [_test_config_encoder, _test_config_decoder]

def reset_seeds():
torch.manual_seed(0)
random.seed(0)
torch.manual_seed(42)
torch.cuda.manual_seed(42)
random.seed(42)

@pytest.mark.skipif(
not torch.cuda.is_available(), reason="This test requires a gpu"
@@ -140,8 +138,7 @@ def test_pytorch_encoder_parity(device=torch.device("cuda")):

fit_ratio_xformer = eval_start_xformer / eval_stop_xformer
fit_ratio_pytorch = eval_start_pytorch / eval_stop_pytorch

print(fit_ratio_pytorch, fit_ratio_xformer)
print("fit ratios: ", fit_ratio_pytorch, fit_ratio_xformer)

# Catch a broken training
assert fit_ratio_xformer > 120
@@ -172,7 +169,6 @@ def test_pytorch_tranformer_parity(device=torch.device("cuda")):
dim_feedforward=4 * EMB,
dropout=DROP,
activation=ACTIVATION,
layer_norm_eps=1e-06,
batch_first=True, # (batch, seq, feature)
device=device,
)
2 changes: 1 addition & 1 deletion xformers/components/__init__.py
@@ -12,7 +12,7 @@

from .activations import Activation, build_activation # noqa
from .attention import Attention, build_attention # noqa
from .in_proj_container import InProjContainer, InProjParams # noqa
from .input_projection import InProjContainer, InProjParams # noqa
from .multi_head_dispatch import MultiHeadDispatch # noqa
from .multi_head_dispatch import MultiHeadDispatchConfig
from .patch_embedding import PatchEmbeddingConfig # noqa
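The container module itself was renamed from in_proj_container to input_projection; the names re-exported from xformers.components are unchanged, so only imports that reached into the old module path need updating. A before/after sketch:

```python
# Before this commit:
# from xformers.components.in_proj_container import InProjContainer, InProjParams

# After this commit, either of these works:
from xformers.components import InProjContainer, InProjParams
from xformers.components.input_projection import InProjContainer, InProjParams  # noqa: F811
```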
2 changes: 1 addition & 1 deletion xformers/components/attention/compositional.py
@@ -29,7 +29,7 @@
register_attention,
)
from xformers.components.attention.core import _softmax
from xformers.components.in_proj_container import InProjContainer, InProjParams
from xformers.components.input_projection import InProjContainer, InProjParams


def _either_or(a: Optional[int], b: int) -> int: