
refactor: replace torch.no_grad with torch.inference_mode (where possible) #3601

Merged: 5 commits from try_torch_inference_mode into deepset-ai:main on Nov 23, 2022

Conversation

anakin87
Member

@anakin87 anakin87 commented Nov 17, 2022

Related Issues

None

Proposed Changes:

As explained in this PyTorch tweet, replacing torch.no_grad with torch.inference_mode can lead to performance improvements.
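For context, the swap is essentially the following (a minimal sketch with a stand-in nn.Linear model, not actual Haystack code):

```python
import torch

model = torch.nn.Linear(10, 2)  # stand-in for any inference-only model
x = torch.randn(1, 10)

# Before: no_grad disables gradient tracking
with torch.no_grad():
    out_a = model(x)

# After: inference_mode additionally skips view tracking and version
# counter bookkeeping, which can make inference slightly faster
with torch.inference_mode():
    out_b = model(x)
```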

Just a first draft to run the CI and see if anything breaks...

@anakin87 anakin87 marked this pull request as ready for review November 17, 2022 22:03
@anakin87 anakin87 requested a review from a team as a code owner November 17, 2022 22:03
@anakin87 anakin87 requested review from masci and removed request for a team November 17, 2022 22:03
@anakin87 anakin87 marked this pull request as draft November 17, 2022 22:10
@anakin87 anakin87 marked this pull request as ready for review November 17, 2022 22:10
@anakin87 anakin87 marked this pull request as draft November 17, 2022 22:17
@anakin87 anakin87 marked this pull request as ready for review November 17, 2022 22:17
@anakin87 anakin87 marked this pull request as draft November 17, 2022 22:45
@anakin87 anakin87 marked this pull request as ready for review November 18, 2022 08:13
@anakin87
Member Author

Now the CI is passing, but I'm not sure that all the modified methods are covered by tests.

@julian-risch @bogdankostic Do you think that this change is safe? What is your impression?

@anakin87 anakin87 changed the title refactor: try to replace torch.no_grad with torch.inference_mode refactor: replace torch.no_grad with torch.inference_mode (where possible) Nov 18, 2022
@julian-risch
Member

@anakin87 The only file that needs manual testing is haystack/utils/augment_squad.py. Maybe you could run it with and without your changes to compare the results? I don't think your changes will break anything, and I am confident they will bring some speed improvement. So I am very much looking forward to seeing this PR merged. Good work! 👍
A bonus would be to have some speed measurements.

@anakin87
Member Author

anakin87 commented Nov 22, 2022

Hey @julian-risch, thanks for your feedback!

Test augment_squad

To properly test haystack/utils/augment_squad.py with inference mode, I ran this Colab notebook.
The results look fine.

Speed measurements

I also used the same notebook to record some time measurements and compared them with the original no_grad version (this other notebook).

In particular, I executed the following cell:

```
%%timeit -n 1 -r 5
!python augment_squad.py \
    --squad_path data/distil_a_reader/squad/squad_small.json \
    --glove_path data/distil_a_reader/glove/glove.6B.300d.txt \
    --output_path augmented_dataset.json \
    --multiplication_factor 2
```

I obtained:

- no_grad: 2min 27s ± 4.67 s per loop (mean ± std. dev. of 5 runs, 1 loop each)
- inference_mode: 2min 12s ± 2.91 s per loop (mean ± std. dev. of 5 runs, 1 loop each)

The results sound exciting, but I'd take them with a grain of salt: I've found very few resources that measure the impact of this change, and some (like this one, from a different domain) report only small performance improvements.

In any case, the PyTorch docs say:

> If it works out of the box for your use case it’s a free performance win.
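For reference, torch.inference_mode can also be used as a decorator, which is convenient for wrapping whole predict methods (a generic sketch, not code from this PR):

```python
import torch

@torch.inference_mode()
def predict(model: torch.nn.Module, batch: torch.Tensor) -> torch.Tensor:
    # The whole call runs in inference mode; outputs are inference tensors
    return model(batch)
```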

@julian-risch
Member

@anakin87 That's great news! Thank you for the extra effort of conducting speed experiments. I have one last question before approving and merging this PR. There are two calls to torch.no_grad() in haystack/nodes/reader/table.py and two in haystack/modeling/training/base.py. Is there a specific reason not to use inference_mode there instead of no_grad too? Or can we also replace these four calls? It looks like the latter to me, no? Either way, this is a great PR! 👍

@anakin87
Member Author

> There are two calls to torch.no_grad() in haystack/nodes/reader/table.py and two in haystack/modeling/training/base.py. Is there a specific reason not to use inference_mode there instead of no_grad too?

Yes. At first I tried to apply the change everywhere.
As you can see from the failed tests on the first few commits, I ran into errors (RuntimeError: Inference tensors cannot be saved for backward. To work around you can make a clone to get a normal tensor and use it in autograd.) and reverted the changes that triggered them.
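For anyone hitting the same issue, this is roughly the failure mode (a minimal sketch, not Haystack code): tensors created under inference_mode cannot be used in autograd, and the clone mentioned in the error message is the workaround:

```python
import torch

w = torch.randn(3, requires_grad=True)

with torch.inference_mode():
    x = torch.randn(3)  # x is an inference tensor

try:
    # autograd would need to save x for the backward pass, which is forbidden
    loss = (w * x).sum()
    loss.backward()
except RuntimeError as err:
    print(err)  # "Inference tensors cannot be saved for backward. ..."

# Workaround from the error message: clone to get a normal tensor
loss = (w * x.clone()).sum()
loss.backward()  # works
```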

Member

@julian-risch julian-risch left a comment

These changes look good to me! 👍 Thanks again for testing the augment_squad.py script and comparing the speed before and after the changes.

@julian-risch julian-risch merged commit f43bc56 into deepset-ai:main Nov 23, 2022
@anakin87 anakin87 deleted the try_torch_inference_mode branch November 23, 2022 08:28
@sjrl
Contributor

sjrl commented Nov 25, 2022

> There are two calls to torch.no_grad() in haystack/nodes/reader/table.py and two in haystack/modeling/training/base.py. Is there a specific reason not to use inference_mode there instead of no_grad too?

> Yes. At first I tried to apply the change everywhere. As you can see from the failed tests on the first few commits, I ran into errors (RuntimeError: Inference tensors cannot be saved for backward. To work around you can make a clone to get a normal tensor and use it in autograd.) and reverted the changes that triggered them.

@julian-risch I've been working on the TableQA code for a client project. I ran into this same error there, and I believe I can fix it with minimal changes. Would you like me to open a new PR for that?

Update: Ah, never mind. I remember now that this would require an update in the transformers library, since the error originates there.

ZanSara added a commit that referenced this pull request Nov 28, 2022
* Fix docstrings for DocumentStores

* Fix docstrings for AnswerGenerator

* Fix docstrings for Connector

* Fix docstrings for DocumentClassifier

* Fix docstrings for LabelGenerator

* Fix docstrings for QueryClassifier

* Fix docstrings for Ranker

* Fix docstrings for Retriever and Summarizer

* Fix docstrings for Translator

* Fix docstrings for Pipelines

* Fix docstrings for Primitives

* Fix Python code block spacing

* Add line break before code block

* Fix code blocks

* fix: discard metadata fields if not set in Weaviate (#3578)

* fix weaviate bug in returning embeddings and setting empty meta fields

* review comment

* Update unstable version and openapi schema (#3584)

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>

* fix: Flatten `DocumentClassifier` output in `SQLDocumentStore`; remove `_sql_session_rollback` hack in tests (#3273)

* first draft

* fix

* fix

* move test to test_sql

* test: add test to check id_hash_keys is not ignored (#3577)

* refactor: Generate JSON schema when missing (#3533)

* removed unused script

* print info logs when generating openapi schema

* create json schema only when needed

* fix tests

* Remove leftover

Co-authored-by: ZanSara <sarazanzo94@gmail.com>

* move milvus tests to their own module (#3596)

* feat: store metadata using JSON in SQLDocumentStore (#3547)

* add warnings

* make the field cachable

* review comment

* Pin faiss-cpu as 1.7.3 seems to have problems (#3603)

* Update Haystack imports (#3599)

* Update Python version (#3602)

* fix: `ParsrConverter` fails on pages without text (#3605)

* try to fix bug

* remove print

* leftover

* refactor: update Squad data  (#3513)

* refractor the to_squad data class

* fix the validation label

* refractor the to_squad data class

* fix the validation label

* add the test for the to_label object function

* fix the tests for to_label_objects

* move all the test related to squad data to one file

* remove unused imports

* revert tiny_augmented.json

Co-authored-by: ZanSara <sarazanzo94@gmail.com>

* Url fixes (#3592)

* add 2 example scripts

* fixing faq script

* fixing some urls

* removing example scripts

* black reformatting

* add labeler to the repo (#3609)

* convert eval metrics to python float (#3612)

* feat: add support for `BM25Retriever` in `InMemoryDocumentStore` (#3561)

* very first draft

* implement query and query_batch

* add more bm25 parameters

* add rank_bm25 dependency

* fix mypy

* remove tokenizer callable parameter

* remove unused import

* only json serializable attributes

* try to fix: pylint too-many-public-methods / R0904

* bm25 attribute always present

* convert errors into warnings to make the tutorial 1 work

* add docstrings; tests

* try to make tests run

* better docstrings; revert not running tests

* some suggestions from review

* rename elasticsearch retriever as bm25 in tests; try to test memory_bm25

* exclude tests with filters

* change elasticsearch to bm25 retriever in test_summarizer

* add tests

* try to improve tests

* better type hint

* adapt test_table_text_retriever_embedding

* handle non-textual docs

* query only textual documents

* Incorporate Reviewer feedback

* refactor: replace `torch.no_grad` with `torch.inference_mode` (where possible) (#3601)

* try to replace torch.no_grad

* revert erroneous change

* revert other module breaking

* revert training/base

Co-authored-by: Massimiliano Pippi <mpippi@gmail.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: Stefano Fiorucci <44616784+anakin87@users.noreply.github.com>
Co-authored-by: Julian Risch <julian.risch@deepset.ai>
Co-authored-by: ZanSara <sarazanzo94@gmail.com>
Co-authored-by: Espoir Murhabazi <espoir.mur@gmail.com>
Co-authored-by: Tuana Celik <tuana.celik@deepset.ai>
Co-authored-by: tstadel <60758086+tstadel@users.noreply.github.com>
@julian-risch
Member

@sjrl Thanks to your PR, Haystack's transformers dependency has been updated. Do you know whether that now also allows us to use inference_mode where we previously got the error mentioned above?

@sjrl
Contributor

sjrl commented Dec 19, 2022

Yes, it does! It also requires a small code change in our code base, which I've opened as PR #3731.
