Skip to content

Commit

Permalink
Update _classification.py
Browse files Browse the repository at this point in the history
  • Loading branch information
muzakkirhussain011 authored Feb 23, 2024
1 parent 154519b commit ccadb93
Showing 1 changed file with 8 additions and 12 deletions.
20 changes: 8 additions & 12 deletions ivy/functional/frontends/sklearn/metrics/_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def accuracy_score(y_true, y_pred, *, normalize=True, sample_weight=None):


@to_ivy_arrays_and_back
def precision_score(y_true, y_pred, *, sample_weight=None):
def recall_score(y_true, y_pred, *, sample_weight=None):
# Ensure that y_true and y_pred have the same shape
if y_true.shape != y_pred.shape:
raise IvyValueError("y_true and y_pred must have the same shape")
Expand All @@ -29,24 +29,20 @@ def precision_score(y_true, y_pred, *, sample_weight=None):
if sample_weight is not None:
sample_weight = ivy.array(sample_weight)
if sample_weight.shape[0] != y_true.shape[0]:
raise IvyValueError(
"sample_weight must have the same length as y_true and y_pred"
)
raise IvyValueError("sample_weight must have the same length as y_true and y_pred")
sample_weight = sample_weight / ivy.sum(sample_weight)
else:
sample_weight = ivy.ones_like(y_true)

# Calculate true positives and predicted positives
true_positives = ivy.logical_and(ivy.equal(y_true, 1), ivy.equal(y_pred, 1)).astype(
"int64"
)
predicted_positives = ivy.equal(y_pred, 1).astype("int64")
# Calculate true positives and actual positives
true_positives = ivy.logical_and(ivy.equal(y_true, 1), ivy.equal(y_pred, 1)).astype("int64")
actual_positives = ivy.equal(y_true, 1).astype("int64")

# Apply sample weights
weighted_true_positives = ivy.multiply(true_positives, sample_weight)
weighted_predicted_positives = ivy.multiply(predicted_positives, sample_weight)
weighted_actual_positives = ivy.multiply(actual_positives, sample_weight)

# Compute precision
ret = ivy.sum(weighted_true_positives) / ivy.sum(weighted_predicted_positives)
# Compute recall
ret = ivy.sum(weighted_true_positives) / ivy.sum(weighted_actual_positives)
ret = ret.astype("float64")
return ret

0 comments on commit ccadb93

Please sign in to comment.