
[wip] [Pipe] supporting None and non-Tensors in forward's input/output #50693

Closed

Conversation

@stas00 (Contributor) commented Jan 18, 2021

I was trying to convert transformers T5 to use the new Pipe and I encountered a dozen input args, some of which are bools, some optionally None, and yet others tuples within tuples. You can see an example of the complex inputs it gets:

https://github.com/huggingface/transformers/blob/357fb1c5d8b6a16f042f9b504f023d935086e8e5/src/transformers/models/t5/modeling_t5.py#L600-L612

Currently Pipe only handles input/output which is either a single Tensor or a tuple of Tensors, so we can't use it as it stands now.

A user needs to be able to pass as a part of the input and output tuple:

  1. None - in transformers these are passed to forward when the value needs to be optionally generated by downstream layers, but they are Tensors at other times
  2. Bools - these are flow-control flags that the user's code needs to pass to the model's forward. These aren't available at the model's init.
  3. Tuples of Tensors within a tuple - there are several aggregators that append to a tuple, which is then returned as part of the outputs - so again, some recursive introspection and .to() is needed; see below.

Bottom line: a whole bunch of variable types may need to be passed to and from the forward function in the pipe chain. As long as these structures can be traversed, moved .to() the right device, and sliced where needed, any type of variable should be supported. Please see below.

I made microbatch.py work with Nones - this PR - but stumbled upon errors in the C++ implementation:

TypeError: WaitBackward.forward: expected Tensor or tuple of Tensor (got NoneType) for return value 3

I'd love some help with this and also to figure out the Bools and other structures.

I'm very open to a different way of resolving this issue as well.

Proposal

To summarize I propose to change the user-side forward in Pipe to support the following requirements:

def forward(self, input_to_slice, input_to_copy):
    [...]
    return input_to_slice, input_to_copy

where the input_to_slice tuple:

  • may contain a single Tensor or a tuple of (Tensor|None). None must be supported in the input, since a given argument is None only some of the time and normal input data at other times. So micro-batch slice it when it's not None, and pass None through otherwise.
  • the first dimension is of batch-size length (the dimension to slice on)
  • it's obvious how to switch it with .to()

Basically, this is what we have now, plus support for any number of Nones in the tuple.
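For illustration, a minimal sketch of what such a None-aware micro-batch split could look like - this is my own sketch under the assumptions above, not the existing microbatch.py API, and the name scatter_with_none is hypothetical:

import torch
from typing import Optional, Tuple

def scatter_with_none(inputs: Tuple[Optional[torch.Tensor], ...], chunks: int):
    """Split a tuple of (Tensor|None) into `chunks` micro-batches along dim 0.
    Tensors are chunked on the batch dimension; None entries are passed
    through unchanged to every micro-batch."""
    per_arg = [
        torch.chunk(t, chunks) if isinstance(t, torch.Tensor) else (None,) * chunks
        for t in inputs
    ]
    # regroup per micro-batch: [(arg0_mb0, arg1_mb0, ...), (arg0_mb1, arg1_mb1, ...)]
    return list(zip(*per_arg))

# example: the second argument is None and is simply carried along
x = torch.randn(4, 8)
micro_batches = scatter_with_none((x, None, x.clone()), chunks=2)
assert micro_batches[0][1] is None and micro_batches[0][0].shape[0] == 2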

and then add an optional input_to_copy tuple, which:

  • could be None
  • may contain any nested tuples, lists, dicts, simple vars (but no objects, unless perhaps they have .to() implemented)
  • any variable in the tuple can be of any dimension - we are not slicing these
  • will be switched with .to() by recursively traversing the structure. Here is a recursive_to function that does the recursive traversal (sketched below); it may need added support for objects which implement .to()
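A minimal sketch of what such a recursive_to could look like - my own approximation of the helper referenced above, not necessarily its exact implementation:

import torch

def recursive_to(device, obj):
    """Recursively move every Tensor found in nested tuples/lists/dicts to `device`.
    Non-tensor leaves (None, bools, ints, ...) are returned unchanged."""
    if isinstance(obj, torch.Tensor):
        return obj.to(device)
    if isinstance(obj, (list, tuple)):
        return type(obj)(recursive_to(device, o) for o in obj)
    if isinstance(obj, dict):
        return {k: recursive_to(device, v) for k, v in obj.items()}
    if hasattr(obj, "to"):  # objects that implement .to(), as mentioned above
        return obj.to(device)
    return obj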

Each stage of the Pipe will receive:

  • self
  • a microbatch slice of input_to_slice switched to the right device
  • a full copy of input_to_copy switched to the right device

the Pipe will return:

  • input_to_slice reconstructed to a full batch (updated with the outputs)
  • a full copy of input_to_copy (updated with the outputs)

The names are just to make the proposal clear - surely they should be named something else should the proposal be accepted.

Thank you!

p.s. If I'm not mistaken, pytorch Pipe is derived from / modelled after FairScale's, and there is also DeepSpeed's implementation, so let me know if I should work it out with one of the above first and then sync with pytorch.

tagging @blefaudeux from FairScale and @ShadenSmith from DeepSpeed in case you have some insight - and since I believe we need all 3 sets of APIs to agree. If someone else at your org is in charge of that domain, please kindly tag them instead. thank you!

tagging @pritamdamania87 who originally suggested I use Pipe here.

@facebook-github-bot added the cla signed and oncall: distributed (add this issue/PR to distributed oncall triage queue) labels on Jan 18, 2021
@stas00 changed the title from "[wip] [Pipe] supporting None and Bools" to "[wip] [Pipe] supporting None and Bools in forward's input/output" on Jan 18, 2021
@stas00 changed the title from "[wip] [Pipe] supporting None and Bools in forward's input/output" to "[wip] [Pipe] supporting None and non-Tensors in forward's input/output" on Jan 19, 2021
@blefaudeux (Contributor) commented:

@stas00 I know very little about that codebase, although from a distance I can imagine some reasons why there would be input type restrictions, for instance because of autograd, but that's just a supposition. Tagging @froody who would know best, and @msbaines as my second best guess

@stas00 (Contributor, Author) commented Jan 19, 2021

@stas00 I know very little about that codebase, although from a distance I can imagine some reasons why there would be input type restrictions, for instance because of autograd, but that's just a supposition. Tagging @froody who would know best, and @msbaines as my second best guess

Thank you very much for tagging the right people, @blefaudeux!

@stas00 (Contributor, Author) commented Jan 20, 2021

Here is a related discussion specific to DeepSpeed PP: microsoft/DeepSpeed#659

@pritamdamania87 (Contributor) commented:

def forward(self, input_to_slice, input_to_copy):

I really like this proposal @stas00! As per my understanding input_to_slice is the input that gets partitioned across multiple microbatches and input_to_copy is the "common" input across all microbatches (ex: could be some configs or bools indicating what to execute)?

the Pipe will return:

  • input_to_slice reconstructed to a full batch (updated with the outputs)
  • a full copy of input_to_copy (updated with the outputs)

I'm not sure I understood this correctly: why does the pipe need to return input_to_slice, input_to_copy? Is this for the case where models might also output data per microbatch and also some "common" vars which are the same across all microbatches?

  1. tuples of Tensors within a tuple - there are several aggregators that append to a tuple, which is then returned as part of outputs - so again, some recursive introspection and to() is needed see below.

This requirement doesn't seem to be covered in your proposal?

@stas00 (Contributor, Author) commented Jan 21, 2021

def forward(self, input_to_slice, input_to_copy):

I really like this proposal @stas00!

Thank you for your validation, @pritamdamania87

As per my understanding input_to_slice is the input that gets partitioned across multiple microbatches and input_to_copy is the "common" input across all microbatches (ex: could be some configs or bools indicating what to execute)?

That's exactly right.

It may contain tensors too, so this input_to_copy will need to be traversed recursively and moved .to() the right device. But otherwise it gets passed as is.

the Pipe will return:

  • input_to_slice reconstructed to a full batch (updated with the outputs)
  • a full copy of input_to_copy (updated with the outputs)

I'm not sure I understood this correctly: why does the pipe need to return input_to_slice, input_to_copy? Is this for the case where models might also output data per microbatch and also some "common" vars which are the same across all microbatches?

  1. To match the pipeline stage definition where outputs of one stage become inputs to the next stage, so it's consistent. I can see how it doesn't have to be so, and input_to_copy could just get copied to each stage as is.

  2. But if something gets aggregated, it might be useful to return it so the same data keeps flowing, and let the user decide whether they want to modify input_to_copy or return it as is. This could be configurable too, if the user wants things somewhat simpler. It could also be a magical do-what-i-mean: if the stage of the pipe returned this 2nd structure, pass it forward to the next pipeline stage; if only one variable was returned, pass along the original input_to_copy.

  1. tuples of Tensors within a tuple - there are several aggregators that append to a tuple, which is then returned as part of outputs - so again, some recursive introspection and to() is needed see below.

This requirement doesn't seem to be covered in your proposal?

Indeed. I think it'll have to fit into input_to_copy - since it'd be unrestricted, besides requiring all "leaves" to be .to()-able.

Basically, I'm currently dealing with a complex structure of batch-sized tensors inside a tuple inside another tuple, and the tensors aren't even of the same shape (2 different shapes) - that's why they are in tuples and not a single tensor in the first place. So since it needs to be sliced for micro-batching, I don't think we have any choice here but to remap that structure to a normal tensor and then reconstruct it on the other side to what it originally was. I will have to convert it to 2 tensors so that I can match the sizes. I wish we had NestedTensor available. But it's OK, it's already complicated, so what's another flip.

Mind you, we are converting what used to be a complex loop over a ModuleList, with full control over what was passed to forward within that loop, to a setup with much less control.

@stas00 (Contributor, Author) commented Jan 21, 2021

I wonder if we could just have **kwargs for that input_to_copy - it'd be even more intuitive to the user, and it'd be even easier to port applications, since we already use dict params. E.g. all those use_this=True, enable_that=False flags would just remain unmodified.

But I guess we would need to stick to a single dict variable and not all model(input, key1=val1, key2=val2) style of args.

On the other hand, the func signature could even be the totally normal:

def forward(self, *input_to_slice, **kwargs_to_copy):
    [...]
    return input_to_slice, kwargs_to_copy

except with a restriction that the non-keyword args will have to be a tuple of (Tensor|None), with the batch size as the first dimension where not None. Perhaps it'd be too confusing?

I guess the biggest inconsistency then would be: how do you return the same thing back? Python doesn't quite have a feature to do:

return x, y=5

I guess it will have to be packed into a dict then:

return x, dict(y=5)

If we do that, then following up on my answer to your question about why bother returning the input_to_copy struct: a user could return the full dict or just some parts of it, and the pipe wrapper would merge the original kwargs_to_copy with whatever the stage returned (if anything) and then pass the updated structure on to the next stage.
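To make that merge rule concrete, a minimal sketch, assuming each stage may return either just its outputs or an (outputs, partial_kwargs) pair; the name merge_stage_output is hypothetical:

def merge_stage_output(stage_output, kwargs_to_copy):
    """If a stage returned (outputs, dict), overlay the dict onto the original
    kwargs_to_copy; otherwise pass kwargs_to_copy to the next stage unchanged."""
    if (isinstance(stage_output, tuple) and len(stage_output) == 2
            and isinstance(stage_output[1], dict)):
        outputs, returned_kwargs = stage_output
        return outputs, {**kwargs_to_copy, **returned_kwargs}
    return stage_output, kwargs_to_copy

# example: a stage updates one flag, everything else is carried along
out, kwargs = merge_stage_output(("hidden", {"use_cache": False}),
                                 {"use_cache": True, "output_attentions": False})
assert kwargs == {"use_cache": False, "output_attentions": False}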

Thoughts?

This proposal is also backward compatible with the existing API as it extends the functionality and should work just fine with the original much more restrictive API.

@stas00 (Contributor, Author) commented Jan 21, 2021

For posterity, to bypass that tensor only restriction I tried to solve the data aggregation problem using a closure:

        def present_key_value_states_add(x):
            nonlocal present_key_value_states
            present_key_value_states += (x,)

and passing it to the pipeline wrapper class. I was getting everything messed up, until I decided to print out the order in which each stage + micro-batch gets executed, and quickly (and obviously) saw that in the pipeline the order of execution is very unpredictable.

In this printout the first number is the block id, the second is the micro-batch id - we have 6 blocks and 2 micro-batches.

0 0
1 0
2 0
0 1
3 0
41 0 1

2 1
5 0
3 1
4 1
5 1
0 0
1 0
2 0
3 0
0 1

So scratching that idea and will try a different approach instead.

I think I may sort out how to flatten those tuple-in-tuple structures into tensors and then restore them on the other side of the pipeline to what the application expects. I will post an update if and when I make it work.


Edit: omg, I made it work by using a closure and slots, which ensured the correct insertion order. I track each block id and each micro-batch id:

        present_key_value_states = [[0 for x in range(2)] for y in range(6)]
        def present_key_value_states_add(x, block_id, micro_batch_id):
            nonlocal present_key_value_states
            present_key_value_states[block_id][micro_batch_id] = x

and then I had to reconstruct it back into the correct tuple of tuples, manually merge the micro-batches, and switch all tensors to the same device.
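For illustration, a minimal sketch of that reconstruction step, assuming the 6-block x 2-micro-batch slots from the closure above, with each slot holding a tuple of per-key tensors whose micro-batch halves get concatenated on the batch dimension (the helper name and exact shapes are assumptions):

import torch

def reconstruct(slots, device):
    """Rebuild the tuple-of-tuples-of-tensors from the [block_id][micro_batch_id]
    slots, concatenating micro-batches on the batch dimension and moving
    everything to one device."""
    result = []
    for per_block in slots:                   # one entry per block (6 here)
        # per_block is [mb0, mb1], each a tuple of per-key tensors (4 here)
        n_keys = len(per_block[0])
        merged = tuple(
            torch.cat([mb[k].to(device) for mb in per_block], dim=0)
            for k in range(n_keys)
        )
        result.append(merged)
    return tuple(result)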

I'm yet to figure out how to do this in a simpler way, but this is my first success. I made PP work with t5 and a lot of hardcoded hacks. Next: cleaning it up and generalizing it!

I think the other, more complicated solution would be to prepare an empty tensor of the correct dimensions and then pass it to forward and back through each block and stage - again filling the right slots - but each of these is so difficult to think about. Perhaps after doing it a few times I will get the hang of it.

@pritamdamania87 (Contributor) commented:

Thanks for all of the detailed comments @stas00! I was wondering if the following contracts made sense:

  1. The first stage of the pipeline needs to have a forward signature that is compatible with the data types passed into the forward method for the Pipe. For example (this is using your inputs_to_slice and kwargs_to_copy idea):

def forward_stage1(tensor, foo, bar, baz):
    pass

Pipe.forward(tensor, foo=1, bar=2, baz=2)

  2. There is no restriction for intermediate stages of the pipeline. Basically the output of one stage of the pipeline is recursively copied (via .to()) to the next device in the pipeline. As a result, we don't have to adhere to the inputs_to_slice, kwargs_to_copy convention for intermediate modules.
  3. The return type of the last stage of the pipeline has to be compatible with the return type of the pipeline and should be a type we can aggregate. I'm wondering if we can keep the restriction that the last stage of the pipeline only returns a Tensor/Sequence of Tensors that we can aggregate? Because, if we allow the last stage of the pipeline to return arbitrary types, it is not clear how to aggregate these across multiple microbatches.

Putting this all together in an example might help. So let's say our original model is as follows:

class MyModel(nn.Module):
    def forward(self, inp, foo, bar, baz):
        tmp1 = inp + foo
        tmp2 = tmp1 + bar
        tmp3 = tmp2 + baz
        return tmp3

We can represent this into a pipe as follows:

class Stage1(nn.Module):
    def forward(self, inp, foo, bar, baz):
        return inp + foo, bar, baz

class Stage2(nn.Module):
    def forward(self, inp, bar, baz):
        return inp + bar, baz

class Stage3(nn.Module):
    def forward(self, inp, baz):
        return inp + baz

seq = nn.Sequential(Stage1(), Stage2(), Stage3())

# Pipe forward
Pipe.forward(torch.rand(10, 10), foo=1, bar=2, baz=3)

The high level idea here is that there is tight coupling between the first stage of the pipeline and Pipe's forward method, and a similar tight coupling between the return type of the last stage and the return type of Pipe's forward. Apart from that, we allow the rest of the stages to be defined arbitrarily by the user (as long as the signatures between subsequent stages are compatible).

@stas00 (Contributor, Author) commented Jan 22, 2021

Thank you for writing out that clear proposal and the examples.

Overall, yes, to everything you said.

wrt the 3rd point, it can mirror exactly the originally proposed input signature of the first stage:

return tuple_to_be_reconstructed, tuple_to_be_returned_as_is

Here is the concrete obstacle I've been trying to overcome wrt return values. I'm converting:

aggregate = ()
for block in layers: # set of 6 identical blocks (t5stack)
    output = block(input)
    aggregate += output[1]

into:

aggregate = ()
# I had to create a wrapper around the original layers to work around the limitations
block_pipe = Pipe(nn.Sequential(*layers), chunks=2, checkpoint="never")
output = block_pipe(input).local_value()
aggregate += ???

So you can see that in the case I've been working on, the first stage of the pipe and all the other stages have to be the same; that's why here the last stage has to return the same output as the first stage. So it has to be the same as the inputs. Does this help?

At the moment I used a really complex solution: a closure with a simple python 2D list of slots, which I fill out through the pipe while keeping track of the depth of the stack and the micro-batch ids to know where to insert the aggregate chunk, and then manually reconstruct the data on the exit from the pipe. Terrible, but it works.

And you're absolutely correct about intermediate stages not needing any restrictions - it's only the very first stage that has to behave in a very restricted way. I appreciate you clarifying this. I certainly missed this point and I think it made my life much more complicated. I need to sit and ponder some more and I will post back if I get to see the light.

I'm very grateful for your feedback and explanations, @pritamdamania87.

@pritamdamania87 (Contributor) commented Jan 22, 2021

Thanks for providing the concrete example, it helps a lot in understanding the problem! As per my understanding, for the example you mentioned above, setting up the pipeline would look like:

class Block(nn.Module):
    def forward(self, inp, aggregate):
        output = self.block(inp)  # the wrapped original block
        aggregate += output[1]
        return output, aggregate

seq = nn.Sequential(Block(), Block())
pipe = Pipe(seq, chunks=2)
output, aggregate = pipe.forward(torch.rand(2, 2), aggregate)

Although, I'm not sure how aggregate would be appropriately accumulated in the pipeline. For example, let's assume there are two microbatches with inputs [0, 1] and [2, 3] respectively. Also, let's assume for simplicity that the block simply outputs the input it receives.

Now the output and aggregate for microbatch 0 at the last stage of the pipeline is:

output = [0, 1]
aggregate = 2

For microbatch 1 at the last stage it would be:

output = [2, 3]
aggregate = 6

Now the pipe can return a combined output of [[0,1], [2,3]] by concatenating on the batch dimension, however we can't just return aggregate as is since the actual value that we want is 6 + 2 = 8. So don't we need to have some way of aggregating non-tensor values as well in this case?

@stas00 (Contributor, Author) commented Jan 22, 2021

The problem is that the aggregate is not a number but a tuple of tuples of tensors:

[screenshot: a tuple of 6 tuples, each holding 4 tensors of different shapes, with batch size 3]

This is what goes in, and a similar structure with different contents needs to go out (well, a few of those).

You can see that the first tuple is of size 6 - we have 6 stages in the pipeline. So each stage takes one of 6 and leaves another of these on the way out.

Then there is a tuple of size 4 - these are 4 different keys; again, it uses all of these and leaves another set on the way out. The tensors are of different dimensions - that's why I think they were forced to use tuples in the first place.

Then you can see the actual tensors of batch size 3. Which is a problem, since the batch-size to be sliced on is hidden deep inside.

It's trivial when it's in the original loop and the batch size remains unchanged, but in this situation it's nuts.

As I said, I made it work using a closure to aggregate and a lot of composing and recomposing from tensors back to python structures. On the way in I invert this structure and stack it into a huge tensor, so that the batch dimension is first, ready to be sliced on. Then on the other side I recompose it back into a tuple of tuples of tensors, by chunking it twice.
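Roughly, that pack/unpack dance could look like the following sketch. It assumes all tensors in one group share the same shape (in reality there are two shape groups, so two such packs are needed), and the helper names pack/unpack are mine:

import torch

def pack(nested):
    """Stack a tuple(n_blocks) of tuple(n_keys) of same-shaped (batch, ...) tensors
    into one tensor with the batch dimension first, so Pipe can slice it."""
    stacked = torch.stack([torch.stack(inner) for inner in nested])  # (blocks, keys, batch, ...)
    return stacked.movedim(2, 0)                                     # (batch, blocks, keys, ...)

def unpack(packed):
    """Reverse of pack(): back to tuple(blocks) of tuple(keys) of (batch, ...) tensors."""
    restored = packed.movedim(0, 2)                                  # (blocks, keys, batch, ...)
    return tuple(tuple(restored[b][k] for k in range(restored.shape[1]))
                 for b in range(restored.shape[0]))

# round-trip example with 6 blocks, 4 keys, batch size 3
nested = tuple(tuple(torch.randn(3, 8, 1, 7) for _ in range(4)) for _ in range(6))
assert torch.equal(unpack(pack(nested))[5][3], nested[5][3])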

This is not a code that can be left in production. I don't think it's efficient either.

I think the whole thing needs to be rethought, but this is how most transformers models are written. Of course, they weren't written to lend themselves to an easy pipeline conversion.

@pritamdamania87 (Contributor) commented:

Then you can see the actual tensors of batch size 3. Which is a problem, since the batch-size to be sliced on is hidden deep inside.

Thanks for providing a detailed example. I'm wondering if we could recursively enter a Tuple/List and slice Tensors we find inside them? Today the Pipe API does this with one level of a Tuple, but I think we should be able to support arbitrarily nested Tensors too. Would this resolve the issue here or is the batch dimension across Tensors in the Tuple of different sizes too?

Since the inputs here seem fairly complicated, I'm wondering if you could share some steps in terms of setting up and running this Transformer model. If I can run it locally and inspect the inputs myself, I'll probably have a much better idea of whether there is something we can do to support such models in the pipeline.

I have a very basic understanding of Transformer models, but I was wondering if the Transformer models could be rewritten to be similar to what we have in our sample benchmarks: https://github.com/pytorch/pytorch/blob/156da22566fc8c49065a6075a4e1352bf4de63d9/benchmarks/distributed/pipeline/pipe.py.

Looking at https://github.com/huggingface/transformers/blob/357fb1c5d8b6a16f042f9b504f023d935086e8e5/src/transformers/models/t5/modeling_t5.py#L600-L612, it seems like that API is more functional where things like hidden_states are passed into the forward function instead of being initialized as part of the Block.

@pritamdamania87 (Contributor) commented:

@stas00 Another example of modularizing Transformers for pipelining can be found here: https://github.com/pytorch/fairseq/blob/master/fairseq/model_parallel/models/pipeline_parallel_transformer/model.py. I was wondering if this would work for your use case?

@stas00 (Contributor, Author) commented Jan 23, 2021

Then you can see the actual tensors of batch size 3. Which is a problem, since the batch-size to be sliced on is hidden deep inside.

Thanks for providing a detailed example. I'm wondering if we could recursively enter a Tuple/List and slice Tensors we find inside them? Today the Pipe API does this with one level of a Tuple, but I think we should be able to support arbitrarily nested Tensors too. Would this resolve the issue here or is the batch dimension across Tensors in the Tuple of different sizes too?

I think this would be super magical, yes, that would have saved so much trouble.

Since the inputs here seem fairly complicated, I'm wondering if you could share some steps in terms of setting up and running this Transformer model. If I can run it locally and inspect the inputs myself, I'll probably have a much better idea of whether there is something we can do to support such models in the pipeline.

Yes, of course. It's very simple.

cd /tmp
git clone https://github.com/huggingface/transformers/
cd transformers
pip install -e .

and now here is a very simple script I've been using to test:

import sys
# edit this to point to where you checked out transformers if it's not under /tmp/transformers - "src" is where the source is
sys.path.insert(0, "/tmp/transformers/src")

from transformers import T5Tokenizer, T5ForConditionalGeneration
mname = "t5-small"
tokenizer = T5Tokenizer.from_pretrained(mname)
model = T5ForConditionalGeneration.from_pretrained(mname, return_dict=True)

texts = ["This is good", "This is bad"]
texts = ["translate English to French: "+x for x in texts]
batch = tokenizer.prepare_seq2seq_batch(texts, return_tensors="pt")
outputs = model.generate(**batch)
for x in outputs:
    decoded = tokenizer.decode(x, skip_special_tokens=True)
    print(decoded)

Now, for this t5-small with 6 blocks, the stack goes:

T5ForConditionalGeneration->
   T5Stack(encoder)->T5Block->T5Block->T5Block->T5Block->T5Block->T5Block
   T5Stack(decoder)->T5Block->T5Block->T5Block->T5Block->T5Block->T5Block

I omitted a few small extra layers, but this is the bulk of it. And each T5Block calls multiple other modules, but I think those should remain in the block and not be sequentialized.

As you can see, the stacks lend themselves really well to a pipeline, so my initial attempt is not to convert the whole model to Pipe, but only the repeated blocks, resulting in 2 pipes. So currently the version I'm working on symbolically looks like:

T5ForConditionalGeneration->
   T5Stack(encoder)->Pipe(Sequential([T5StackPipeSegment * 6]))
   T5Stack(decoder)->Pipe(Sequential([T5StackPipeSegment * 6]))

Please let me know if you have any questions or if in any way I can help you to quickly understand it. I find that using a good debugger like pycharm makes it much easier to quickly navigate through the stages and easily visualize the parameters.

I have a very basic understanding of Transformer models, but I was wondering if the Transformer models could be rewritten to be similar to what we have in our sample benchmarks: https://github.com/pytorch/pytorch/blob/156da22566fc8c49065a6075a4e1352bf4de63d9/benchmarks/distributed/pipeline/pipe.py.

Yes, this is what I have been using as the example, as you originally recommended I look into.

Thank you for all the other links, @pritamdamania87 - I will study them in the next few days.

@stas00 (Contributor, Author) commented Jan 28, 2021

@pritamdamania87, so what needs to happen for the proposal, as you adjusted it, to become a reality?

Would be awesome to have the main functionality working and well-tested before 1.8 is released.


Also there are a few more issues to discuss - should I open a separate issue about those?

We need these 2 supported by DeepSpeed:

  1. Tied Layers - this one is crucial or none of the transformers training will do well - since almost all models there have tied weights.
  2. Memory-Efficient Model Construction - this one I haven't had a chance to validate yet, as I'm first focusing on making things just work. But it'd definitely be very important.

Thank you!

@pritamdamania87 (Contributor) commented:

@pritamdamania87, so what needs to happen for the proposal, as you adjusted it, to become a reality?

Would be awesome to have the main functionality working and well-tested before 1.8 is released.

Hey @stas00, I've been tied up with a few things related to the 1.8 release and I'll get to this once that is resolved. Regarding the 1.8 release, we're releasing pipeline parallelism with the existing API. This is still a "beta" feature, though, so the enhancements proposed here can be included in 1.9.

Also there are a few more issues to discuss - should I open a separate issue about those?

Yes, it would be good to open separate issues for those.

@stas00 (Contributor, Author) commented Feb 9, 2021

Here you are, @pritamdamania87.

I appreciate the update and will patiently wait till your other tasks are completed.

@msbaines commented Feb 9, 2021

@pritamdamania87, so what needs to happen for the proposal, as you adjusted it, to become a reality?

Would be awesome to have the main functionality working and well-tested before 1.8 is released.

Also there are a few more issues to discuss - should I open a separate issue about those?

We need these 2 supported by DeepSpeed:

  1. Tied Layers - this one is crucial or none of the transformers training will do well - since almost all models there have tied weights.

Pipe supports shared parameters by default, as long as the partitions with the shared parameter are mapped to the same device:

raise ValueError("module with duplicate parameters on distinct devices is not supported")

What type of models are you using shared-parameters in? So far, we've only seen shared-parameters being used for the shared-embedding table used in Machine Translation models.

  2. Memory-Efficient Model Construction - this one I haven't had a chance to validate yet, as I'm first focusing on making things just work. But it'd definitely be very important.

Thank you!

@stas00 (Contributor, Author) commented Feb 9, 2021

Thank you for your follow up, @msbaines - since I created a dedicated issue for this feature let's continue there: #51931 (comment)

@pritamdamania87 (Contributor) commented:

and now here is a very simple script I've been using to test:

Looking at the simple script it seems like we use only input_ids and attention_mask which can be represented as a Tuple of two Tensors in our current Pipe API. Although, I'm assuming you're running into issues with much more complicated examples since the T5Block has several input arguments. I was wondering if you could point me to a more complex example that would illustrate the problem?

@stas00 (Contributor, Author) commented Feb 16, 2021

I think you're thinking model(**inputs), but this is model.generate(**inputs). And so I present to you generate:

https://github.com/huggingface/transformers/blob/1c8c2d9ab34b8c8d326db9e0608f8e54cfccb885/src/transformers/generation_utils.py#L616-L647

and lots and lots happens between it and the eventual model(**inputs), and most likely you want this entry point where the model is called:

https://github.com/huggingface/transformers/blob/1c8c2d9ab34b8c8d326db9e0608f8e54cfccb885/src/transformers/generation_utils.py#L1407-L1412

which in the case of T5ForConditionalGeneration calls this first forward:

https://github.com/huggingface/transformers/blob/1c8c2d9ab34b8c8d326db9e0608f8e54cfccb885/src/transformers/models/t5/modeling_t5.py#L1443-L1460

@pritamdamania87 (Contributor) commented:

@stas00 Yes you are right, I was thinking about model(**inputs). Although, when I actually printed the args inside the forward function of T5Block, everything else was None except input_ids and attention_mask.

Btw, can you create a gh issue for this and we can continue the discussion there :) It's easier to keep track of gh issues instead of discussing this on a PR :)

@pritamdamania87 (Contributor) commented:

@stas00 Yes you are right, I was thinking about model(**inputs). Although, when I actually printed the args inside the forward function of T5Block, everything else was None except input_ids and attention_mask.

Nvm, I was looking at only the first few log lines; it looks like the later T5Blocks have most of the input arguments populated.

@stas00 (Contributor, Author) commented Feb 17, 2021

Yes, the first time the model is called from generate it's almost all Nones, and then it's all filled out on subsequent calls within the same generate call.

I'd be happy to open a dedicated Issue - I'm just a bit lost on which specific topic? I know my PR turned into a small battle field - in a sense that we discuss multiple things....

@pritamdamania87 (Contributor) commented:

I'd be happy to open a dedicated Issue - I'm just a bit lost on which specific topic? I know my PR turned into a small battle field - in a sense that we discuss multiple things....

I guess creating an issue with your original PR summary would be a good idea :) Basically, the issue I'm interested in tracking is if there is a nice way to incorporate these T5 blocks into a pipeline.

@pritamdamania87 (Contributor) commented:

After spending some time poking around, it looks like the limitation of Tensor or Tuple of Tensors comes from the fact that PyTorch's autograd machinery has this limitation when it comes to checkpointing: https://github.com/pytorch/pytorch/blob/master/torch/csrc/autograd/python_function.cpp#L369. This is basically the error that you were seeing:

TypeError: WaitBackward.forward: expected Tensor or tuple of Tensor (got NoneType) for return value 3

The restriction of Tensor/tuple of Tensors comes from the autograd machinery since it is hard to support arbitrary datatypes there. Although, I think I might have some workaround that might work here.

This is what the input structure for the T5Block looks like when I ran your example locally:

HIDDEN: torch.Size([2, 1, 512])
ATTENTION MASK: torch.Size([2, 1, 1, 7])
BIAS: torch.Size([2, 8, 1, 7])
EHIDDEN: torch.Size([2, 9, 512])
EATTENTION: torch.Size([2, 1, 1, 9])
EPOSBIAS: torch.Size([2, 8, 1, 9])
ELBIAS: None
ELAYERHEADMAS: None
PAST KEY: (torch.Size([2, 8, 6, 64]), torch.Size([2, 8, 6, 64]), torch.Size([2, 8, 9, 64]), torch.Size([2, 8, 9, 64]), )
USE CACHE: True
OUTPUT ATTENTIONS: False
RETURN DICT: True
  1. The first 5 inputs are just tensors, so we can wrap all of them into a tuple of tensors.
  2. The next two None arguments can be replaced by something like torch.Tensor() to represent None as an empty Tensor and put those arguments into the tuple as well.
  3. Now past_key_value is a tuple of tensors, but we can just append this tuple to our existing tuple of tensors to have a flat tuple with all the tensors. We can rewrite the T5Block such that it assumes all tensors after encoder_layer_head_mask belong to past_key_value.
  4. For the boolean values you could probably represent them as a torch.BoolTensor.

This is slightly hacky, but I think it would work and the core idea is to represent everything in a flat tuple of Tensors.
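A rough sketch of that encoding, just to make the idea concrete; the empty-tensor-for-None and BoolTensor conventions follow the suggestions above, and the helper names encode_arg / decode_arg are hypothetical:

import torch

def encode_arg(x, device="cpu"):
    """Encode one argument as a Tensor so the whole input can be a flat tuple of
    Tensors: None -> empty tensor, bool -> 0-dim BoolTensor, Tensor -> unchanged."""
    if x is None:
        return torch.empty(0, device=device)
    if isinstance(x, bool):
        return torch.tensor(x, dtype=torch.bool, device=device)
    return x

def decode_arg(t):
    """Reverse the encoding on the receiving side."""
    if t.numel() == 0:
        return None
    if t.dtype == torch.bool and t.dim() == 0:
        return bool(t.item())
    return t

# e.g. (hidden, mask, None, True) becomes a flat tuple of 4 Tensors and back
args = (torch.randn(2, 1, 512), torch.randn(2, 1, 1, 7), None, True)
encoded = tuple(encode_arg(a) for a in args)
decoded = tuple(decode_arg(t) for t in encoded)
assert decoded[2] is None and decoded[3] is True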

@stas00 (Contributor, Author) commented Feb 18, 2021

Thank you for getting to the root of this limitation, @pritamdamania87! It helps a lot to know why that limitation is there in the first place!

Thank you for proposing the hacky workaround. I think I mentioned a few weeks ago that I did make it work with 2 pipelines using a slightly different type of hack: huggingface/transformers#9765. Except the outcome is very inefficient - I can't break 50% gpu util over 2 gpus with any chunk size. So it's kind of pointless at the moment; one might just as well use naive MP and have just one gpu working at a time (which is a terrible waste of resources).

The problem is that any of the hacks we may have to adopt leads to quite complex code. And since transformers tries to have the models editable by researchers, this is just not going to be great.

And as I mentioned elsewhere, SageMaker recently started providing pipeline support for any model, not requiring it to be nn.Sequential, so they must have found a way around this limitation. Unfortunately their code is proprietary. They do say that it's faster with nn.Sequential.

@pritamdamania87 (Contributor) commented:

I think I mentioned a few weeks ago that I did make it work with 2 pipelines using a slightly different type of hack: huggingface/transformers#9765. Except the outcome is very inefficient - I can't break 50% gpu util over 2 gpus with any chunk size. So it's kind of pointless at the moment; one might just as well use naive MP and have just one gpu working at a time (which is a terrible waste of resources).

If you still have some of this code, is it possible to share some steps to run it so I can try to run it locally to see what might be causing low gpu utilization?

@stas00 (Contributor, Author) commented Feb 18, 2021

If you still have some of this code, is it possible to share some steps to run it so I can try to run it locally to see what might be causing low gpu utilization?

It's fully documented in the PR huggingface/transformers#9765. I can't link directly to the headers, but you will find a copy-n-paste Setup followed by 2 Deployment versions - one via a custom script (ready to be run), though you can't appreciate the performance in that one, and then via the HF trainer, where you can push lots of data and thus get enough runtime to measure utilization. The full run command is under the "Benchmarks" section - the baseline and then with the pipeline enabled.

Please let me know if you get stuck anywhere, I tried to make the reproduction an easy copy-n-paste.

@stas00 (Contributor, Author) commented Mar 13, 2021

I'd be happy to open a dedicated Issue - I'm just a bit lost on which specific topic? I know my PR turned into a small battle field - in a sense that we discuss multiple things....

I guess creating an issue with your original PR summary would be a good idea :) Basically, the issue I'm interested in tracking is if there is a nice way to incorporate these T5 blocks into a pipeline.

Apologies, I neglected to do it sooner, here it is: #53952

Thank you, @pritamdamania87!
