Diverging PT-Flax Wav2Vec2 Hidden-States #15754
Just looked into this and I'm pretty sure the reason is that matrix multiplication is approximated on TPU, whereas it is not on CPU (a small standalone illustration of this follows the output below). I'm getting results identical to the ones @sanchit-gandhi reported above when the code is run on TPU. E.g. if I put your code snippet into a file called `run.py`:

```python
from transformers import FlaxWav2Vec2Model, Wav2Vec2Model
import torch
import numpy as np

model_fx = FlaxWav2Vec2Model.from_pretrained("facebook/wav2vec2-large-lv60", from_pt=True)
model_pt = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-large-lv60")

input_torch = torch.ones((2, 5000), dtype=torch.float32)
input_fx = input_torch.cpu().numpy()

with torch.no_grad():
    output_hidden_states_pt = model_pt(input_torch, output_hidden_states=True)
    output_hidden_states_fx = model_fx(input_fx, output_hidden_states=True)

features_pt = output_hidden_states_pt.extract_features
features_fx = output_hidden_states_fx.extract_features

# Note: these "diff" values are max(|pt| - |fx|), not max(|pt - fx|),
# which is why some of them can come out negative.
for feature_fx, feature_pt in zip(features_fx, features_pt):
    print(f"Feature diff {np.max(np.abs(feature_pt.numpy()) - np.asarray(np.abs(feature_fx)))}")

hidden_states_pt = output_hidden_states_pt.hidden_states
hidden_states_fx = output_hidden_states_fx.hidden_states

for layer, (hidden_state_fx, hidden_state_pt) in enumerate(zip(hidden_states_fx, hidden_states_pt)):
    print(f"Layer {layer} diff: {np.max(np.abs(hidden_state_pt.numpy()) - np.asarray(np.abs(hidden_state_fx)))}")

output_logits_pt = output_hidden_states_pt.last_hidden_state
output_logits_fx = output_hidden_states_fx.last_hidden_state

print("Check if shapes are equal")
print(f"Shape PyTorch {output_logits_pt.shape} | Shape Flax {output_logits_fx.shape}")

print("Check if output values are equal")
print(f"Diff {np.max(np.abs(output_logits_pt.numpy()) - np.asarray(np.abs(output_logits_fx)))})")
```

and run it with `JAX_PLATFORM_NAME=tpu python run.py` on a TPU VM v3-8, I get exactly the results you reported above. However, when I force the computation onto CPU with `JAX_PLATFORM_NAME=cpu python run.py`, I get "much better" results, which look as follows:

```
Feature diff 1.2874603271484375e-05
Feature diff 1.2874603271484375e-05
Layer 0 diff: 0.0005588531494140625
Layer 1 diff: 0.000522613525390625
Layer 2 diff: 0.000514984130859375
Layer 3 diff: 0.0005235671997070312
Layer 4 diff: 0.0005626678466796875
Layer 5 diff: 0.0007305145263671875
Layer 6 diff: 0.0007200241088867188
Layer 7 diff: 0.0028638839721679688
Layer 8 diff: 0.003765106201171875
Layer 9 diff: 0.00331878662109375
Layer 10 diff: 0.0035552978515625
Layer 11 diff: 0.003631591796875
Layer 12 diff: 0.0038299560546875
Layer 13 diff: 0.00469207763671875
Layer 14 diff: 0.004852294921875
Layer 15 diff: 0.0047607421875
Layer 16 diff: 0.0048065185546875
Layer 17 diff: 0.0046844482421875
Layer 18 diff: 0.0043487548828125
Layer 19 diff: 0.0041351318359375
Layer 20 diff: 0.00341796875
Layer 21 diff: 0.00220489501953125
Layer 22 diff: 0.01416015625
Layer 23 diff: 0.091796875
Layer 24 diff: -0.13190405070781708
```
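As a small standalone illustration of the point above (this example is not from the thread; the shapes and seeds are arbitrary): on TPU, JAX by default runs float32 matrix multiplications at reduced (bfloat16-based) precision, and requesting the highest per-operation precision brings the result much closer to what CPU/PyTorch produce.

```python
# Illustration only: compare the TPU default matmul precision against the
# highest available precision on the same inputs. Run with
# JAX_PLATFORM_NAME=tpu to see a noticeable gap; on CPU the two agree.
import jax
import jax.numpy as jnp
import numpy as np

a = jax.random.normal(jax.random.PRNGKey(0), (1024, 1024), dtype=jnp.float32)
b = jax.random.normal(jax.random.PRNGKey(1), (1024, 1024), dtype=jnp.float32)

out_default = jnp.matmul(a, b)                                       # default precision (approximate on TPU)
out_highest = jnp.matmul(a, b, precision=jax.lax.Precision.HIGHEST)  # highest available precision

print("max abs diff:", np.max(np.abs(np.asarray(out_default) - np.asarray(out_highest))))
```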
@sanchit-gandhi, you can also get the same high precision when adding the following flag:
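(The flag itself isn't quoted here; judging from the JAX PR referenced at the end of the thread, jax-ml/jax#6143, it is presumably JAX's default matmul precision setting. A minimal sketch of how that setting can be applied, assuming the `jax_default_matmul_precision` config introduced in that PR:)

```python
# Presumed flag (based on jax-ml/jax#6143): raise JAX's default matmul
# precision from the TPU default to full float32.
import jax

# Option 1: set it globally, e.g. at the top of run.py
jax.config.update("jax_default_matmul_precision", "float32")

# Option 2: scope it to a block of code with the context manager
# with jax.default_matmul_precision("float32"):
#     output_hidden_states_fx = model_fx(input_fx, output_hidden_states=True)

# Option 3: the corresponding environment variable should also work:
#   JAX_DEFAULT_MATMUL_PRECISION=float32 JAX_PLATFORM_NAME=tpu python run.py
```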
This way the code is still run on TPU, but we now get much higher precision, as good as on CPU:

```
Feature diff 1.9788742065429688e-05
Feature diff 2.3126602172851562e-05
Layer 0 diff: 0.000499725341796875
Layer 1 diff: 0.00046062469482421875
Layer 2 diff: 0.0004367828369140625
Layer 3 diff: 0.00043582916259765625
Layer 4 diff: 0.0004558563232421875
Layer 5 diff: 0.00041866302490234375
Layer 6 diff: 0.000457763671875
Layer 7 diff: 0.00038313865661621094
Layer 8 diff: 0.0006923675537109375
Layer 9 diff: 0.00075531005859375
Layer 10 diff: 0.00160980224609375
Layer 11 diff: 0.0029754638671875
Layer 12 diff: 0.003421783447265625
Layer 13 diff: 0.0038909912109375
Layer 14 diff: 0.003070831298828125
Layer 15 diff: 0.003139495849609375
Layer 16 diff: 0.0022563934326171875
Layer 17 diff: 0.002674102783203125
Layer 18 diff: 0.0031871795654296875
Layer 19 diff: 0.00322723388671875
Layer 20 diff: 0.003253936767578125
Layer 21 diff: 0.004852294921875
Layer 22 diff: 0.05615234375
Layer 23 diff: 0.228515625
Layer 24 diff: -0.1278175562620163
Check if shapes are equal
Shape PyTorch torch.Size([2, 15, 1024]) | Shape Flax (2, 15, 1024)
Check if output values are equal
Diff 5.14984130859375e-05)
```
Think we can close this for now - maybe JAX will soon change its default TPU precision behavior.

Info is taken from this PR btw: jax-ml/jax#6143
I noticed an absence of PT-Flax cross-tests in the test file for the FlaxWav2Vec2 model, and was wondering whether a layer-by-layer comparison of the PT-Flax hidden-states had been performed at any point. Having a look through for myself, I found that whilst the final hidden-states agreed to within the prescribed 4e-2 threshold (see the final line of the output), many of the intermediate states did not. Running the script reveals this (a sketch of such a layer-by-layer comparison is included after the environment info below):
Output:
Environment:
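For reference, a minimal standalone sketch of the kind of layer-by-layer PT-Flax comparison described above. The 4e-2 threshold and the checkpoint mirror the snippet quoted in the first comment; the script itself is illustrative and is not the cross-test used in the transformers test suite.

```python
# Illustrative layer-by-layer PT-Flax hidden-state comparison for Wav2Vec2.
import numpy as np
import torch
from transformers import FlaxWav2Vec2Model, Wav2Vec2Model

THRESHOLD = 4e-2  # tolerance mentioned in the issue

model_fx = FlaxWav2Vec2Model.from_pretrained("facebook/wav2vec2-large-lv60", from_pt=True)
model_pt = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-large-lv60")

inputs_pt = torch.ones((2, 5000), dtype=torch.float32)
with torch.no_grad():
    out_pt = model_pt(inputs_pt, output_hidden_states=True)
out_fx = model_fx(inputs_pt.numpy(), output_hidden_states=True)

for layer, (h_pt, h_fx) in enumerate(zip(out_pt.hidden_states, out_fx.hidden_states)):
    # Maximum absolute element-wise difference between the two frameworks
    diff = float(np.max(np.abs(h_pt.numpy() - np.asarray(h_fx))))
    status = "OK" if diff < THRESHOLD else "DIVERGES"
    print(f"Layer {layer}: max abs diff {diff:.6f} ({status})")
```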