Match torch.fake_quantize numerics in 8da4w QAT #229
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/229
Note: Links to docs will display an error until the docs builds have been completed.
❌ 1 New Failure as of commit 54c48d1 with merge base 3dd16c9.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
By the way, a few existing tests will fail because we haven't made the corresponding changes in PyTorch core yet.
Summary: Follow-up to pytorch/ao#229. This resolves the difference between `input.div(scales)` and `input.mul(1.0 / scales)`, which results in small numerical discrepancies on some inputs.

Test Plan:
python test/test_quantization.py TestQuantizedTensor.test_decomposed_quantize_per_channel_group
python test/test_quantization.py TestQuantizedTensor.test_decomposed_quantize_per_token

Reviewers: jerryzh168
Subscribers: jerryzh168, supriyar

[ghstack-poisoned]
These seem to be small differences and I don't expect they will actually cause large errors.
…125781)

Summary: Follow-up to pytorch/ao#229. This resolves the difference between `input.div(scales)` and `input.mul(1.0 / scales)`, which results in small numerical discrepancies on some inputs.

Test Plan:
python test/test_quantization.py TestQuantizedTensor.test_decomposed_quantize_per_channel_group
python test/test_quantization.py TestQuantizedTensor.test_decomposed_quantize_per_token

Reviewers: jerryzh168
Subscribers: jerryzh168, supriyar

Pull Request resolved: #125781
Approved by: https://github.com/jerryzh168
Can we guard the QAT feature itself under `TORCH_VERSION_AFTER_2_4`?
When adding this guard, please provide a clear error message by adding an empty stub for users of older versions.
Yup, both done.
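For context, the general shape of such a guard is roughly as follows. This is a hedged sketch only, not code from this PR: the flag is computed locally here because the exact torchao import path isn't shown above, and the class name and guard placement are illustrative.

```python
import torch

# Illustrative version flag: torchao defines a TORCH_VERSION_AFTER_2_4 constant,
# but its import path is not shown here, so compute an equivalent locally.
_major, _minor = (int(v) for v in torch.__version__.split(".")[:2])
TORCH_VERSION_AFTER_2_4 = (_major, _minor) >= (2, 4)

if TORCH_VERSION_AFTER_2_4:
    class Int8DynActInt4WeightQATQuantizer:
        """Real 8da4w QAT quantizer implementation lives here."""
        ...
else:
    class Int8DynActInt4WeightQATQuantizer:
        """Empty stub for older torch versions, raising a clear error message."""
        def __init__(self, *args, **kwargs):
            raise ValueError(
                "Int8DynActInt4WeightQATQuantizer requires torch 2.4 or later"
            )
```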
Summary: There are two subtle differences between the 8da4w quant primitives and `torch.fake_quantize_per_channel_affine` today:

1. 8da4w uses float32 zero points; torch.fake_quantize uses int32 zero points
2. 8da4w uses `input.div(scales)`; torch.fake_quantize uses `input.mul(1.0 / scales)`
Of these two differences, the second is smaller and only resulted in 0.1% of elements mismatching in unit tests, but it is a source of numerical divergence nonetheless.
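To see why the second difference matters at all, here is a minimal standalone sketch (not from this PR; the random inputs are made up purely to show that `div(scales)` and `mul(1.0 / scales)` can round to different integers for a small fraction of elements):

```python
import torch

torch.manual_seed(0)
x = torch.randn(10_000, dtype=torch.float32)
scales = torch.rand(10_000, dtype=torch.float32) + 0.01

# The two formulations are not bit-identical in floating point, so after
# rounding, a small fraction of elements can land on different integers.
q_div = torch.round(x.div(scales))
q_mul = torch.round(x.mul(1.0 / scales))

mismatch = (q_div != q_mul).float().mean().item()
print(f"fraction of mismatched elements: {mismatch:.4%}")
```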
This commit changes 8da4w QAT quant primitives to match the torch.fake_quantize behavior for both of these differences. In a future commit, we will change the 8da4w PTQ quant primitives as well so PTQ and QAT remain consistent.
Note: This commit also has the side effect of significantly reducing the memory footprint for bf16 inputs, since we now cast them to fp32 before multiplying them by the fp32 scales. The reduced memory usage is presumably because mixed bf16 * fp32 kernels are not as memory efficient.
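For illustration, here is a simplified sketch of the fake-quantize numerics described above: int32 zero points, `mul(1.0 / scales)` instead of `div(scales)`, and casting bf16 inputs to fp32 before scaling. The function name, shapes, and per-group reshaping are assumptions for this sketch (not the actual torchao code), and the straight-through-estimator wrapper used in real QAT training is omitted.

```python
import torch

def fake_quantize_per_channel_group(
    x: torch.Tensor,            # (out_features, in_features), fp32 or bf16
    scales: torch.Tensor,       # one fp32 scale per group
    zero_points: torch.Tensor,  # one int32 zero point per group (not fp32)
    quant_min: int = -8,
    quant_max: int = 7,
    group_size: int = 32,
) -> torch.Tensor:
    orig_dtype = x.dtype
    # Cast bf16 inputs to fp32 before scaling, as noted above.
    grouped = x.to(torch.float32).reshape(-1, group_size)
    s = scales.reshape(-1, 1)
    zp = zero_points.reshape(-1, 1)
    # Match torch.fake_quantize: multiply by 1/scale rather than divide by scale.
    q = torch.round(grouped.mul(1.0 / s)) + zp
    q = torch.clamp(q, quant_min, quant_max)
    # Dequantize back to the original dtype.
    dq = (q - zp) * s
    return dq.reshape(x.shape).to(orig_dtype)

# Example usage with made-up shapes: 4 output channels, 64 inputs, group size 32.
x = torch.randn(4, 64, dtype=torch.bfloat16)
scales = torch.rand(4 * 64 // 32, dtype=torch.float32) + 0.01
zero_points = torch.zeros(4 * 64 // 32, dtype=torch.int32)
out = fake_quantize_per_channel_group(x, scales, zero_points)
```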
Test Plan:
python test/quantization/test_qat.py -k test_qat_generic_fake_quantize
Reviewers: jerryzh168, cpuhrsch
Subscribers: jerryzh168, cpuhrsch, supriyar