[DO NOT MERGE] MSE optimal scale #1105

Draft · wants to merge 1 commit into base: dev
49 changes: 49 additions & 0 deletions src/brevitas/core/stats/stats_op.py
@@ -201,6 +201,55 @@ def forward(self, x: Tensor):
return torch.abs(max_val - min_val)


class OptimalIntSymmetricScale(brevitas.jit.ScriptModule):

def __init__(self, N: int) -> None:
super(OptimalIntSymmetricScale, self).__init__()
# Possible quantized values are {-N, ..., 0, ..., N}
self.N = N

@brevitas.jit.script_method
def forward(self, x: Tensor):
# Number of elements in the vector
P = len(x)
# Sort absolute values in ascending order
abs_x_sorted, _ = torch.sort(torch.abs(x))

        # Scales at which at least one element changes its optimal quantized value
transition_scales = 2 * abs_x_sorted.unsqueeze(0) / (
2 * torch.arange(start=0, end=self.N).unsqueeze(1) + 1)
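        # Entry (k, i) is the scale at which round(abs_x_sorted[i] / s) changes from
        # k to k + 1: the rounding flips exactly at abs_x_sorted[i] / s = k + 1/2,
        # i.e. at s = 2 * abs_x_sorted[i] / (2 * k + 1)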

        # This operation could be optimized: each row of transition_scales is already
        # sorted due to the monotonicity of the transition scales in k, so the sorting
        # cost could be reduced from O(NP log(NP)) to O(NP log N)
_, scales_sorting_indices = torch.sort(transition_scales.view(-1))

# Book-keeping values for determining the optimal scale
sum_w_q = 0
sum_q_squared = 0

optimal_scale = None
optimal_neg_error = float('-inf')

# Update the running scale every time a quantized assignment changes, keeping the value with the lowest loss
for j in reversed(range(P * self.N)):
            # Recover the level index k and element index i of this transition
k, i = scales_sorting_indices[j] // P, scales_sorting_indices[j] % P
# The running sums need to be updated to account for the change in the quantized assignment
sum_w_q -= abs_x_sorted[i] * k
sum_q_squared -= torch.square(k)
sum_w_q += abs_x_sorted[i] * (k + 1)
sum_q_squared += torch.square(k + 1)

neg_error = sum_w_q / torch.sqrt(sum_q_squared)
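            # For a fixed assignment q, MSE(s) = sum(w^2) - 2 * s * sum(|w|*q) + s^2 * sum(q^2)
            # is minimized at s* = sum(|w|*q) / sum(q^2), leaving a residual of
            # sum(w^2) - (sum(|w|*q))^2 / sum(q^2); maximizing neg_error = sum(|w|*q) / sqrt(sum(q^2))
            # therefore minimizes the MSE across assignments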

            # Check whether the current scale maximizes the negative error. If so, update the optimal scale
if neg_error > optimal_neg_error:
optimal_neg_error = neg_error
optimal_scale = sum_w_q / sum_q_squared

return optimal_scale


class AbsMaxAve(brevitas.jit.ScriptModule):
__constants__ = ['stats_reduce_dim']

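For context, a minimal sketch of how the new statistic could be exercised on its own. This is not part of the PR: the quantizer wiring that would normally surround the stat is omitted, and the weight tensor and the choice N = 3 are made up for illustration.

    import torch

    from brevitas.core.stats.stats_op import OptimalIntSymmetricScale

    torch.manual_seed(0)
    # Toy weight tensor standing in for a layer's weights
    w = torch.randn(256)

    # N = 3 restricts the quantized values to {-3, ..., 0, ..., 3}
    stat = OptimalIntSymmetricScale(N=3)
    scale = stat(w)

    # Quantize with the returned scale and measure the reconstruction error
    w_q = scale * torch.clamp(torch.round(w / scale), -3, 3)
    mse = torch.mean(torch.square(w - w_q))

The tests added below follow the same pattern, checking the returned scale against a closed-form ternary solution and against a grid search.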
62 changes: 62 additions & 0 deletions tests/brevitas/core/test_opt_scale.py
@@ -0,0 +1,62 @@
# Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

import pytest
import pytest_cases
import torch

from brevitas.core.stats.stats_op import OptimalIntSymmetricScale
from tests.conftest import SEED

# Number of weights to generate
P = 100
GRID_SEARCH_ITERS = 1000
ATOL = 1. / GRID_SEARCH_ITERS


class TestScale:

def test_optimal_scale_ternary(self):
# Quantized values are {-1, 0, 1}
N = 1
# Generate a vector of random weights
x = torch.rand((P,), dtype=torch.float32)

# Optimal scale in the ternary case admits a closed-form solution
# See https://arxiv.org/pdf/1707.04319
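        # For a support made of the j largest magnitudes, the optimal scale is their
        # mean and the residual error is sum(x^2) - (sum of top-j |x|)^2 / j, so the
        # optimal j maximizes cumsum(|x|) / sqrt(j)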
abs_sorted_x, _ = torch.sort(torch.abs(x), descending=True)
j_optimal = torch.argmax(
torch.cumsum(abs_sorted_x, dim=-1) / torch.sqrt(torch.arange(start=1, end=P + 1)))
gt_optimal_scale = torch.sum(abs_sorted_x[:j_optimal + 1]) / (j_optimal + 1)

optimal_int_symmetric_scale = OptimalIntSymmetricScale(N=N)
optimal_scale = optimal_int_symmetric_scale(x)

# Compare scales
assert torch.allclose(gt_optimal_scale, optimal_scale)

@pytest_cases.parametrize("N", [2, 3, 5])
    # Quantized values are {-N, ..., 0, ..., N}
def test_optimal_scale_grid_search(self, N):
# Generate a vector of random weights
x = torch.rand((P,), dtype=torch.float32)

# Compute optimal scale
optimal_int_symmetric_scale = OptimalIntSymmetricScale(N=N)
optimal_scale = optimal_int_symmetric_scale(x)

        # Compare with the scale obtained via grid search
def error_closure(scale):
return torch.sum(torch.square(x - scale * torch.clamp(torch.round(x / scale), -N, N)))
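        # For any assignment q, the per-assignment optimum sum(|x|*q) / sum(q^2) is at
        # most max(|x|) < 1, so a uniform grid over [0, 1) with step
        # 1 / GRID_SEARCH_ITERS brackets the true optimum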

gt_optimal_scale = None
gt_optimal_error = float('inf')

for i in range(GRID_SEARCH_ITERS):
curr_scale = torch.tensor(i / GRID_SEARCH_ITERS, dtype=torch.float32)
curr_error = error_closure(curr_scale)
if curr_error < gt_optimal_error:
gt_optimal_error = curr_error
gt_optimal_scale = curr_scale

        assert torch.allclose(optimal_scale, gt_optimal_scale, atol=ATOL, rtol=1e-1)