Skip to content

Commit

Permalink
Improve performance of clip_values utility function(cleanlab#1104)
Browse files Browse the repository at this point in the history
  • Loading branch information
gogetron authored Jul 2, 2024
1 parent e67c4ae commit c915f77
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 12 deletions.
19 changes: 8 additions & 11 deletions cleanlab/internal/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def clip_noise_rates(noise_matrix: np.ndarray) -> np.ndarray:
return noise_matrix


def clip_values(x, low=0.0, high=1.0, new_sum=None) -> np.ndarray:
def clip_values(x, low=0.0, high=1.0, new_sum: Optional[float] = None) -> np.ndarray:
"""Clip all values in p to range [low,high].
Preserves sum of x.
Expand All @@ -115,17 +115,14 @@ def clip_values(x, low=0.0, high=1.0, new_sum=None) -> np.ndarray:
x : np.ndarray
A list of clipped values, summing to the same sum as x."""

def clip_range(a, low=low, high=high):
"""Clip a into range [low,high]"""
return min(max(a, low), high)

vectorized_clip = np.vectorize(
clip_range
) # Vectorize clip_range for efficiency with np.ndarrays
prev_sum = sum(x) if new_sum is None else new_sum # Store previous sum
x = vectorized_clip(x) # Clip all values (efficiently)
if len(x.shape) > 1:
raise TypeError(
f"only size-1 arrays can be converted to Python scalars but 'x' had shape {x.shape}"
)
prev_sum = np.sum(x) if new_sum is None else new_sum # Store previous sum
x = np.clip(x, low, high) # Clip all values (efficiently)
x = (
x * prev_sum / np.clip(float(sum(x)), a_min=TINY_VALUE, a_max=None)
x * prev_sum / np.clip(np.sum(x), a_min=TINY_VALUE, a_max=None)
) # Re-normalized values to sum to previous sum
return x

Expand Down
2 changes: 1 addition & 1 deletion tests/test_latent_algebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@
# You should have received a copy of the GNU Affero General Public License
# along with cleanlab. If not, see <https://www.gnu.org/licenses/>.

from cleanlab.internal import latent_algebra
import numpy as np
import pytest

from cleanlab.internal import latent_algebra

s = [0] * 10 + [1] * 5 + [2] * 15
nm = np.array([[1.0, 0.0, 0.2], [0.0, 0.7, 0.2], [0.0, 0.3, 0.6]])
Expand Down

0 comments on commit c915f77

Please sign in to comment.