Masked tensor support as CompositeTensors? #18413
Labels: `keras-team-review-pending` (pending review by a Keras team member), `type:feature` (the user is asking for a new feature), `type:support` (user is asking for help / an implementation question; Stack Overflow would be better suited).
I'm not sure this is the correct place for this, and I'm even less confident that it's my place to say any of it, but having spent a good deal of time playing with masked tensors I figured I'd share my thoughts/frustrations in the hope that some aspects of the design could be reconsidered for the next major version bump. Apologies in advance for the long read.
## TL;DR

## Masking in metrics

### Abusable - both intentionally and accidentally
Let's say I find a training script for a keras-based SOTA image classification model. I can almost guarantee you I can take that script and, with a minor tweak to the model architecture and without any retraining or manual weight adjustment, get 100% top-1 accuracy. How? By adding a masking layer to the final predictions such that all but the most confident prediction are masked, the metrics will simply ignore all the rest.
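To make this concrete, here's a toy sketch (NumPy only, invented numbers) of a metric that averages only over unmasked entries, and how it can be gamed:

```python
import numpy as np

# Toy logits for 4 samples, 3 classes; the model is right on only 1 of 4.
logits = np.array([
    [2.0, 1.0, 0.5],   # predicts class 0, label 0 -> correct
    [0.2, 0.9, 0.4],   # predicts class 1, label 2 -> wrong
    [0.3, 0.4, 0.8],   # predicts class 2, label 0 -> wrong
    [1.0, 0.2, 0.1],   # predicts class 0, label 1 -> wrong
])
labels = np.array([0, 2, 0, 1])
preds = logits.argmax(axis=-1)

def masked_accuracy(preds, labels, mask):
    # A masked metric averages only over unmasked entries,
    # mirroring how a propagated mask is folded into the metric.
    mask = mask.astype(float)
    return (preds == labels).astype(float) @ mask / mask.sum()

honest = masked_accuracy(preds, labels, np.ones(4, bool))  # 0.25

# "Mask away" everything except the single most confident
# (and here, correct) prediction: the metric reports perfection.
confidence = logits.max(axis=-1)
abusive_mask = confidence == confidence.max()
inflated = masked_accuracy(preds, labels, abusive_mask)    # 1.0
print(honest, inflated)
```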
This seems absurd to me. Now maybe you're thinking, "That's a bit far-fetched; we shouldn't have to program such that our code can't intentionally be abused this way," and to some extent I'd agree. But mask propagation is incredibly easy to implement incorrectly, and it fails silently when such a mistake is made. These silent failures can be highly misleading.
Example: I recently wanted to see how well a very basic resnet-based architecture would do on common NLP tasks. I took one of the official keras/keras-nlp examples and replaced the transformer model with a basic resnet architecture, and lo and behold, accuracy improved significantly. Does this mean resnets are the future of NLP?
No. At least, that's not something I can say from my experiments. Why not? Well, my resnet had residual connections that I made using `x + y` instead of `Add()([x, y])`. This (silently) destroyed the mask, so the loss and metrics included contributions from padding. Unsurprisingly, it's very easy for the model to learn that only PAD tokens appear after a PAD token, and when your dataset has a non-trivial amount of padding, this significantly improves metrics.

### Not used correctly in any keras NLP example (from what I understand)
I accept that masking can be relevant in more than just NLP cases, but that seems to be the most common use case in the examples. AFAIK there's not a single NLP example that does this correctly, however. To illustrate, consider a sequence generation example where the labels are a padded version of "hello world", padded to token length 6. We'll assume a causal decoder that takes inputs and evaluates based on a shifted version. The inputs should be masked to exclude padding tokens, as should the targets - but I've yet to see an example which applies this shift to the masks.
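With illustrative token ids (the START/END/PAD conventions here are assumptions), the shift looks like:

```python
import numpy as np

PAD, END = 0, 3
# "hello world" with a START token, padded to token length 6.
inputs  = np.array([9, 1, 2, END, PAD, PAD])   # 9 = START, 1 = hello, 2 = world
targets = np.array([1, 2, END, PAD, PAD, PAD]) # inputs shifted left by one

input_mask  = inputs != PAD   # [1, 1, 1, 1, 0, 0]
target_mask = targets != PAD  # [1, 1, 1, 0, 0, 0]

# The masks disagree at the 4th token: reusing the input mask for the
# targets credits the model for predicting PAD after END.
mismatch = np.flatnonzero(input_mask != target_mask)
print(mismatch)  # [3]
```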
Note the mismatch between the input and target mask at the 4th token. As a result, every example I've seen gives credit to a model for learning that every "(END)" token is followed by a "(PAD)" token. In my opinion, this is not a meaningful thing to learn, and serves only to artificially inflate metric scores, particularly for short sequences.
To further illustrate the resnet example I described above: after accidentally using `+` instead of `Add()`, the mask my metrics used covered the padding positions as well. In this case, a model can achieve 50% accuracy by learning nothing more than that `PAD` always follows `END`/`PAD` tokens (and this can be improved by increasing the maximum sequence length/amount of padding).

### Metrics issues: resolution
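For concreteness, building per-token `sample_weights` at dataset-construction time can be as simple as the following framework-free sketch (token ids and the padding convention are illustrative):

```python
import numpy as np

PAD = 0

def make_example(token_ids, max_len):
    """Pad, shift, and build sample weights instead of relying on masks."""
    ids = np.full(max_len + 1, PAD)
    ids[:len(token_ids)] = token_ids
    inputs, targets = ids[:-1], ids[1:]
    # Weight real target tokens 1.0 and padding 0.0 - the loss/metrics
    # then ignore padding, with no mask propagation to get wrong.
    sample_weights = (targets != PAD).astype(np.float32)
    return inputs, targets, sample_weights

inputs, targets, w = make_example([9, 1, 2, 3], max_len=6)
print(w)  # [1. 1. 1. 0. 0. 0.]
```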
Both of these issues can be resolved using `sample_weights` instead of masking. I'm unaware of any case where masking can do anything that appropriate `sample_weights` couldn't, and implementation during dataset construction is straightforward.

On this point, my question is: if
`sample_weights` can do everything masking can, isn't abusable, and is (in my opinion) less confusing, can we completely remove consideration of masks from metrics?

## Masking in Models
Despite the above, I accept there's a valid place for masking in model construction, be it in NLP, GNNs, point cloud networks or other use cases. In tensorflow there are `RaggedTensor`s that can potentially be used for these representations, but masked tensors are often better because they allow for constant shapes, which can have a significant effect on performance, especially with XLA compilation.

That said, having played around with supporting masks in my models, I find the developer experience unintuitive, error prone and difficult to debug. Supporting them in a layer involves setting a public attribute, changing function signatures for `call`, `build` and `compute_output_[spec|shape]`, and potentially adding a new `compute_output_mask`. If you make a mistake in any of these, there is generally no error, and the fallback behaviour is almost never what's intended. At best you get a warning (not always), but that warning doesn't include a stack trace, so all you know is you've made a mistake somewhere. Outside of layer implementations there isn't even a way to check whether a tensor is masked or not using the public API, and even unit tests depend on private API use to test some functionality. The fact that there are numerous open bugs related to masking makes me believe I'm not the only one struggling.
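To illustrate the failure mode, here's a toy re-creation of the propagation rule - deliberately not the real Keras API - showing how a layer that forgets to declare the attribute silently drops the mask rather than raising:

```python
# Toy model of Keras-style mask propagation (not the real API):
# a layer that forgets to declare supports_masking silently drops
# the incoming mask instead of raising an error.
class ToyLayer:
    supports_masking = False

    def __call__(self, x, mask=None):
        out = self.call(x)
        if mask is None or not self.supports_masking:
            # Mask silently destroyed - no error, no stack trace.
            return out, None
        return out, self.compute_mask(mask)

    def call(self, x):
        return x

    def compute_mask(self, mask):
        return mask

class MaskAware(ToyLayer):
    supports_masking = True

mask = [True, True, False]
_, m1 = MaskAware()("x", mask)  # mask propagated
_, m2 = ToyLayer()("x", mask)   # mask gone, silently
print(m1, m2)  # [True, True, False] None
```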
This is completely different to my experience with composite tensors in tensorflow. If I want a layer to support a `RaggedTensor`, I simply make sure my `call` method supports `RaggedTensor` inputs. Same for `SparseTensor`s. If I try and use a `RaggedTensor` input in an unsupported way, it throws an error and I know exactly what I need to fix. If I want the output to have a different structure, I return a different structure. If I want to check whether a tensor is a regular tensor or a ragged/sparse tensor when building models with the functional API, I just look at its type. There is no analogous `supports_masking` attribute to set or `compute_output_mask` method to implement, and the number and names of arguments in other methods remain unchanged.

Now I appreciate that `keras_core` does not currently support composite tensors, but I'm very much hoping that's on the cards, at least in a limited capacity with a few concrete classes (e.g. `SparseTensor`, `RaggedTensor`) or a more general framework for users to define their own composite structures like tensorflow and jax have.

My questions on this point are:
Relevant issues:

- `Merge` (e.g. `Add`, `Maximum`) layers do not propagate masks #18416