
Allow model to specify padding value instead of using first element #1279

Closed
oleg-yaroshevskiy opened this issue Mar 14, 2019 · 10 comments

@oleg-yaroshevskiy commented Mar 14, 2019

Hi,
I know questions aren't welcome here, but after some testing I believe this is a bug. It can't be reproduced without the pad_variable_length_inputs: true flag. I'll close the corresponding issue on the t2t repo as soon as possible.


For a long time I've been trying to solve a TF Serving problem with Transformer NMT model inference. Let's assume I create a simple signature:

inputs = tf.placeholder(dtype=tf.int64, shape=[None, None], name="input_logits")
features = {"inputs": standardize_shapes(inputs)}

return tf.estimator.export.ServingInputReceiver(
    features=features, receiver_tensors=inputs)

and everything works well.
Then I want to enable batching, and because of the different sequence lengths I set
pad_variable_length_inputs: true in the batching.config file.
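
For reference, a minimal sketch of what that batching config might look like (only pad_variable_length_inputs is the flag in question; the other fields and their values are illustrative, not taken from my setup):

num_batch_threads { value: 4 }
max_batch_size { value: 32 }
batch_timeout_micros { value: 5000 }
pad_variable_length_inputs: true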
That's where the problem starts. Serving returns garbage for short sequences, e.g.:

first (1, 18) {'instances': [{'inputs': [336, 201, 506, 26, 1902, 2339, 6677, 13748, 14, 55, 11864, 3258, 7380, 1368, 2770, 26172, 662, 1]}]} enfr {
 >>>   "predictions": [
        {
            "outputs": [367, 169, 1306, 2, 37, 10043, 301, 19606, 4, 19005, 151, 85, 15888, 11427, 20653, 5131, 21936, 16, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
            "scores": -9.25369
        }
    ]
}

second (1, 3) {'instances': [{'inputs': [1544, 657, 1]}]} enfr {
>>>    "predictions": [
        {
            "outputs": [2865, 1141, 1141, 1141, 1141, 1141, 1141, 1141, 1141, 1141, 1141, 1141, 1141, 1141, 1141, 1141, 1141, 1141, 1141, 1141, 1141, 1141, 1141, 2865, 2865, 2865, 2865, 2865, 2865, 2865, 2865, 2865, 2865, 2865, 2865, 2865, 1],
            "scores": -28.8886
        }
    ]
}

As we can see, the longer sequence was predicted well, while the short one is garbled with repeated tokens. This can't be reproduced with single-element inference or outside the TF Serving environment.

Example after decoding:

['Made from steel']->['MadeMadeMadeMadeMadeMadeMadeMadeMadeMadeMadeMadeMadeM
adeMadeMadeMadeMadeMadeMadeMadeMadeMadeMadeMadeMadeMadeMadeMadeMadeMadeMadeMadeMadeMadeMadeMadeMadeMadeMadeMadeMadeMadeMadeMadeMadeMadeMadeMadeMa
deMadeMadeMadeMadeMademade.']

Any ideas? Is it related to TF Serving padding, or is my input signature just not the best choice?
Any clue would be appreciated, thanks.

Environment information

TF Serving 1.13.0
tensor2tensor==1.6.6

@oleg-yaroshevskiy (Author)

I've located the problem in batching_util.cc:

    T pad_value(input.flat<T>()(0));  // using existing values in padding
    output->tensor<T, num_dims>() = inputs.pad(padding, pad_value);

It looks like the tensor is padded with the first value of the flattened tensor rather than with zero. That might work for image processing, but it isn't appropriate for sequence batching.
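
For illustration, a minimal NumPy sketch of what this behaviour means for the short request above (illustrative code only, not TF Serving internals):

import numpy as np

# Pad each request to the batch's max length using the request's own
# first element, mirroring the batching_util.cc snippet above.
def pad_with_first_element(seq, target_len):
    pad_value = seq[0]  # first value of the flattened tensor, not zero
    return np.pad(seq, (0, target_len - len(seq)), constant_values=pad_value)

short_request = np.array([1544, 657, 1])
print(pad_with_first_element(short_request, 18))
# Every padded position is 1544, so the model sees repeated real tokens
# instead of anything it could recognise as padding.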

Please take a look.

@christisg (Member)

Thank you for your feedback.

Could you elaborate on why padding with the first value (and not zero) doesn't work with sequence batching?

We chose to use a real element instead of zero precisely to avoid assigning a specific meaning to zero.

@gowthamkpr

@oleg-yaroshevskiy Closing this issue as this has been in "awaiting response" status for more than a month. Please add additional comments and we can open the issue again. Thanks!

@maximedb

Is there any way to specify the padding value in the serving parameters? Padding with zero is quite common.

@oleg-yaroshevskiy (Author)

Could you elaborate on why padding with the first value (and not zero) doesn't work with sequence batching?

That might be related to the tensor2tensor Transformer implementation, as I'm not sure whether it does any masking in the encoder. Imagine a batch of size 2 whose padded input looks like:
[5, 7, 1, 5, 5, 5]
[4, 7, 9, 8, 6, 1]
Doing self-attention over the whole sequence will then lead to bad encoder representations.

In Hugging Face Transformers there is an optional masking argument for this.
I believe zero padding is the common pad-value convention in NLP nowadays.
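
A minimal sketch of the difference (plain NumPy for illustration, not tensor2tensor code):

import numpy as np

# With zero padding, an attention mask can be recovered from the inputs themselves.
zero_padded = np.array([[5, 7, 1, 0, 0, 0],
                        [4, 7, 9, 8, 6, 1]])
mask = (zero_padded != 0).astype(np.float32)  # 1 = real token, 0 = padding

# With first-element padding the first row becomes [5, 7, 1, 5, 5, 5]:
# the pad positions are indistinguishable from real tokens, so unmasked
# self-attention mixes them into every encoder representation.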

misterpeddy changed the title from "Batching padding issue" to "Allow model to specify padding value instead of using first element" on Mar 2, 2020
@misterpeddy (Member)

Reopening and marking as a feature request. If anyone wants to contribute a fix to the project, please feel free to!

@singhniraj08

@oleg-yaroshevskiy,

Are you still looking for a resolution? We are planning to prioritise issues based on community interest. Please let us know whether this issue still persists with the latest TF Serving 2.12.1 release so that we can work on fixing it.
Thank you for your contributions.

@github-actions

This issue has been marked stale because it has had no activity for the past 7 days. It will be closed if no further activity occurs. Thank you.

github-actions bot added the stale label on May 26, 2023
@github-actions bot commented Jun 2, 2023

This issue was closed due to lack of activity after being marked stale for the past 7 days.

github-actions bot closed this as completed on Jun 2, 2023
@yzhaoa commented Jan 19, 2024

I think the correct way to work around this issue is to encode the variable-length features using the RaggedTensor.from_tensor(tensor, lengths) format, i.e. with two input tensors:

  1. input: a variable-length (None, None)-shaped tensor, optionally with 0 padding.
  2. input_lengths: an int64 tensor (a scalar before batching, (None,)-shaped after batching) containing the length of each input row.

The original variable-length RaggedTensor can then be recovered by calling RaggedTensor.from_tensor(input, lengths=input_lengths).

This is safe against whatever padding algorithm TensorFlow Serving uses, because the lengths argument truncates off all the padding values.
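
A minimal sketch of that workaround (the tensor names are illustrative):

import tensorflow as tf

# Padded batch as it might arrive from TF Serving, plus per-row lengths.
inputs = tf.constant([[5, 7, 1, 0, 0, 0],
                      [4, 7, 9, 8, 6, 1]], dtype=tf.int64)
input_lengths = tf.constant([3, 6], dtype=tf.int64)

# Recover the original variable-length rows; whatever the padding value
# was, everything past each row's length is dropped.
ragged = tf.RaggedTensor.from_tensor(inputs, lengths=input_lengths)
# -> <tf.RaggedTensor [[5, 7, 1], [4, 7, 9, 8, 6, 1]]>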
