Fix KL error method per-tensor #1041

Merged: 2 commits, merged Apr 14, 2024
@@ -89,8 +89,8 @@ def _lp_error_histogram(q_bins: np.ndarray,


def _kl_error_function(x: np.ndarray,
range_min: float,
range_max: float,
range_min: np.ndarray,
range_max: np.ndarray,
n_bins: int = 2048,
n_bits: int = 8) -> np.float32:
"""
@@ -148,7 +148,8 @@ def _kl_error_function_wrapper(x: np.ndarray,
range_min: np.ndarray,
range_max: np.ndarray,
n_bins: int = 2048,
n_bits: int = 8) -> np.ndarray:
n_bits: int = 8,
per_channel: int = False) -> np.ndarray:
"""
Computes the error function between a tensor and its quantized version for each channel.
The error is based on the KL-divergence between the distributions.
@@ -161,24 +162,28 @@ def _kl_error_function_wrapper(x: np.ndarray,
range_max: Array specifying the maximum bound of the quantization range for each channel.
n_bins: Number of bins for the float histogram.
n_bits: Number of bits used for quantization.
per_channel: Whether quantization is done per-channel.

Returns:
An array containing the KL-divergence between the float and quantized histograms of the tensor for each channel.

"""

error_list = []
for j in range(x.shape[0]): # iterate all channels of the tensor.
error_list.append(_kl_error_function(x[j], range_min[j], range_max[j], n_bins=n_bins, n_bits=n_bits))
if per_channel:
for j in range(x.shape[0]): # iterate all channels of the tensor.
error_list.append(_kl_error_function(x[j], range_min[j], range_max[j], n_bins=n_bins, n_bits=n_bits))
Collaborator comment: Maybe this can be done in a vectorial way?

else:
error_list.append(_kl_error_function(x, range_min, range_max, n_bins=n_bins, n_bits=n_bits))
return np.asarray(error_list)
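
For orientation, a minimal sketch of how the updated wrapper is meant to be called in its two modes; the tensor shape and range values are hypothetical, and it assumes direct access to the private helper:

import numpy as np

# Hypothetical weight tensor with 3 output channels on axis 0.
x = np.random.randn(3, 64).astype(np.float32)

# Per-channel: one range bound per channel; the wrapper loops over axis 0
# and returns one KL error per channel.
per_channel_err = _kl_error_function_wrapper(
    x,
    range_min=np.array([-1.0, -0.5, -2.0]),
    range_max=np.array([1.0, 0.5, 2.0]),
    n_bits=8,
    per_channel=True)   # result shape: (3,)

# Per-tensor: a single range for the whole tensor; the wrapper calls
# _kl_error_function once and returns a one-element array.
per_tensor_err = _kl_error_function_wrapper(
    x,
    range_min=np.array(-2.0),
    range_max=np.array(2.0),
    n_bits=8,
    per_channel=False)  # result shape: (1,)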


def _kl_error_histogram(q_bins: np.ndarray,
q_count: np.ndarray,
bins: np.ndarray,
counts: np.ndarray,
range_min: float,
range_max: float) -> np.float32:
range_min: np.ndarray,
range_max: np.ndarray) -> np.float32:
"""
Compute the error function between a histogram and its quantized version.
The error is based on the KL-divergence between the two distributions.
@@ -241,8 +246,8 @@ def _kl_error_histogram(q_bins: np.ndarray,


def _get_bins_indices_from_range(bins: np.ndarray,
range_min: float,
range_max: float) -> Tuple[int, int]:
range_min: np.ndarray,
range_max: np.ndarray) -> Tuple[int, int]:
"""
For given bins and a threshold, compute the first and last bins that fall within the threshold
range.
@@ -262,7 +267,7 @@ def _get_bins_indices_from_range(bins: np.ndarray,
return first_bin_idx, last_bin_idx


def _is_range_valid(bins: np.ndarray, range_min: float, range_max: float) -> bool:
def _is_range_valid(bins: np.ndarray, range_min: np.ndarray, range_max: np.ndarray) -> bool:
"""
Check whether any bins from a numpy array of bins fall within
a threshold range.
@@ -387,15 +392,36 @@ def get_threshold_selection_tensor_error_function(quantization_method: Quantizat

Returns: a Callable method that calculates the error between a tensor and a quantized tensor.
"""
    if quant_error_method == qc.QuantizationErrorMethod.KL:
        if axis is None:
            # per-tensor
            if quantization_method == QuantizationMethod.UNIFORM:
                return lambda x, y, threshold: _kl_error_function_wrapper(x, range_min=threshold[0],
                                                                          range_max=threshold[1],
                                                                          n_bits=n_bits,
                                                                          per_channel=False)
            else:
                return lambda x, y, threshold: _kl_error_function_wrapper(x, range_min=0 if not signed else -threshold,
                                                                          range_max=threshold,
                                                                          n_bits=n_bits,
                                                                          per_channel=False)
        else:
            # per-channel
            if quantization_method == QuantizationMethod.UNIFORM:
                return lambda x, y, threshold: _kl_error_function_wrapper(x, range_min=threshold[:, 0],
                                                                          range_max=threshold[:, 1],
                                                                          n_bits=n_bits,
                                                                          per_channel=True)
            else:
                return lambda x, y, threshold: _kl_error_function_wrapper(x, range_min=0 if not signed else -threshold,
                                                                          range_max=threshold,
                                                                          n_bits=n_bits,
                                                                          per_channel=True)
Collaborator comment: Maybe use variables to pass to _kl_error_function_wrapper to avoid unnecessary duplications?
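
A rough sketch of the kind of refactoring this comment suggests; names such as per_channel and _kl_error are illustrative and not part of the PR:

if quant_error_method == qc.QuantizationErrorMethod.KL:
    per_channel = axis is not None

    def _kl_error(x, y, threshold):
        if quantization_method == QuantizationMethod.UNIFORM:
            # Uniform thresholds carry explicit (min, max) bounds.
            range_min = threshold[:, 0] if per_channel else threshold[0]
            range_max = threshold[:, 1] if per_channel else threshold[1]
        else:
            # Symmetric/power-of-two thresholds define a range around zero.
            range_min = 0 if not signed else -threshold
            range_max = threshold
        return _kl_error_function_wrapper(x, range_min=range_min, range_max=range_max,
                                          n_bits=n_bits, per_channel=per_channel)

    return _kl_error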


quant_method_error_function_mapping = {
qc.QuantizationErrorMethod.MSE: lambda x, y, threshold: compute_mse(x, y, norm=norm, axis=axis),
qc.QuantizationErrorMethod.MAE: lambda x, y, threshold: compute_mae(x, y, norm=norm, axis=axis),
qc.QuantizationErrorMethod.LP: lambda x, y, threshold: compute_lp_norm(x, y, p=p, norm=norm, axis=axis),
qc.QuantizationErrorMethod.KL:
lambda x, y, threshold: _kl_error_function_wrapper(x, range_min=threshold[:,0], range_max=threshold[:,1],
n_bits=n_bits) if quantization_method == QuantizationMethod.UNIFORM
else _kl_error_function_wrapper(x, range_min=0 if not signed else -threshold, range_max=threshold, n_bits=n_bits)
}

return quant_method_error_function_mapping[quant_error_method]
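
For completeness, a hedged usage sketch of the selector itself; the parameter names follow the diff, but the exact signature, tensor shape and threshold layout are assumptions:

import numpy as np

float_tensor = np.random.randn(3, 64).astype(np.float32)
thresholds = np.array([1.0, 0.5, 2.0])  # one symmetric threshold per channel (assumed layout)

error_fn = get_threshold_selection_tensor_error_function(
    quantization_method=QuantizationMethod.SYMMETRIC,
    quant_error_method=qc.QuantizationErrorMethod.KL,
    p=2, norm=False, axis=0, n_bits=8, signed=True)

# The KL path ignores the quantized-tensor argument (y) and works directly
# from the float tensor and the candidate thresholds.
per_channel_errors = error_fn(float_tensor, None, thresholds)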
@@ -58,7 +58,8 @@ def representative_data_gen():
tp_model = generate_test_tp_model({
'weights_quantization_method': quantize_method,
'weights_n_bits': 8,
'activation_n_bits': 8})
'activation_n_bits': 8,
'weights_per_channel_threshold': per_channel})
tpc = generate_keras_tpc(name="kl_quant_config_weights_test", tp_model=tp_model)

qc = mct.core.QuantizationConfig(activation_error_method=mct.core.QuantizationErrorMethod.NOCLIPPING,
@@ -68,7 +68,8 @@ def representative_data_gen():
tp = generate_test_tp_model({
'weights_quantization_method': quantize_method,
'weights_n_bits': 8,
'activation_n_bits': 16})
'activation_n_bits': 16,
'weights_per_channel_threshold': per_channel})
tpc = generate_keras_tpc(name="quant_config_weights_test", tp_model=tp)

qc = mct.core.QuantizationConfig(activation_error_method=mct.core.QuantizationErrorMethod.NOCLIPPING,
@@ -54,9 +54,10 @@ def representative_dataset():
yield [np.random.randn(1, 16, 16, 4).astype(np.float32)]


def get_tpc():
tp = generate_test_tp_model({
'weights_quantization_method': mct.target_platform.QuantizationMethod.SYMMETRIC})
def get_tpc(per_channel):
tp = generate_test_tp_model(edit_params_dict={
'weights_quantization_method': mct.target_platform.QuantizationMethod.SYMMETRIC,
'weights_per_channel_threshold': per_channel})
tpc = generate_keras_tpc(name="symmetric_threshold_selection_test", tp_model=tp)

return tpc
@@ -99,7 +100,8 @@ def run_test_for_threshold_method(self, threshold_method, per_channel=True):

in_model = create_network()
graph = prepare_graph_with_quantization_parameters(in_model, KerasImplementation(), DEFAULT_KERAS_INFO,
representative_dataset, lambda name, _tp: get_tpc(),
representative_dataset,
lambda name, _tp: get_tpc(per_channel),
qc=qc, input_shape=(1, 16, 16, 4))

nodes_list = list(graph.nodes)
@@ -53,9 +53,10 @@ def representative_dataset():
yield [np.random.randn(1, 16, 16, 4).astype(np.float32)]


def get_tpc():
def get_tpc(per_channel):
tp = generate_test_tp_model({
'weights_quantization_method': mct.target_platform.QuantizationMethod.UNIFORM})
'weights_quantization_method': mct.target_platform.QuantizationMethod.UNIFORM,
'weights_per_channel_threshold': per_channel})
tpc = generate_keras_tpc(name="uniform_range_selection_test", tp_model=tp)

return tpc
@@ -98,7 +99,8 @@ def run_test_for_threshold_method(self, threshold_method, per_channel=True):

in_model = create_network()
graph = prepare_graph_with_quantization_parameters(in_model, KerasImplementation(), DEFAULT_KERAS_INFO,
representative_dataset, lambda name, _tp: get_tpc(),
representative_dataset,
lambda name, _tp: get_tpc(per_channel),
qc=qc, input_shape=(1, 16, 16, 4))

nodes_list = list(graph.nodes)