
Modify BERT/BERT-descendants to be TorchScript-able (not just traceable) #5067

Closed
sbrody18 opened this issue Jun 16, 2020 · 20 comments

sbrody18 commented Jun 16, 2020

🚀 Feature request

Modify the BERT models (src/transformers/modeling_bert.py) to conform to TorchScript requirements, so they can be jit.script()-ed, not just jit.trace()-ed (currently the only supported option).

Note: I have a working version implementing this, which I would like to contribute.
See below.

Motivation

A scriptable model would allow for variable-length input, offering big speedup gains and simplification (no need to create different models for different input lengths).

In addition, it would avoid other potential pitfalls with tracing (e.g., code paths that are input dependent and not covered by the tracing example input).
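A minimal toy example of that kind of pitfall (an illustrative function, not library code): tracing freezes whichever branch the example input happens to take, while scripting preserves the branch.

```python
import torch

def pool(x: torch.Tensor) -> torch.Tensor:
    # Data-dependent branch: trace() records only the path the example
    # input takes; script() compiles both paths.
    if x.size(0) > 2:
        return x.mean(dim=0)
    return x.sum(dim=0)

traced = torch.jit.trace(pool, torch.randn(4, 3))  # records the mean path
scripted = torch.jit.script(pool)

short = torch.ones(2, 3)
print(traced(short))    # still averages: the frozen branch is wrong here
print(scripted(short))  # correctly sums, since size(0) <= 2
```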

Related issues:
#2417
#1204
possibly also
#1477
#902

Your contribution

I have a working PR that modifies all the models in src/transformers/modeling_bert.py and makes them TorchScript-able. I have not tested it on other models that use BERT components (e.g., albert), but it should be possible to expand the capability to those, as well.

However, it would require some significant work to make it ready for submission: besides formatting, documentation, testing etc., my current version changes the method signatures, and I would need to avoid that to maintain backward-compatibility.

Before putting in that work, I'd like to make sure that such a PR is something you'd be interested in and would be willing to merge in, assuming it meets the requirements.

@LysandreJik (Member)

Hi! This is interesting. Could you summarize the changes that would be needed in order to make our models scriptable?

@sbrody18 (Author)

Sure, mostly my changes fall into these categories:

1. Class members can only be basic types, None, nn.Modules, or lists/tuples thereof

  • Solution: don't save whole config in the model, only individual entries you need, which are basic types
  • Solution for nn.functional: use the nn.Module equivalent of nn.functional
  • Solution for other functions: define and call the function globally, not as a class member

2. Inputs are assumed to be Tensors

  • Solution: use typing to tell TorchScript the types (note - requires typing to be supported. I checked in python 3.7, but not 3.5 or 3.6)

3. TorchScript can't figure out that an Optional is not None

  • Solution: add assertions to help TorchScript

4. Variable types are not allowed to change depending on conditionals

  • Solution: use consistent types (with Optional to tell TorchScript that a variable/argument can be None). This is where I had to change the interface: current BERT models optionally return attention probabilities, so I changed them to always return an output tuple of the same size, with None entries where a value is not computed.

5. TorchScript can't handle the expand (*) operator on lists

  • Solution: explicitly enumerate the arguments

6. You can't use nn.Modules as local variables (e.g., modules that take a variable number of arguments)

  • Solution: use the nn.functional equivalents of the modules.

7. TorchScript doesn't know about nn.ModuleList's enumerate

  • Solution: use a regular for loop

Most of these are pretty small changes and do not affect the logic. #4 and #1c can be tricky, and #5 might be an issue with recent changes made here: #4874
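A hypothetical module illustrating several of these rules at once (the names and shapes are illustrative, not the actual PR code):

```python
import torch
from torch import nn
from typing import Optional, Tuple

class ScriptableBlock(nn.Module):
    def __init__(self, hidden_size: int, num_layers: int):
        super().__init__()
        # Rule 1: store only the basic-typed config entries we need,
        # not a whole config object.
        self.hidden_size = hidden_size
        self.layers = nn.ModuleList(
            [nn.Linear(hidden_size, hidden_size) for _ in range(num_layers)]
        )

    # Rule 2: annotate inputs; Rule 4: always return the same-sized tuple,
    # with Optional for entries that may be None.
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        mask_sum: Optional[torch.Tensor] = None
        if attention_mask is not None:
            # Rule 3: inside this branch TorchScript has refined the
            # type from Optional[Tensor] to Tensor.
            mask_sum = attention_mask.sum()
        # Rule 7: plain for loop instead of enumerate(nn.ModuleList).
        for layer in self.layers:
            hidden_states = layer(hidden_states)
        return hidden_states, mask_sum

scripted = torch.jit.script(ScriptableBlock(8, 2))
```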

@mfuntowicz (Member)

Hi @sbrody18,

Thanks for opening this issue and taking the time to dive into our TorchScript support.

Regarding A scriptable model would allow for variable-length input, offering big speedup gains and simplification:

Do you have some numbers to compare against the current transformers library? We ran some TorchScript tests and the differences were not that huge at the time; maybe this has changed since? I (and probably others) would be very interested in knowing more about this aspect.

Regarding the list of changes you suggested:

I'm currently not really in favour of such changes, as they would change almost the entire way the library is designed and would have an impact on all the models. Some of them might be further discussed if there are real performance benefits.


sbrody18 commented Jun 17, 2020

Hi @mfuntowicz,
My co-workers and I have run experiments showing that inference time scales more or less linearly with the input size (also supported by the article linked below).

Assuming you are trying to run in C++ (which is the reason to use TorchScript), the current solution, using trace(), means you can only use fixed-length input: you have to set a large value for max_length to support your longest expected input, and zero-pad all inputs to that length.
That means if your max_length is 1000 tokens and your average length is 20 tokens, your inference takes 50x longer than it should.
You can see an example of how big a difference this makes, here, under 'Scenario #3: Smaller Inputs (Dynamic Shapes)'.
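A back-of-the-envelope sketch of that padding overhead, using the toy numbers above (not a benchmark, just the token arithmetic):

```python
import torch
import torch.nn.functional as F

# With a traced fixed-shape model, every request must be padded to
# max_length before inference.
max_length, avg_length = 1000, 20
ids = torch.randint(0, 30000, (1, avg_length))
padded = F.pad(ids, (0, max_length - avg_length))  # zero-pad to max_length

# The model now processes 50x more tokens than the request contains.
print(padded.shape[1] // ids.shape[1])
```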

I'm guessing the tests you ran were focused specifically on the technical behavior of the models on a fixed input set and didn't take into account the max-length issue. Also, this is only an issue if you need to use TorchScript in order to run in C++.

Re. the change to design, my intention is to keep the model changes to a minimum (e.g., adding type hints and asserts does not change the design at all) and make sure they are fully backwards compatible. There would still be some changes required, but I don't think they are drastic.

As I said in the original post, I have a PR where I did a lot of the work, and I'd be happy to work with someone to figure out how to get it to a state where it can be merged.


eellison commented Jul 2, 2020

@sbrody18 do you mind sharing your fork ?


sbrody18 commented Jul 2, 2020

Yes, I can do so, but it may have to wait a week or two - things are busy at the moment.

@adampauls

I am very interested in this work as well. Our team would like to be able to use TorchScript so we can train without depending on Python. If there's any way I can be of help, I would gladly offer some time here!

@sbrody18 (Author)

Sorry for the delay. I hope to have a reasonable PR later this week.

@sbrody18 (Author)

My change is available at https://github.com/sbrody18/transformers/tree/scripting

Note that it is based off of a commit from earlier this month:
ef0e9d8...sbrody18:scripting
Since then there have been changes made to the BertModel interface adding a return_tuple argument and changing the return type of the forward method, and this would require more effort to resolve.

I listed the principles I used in #5067 (comment). The original components tended to return different sized tuples, depending on arguments, which is problematic for TorchScript. When a component BertX required an interface change to be scriptable, I made a BertScriptableX version with the modifications, and had the BertX component inherit from it and just modify the output so it is compatible with the original API.

I made scriptable versions of BertModel and all the BertFor<Task> classes, except BertForMaskedLM (some complexities there were too much work for a proof of concept).
I added a test to demonstrate the scripting capability.

Note that my change disables the gradient_checkpoint path in the encoder. I think this can be resolved, but I didn't have the time to work on it.
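The wrapper pattern might look roughly like this (illustrative names and a stand-in computation, not the actual PR code): the scriptable base class always returns a fixed-size, Optional-typed tuple, and the compatibility subclass trims it back to the original variable-sized API.

```python
import torch
from torch import nn
from typing import Optional, Tuple

class BertScriptableLayer(nn.Module):
    """Scriptable variant: fixed-size return tuple with Optional entries."""
    def __init__(self, hidden_size: int):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)

    def forward(
        self, x: torch.Tensor, output_attentions: bool = False
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        attn: Optional[torch.Tensor] = None
        if output_attentions:
            attn = torch.softmax(x, dim=-1)  # stand-in for attention probs
        return self.dense(x), attn

class BertLayer(BertScriptableLayer):
    """Backward-compatible wrapper: restores the variable-sized output."""
    def forward(self, x, output_attentions=False):
        out, attn = super().forward(x, output_attentions)
        return (out, attn) if output_attentions else (out,)

# Only the scriptable base needs to pass torch.jit.script().
scripted = torch.jit.script(BertScriptableLayer(4))
```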

@sbrody18 (Author)

@sgugger @joeddav: see comment above for preliminary PR.
Probably too big and complicated to try to merge as is, but would be happy to work with someone to break things down into reasonable chunks.

sgugger (Collaborator) commented Aug 14, 2020

Thanks for all the work. Looking at this and our recent changes in the model API (in particular the return_dict argument) I think we probably won't be able to have the models be fully compatible with TorchScript. What is possible however would be to have a second version of the models that don't have the option of return_dict (we can also remove output_hiddens/output_attentions if it makes life easier) and would be fully scriptable.

Since you already started with some components in a different class, I think we should have two models (say, BertModel and ScriptableBertModel) with the same named parameters, so you can seamlessly save/load from one to the other (a workflow would then be to experiment with BertModel, save the fine-tuned model, and then switch to ScriptableBertModel for inference, for instance).

Then I'm not sure what's easiest:

  • have the two inherit from some base class, with a minimal set of methods that need to be different (probably just the forward?)
  • or have the second class be a complete rewrite.

I think we should focus on having a proof of concept on one model before moving forward with others.
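A toy sketch of that save/load workflow (hypothetical two-class setup with matching parameter names, not library code):

```python
import torch
from torch import nn

class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.dense = nn.Linear(4, 4)
    def forward(self, x):
        return self.dense(x)

class ScriptableTinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.dense = nn.Linear(4, 4)  # same parameter name as TinyModel
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.dense(x)

m = TinyModel()                     # experiment / fine-tune with this one
s = ScriptableTinyModel()
s.load_state_dict(m.state_dict())   # works because names and shapes match
scripted = torch.jit.script(s)      # deploy this one for inference
```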

@sbrody18 (Author)

That makes sense to me. It will probably result in some amount of code duplication, and we'd need to make sure the named parameters stay in sync, but it would probably be easier to maintain.
So would you suggest the ScriptableBertModel is a separate file?

sgugger (Collaborator) commented Aug 14, 2020

Not necessarily a separate file; I guess it depends on the amount of code to rewrite. I think we can worry about this in a second stage, once we have a good proof of concept.

@sbrody18 (Author)

@sgugger Please see POC implementation in PR above.

@kevinstephano

@sbrody18 in the original PR #6846 you created for this issue, you mentioned you saw a large perf increase with dynamic sequences. What did you use as a test to make that determination?

@sbrody18 (Author)

@kevinstephano - see the discussion and conclusions here.
We saw a large performance increase with an older version of PyTorch, where traced models required the input to be the same length as the one used for tracing, making it necessary to pad short sequences at inference and adding a lot of unnecessary computation overhead.
With recent versions of PyTorch (>=1.3, I think), this is no longer the case.

stale bot commented Dec 25, 2020

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

@stale stale bot added the wontfix label Dec 25, 2020
@stale stale bot closed this as completed Jan 2, 2021
fteufel (Contributor) commented Jan 6, 2021

Hi,

I just tried jit.script()-ing BERT from the PR (I copied modeling_bert and modeling_utils and replaced relative imports of other dependencies with imports from the transformers master branch).
I see there are try blocks left in the code, which cause jit.script to fail:

UnsupportedNodeError: try blocks aren't supported:
  File "/zhome/1d/8/153438/experiments/master-thesis/export_model/modeling_utils_script_proof.py", line 131
        Get torch.device from module, assuming that the whole module has one device.
        """
        try:
        ~~~ <--- HERE
            return next(self.parameters()).device
        except StopIteration:

@sbrody18 how did you export the model? I guess the workaround would be to remove the try blocks, but apparently it worked for you as-is.


sbrody18 commented Jan 6, 2021

@fteufel you can see #6846 for a stand-alone implementation that worked at a previous version of the transformers library. Maybe that's good enough for your purposes?
The transformers library has changed significantly since these PRs and I'm not sure if that try was added. If you are using code from the transformers master branch in the model itself, it's likely you will encounter several unscriptable bits.
Specifically for that function, you can either:
a. remove the try block, since there should always be at least one parameter on the model
b. use next with a default:

    first_param = next(self.parameters(), None)
    assert first_param is not None
    return first_param.device

c. figure out a better way to decide the model device :)


MrRace commented Nov 1, 2022

@sbrody18 it seems this has not been merged into the official transformers? My transformers version is 4.21.3, and it cannot use jit.script to convert a BERT model to TorchScript.
