Skip to content

Commit

Permalink
feat: Implement precision_score function and test aligned with sklearn metrics (#28407)
Browse files Browse the repository at this point in the history
  • Loading branch information
muzakkirhussain011 authored Mar 5, 2024
1 parent be9f4f9 commit 108e7b8
Show file tree
Hide file tree
Showing 4 changed files with 116 additions and 14 deletions.
24 changes: 12 additions & 12 deletions ivy/functional/backends/jax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def _array_unflatten(aux_data, children):

# update these to add new dtypes
valid_dtypes = {
"0.4.24 and below": (
"0.4.25 and below": (
ivy.int8,
ivy.int16,
ivy.int32,
Expand All @@ -121,7 +121,7 @@ def _array_unflatten(aux_data, children):
)
}
valid_numeric_dtypes = {
"0.4.24 and below": (
"0.4.25 and below": (
ivy.int8,
ivy.int16,
ivy.int32,
Expand All @@ -140,7 +140,7 @@ def _array_unflatten(aux_data, children):
}

valid_int_dtypes = {
"0.4.24 and below": (
"0.4.25 and below": (
ivy.int8,
ivy.int16,
ivy.int32,
Expand All @@ -153,12 +153,12 @@ def _array_unflatten(aux_data, children):
}

valid_uint_dtypes = {
"0.4.24 and below": (ivy.uint8, ivy.uint16, ivy.uint32, ivy.uint64)
"0.4.25 and below": (ivy.uint8, ivy.uint16, ivy.uint32, ivy.uint64)
}
valid_float_dtypes = {
"0.4.24 and below": (ivy.bfloat16, ivy.float16, ivy.float32, ivy.float64)
"0.4.25 and below": (ivy.bfloat16, ivy.float16, ivy.float32, ivy.float64)
}
valid_complex_dtypes = {"0.4.24 and below": (ivy.complex64, ivy.complex128)}
valid_complex_dtypes = {"0.4.25 and below": (ivy.complex64, ivy.complex128)}


# leave these untouched
Expand All @@ -173,12 +173,12 @@ def _array_unflatten(aux_data, children):
# invalid data types

# update these to add new dtypes
invalid_dtypes = {"0.4.24 and below": ()}
invalid_numeric_dtypes = {"0.4.24 and below": ()}
invalid_int_dtypes = {"0.4.24 and below": ()}
invalid_float_dtypes = {"0.4.24 and below": ()}
invalid_uint_dtypes = {"0.4.24 and below": ()}
invalid_complex_dtypes = {"0.4.24 and below": ()}
invalid_dtypes = {"0.4.25 and below": ()}
invalid_numeric_dtypes = {"0.4.25 and below": ()}
invalid_int_dtypes = {"0.4.25 and below": ()}
invalid_float_dtypes = {"0.4.25 and below": ()}
invalid_uint_dtypes = {"0.4.25 and below": ()}
invalid_complex_dtypes = {"0.4.25 and below": ()}

# leave these untouched
invalid_dtypes = _dtype_from_version(invalid_dtypes, backend_version)
Expand Down
35 changes: 34 additions & 1 deletion ivy/functional/frontends/sklearn/metrics/_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def accuracy_score(y_true, y_pred, *, normalize=True, sample_weight=None):


@to_ivy_arrays_and_back
def precision_score(y_true, y_pred, *, sample_weight=None):
    """Compute the binary precision score: TP / (TP + FP).

    Mirrors ``sklearn.metrics.precision_score`` for binary labels {0, 1}.

    Parameters
    ----------
    y_true
        Ground-truth binary labels.
    y_pred
        Predicted binary labels; must have the same shape as ``y_true``.
    sample_weight
        Optional per-sample weights; must match the first dimension of
        ``y_true``. Weights are normalized to sum to 1 before use.

    Returns
    -------
    ret
        Precision as a float64 scalar array. Returns 0.0 when there are no
        predicted positives (sklearn's default ``zero_division`` behavior),
        instead of producing NaN from a 0/0 division.

    Raises
    ------
    IvyValueError
        If shapes mismatch or ``sample_weight`` has the wrong length.
    """
    # Ensure that y_true and y_pred have the same shape
    if y_true.shape != y_pred.shape:
        raise IvyValueError("y_true and y_pred must have the same shape")

    # Check if sample_weight is provided and normalize it
    if sample_weight is not None:
        sample_weight = ivy.array(sample_weight)
        if sample_weight.shape[0] != y_true.shape[0]:
            raise IvyValueError(
                "sample_weight must have the same length as y_true and y_pred"
            )
        sample_weight = sample_weight / ivy.sum(sample_weight)
    else:
        sample_weight = ivy.ones_like(y_true)

    # Calculate true positives and predicted positives
    true_positives = ivy.logical_and(ivy.equal(y_true, 1), ivy.equal(y_pred, 1)).astype(
        "int64"
    )
    predicted_positives = ivy.equal(y_pred, 1).astype("int64")

    # Apply sample weights
    weighted_true_positives = ivy.multiply(true_positives, sample_weight)
    weighted_predicted_positives = ivy.multiply(predicted_positives, sample_weight)

    # Guard against 0/0: with no predicted positives sklearn's default
    # zero_division treats precision as 0.0 rather than NaN.
    denominator = ivy.sum(weighted_predicted_positives)
    if denominator == 0:
        return ivy.array(0.0).astype("float64")

    # Compute precision
    ret = ivy.sum(weighted_true_positives) / denominator

    ret = ret.astype("float64")
    return ret


@to_ivy_arrays_and_back
def recall_score(y_true, y_pred, *, sample_weight=None):
# Ensure that y_true and y_pred have the same shape
if y_true.shape != y_pred.shape:
raise IvyValueError("y_true and y_pred must have the same shape")

# Check if sample_weight is provided and normalize it
if sample_weight is not None:
sample_weight = ivy.array(sample_weight)
if sample_weight.shape[0] != y_true.shape[0]:
raise IvyValueError(
"sample_weight must have the same length as y_true and y_pred"
)
sample_weight = sample_weight / ivy.sum(sample_weight)
else:
sample_weight = ivy.ones_like(y_true)
# Calculate true positives and actual positives
true_positives = ivy.logical_and(ivy.equal(y_true, 1), ivy.equal(y_pred, 1)).astype(
"int64"
Expand All @@ -49,5 +81,6 @@ def recall_score(y_true, y_pred, *, sample_weight=None):

# Compute recall
ret = ivy.sum(weighted_true_positives) / ivy.sum(weighted_actual_positives)

ret = ret.astype("float64")
return ret
6 changes: 5 additions & 1 deletion ivy/utils/assertions.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,11 @@ def check_dev_correct_formatting(device):


def _check_jax_x64_flag(dtype):
if ivy.backend == "jax" and not ivy.functional.backends.jax.jax.config.x64_enabled:
if (
ivy.backend == "jax"
and not ivy.functional.backends.jax.jax.config.jax_enable_x64
):

ivy.utils.assertions.check_elem_in_list(
dtype,
["float64", "int64", "uint64", "complex128"],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,71 @@ def test_sklearn_accuracy_score(
)


@handle_frontend_test(
    fn_tree="sklearn.metrics.precision_score",
    arrays_and_dtypes=helpers.dtype_and_values(
        available_dtypes=helpers.get_dtypes("integer"),
        num_arrays=2,
        min_value=0,
        max_value=1,  # Precision score is for binary classification
        shared_dtype=True,
        shape=(helpers.ints(min_value=2, max_value=5)),
    ),
    sample_weight=st.lists(
        st.floats(min_value=0.1, max_value=1), min_size=2, max_size=5
    ),
)
def test_sklearn_precision_score(
    arrays_and_dtypes,
    on_device,
    fn_tree,
    frontend,
    test_flags,
    backend_fw,
    sample_weight,
):
    """Frontend test for sklearn.metrics.precision_score.

    Binarizes the generated labels, reshapes the hypothesis-drawn
    sample_weight to match the label length, and delegates the
    backend-vs-ground-truth comparison to test_frontend_function.
    """
    dtypes, values = arrays_and_dtypes

    # Force the drawn values into {0, 1}: round, then cast to int.
    values = [np.round(arr).astype(int) for arr in values]

    # Make sample_weight exactly as long as the label arrays:
    # pad with ones when too short, then cut down when too long.
    sample_weight = np.array(sample_weight).astype(float)
    target_len = len(values[0])
    if len(sample_weight) != target_len:
        shortfall = max(0, target_len - len(sample_weight))
        sample_weight = np.pad(
            sample_weight,
            (0, shortfall),
            "constant",
            constant_values=1.0,
        )
        sample_weight = sample_weight[:target_len]

    # Torch tensors that track gradients cannot be converted to NumPy
    # directly — detach them first.
    if backend_fw == "torch":

        def _detached(val):
            if isinstance(val, torch.Tensor) and val.requires_grad:
                return val.detach().numpy()
            return val

        values = [_detached(val) for val in values]

    helpers.test_frontend_function(
        input_dtypes=dtypes,
        backend_to_test=backend_fw,
        test_flags=test_flags,
        fn_tree=fn_tree,
        frontend=frontend,
        on_device=on_device,
        y_true=values[0],
        y_pred=values[1],
        sample_weight=sample_weight,
    )


@handle_frontend_test(
fn_tree="sklearn.metrics.recall_score",
arrays_and_dtypes=helpers.dtype_and_values(
Expand Down

0 comments on commit 108e7b8

Please sign in to comment.