Skip to content

Commit

Permalink
don't skip
Browse files Browse the repository at this point in the history
  • Loading branch information
dsikka committed Oct 21, 2024
1 parent 983e750 commit b8380f7
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 39 deletions.
36 changes: 9 additions & 27 deletions tests/test_quantization/lifecycle/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,10 @@ def test_wrap_module_forward_quantized(create_quantization_scheme):
assert not func_forward == layer.forward.__func__


@pytest.mark.skip(reason="wip")
@pytest.mark.parametrize(
"quantization_status", ["initialized", "calibration", "frozen"]
)
def test_forward_quantize(create_quantization_scheme, quantization_status):
@pytest.mark.parametrize("quantization_status", ["initialized", "calibration"])
def test_forward_quantize(
mock_per_tensor_calibration, create_quantization_scheme, quantization_status
):
num_bits = 8
quantization_scheme = create_quantization_scheme(
targets=["*"],
Expand All @@ -76,37 +75,20 @@ def test_forward_quantize(create_quantization_scheme, quantization_status):
if layer.quantization_status == QuantizationStatus.INITIALIZED:
# Init zp and scales
initialize_module_for_quantization(layer, quantization_scheme)
# init weight observers; update weight scales/zp
set_module_for_calibration(layer)
# mock weight calibration
mock_per_tensor_calibration(layer, "weight", value=layer.weight.data)
# call quant/dequant on weights
out = forward_quantize(layer, layer.weight, "weight", quantization_args)
assert torch.allclose(out, layer.weight.data, atol=0.2)
elif layer.quantization_status == QuantizationStatus.CALIBRATION:
# init zp/scales
initialize_module_for_quantization(layer, quantization_scheme)
# init weight observers; update weight scales/zp
set_module_for_calibration(layer)
# init input observers, update input scales/zp
calibrate_activations(
module=layer,
value=dummy_tensor,
base_name="input",
quantization_args=quantization_args,
)
# run weight and input calibration
mock_per_tensor_calibration(layer, "weight", value=layer.weight.data)
mock_per_tensor_calibration(layer, "input", value=dummy_tensor)
# call quant/dequant on inputs
out = forward_quantize(layer, dummy_tensor, "input", quantization_args)
assert torch.allclose(out, dummy_tensor, atol=0.2)
assert layer.input_observer._num_observed_tokens == dummy_tensor.shape[0]
elif layer.quantization_status == QuantizationStatus.FROZEN:
# init weight observers
initialize_module_for_quantization(layer, quantization_scheme)
# init weight observers; update weight scales/zp
set_module_for_calibration(layer)
# remove weight observers and any input observers
freeze_module_quantization(layer)
# call quant/dequant on weights
out = forward_quantize(layer, layer.weight.data, "weight", quantization_args)
assert torch.allclose(out, layer.weight.data, atol=0.2)


@pytest.mark.parametrize(
Expand Down
20 changes: 8 additions & 12 deletions tests/test_quantization/lifecycle/test_lifecycle.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,7 @@
from torch.nn import Linear


@pytest.mark.skip(reason="wip")
def test_lifecyle(create_quantization_scheme):
def test_lifecyle(mock_per_tensor_calibration, create_quantization_scheme):
num_bits = 8

quantization_scheme = create_quantization_scheme(
Expand Down Expand Up @@ -61,19 +60,17 @@ def test_lifecyle(create_quantization_scheme):
assert hasattr(layer, "quantization_status")
assert layer.quantization_status == QuantizationStatus.INITIALIZED

set_module_for_calibration(layer)
assert hasattr(layer, "weight_observer")
assert layer.quantization_status == QuantizationStatus.CALIBRATION

# do a calibration step
assert torch.numel(layer.input_zero_point.data) == 1
assert torch.numel(layer.input_scale) == 1
assert torch.numel(layer.weight_scale) == 1
assert torch.numel(layer.weight_zero_point) == 1

random_input = torch.randn(4, 4)
random_input[0][0] = 42 # skew distribution to force non-zero zp
layer(random_input)

# do a calibration step
mock_per_tensor_calibration(layer, "weight", value=layer.weight)
mock_per_tensor_calibration(layer, "input", value=random_input)

# zero-points and scale should be updated after forward pass
assert torch.numel(layer.input_zero_point.data) > 0
Expand All @@ -94,7 +91,9 @@ def test_lifecyle(create_quantization_scheme):
for _ in range(10):
random_input = torch.randn(4, 4)
random_input[0][0] = 42 # skew distribution to force non-zero zp
layer(random_input)

mock_per_tensor_calibration(layer, "weight", value=layer.weight)
mock_per_tensor_calibration(layer, "input", value=random_input)

assert initialized_layer_input_zero_point != 0
assert initialized_layer_input_scale != layer.input_scale
Expand All @@ -108,9 +107,6 @@ def test_lifecyle(create_quantization_scheme):
layer_before_freeze_input_scale = deepcopy(layer.input_scale)
layer_before_freeze_weight_scale = deepcopy(layer.weight_scale)

# Freeze, no update after any forward pass
freeze_module_quantization(layer)

for _ in range(10):
layer(torch.randn(4, 4))
assert layer_before_freeze_input_zero_point == layer.input_zero_point
Expand Down

0 comments on commit b8380f7

Please sign in to comment.