
Diverging PT-Flax Wav2Vec2 Hidden-States #15754

Closed · sanchit-gandhi opened this issue Feb 21, 2022 · 4 comments

sanchit-gandhi (Contributor) commented Feb 21, 2022

I noticed an absence of PT-Flax cross-tests in the test file for the FlaxWav2Vec2 model. Was a layer-by-layer comparison of the PT and Flax hidden-states performed at any point? Having a look through for myself, I found that whilst the final hidden-states agree to within the prescribed 4e-2 threshold (see the final line of the output), many of the intermediate states do not. Running the following script reveals this:

from transformers import FlaxWav2Vec2Model, Wav2Vec2Model
import torch
import numpy as np


# load the same checkpoint in Flax (converted from the PT weights) and in PyTorch
model_fx = FlaxWav2Vec2Model.from_pretrained("facebook/wav2vec2-large-lv60", from_pt=True)
model_pt = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-large-lv60")

# identical dummy inputs for both frameworks
input_torch = torch.ones((2, 5000), dtype=torch.float32)
input_fx = input_torch.cpu().numpy()

with torch.no_grad():
    output_hidden_states_pt = model_pt(input_torch, output_hidden_states=True)
output_hidden_states_fx = model_fx(input_fx, output_hidden_states=True)


# compare the outputs of the convolutional feature extractor
features_pt = output_hidden_states_pt.extract_features
features_fx = output_hidden_states_fx.extract_features

for feature_fx, feature_pt in zip(features_fx, features_pt):
    print(f"Feature diff {np.max(np.abs(feature_pt.numpy()) - np.asarray(np.abs(feature_fx)))}")


# compare the hidden-states layer by layer
hidden_states_pt = output_hidden_states_pt.hidden_states
hidden_states_fx = output_hidden_states_fx.hidden_states

for layer, (hidden_state_fx, hidden_state_pt) in enumerate(zip(hidden_states_fx, hidden_states_pt)):
    print(f"Layer {layer} diff: {np.max(np.abs(hidden_state_pt.numpy()) - np.asarray(np.abs(hidden_state_fx)))}")


# compare the final hidden-states
output_logits_pt = output_hidden_states_pt.last_hidden_state
output_logits_fx = output_hidden_states_fx.last_hidden_state

print("Check if shapes are equal")
print(f"Shape PyTorch {output_logits_pt.shape} | Shape Flax {output_logits_fx.shape}")

print("Check if output values are equal")
print(f"Diff {np.max(np.abs(output_logits_pt.numpy()) - np.asarray(np.abs(output_logits_fx)))})")

Output:

Feature diff 0.18386149406433105
Feature diff 0.18386149406433105
Layer 0 diff: 0.3600578308105469
Layer 1 diff: 0.8296781778335571
Layer 2 diff: 0.7679004669189453
Layer 3 diff: 0.8805904388427734
Layer 4 diff: 0.8832664489746094
Layer 5 diff: 1.0264105796813965
Layer 6 diff: 0.9948457479476929
Layer 7 diff: 3.742971420288086
Layer 8 diff: 4.417335510253906
Layer 9 diff: 3.787109375
Layer 10 diff: 4.51409912109375
Layer 11 diff: 5.208351135253906
Layer 12 diff: 4.7732086181640625
Layer 13 diff: 6.4005889892578125
Layer 14 diff: 5.65545654296875
Layer 15 diff: 5.3276214599609375
Layer 16 diff: 5.1604156494140625
Layer 17 diff: 6.2522430419921875
Layer 18 diff: 6.3909912109375
Layer 19 diff: 7.9093780517578125
Layer 20 diff: 7.6656036376953125
Layer 21 diff: 5.81195068359375
Layer 22 diff: 206.0625
Layer 23 diff: 150.2744140625
Layer 24 diff: -0.012473279610276222
Check if shapes are equal
Shape PyTorch torch.Size([2, 15, 1024]) | Shape Flax (2, 15, 1024)
Check if output values are equal
Diff 0.0317840576171875)
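
As a side note, the script above prints the element-wise maximum of |pt| - |fx| rather than of |pt - fx|, which is why the layer 24 diff can come out negative. A PT-Flax cross-test would more typically assert on the maximum absolute difference directly; a minimal sketch of such a per-layer check (the helper name assert_hidden_states_close is hypothetical, and 4e-2 is the threshold mentioned above):

import numpy as np

# a minimal per-layer PT-Flax closeness check (hypothetical helper);
# the metric here is max |pt - fx|, and 4e-2 is the threshold mentioned above
def assert_hidden_states_close(hidden_states_pt, hidden_states_fx, tol=4e-2):
    for layer, (hs_pt, hs_fx) in enumerate(zip(hidden_states_pt, hidden_states_fx)):
        max_diff = np.max(np.abs(hs_pt.numpy() - np.asarray(hs_fx)))
        assert max_diff < tol, f"Layer {layer}: max abs diff {max_diff} exceeds {tol}"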

Environment:

- `transformers` version: 4.17.0.dev0
- Platform: Darwin-20.2.0-x86_64-i386-64bit
- Python version: 3.7.9
- PyTorch version (GPU?): 1.10.2 (False)
- Tensorflow version (GPU?): 2.8.0 (False)
- Flax version (CPU?/GPU?/TPU?): 0.4.0 (cpu)
- Jax version: 0.3.0
- JaxLib version: 0.3.0
- Using GPU in script?: No
- Using distributed or parallel set-up in script?: No
sanchit-gandhi changed the title from Diverging PT-Flax Hidden-States to Diverging PT-Flax Wav2Vec2 Hidden-States on Feb 21, 2022
sanchit-gandhi self-assigned this on Feb 22, 2022
patrickvonplaten (Contributor) commented:

Just looked into this and I'm pretty sure the reason is that matrix multiplication is approximated on TPU, whereas it is not on CPU. I get results identical to the ones @sanchit-gandhi reported above when the code is run on TPU.

E.g. if I put your code snippet from above into a file called run.py and then do the following:

JAX_PLATFORM_NAME=tpu python run.py

on a TPU VM v3-8, I get identical results to the ones you reported above. However, when I force the computation to take place on CPU by doing the following:

JAX_PLATFORM_NAME=cpu python run.py

I get "much better" results, which look as follows:

Feature diff 1.2874603271484375e-05
Feature diff 1.2874603271484375e-05
Layer 0 diff: 0.0005588531494140625
Layer 1 diff: 0.000522613525390625
Layer 2 diff: 0.000514984130859375
Layer 3 diff: 0.0005235671997070312
Layer 4 diff: 0.0005626678466796875
Layer 5 diff: 0.0007305145263671875
Layer 6 diff: 0.0007200241088867188
Layer 7 diff: 0.0028638839721679688
Layer 8 diff: 0.003765106201171875
Layer 9 diff: 0.00331878662109375
Layer 10 diff: 0.0035552978515625
Layer 11 diff: 0.003631591796875
Layer 12 diff: 0.0038299560546875
Layer 13 diff: 0.00469207763671875
Layer 14 diff: 0.004852294921875
Layer 15 diff: 0.0047607421875
Layer 16 diff: 0.0048065185546875
Layer 17 diff: 0.0046844482421875
Layer 18 diff: 0.0043487548828125
Layer 19 diff: 0.0041351318359375
Layer 20 diff: 0.00341796875
Layer 21 diff: 0.00220489501953125
Layer 22 diff: 0.01416015625
Layer 23 diff: 0.091796875
Layer 24 diff: -0.13190405070781708
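
As an aside, the backend can also be pinned from inside the script rather than via the environment variable; a minimal sketch, assuming the jax_platform_name config option is available in this JAX version:

import jax

# alternative to JAX_PLATFORM_NAME=cpu: pin the backend from inside the script
# (must run before the first JAX computation is dispatched; assumes the
# "jax_platform_name" config option is available in this JAX version)
jax.config.update("jax_platform_name", "cpu")

print(jax.devices())  # should now list only CPU devices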

patrickvonplaten (Contributor) commented:

@sanchit-gandhi, you can also get the same high precision by setting the following environment variable:

JAX_DEFAULT_MATMUL_PRECISION=float32 python run.py

This way the code still runs on TPU, but we now get a much higher precision, as good as the one on CPU:

Feature diff 1.9788742065429688e-05
Feature diff 2.3126602172851562e-05
Layer 0 diff: 0.000499725341796875
Layer 1 diff: 0.00046062469482421875
Layer 2 diff: 0.0004367828369140625
Layer 3 diff: 0.00043582916259765625
Layer 4 diff: 0.0004558563232421875
Layer 5 diff: 0.00041866302490234375
Layer 6 diff: 0.000457763671875
Layer 7 diff: 0.00038313865661621094
Layer 8 diff: 0.0006923675537109375
Layer 9 diff: 0.00075531005859375
Layer 10 diff: 0.00160980224609375
Layer 11 diff: 0.0029754638671875
Layer 12 diff: 0.003421783447265625
Layer 13 diff: 0.0038909912109375
Layer 14 diff: 0.003070831298828125
Layer 15 diff: 0.003139495849609375
Layer 16 diff: 0.0022563934326171875
Layer 17 diff: 0.002674102783203125
Layer 18 diff: 0.0031871795654296875
Layer 19 diff: 0.00322723388671875
Layer 20 diff: 0.003253936767578125
Layer 21 diff: 0.004852294921875
Layer 22 diff: 0.05615234375
Layer 23 diff: 0.228515625
Layer 24 diff: -0.1278175562620163
Check if shapes are equal
Shape PyTorch torch.Size([2, 15, 1024]) | Shape Flax (2, 15, 1024)
Check if output values are equal
Diff 5.14984130859375e-05)
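
The same effect should also be achievable from within the script; a minimal sketch, assuming the jax.default_matmul_precision context manager is available in this JAX version:

import numpy as np
import jax
from transformers import FlaxWav2Vec2Model

model_fx = FlaxWav2Vec2Model.from_pretrained("facebook/wav2vec2-large-lv60", from_pt=True)
input_fx = np.ones((2, 5000), dtype=np.float32)

# run the Flax forward pass with float32 matmul precision; this should be
# equivalent in effect to JAX_DEFAULT_MATMUL_PRECISION=float32 (assumes the
# jax.default_matmul_precision context manager is available in this JAX version)
with jax.default_matmul_precision("float32"):
    output_fx = model_fx(input_fx, output_hidden_states=True)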

patrickvonplaten (Contributor) commented Feb 22, 2022

I think we can close this for now - maybe JAX will soon change its default TPU precision behavior

patrickvonplaten (Contributor) commented:

By the way, the info is taken from this PR: jax-ml/jax#6143
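
For reference, JAX also lets you request higher precision per operation; an illustrative sketch (not from the linked PR, just the standard precision argument on matmul-like ops):

import jax
import jax.numpy as jnp

# request the highest available matmul precision (effectively float32 on TPU)
# for this single operation via the `precision` argument
a = jnp.ones((4, 8), dtype=jnp.float32)
b = jnp.ones((8, 2), dtype=jnp.float32)

c = jnp.matmul(a, b, precision=jax.lax.Precision.HIGHEST)
print(c.shape)  # (4, 2)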
