
Match torch.fake_quantize numerics in 8da4w QAT #229

Merged: 1 commit into main from fq-test on May 15, 2024

Conversation

andrewor14 (Contributor) commented May 8, 2024

Summary: There are two subtle differences between the 8da4w quant primitives and torch.fake_quantize_per_channel_affine today:

1. 8da4w uses float32 zero points, while torch.fake_quantize uses int32 zero points.

2. 8da4w uses input.div(scales), while torch.fake_quantize uses input.mul(1.0 / scales).

Of these two differences, the second is smaller and resulted in only 0.1% of elements mismatching in unit tests, but it is a source of numerical divergence nonetheless.

This commit changes 8da4w QAT quant primitives to match the torch.fake_quantize behavior for both of these differences. In a future commit, we will change the 8da4w PTQ quant primitives as well so PTQ and QAT remain consistent.

Note: This commit also has the side effect of significantly reducing the memory footprint for bf16 inputs, since we now cast them to fp32 before multiplying them by the fp32 scales. The reduction is presumably because mixed bf16 * fp32 kernels are not as memory efficient.
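Concretely, the matched behavior corresponds to roughly the following minimal sketch (not the actual torchao code; the function name is made up, and scales and zero_points are assumed to be fp32 tensors shaped to broadcast against the input):

```python
import torch

def fake_quantize_sketch(x, scales, zero_points, quant_min=-8, quant_max=7):
    # Sketch of the fake-quantize numerics described above, not the real
    # torchao 8da4w QAT implementation.
    # Cast bf16 inputs to fp32 before multiplying by the fp32 scales.
    x_fp32 = x.to(torch.float32)
    # Difference 1: use int32 zero points, like torch.fake_quantize.
    zp = zero_points.to(torch.int32)
    # Difference 2: use mul(1.0 / scales) rather than div(scales).
    q = torch.round(x_fp32 * (1.0 / scales)) + zp
    q = torch.clamp(q, quant_min, quant_max)
    # Dequantize back to the input's original dtype.
    return ((q - zp) * scales).to(x.dtype)
```

The defaults quant_min=-8 and quant_max=7 correspond to the signed int4 range used for the 4-bit weights in 8da4w.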

Test Plan:
python test/quantization/test_qat.py -k test_qat_generic_fake_quantize

Reviewers: jerryzh168, cpuhrsch

Subscribers: jerryzh168, cpuhrsch, supriyar

andrewor14 requested a review from jerryzh168 on May 8, 2024 18:34

pytorch-bot bot commented May 8, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/229

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure

As of commit 54c48d1 with merge base 3dd16c9:

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

facebook-github-bot added the CLA Signed label on May 8, 2024. (This label is managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed.)
andrewor14 (Contributor, Author)

By the way, a few existing tests will fail because we haven't made the corresponding input.mul(1.0 / scale) change in PyTorch yet. I'm writing a PR there now.

andrewor14 added a commit to pytorch/pytorch that referenced this pull request May 8, 2024
Summary: Follow-up to pytorch/ao#229.
This resolves the difference between `input.div(scales)` and
`input.mul(1.0 / scales)`, which results in small numerical
discrepancies on some inputs.

Test Plan:
python test/test_quantization.py TestQuantizedTensor.test_decomposed_quantize_per_channel_group
python test/test_quantization.py TestQuantizedTensor.test_decomposed_quantize_per_token

Reviewers: jerryzh168

Subscribers: jerryzh168, supriyar

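As a quick aside, the discrepancy exists because 1.0 / scales is itself rounded to the nearest float before the multiply, so x.div(scales) and x.mul(1.0 / scales) can disagree by about one ulp on some elements. A standalone check (illustrative only, not from the PR):

```python
import torch

torch.manual_seed(0)
x = torch.randn(1_000_000)
scales = torch.rand(1) + 0.1

a = x.div(scales)
b = x.mul(1.0 / scales)
# Typically a small fraction of elements differ, each by roughly one ulp.
print("fraction mismatched:", (a != b).float().mean().item())
```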
jerryzh168 (Contributor)

These seem to be small differences, and I don't actually expect they will cause large errors.

pytorchmergebot pushed a commit to pytorch/pytorch that referenced this pull request May 8, 2024
pytorchmergebot pushed a commit to pytorch/pytorch that referenced this pull request May 9, 2024
pytorchmergebot pushed a commit to pytorch/pytorch that referenced this pull request May 13, 2024
pytorchmergebot pushed a commit to pytorch/pytorch that referenced this pull request May 14, 2024
Pull Request resolved: #125781
Approved by: https://github.com/jerryzh168
andrewor14 force-pushed the fq-test branch 2 times, most recently from 5fa3fb3 to 56f0f78, on May 15, 2024 18:17
jerryzh168 (Contributor) left a comment

Can we guard the QAT feature itself under TORCH_VERSION_AFTER_2_4?

cpuhrsch (Contributor)

When adding this guard, please provide a clear error message via an empty stub for users of older versions.

andrewor14 (Contributor, Author)

> Can we guard the QAT feature itself under TORCH_VERSION_AFTER_2_4?
> When adding this guard, please provide a clear error message via an empty stub for users of older versions.

Yup, both done.
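For reference, a guard in that spirit could look like the following rough sketch (the quantizer class name and the locally derived version flag are assumptions for illustration, not the exact code merged here):

```python
import torch

# Stand-in for torchao's TORCH_VERSION_AFTER_2_4 flag, derived locally from
# torch.__version__ so the snippet is self-contained.
_TORCH_VERSION_AFTER_2_4 = tuple(
    int(p) for p in torch.__version__.split("+")[0].split(".")[:2]
) >= (2, 4)

if _TORCH_VERSION_AFTER_2_4:
    class Int8DynActInt4WeightQATQuantizer:  # class name assumed for illustration
        def prepare(self, model: torch.nn.Module) -> torch.nn.Module:
            # Real QAT preparation (swapping in fake-quantized linears) goes here.
            return model
else:
    class Int8DynActInt4WeightQATQuantizer:
        """Empty stub for older torch versions that fails with a clear message."""

        def __init__(self, *args, **kwargs):
            raise NotImplementedError(
                "8da4w QAT requires PyTorch 2.4 or later; "
                f"found torch=={torch.__version__}"
            )
```

On an older install, constructing the quantizer then fails immediately with a readable message rather than with a confusing error deeper in the stack.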

andrewor14 force-pushed the fq-test branch 2 times, most recently from faa4781 to 5517770, on May 15, 2024 21:57
andrewor14 merged commit cae3d82 into main on May 15, 2024
13 checks passed
andrewor14 deleted the fq-test branch on May 16, 2024 15:29
lancerts pushed a commit to lancerts/ao that referenced this pull request May 17, 2024
dbyoung18 pushed a commit to dbyoung18/ao that referenced this pull request Jul 31, 2024
yanbing-j pushed a commit to yanbing-j/ao that referenced this pull request Dec 9, 2024