
Add bias quantization in QAT and refactor the code of weight quantization #2914

Merged
merged 11 commits into microsoft:master on Oct 10, 2020

Conversation

linbinskn
Contributor

@linbinskn linbinskn commented Sep 22, 2020

Add bias quantization and refactor the code. Once this PR is merged, I will submit a new PR to update the doc.
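(For context: in quantization-aware training, the bias is commonly fake-quantized with the product of the input and weight scales so that it can be added directly to the integer accumulator of the quantized layer. Whether NNI's QAT quantizer follows exactly this convention is not shown in this thread, so the sketch below is illustrative only; the function name and signature are hypothetical.)

import torch

def fake_quantize_bias(bias, input_scale, weight_scale, bits=32):
    # Illustrative only: the bias is expressed on the scale
    # input_scale * weight_scale, the convention many QAT schemes use so the
    # quantized bias can be added to the int accumulator of a conv/linear.
    scale = input_scale * weight_scale
    qmax = 2 ** (bits - 1) - 1
    q = torch.clamp(torch.round(bias / scale), -qmax - 1, qmax)
    return q * scale  # de-quantized ("fake quantized") bias used during training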

@linbinskn linbinskn changed the title Add bias quantization in QAT and refactor the code of weight quantiza… Add bias quantization in QAT and refactor the code of weight quantization Sep 22, 2020
@QuanluZhang QuanluZhang mentioned this pull request Sep 23, 2020
@QuanluZhang QuanluZhang linked an issue Sep 23, 2020 that may be closed by this pull request
out[out == 0] = 1
return out
weight[weight == 0] = 1
wrapper.module.weight.data = weight
Contributor

The return is removed? Was the returned value never used?

Contributor Author

The returned value was the quantized weight. There is no need to return it in this version, because the weight is now quantized in place inside quantize_weight().
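(For readers following the diff, the in-place pattern being described looks roughly like the sketch below. Only wrapper.module.weight.data comes from the snippet above; the helper name, the scale/zero_point arguments, and the bit width are assumptions for illustration, and the real quantize_weight() in NNI may differ.)

import torch

def quantize_weight_in_place(wrapper, scale, zero_point, bits=8):
    # Sketch: fake-quantize the weight and write it back to the wrapped module,
    # so the caller no longer needs a return value.
    weight = wrapper.module.weight.data
    qmin, qmax = 0, 2 ** bits - 1
    q = torch.clamp(torch.round(weight / scale + zero_point), qmin, qmax)
    wrapper.module.weight.data = (q - zero_point) * scale  # updated in place, nothing returned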

assert math.isclose(model.conv2.module.scale, 6 / 255, abs_tol=eps)
assert model.conv2.module.zero_point in (42, 43)
# test value of weight and bias after quantization
weight = torch.tensor([[1.1287, 2.3456], [3.7814, 5.9723]])
weight_valid = torch.tensor([[1.1242, 2.3421], [3.7707, 5.9723]])
Contributor

Just curious, how are the values of weight_valid and bias_valid calculated?

Contributor Author

weight_valid and bias_valid were calculated by applying the quantization function manually. I will update the test case and its comments after the code freeze.
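(A sketch of how the expected tensor can be reproduced by hand. These particular numbers are consistent with an 8-bit scheme whose range is anchored at 0, i.e. scale = max / (2^bits - 1) and zero_point = 0; that scheme is an assumption here, and the test itself may derive the values differently.)

import torch

weight = torch.tensor([[1.1287, 2.3456], [3.7814, 5.9723]])

bits = 8
qmin, qmax = 0, 2 ** bits - 1
# include 0 in the tracked range, as min/max-based QAT observers typically do
rmin = min(0.0, weight.min().item())
rmax = max(0.0, weight.max().item())
scale = (rmax - rmin) / (qmax - qmin)        # 5.9723 / 255
zero_point = qmin - round(rmin / scale)      # 0 for this tensor

q = torch.clamp(torch.round(weight / scale + zero_point), qmin, qmax)
print((q - zero_point) * scale)              # ~[[1.1242, 2.3421], [3.7707, 5.9723]], matching weight_valid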

weight_bits = get_bits_length(config, 'weight')
quant_start_step = config.get('quant_start_step', 0)
assert weight_bits >= 1, "quant bits length should be at least 1"

if quant_start_step > self.steps:
return weight
return
Contributor

add return here

Contributor Author

Ok, I have fixed it.
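(A minimal sketch of the agreed control flow: quantization is skipped, and nothing is returned, until training reaches quant_start_step. The function and argument names below are illustrative; in NNI the check lives inside the quantizer and compares against self.steps, as in the snippet above.)

def run_quantize_step(weight, config, current_step, fake_quantize):
    # Skip quantization entirely before quant_start_step; the early return
    # added per the review leaves the weight untouched.
    quant_start_step = config.get('quant_start_step', 0)
    if quant_start_step > current_step:
        return
    fake_quantize(weight)  # e.g. the in-place fake-quantization sketched earlier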

@chicm-ms chicm-ms merged commit 0a6c234 into microsoft:master Oct 10, 2020
Development

Successfully merging this pull request may close these issues.

NNI can't quantize bias?
3 participants