-
Notifications
You must be signed in to change notification settings - Fork 27.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix Mamba slow path bug with dtype mismatch. #32691
Conversation
Thanks @Adibvafa for the fix! Do you think we could we add a test for f16/bf16 inference to avoid possible regressions later? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, @molbap feel free to merge once we properly test this case!
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
Hi again @Adibvafa , I'm back from holidays :) iiuc there aren't tests for various inference dtypes, do you think they could be added to this PR? then we can merge it! |
@molbap There seems to be an issue with the test. I will fix it. |
Thanks for your work on that @Adibvafa ! |
@molbap I added a test named |
Thanks 🤗 |
* Fix Mamba slow path bug with dtype mismatch. * Update test_modeling_mamba.py * Improve style. * Fix issue with cache position of dtype mismatch test. * Change test for slow path. * Revert changes. * Switch to buggy code and add test to catch it. * Fix the dtype mismatch bug and add test code to verify it. * Fix minor bug with test. * Fix incorrect dtype of model output. * Fix incorrect dtype of cache. * Fix incorrect dtype of ssm cache. * Fix incorrect dtype of conv state. * Remove assertion for ssm state. * Add assertion for conv state dtype. * Fix all issues with dtype mismatch test.
* Fix Mamba slow path bug with dtype mismatch. * Update test_modeling_mamba.py * Improve style. * Fix issue with cache position of dtype mismatch test. * Change test for slow path. * Revert changes. * Switch to buggy code and add test to catch it. * Fix the dtype mismatch bug and add test code to verify it. * Fix minor bug with test. * Fix incorrect dtype of model output. * Fix incorrect dtype of cache. * Fix incorrect dtype of ssm cache. * Fix incorrect dtype of conv state. * Remove assertion for ssm state. * Add assertion for conv state dtype. * Fix all issues with dtype mismatch test.
* Fix Mamba slow path bug with dtype mismatch. * Update test_modeling_mamba.py * Improve style. * Fix issue with cache position of dtype mismatch test. * Change test for slow path. * Revert changes. * Switch to buggy code and add test to catch it. * Fix the dtype mismatch bug and add test code to verify it. * Fix minor bug with test. * Fix incorrect dtype of model output. * Fix incorrect dtype of cache. * Fix incorrect dtype of ssm cache. * Fix incorrect dtype of conv state. * Remove assertion for ssm state. * Add assertion for conv state dtype. * Fix all issues with dtype mismatch test.
What does this PR do?
Fix issue #32690
Before submitting
Pull Request section?
to it if that's the case.
documentation guidelines, and
here are tips on formatting docstrings.
Who can review?
@ArthurZucker