
TF: XLA model output differs when certain outputs are passed #16838

Closed
gante opened this issue Apr 19, 2022 · 24 comments · Fixed by #16892

@gante
Member

gante commented Apr 19, 2022

Depending on the passed inputs, the output of an XLA-compiled model may differ significantly from its non-XLA counterpart. This suggests we should add tests for XLA-output equivalence, just as we do for e.g. PT-TF, since equivalence is not guaranteed.

At the moment, this blocks further development in generate() (we can't reliably reproduce non-XLA results with XLA). I will assess the problem for T5 (the first model where I noticed it), then check whether it is present in other key models, and finally add equivalence tests.

cc @patrickvonplaten @Rocketknight1 (feel free to pitch in with ideas and suggestions)


Example for reproducibility (updated: assert diff < x -> print diff):

import tensorflow as tf
from transformers import TFT5ForConditionalGeneration, T5Tokenizer


tokenizer = T5Tokenizer.from_pretrained("t5-base")
model = TFT5ForConditionalGeneration.from_pretrained("t5-base")
model_xla = tf.function(model, jit_compile=True)
pad_token_id = model.config.pad_token_id

sentence_1 = "Translate English to German: I have a cat, two dogs, three horses, and four birds."
sentence_2 = "Translate English to German: I have a cat, two dogs, and three horses."

ids_single = tokenizer([sentence_1], return_tensors="tf", padding=True).input_ids
decoder_ids_single = tf.zeros((1, 1), dtype=tf.int32)
attention_single = tf.cast(tf.math.not_equal(ids_single, pad_token_id), dtype=tf.int32)  # as computed in generate

ids_pair = tokenizer([sentence_1, sentence_2], return_tensors="tf", padding=True).input_ids
decoder_ids_pair = tf.zeros((2, 1), dtype=tf.int32)
attention_pair = tf.cast(tf.math.not_equal(ids_pair, pad_token_id), dtype=tf.int32)  # as computed in generate

# case 1: with batch size = 1 and NO attention mask, XLA and non-XLA match
outputs = model(input_ids=ids_single, decoder_input_ids=decoder_ids_single)
outputs_xla = model_xla(input_ids=ids_single, decoder_input_ids=decoder_ids_single)
print(tf.math.reduce_max(tf.math.abs(outputs.logits - outputs_xla.logits)).numpy())

# case 2: with batch size > 1 and NO attention mask, XLA and non-XLA match
outputs = model(input_ids=ids_pair, decoder_input_ids=decoder_ids_pair)
outputs_xla = model_xla(input_ids=ids_pair, decoder_input_ids=decoder_ids_pair)
print(tf.math.reduce_max(tf.math.abs(outputs.logits - outputs_xla.logits)).numpy())

# case 3 FAILING: with batch size = 1 and attention mask, XLA and non-XLA do NOT match
outputs = model(input_ids=ids_single, decoder_input_ids=decoder_ids_single, attention_mask=attention_single)
outputs_xla = model_xla(input_ids=ids_single, decoder_input_ids=decoder_ids_single, attention_mask=attention_single)
print(tf.math.reduce_max(tf.math.abs(outputs.logits - outputs_xla.logits)).numpy())

# case 4 FAILING: with batch size > 1 and attention mask, XLA and non-XLA do NOT match
outputs = model(input_ids=ids_pair, decoder_input_ids=decoder_ids_pair, attention_mask=attention_pair)
outputs_xla = model_xla(input_ids=ids_pair, decoder_input_ids=decoder_ids_pair, attention_mask=attention_pair)
print(tf.math.reduce_max(tf.math.abs(outputs.logits - outputs_xla.logits)).numpy())
@gante gante added the bug label Apr 19, 2022
@gante gante self-assigned this Apr 19, 2022
@patrickvonplaten
Contributor

How significant are the differences? Would it pass with 1e-1?

@ydshieh
Collaborator

ydshieh commented Apr 19, 2022

How significant are the differences? Would it pass with 1e-1?

Tried twice; for "# case 3 FAILING: with batch size = 1 and attention mask", the diffs are as large as 12.3...

@Rocketknight1
Member

Just tried it here.

On CPU:

Test number | Max error
1 | 1.5258789e-05
2 | 1.335144e-05
3 | 12.374499
4 | 12.5263195

On GPU (3090, using TensorFloat32):

Test number | Max error
1 | 0.0053577423
2 | 0.0062656403
3 | 0.0053577423
4 | 0.004333496

@Rocketknight1
Member

My best guess is that there are two separate issues:

  1. XLA on CPU is buggy (I believe this isn't an intended use-case for XLA anyway, because kernel fusion doesn't make much difference there)
  2. GPUs, especially when using tensor cores/TensorFloat32, have somewhat worse precision than CPU, but it's fine if we use a larger tolerance.

@patrickvonplaten
Contributor

Wait, so XLA works on GPU but not on CPU? That's very weird.

@ydshieh
Collaborator

ydshieh commented Apr 19, 2022

@gante The following code and outputs will probably help you spot the problematic places more easily.

One caveat: the XLA model doesn't return hidden_states, attentions, etc., even if I specify that they should be output, so I couldn't compare those directly.

(I was able to get the full outputs with an ugly hack.)

Code

import numpy as np
import tensorflow as tf
from transformers import TFT5Model, T5Tokenizer
from transformers.utils.generic import ModelOutput

checkpoint = "t5-base"

tokenizer = T5Tokenizer.from_pretrained(checkpoint)
model = TFT5Model.from_pretrained(checkpoint)

# Ugly hack to return all outputs
model.config.output_hidden_states = True
model.config.output_attentions = True
model = TFT5Model.from_pretrained(checkpoint, config=model.config)

model_xla = tf.function(model, jit_compile=True)

# tokenizer.pad_token_id = tokenizer.eos_token_id
pad_token_id = tokenizer.pad_token_id

sentence_1 = "Translate English to German: I have a cat, two dogs, three horses, and four birds."
sentence_2 = "Translate English to German: I have a cat, two dogs, and three horses."

ids_single = tokenizer([sentence_1], return_tensors="tf", padding=True).input_ids
decoder_ids_single = tf.zeros((1, 1), dtype=tf.int32)
# attention_single = tf.cast(tf.math.not_equal(ids_single, pad_token_id), dtype=tf.int32)  # as computed in generate
attention_single = tf.cast(tf.ones_like(ids_single), dtype=tf.int32)  # as computed in generate

ids_pair = tokenizer([sentence_1, sentence_2], return_tensors="tf", padding=True).input_ids
decoder_ids_pair = tf.zeros((2, 1), dtype=tf.int32)
# attention_pair = tf.cast(tf.math.not_equal(ids_pair, pad_token_id), dtype=tf.int32)  # as computed in generate
attention_pair = tf.cast(tf.ones_like(ids_pair), dtype=tf.int32)

# case 3 FAILING: with batch size = 1 and attention mask, XLA and non-XLA do NOT match
outputs = model(input_ids=ids_single, decoder_input_ids=decoder_ids_single, attention_mask=attention_single, output_hidden_states=True, output_attentions=True)
outputs_xla = model_xla(input_ids=ids_single, decoder_input_ids=decoder_ids_single, attention_mask=attention_single, output_hidden_states=True, output_attentions=True)


# Please ignore the bad naming - this is just a quick copy from the test script
def check_pt_tf_outputs(tf_outputs, pt_outputs, model_class, tol=1e-5, name="outputs", attributes=None):

    # Allow `ModelOutput` (e.g. `CLIPOutput` has `text_model_output` and `vision_model_output`).
    if isinstance(tf_outputs, ModelOutput):

        tf_keys = tuple([k for k, v in tf_outputs.items() if v is not None])
        pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])

        # (without the hack) XLA models don't return the full outputs at the moment, so the missing ones would need to be ignored
        # keys = tuple(set(tf_keys).intersection(pt_keys))
        # tf_outputs = tuple([tf_outputs[k] for k in keys])
        # pt_outputs = tuple([pt_outputs[k] for k in keys])

        # convert to the case of `tuple`
        # appending each key to the current (string) `names`
        attributes = tuple([f"{name}.{k}" for k in tf_keys])
        check_pt_tf_outputs(tf_outputs.to_tuple(), pt_outputs.to_tuple(), model_class, tol=tol, name=name, attributes=attributes)

    # Allow `list` (e.g. `TransfoXLModelOutput.mems` is a list of tensors.)
    elif type(tf_outputs) in [tuple, list]:

        if attributes is not None:
            # case 1: each output has assigned name (e.g. a tuple form of a `ModelOutput`)
            pass
        else:
            # case 2: each output has no assigned name (e.g. hidden states of each layer) -> add an index to `names`
            attributes = tuple([f"{name}_{idx}" for idx in range(len(tf_outputs))])

        for tf_output, pt_output, attr in zip(tf_outputs, pt_outputs, attributes):
            check_pt_tf_outputs(tf_output, pt_output, model_class, tol=tol, name=attr)

    elif isinstance(tf_outputs, tf.Tensor):
        tf_outputs = tf_outputs.numpy()
        pt_outputs = pt_outputs.numpy()

        # deal with NumPy's scalars to make replacing nan values by 0 work.
        if np.isscalar(tf_outputs):
            tf_outputs = np.array([tf_outputs])
            pt_outputs = np.array([pt_outputs])

        tf_nans = np.isnan(tf_outputs)
        pt_nans = np.isnan(pt_outputs)

        pt_outputs[tf_nans] = 0
        tf_outputs[tf_nans] = 0
        pt_outputs[pt_nans] = 0
        tf_outputs[pt_nans] = 0

        max_diff = np.amax(np.abs(tf_outputs - pt_outputs))
        print(f"{name}: {max_diff}")
    else:
        raise ValueError(
            f"`tf_outputs` should be an instance of `tf.Tensor`, a `tuple`, or an instance of `tf.Tensor`. Got {type(tf_outputs)} instead.")

check_pt_tf_outputs(outputs, outputs_xla, model_class=TFT5Model)

Outputs

outputs.last_hidden_state: 2.800762176513672
outputs.past_key_values_0_0: 4.291534423828125e-06
outputs.past_key_values_0_1: 1.0728836059570312e-06
outputs.past_key_values_0_2: 3.4570693969726562e-06
outputs.past_key_values_0_3: 3.337860107421875e-06
outputs.past_key_values_1_0: 0.4949379563331604
outputs.past_key_values_1_1: 0.8448842763900757
outputs.past_key_values_1_2: 4.291534423828125e-06
outputs.past_key_values_1_3: 4.887580871582031e-06
outputs.past_key_values_2_0: 0.4911351203918457
outputs.past_key_values_2_1: 0.5065852403640747
outputs.past_key_values_2_2: 4.76837158203125e-06
outputs.past_key_values_2_3: 5.7220458984375e-06
outputs.past_key_values_3_0: 0.47093653678894043
outputs.past_key_values_3_1: 0.5624567270278931
outputs.past_key_values_3_2: 4.410743713378906e-06
outputs.past_key_values_3_3: 5.9604644775390625e-06
outputs.past_key_values_4_0: 0.775518536567688
outputs.past_key_values_4_1: 0.934751570224762
outputs.past_key_values_4_2: 5.7220458984375e-06
outputs.past_key_values_4_3: 7.152557373046875e-06
outputs.past_key_values_5_0: 1.0620229244232178
outputs.past_key_values_5_1: 1.1955945491790771
outputs.past_key_values_5_2: 5.7220458984375e-06
outputs.past_key_values_5_3: 9.059906005859375e-06
outputs.past_key_values_6_0: 1.5020784139633179
outputs.past_key_values_6_1: 1.768876552581787
outputs.past_key_values_6_2: 6.4373016357421875e-06
outputs.past_key_values_6_3: 8.344650268554688e-06
outputs.past_key_values_7_0: 1.9831377267837524
outputs.past_key_values_7_1: 1.7343039512634277
outputs.past_key_values_7_2: 6.67572021484375e-06
outputs.past_key_values_7_3: 1.0251998901367188e-05
outputs.past_key_values_8_0: 2.3230268955230713
outputs.past_key_values_8_1: 2.937762498855591
outputs.past_key_values_8_2: 5.7220458984375e-06
outputs.past_key_values_8_3: 9.775161743164062e-06
outputs.past_key_values_9_0: 2.8203392028808594
outputs.past_key_values_9_1: 5.384043216705322
outputs.past_key_values_9_2: 5.9604644775390625e-06
outputs.past_key_values_9_3: 1.33514404296875e-05
outputs.past_key_values_10_0: 4.303163528442383
outputs.past_key_values_10_1: 10.02894401550293
outputs.past_key_values_10_2: 6.198883056640625e-06
outputs.past_key_values_10_3: 1.430511474609375e-05
outputs.past_key_values_11_0: 4.163003921508789
outputs.past_key_values_11_1: 7.657519817352295
outputs.past_key_values_11_2: 4.76837158203125e-06
outputs.past_key_values_11_3: 1.9073486328125e-05
outputs.decoder_hidden_states_0: 0.0
outputs.decoder_hidden_states_1: 2151.3359375
outputs.decoder_hidden_states_2: 2724.79736328125
outputs.decoder_hidden_states_3: 4147.70751953125
outputs.decoder_hidden_states_4: 6162.63720703125
outputs.decoder_hidden_states_5: 7066.3046875
outputs.decoder_hidden_states_6: 7329.43603515625
outputs.decoder_hidden_states_7: 7471.92333984375
outputs.decoder_hidden_states_8: 7749.91162109375
outputs.decoder_hidden_states_9: 8324.51953125
outputs.decoder_hidden_states_10: 8609.3359375
outputs.decoder_hidden_states_11: 7732.30224609375
outputs.decoder_hidden_states_12: 2.800762176513672
outputs.decoder_attentions_0: 0.0
outputs.decoder_attentions_1: 0.0
outputs.decoder_attentions_2: 0.0
outputs.decoder_attentions_3: 0.0
outputs.decoder_attentions_4: 0.0
outputs.decoder_attentions_5: 0.0
outputs.decoder_attentions_6: 0.0
outputs.decoder_attentions_7: 0.0
outputs.decoder_attentions_8: 0.0
outputs.decoder_attentions_9: 0.0
outputs.decoder_attentions_10: 0.0
outputs.decoder_attentions_11: 0.0
outputs.cross_attentions_0: 0.9293187856674194
outputs.cross_attentions_1: 0.8967262506484985
outputs.cross_attentions_2: 0.7246492505073547
outputs.cross_attentions_3: 0.9164008498191833
outputs.cross_attentions_4: 0.8164070248603821
outputs.cross_attentions_5: 0.7364302277565002
outputs.cross_attentions_6: 0.6568543314933777
outputs.cross_attentions_7: 0.6275004744529724
outputs.cross_attentions_8: 0.6810514330863953
outputs.cross_attentions_9: 0.631909966468811
outputs.cross_attentions_10: 0.4159456491470337
outputs.cross_attentions_11: 0.39396628737449646
outputs.encoder_last_hidden_state: 5.960464477539062e-07
outputs.encoder_hidden_states_0: 0.0
outputs.encoder_hidden_states_1: 0.000244140625
outputs.encoder_hidden_states_2: 0.0003662109375
outputs.encoder_hidden_states_3: 0.00048828125
outputs.encoder_hidden_states_4: 0.00048828125
outputs.encoder_hidden_states_5: 0.00048828125
outputs.encoder_hidden_states_6: 0.0009765625
outputs.encoder_hidden_states_7: 0.00048828125
outputs.encoder_hidden_states_8: 0.001953125
outputs.encoder_hidden_states_9: 0.001953125
outputs.encoder_hidden_states_10: 0.0078125
outputs.encoder_hidden_states_11: 0.0078125
outputs.encoder_hidden_states_12: 5.960464477539062e-07
outputs.encoder_attentions_0: 5.066394805908203e-07
outputs.encoder_attentions_1: 5.364418029785156e-07
outputs.encoder_attentions_2: 7.152557373046875e-07
outputs.encoder_attentions_3: 5.960464477539062e-07
outputs.encoder_attentions_4: 5.662441253662109e-07
outputs.encoder_attentions_5: 5.960464477539062e-07
outputs.encoder_attentions_6: 5.364418029785156e-07
outputs.encoder_attentions_7: 6.258487701416016e-07
outputs.encoder_attentions_8: 8.642673492431641e-07
outputs.encoder_attentions_9: 5.960464477539062e-07
outputs.encoder_attentions_10: 7.152557373046875e-07
outputs.encoder_attentions_11: 5.960464477539062e-07

@gante
Member Author

gante commented Apr 19, 2022

Thank you for your suggestions, you have solved the puzzle 🙏 The winning suggestion award goes to @Rocketknight1 -- XLA on CPU is indeed buggy.

I've spun up an Nvidia T4 (= no tf32 format) and got an error < 1e-5 for all cases. tf32 does make the difference slightly bigger, but having a GPU (vs. CPU) is the main factor. The GPU run also passed the generate cases that were failing with XLA on CPU (see below).

As a result of this thread, I was thinking of:

  1. Raising an exception in TF generate when use_xla is True and there are no GPU devices -- @patrickvonplaten WDYT?
  2. Pushing all XLA tests to GPU;
  3. Opening an issue in the TensorFlow repo -- @Rocketknight1 do you think they will care? 🤔

Greedy search translating correctly with GPU: [screenshot]

Greedy search failing with CPU: [screenshot]

Sample behaving okay with GPU (sampling 10 outputs for the first sentence input): [screenshot]

@Rocketknight1
Member

@gante I think they certainly would be interested, but we'd have to localize the bug a little more! If you could fix an input and make a minimal single module that showed the buggy behaviour, you should definitely report that upstream. I totally understand if that's not a priority with everything else on your plate, though!

@patrickvonplaten
Contributor

Cool, great job guys in locating the error!

I don't think it's a good idea to raise an error/exception if XLA is enabled on CPU. XLA should work on CPU -- why wouldn't it? To me this clearly looks like a TF bug, and quite a big one actually.

IMO, lots of people debug their XLA code on CPU, so I think it is pretty important that it works there.
E.g. some generate processors work differently under XLA, and it's important to be able to verify as easily as possible (on CPU) that new XLA code works as expected.
Also note that XLA works inherently differently from non-XLA (static shapes, different computation operations). This should also be easy to test/debug on CPU. E.g. to me it's a non-negligible use case to check on CPU whether your code leads to constant recompilation.
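
As an illustration (a minimal sketch with made-up names, not code from this thread), one way to check for constant recompilation on CPU is to rely on the fact that Python-side side effects inside a tf.function only run while the function is being traced:

import tensorflow as tf

@tf.function(jit_compile=True)
def score_step(input_ids):
    # This print only fires while TF traces the function; seeing it on every
    # call means the function is being retraced (and recompiled) each time.
    print("Tracing score_step with shape:", input_ids.shape)
    return tf.reduce_sum(input_ids)

score_step(tf.zeros((1, 8), dtype=tf.int32))   # traces + compiles
score_step(tf.zeros((1, 8), dtype=tf.int32))   # reuses the compiled program
score_step(tf.zeros((1, 16), dtype=tf.int32))  # new shape -> retraces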

Also, we need to test XLA on CPU so that it runs on CircleCI, IMO.

cc @sanchit-gandhi, who is working quite a bit with XLA at the moment.

@patrickvonplaten
Contributor

It shouldn't be too difficult to locate where the difference comes from, since we know that it works without attention_mask, no?

@ydshieh
Collaborator

ydshieh commented Apr 20, 2022

Changing this line

weights = tf.nn.softmax(scores, axis=-1) # (batch_size, n_heads, query_length, key_length)

to

weights = tf.math.softmax(scores + 1.0, axis=-1)

will solve the problem. This gives the same weights (on CPU + XLA) as the ones computed on a GPU machine (both non-XLA & XLA).

I tested this trick with @gante's code samples.

I also looked at the expected values for weights using the code below.
The expected values look like

[8.03906238e-04 4.91665269e-04 6.60848498e-01 7.20867813e-02,  ...]

Without this fix, on CPU + XLA, we get

[0.04347826, 0.04347826, 0.04347826, 0.04347826,  ...]

I guess some numerical-stability trick for softmax is not applied for XLA + CPU.
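
(For reference, the standard numerical-stability trick for softmax is to subtract the per-row maximum before exponentiating -- a generic sketch below, not the T5 code; whether the XLA+CPU kernel applies something equivalent is exactly what seems to be in question.)

import tensorflow as tf

def reference_stable_softmax(scores, axis=-1):
    # softmax(x) == softmax(x - max(x)): shifting by the row maximum keeps the
    # exponentials in a well-behaved range even with large negative mask values.
    shifted = scores - tf.math.reduce_max(scores, axis=axis, keepdims=True)
    return tf.nn.softmax(shifted, axis=axis)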

The code I use

import numpy as np
import tensorflow as tf
from transformers import TFT5Model, T5Tokenizer
from transformers.utils.generic import ModelOutput

checkpoint = "t5-base"

tokenizer = T5Tokenizer.from_pretrained(checkpoint)
model = TFT5Model.from_pretrained(checkpoint)

# Ugly hack to return all outputs
model.config.output_hidden_states = True
model.config.output_attentions = True
model = TFT5Model.from_pretrained(checkpoint, config=model.config)

model_xla = tf.function(model, jit_compile=True)

# tokenizer.pad_token_id = tokenizer.eos_token_id
pad_token_id = tokenizer.pad_token_id

sentence_1 = "I have a cat, two dogs"
sentence_2 = "I have a cat"

sentence_1 = "Translate English to German: I have a cat, two dogs, three horses, and four birds."
sentence_2 = "Translate English to German: I have a cat, two dogs, and three horses."

ids_single = tokenizer([sentence_1], return_tensors="tf", padding=True).input_ids
decoder_ids_single = tf.zeros((1, 1), dtype=tf.int32)
# attention_single = tf.cast(tf.math.not_equal(ids_single, pad_token_id), dtype=tf.int32)  # as computed in generate
attention_single = tf.cast(tf.ones_like(ids_single), dtype=tf.int32)  # as computed in generate
decoder_attention_single = tf.cast(tf.ones_like(decoder_ids_single), dtype=tf.int32)  # as computed in generate


ids_pair = tokenizer([sentence_1, sentence_2], return_tensors="tf", padding=True).input_ids
decoder_ids_pair = tf.zeros((2, 1), dtype=tf.int32)
# attention_pair = tf.cast(tf.math.not_equal(ids_pair, pad_token_id), dtype=tf.int32)  # as computed in generate
attention_pair = tf.cast(tf.ones_like(ids_pair), dtype=tf.int32)
decoder_attention_pair = tf.cast(tf.ones_like(decoder_ids_pair), dtype=tf.int32)  # as computed in generate



# case 3 FAILING: with batch size = 1 and attention mask, XLA and non-XLA do NOT match
outputs = model(input_ids=ids_single, decoder_input_ids=decoder_ids_single, attention_mask=attention_single, decoder_attention_mask=decoder_attention_single, output_hidden_states=True, output_attentions=True)
outputs_xla = model_xla(input_ids=ids_single, decoder_input_ids=decoder_ids_single, attention_mask=attention_single, decoder_attention_mask=decoder_attention_single, output_hidden_states=True, output_attentions=True)

@sanchit-gandhi
Contributor

sanchit-gandhi commented Apr 20, 2022

As @patrickvonplaten mentioned, it's pretty imperative to have XLA working on CPU for any kind of debugging - there are all sorts of debugging methods that pull values back to the host and perform checks on an op-by-op basis (see https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#nans). These are pretty crucial for understanding the inner workings of a compiled function that you wouldn't otherwise see when running XLA purely on an accelerator.

Also, when running JAX/Flax on CPU, the floating-point precision of the internal computations used in matrix multiplications and convolutions is always highest. When you move to a TPU, the floating-point precision is lowered by default. We need to be able to test our code on CPU at highest precision, especially for any sort of PT-Flax equivalence tests (see #15754). I'm not familiar with how TF treats matmul precision, but it's these sorts of considerations that make running XLA on CPU pretty essential!
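
(Illustrative sketch of the JAX knobs referred to above; the flag names are assumptions based on the JAX docs and worth double-checking against the installed version:)

import jax
import jax.numpy as jnp

# Surface NaNs produced inside jitted code (see the "Common Gotchas" link above).
jax.config.update("jax_debug_nans", True)

# Force full float32 precision for matmuls; on TPU the default uses
# lower-precision passes, whereas CPU always runs at highest precision.
jax.config.update("jax_default_matmul_precision", "float32")

x = jnp.ones((4, 4))
print(jnp.dot(x, x))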

@patrickvonplaten
Contributor

patrickvonplaten commented Apr 20, 2022

Great find @ydshieh! We should talk to the TF guys about this, no?

@ydshieh
Collaborator

ydshieh commented Apr 20, 2022

(sorry, accidentally edited @patrickvonplaten above comment)

Yes. Let's extract (or create) some inputs, and reproduce the issue with only the softmax part.

@gante
Member Author

gante commented Apr 20, 2022

This is great @ydshieh 🔥 I'm going to build a toy example and open an issue in TF, linking to this thread.

@gante
Member Author

gante commented Apr 20, 2022

Pinned down the problem: it is caused by the softmax with numerically masked (= large negative) inputs on XLA+CPU. I've opened an issue on TensorFlow (as backlinked above), which contains a simple reproducible example.

Meanwhile, avoid XLA+CPU :D
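
(A minimal repro along these lines -- my own reconstruction, not necessarily the exact snippet from the TF issue:)

import tensorflow as tf

def masked_softmax(scores):
    # One position carries a large negative "mask" value, as in the attention code.
    return tf.nn.softmax(scores, axis=-1)

xla_masked_softmax = tf.function(masked_softmax, jit_compile=True)

scores = tf.random.normal((1, 10)) + tf.constant([[0.0] * 9 + [-1e9]])
# On CPU, the XLA and non-XLA outputs can differ wildly here.
print(tf.math.reduce_max(tf.math.abs(masked_softmax(scores) - xla_masked_softmax(scores))).numpy())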

@ydshieh
Collaborator

ydshieh commented Apr 20, 2022

If this takes the TF team a long time to fix, we might use a wrapped version of tf.nn.softmax.
I don't like this approach much, though; it's just an option.

@patrickvonplaten
Contributor

I'm sure somewhere hidden there is a TF softmax that is stable on XLA. We could then create a custom def softmax(...) in https://github.com/huggingface/transformers/blob/main/src/transformers/tf_utils.py that wraps tf.nn.softmax(...) for non-XLA and uses a stable version for XLA.
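
(A minimal sketch of what such a wrapper could look like -- illustrative only, the actual name, signature, and constant would be decided in the PR:)

import tensorflow as tf

def stable_softmax(logits, axis=None, name=None):
    # Adding a tiny constant is mathematically a no-op for softmax, but it
    # changes the program XLA compiles and works around the CPU issue.
    return tf.nn.softmax(logits=logits + 1e-9, axis=axis, name=name)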

@gante
Member Author

gante commented Apr 22, 2022

It should work! The toy example below adds said wrapper (with +1), and both CPU and GPU XLA have a difference of ~1e-8 from the non-XLA version 👍 It also confirms that the stable softmax outputs the same values as the original softmax.

import tensorflow as tf


LARGE_PENALTY = -1e9


def stable_softmax(x):
    return tf.nn.softmax(x + 1)


def masked_softmax(x, boolean_mask):
    numerical_mask = (1. - tf.cast(boolean_mask, dtype=tf.float32)) * LARGE_PENALTY
    masked_x = x + numerical_mask
    return stable_softmax(masked_x)


xla_masked_softmax = tf.function(masked_softmax, jit_compile=True)
xla_stable_softmax = tf.function(stable_softmax, jit_compile=True)
x = tf.random.normal((1, 10))

# same outcome regardless of the boolean mask here
boolean_mask = tf.convert_to_tensor([[1] * 9 + [0] * 1], dtype=tf.int32)

# passes
numerical_mask = (1. - tf.cast(boolean_mask, dtype=tf.float32)) * LARGE_PENALTY
masked_x = x + numerical_mask
xla_out = xla_stable_softmax(masked_x)
out = stable_softmax(masked_x)
print(tf.math.reduce_max(tf.math.abs(xla_out - out)).numpy())
assert tf.experimental.numpy.allclose(xla_out, out)

# The stable softmax has the same output as the original fn
unstable_out = tf.nn.softmax(masked_x)
print(tf.math.reduce_max(tf.math.abs(unstable_out - out)).numpy())
assert tf.experimental.numpy.allclose(unstable_out, out)

# passes (with the + 1 in the softmax)
xla_out = xla_masked_softmax(x, boolean_mask)
out = masked_softmax(x, boolean_mask)
print(tf.math.reduce_max(tf.math.abs(xla_out - out)).numpy())
assert tf.experimental.numpy.allclose(xla_out, out)

Opening a PR soon with this temporary fix, and will replace ALL softmax calls with this wrapped version.

@ydshieh
Collaborator

ydshieh commented Apr 22, 2022

Could we use the following instead 🙏

tf.nn.softmax(x - tf.math.reduce_max(x, axis=-1, keepdims=True), axis=-1)

(Edit: didn't work! Very strange :( )

@ydshieh
Collaborator

ydshieh commented Apr 22, 2022

LARGE_PENALTY

The problem with +1 is that it won't work in general (it works in our 2 cases, but I don't know why).

@gante
Member Author

gante commented Apr 22, 2022

@ydshieh I agree that it should be more stable numerically, but I'd rather add a fixed constant. reduce_max would add extra computation (reduce operations are not lightweight) and, if it does fix numerical stability issues, it could introduce a drift between the model at train time and at inference time.

Perhaps not 1, but a very small constant like 1e-9 (which also works in this toy example)

@ydshieh
Collaborator

ydshieh commented Apr 22, 2022

OK, good point @gante. And my suggestion didn't work even with your code above, so using a constant is fine by me.

@Rocketknight1
Member

From further experimentation, I think the reason the small constant works has nothing to do with numerical stability - I think inserting an addition just changes the particular compiled program that XLA generates, and so avoids this issue.
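
(If someone wants to verify this hypothesis, one option -- assuming a recent TF version where experimental_get_compiler_ir is available on jit-compiled tf.functions -- is to dump the HLO that XLA emits for both variants and diff them; a rough sketch:)

import tensorflow as tf

def plain(x):
    return tf.nn.softmax(x, axis=-1)

def shifted(x):
    return tf.nn.softmax(x + 1e-9, axis=-1)

x = tf.random.normal((1, 10))
# experimental_get_compiler_ir returns the (optimized) HLO text for the
# compiled function, which can be compared between the two variants.
hlo_plain = tf.function(plain, jit_compile=True).experimental_get_compiler_ir(x)(stage="optimized_hlo")
hlo_shifted = tf.function(shifted, jit_compile=True).experimental_get_compiler_ir(x)(stage="optimized_hlo")
print(hlo_plain == hlo_shifted)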
