Fix Mamba slow path bug with dtype mismatch. #32691

Merged 21 commits into huggingface:main on Oct 1, 2024

Conversation

Adibvafa
Contributor

What does this PR do?

Fix issue #32690

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • [] Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • [] Did you write any new necessary tests?

Who can review?

@ArthurZucker

@molbap
Contributor

molbap commented Aug 16, 2024

Thanks @Adibvafa for the fix! Do you think we could add a test for f16/bf16 inference to avoid possible regressions later?

Collaborator

@ArthurZucker ArthurZucker left a comment

LGTM, @molbap feel free to merge once we properly test this case!

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@molbap
Contributor

molbap commented Sep 16, 2024

Hi again @Adibvafa, I'm back from holidays :) IIUC there aren't tests for the various inference dtypes. Do you think they could be added to this PR? Then we can merge it!

@Adibvafa
Contributor Author

Hi @molbap, I added a test to Mamba for checking this dtype mismatch.
It seems Mamba2 is suffering from the same bug #33409
If the test is good I can add it there too.

@Adibvafa
Contributor Author

Adibvafa commented Sep 19, 2024

@molbap There seems to be an issue with the test. I will fix it.

@molbap
Contributor

molbap commented Sep 19, 2024

Thanks for your work on that @Adibvafa !

@Adibvafa
Contributor Author

Adibvafa commented Sep 25, 2024

@molbap I added a test named test_dtype_mismatch_handled_in_cache to catch any dtype mismatch between the model and the cache. As expected, it passes only after changing conv_state[:, :, cache_position] = new_conv_state.to(device=conv_state.device) to conv_state[:, :, cache_position] = new_conv_state.to(device=conv_state.device, dtype=conv_state.dtype).
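
For readers following along, the change described above can be sketched in isolation. This is a minimal illustration, not the code from the PR: the helper names update_conv_state and check_dtype_preserved, the tensor shapes, and the float32 input are all assumptions chosen for the example. It shows that casting the incoming state to the cache's dtype before the indexed assignment keeps the cache in its original half-precision dtype. It is also an assumption here, not something stated in the thread, that omitting the dtype argument can make the indexed assignment fail when the dtypes differ.

```python
import torch


def update_conv_state(conv_state, new_conv_state, cache_position):
    """Write new_conv_state into conv_state at cache_position.

    The incoming tensor is moved to the cache's device AND cast to the
    cache's dtype, mirroring the one-line change quoted in the comment.
    """
    conv_state[:, :, cache_position] = new_conv_state.to(
        device=conv_state.device, dtype=conv_state.dtype
    )
    return conv_state


def check_dtype_preserved(cache_dtype):
    """Return True if the cache keeps its dtype and receives the new values."""
    # Hypothetical shapes: (batch, channels, conv_kernel_size).
    conv_state = torch.zeros(2, 4, 3, dtype=cache_dtype)
    # The model output is assumed to arrive in float32.
    new_conv_state = torch.ones(2, 4, 3, dtype=torch.float32)
    cache_position = torch.arange(3)

    updated = update_conv_state(conv_state, new_conv_state, cache_position)
    return updated.dtype == cache_dtype and bool((updated == 1).all())


# Exercise both half-precision dtypes mentioned in the review.
results = {
    dtype: check_dtype_preserved(dtype)
    for dtype in (torch.float16, torch.bfloat16)
}
```

Running the check for both torch.float16 and torch.bfloat16 is a small stand-in for the f16/bf16 coverage requested earlier in the thread. The actual test in the PR may be structured differently.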

@ArthurZucker ArthurZucker merged commit c269c5c into huggingface:main Oct 1, 2024
21 checks passed
@ArthurZucker
Collaborator

Thanks 🤗

NielsRogge pushed a commit to NielsRogge/transformers that referenced this pull request Oct 21, 2024
* Fix Mamba slow path bug with dtype mismatch.

* Update test_modeling_mamba.py

* Improve style.

* Fix issue with cache position of dtype mismatch test.

* Change test for slow path.

* Revert changes.

* Switch to buggy code and add test to catch it.

* Fix the dtype mismatch bug and add test code to verify it.

* Fix minor bug with test.

* Fix incorrect dtype of model output.

* Fix incorrect dtype of cache.

* Fix incorrect dtype of ssm cache.

* Fix incorrect dtype of conv state.

* Remove assertion for ssm state.

* Add assertion for conv state dtype.

* Fix all issues with dtype mismatch test.
BernardZach pushed a commit to BernardZach/transformers that referenced this pull request Dec 5, 2024
BernardZach pushed a commit to innovationcore/transformers that referenced this pull request Dec 6, 2024
4 participants