
feat: Implement precision_score function and test aligned with sklearn metrics #28407

Merged · 68 commits · Mar 5, 2024

Commits
54e5e6f
Create func_wrapper.py
muzakkirhussain011 Sep 4, 2023
8a9de57
feat: Implement precision_score function aligned with sklearn metrics
muzakkirhussain011 Feb 23, 2024
ce7067e
feat: Implement precision_score function test aligned with sklearn me…
muzakkirhussain011 Feb 23, 2024
81ecd3f
🤖 Lint code
ivy-branch Feb 23, 2024
ac9e299
🤖 Lint code
ivy-branch Feb 23, 2024
d4dfe4d
Delete ivy/functional/frontends/transformers/func_wrapper.py
muzakkirhussain011 Feb 23, 2024
a8cffef
Merge branch 'main' into patch-1
muzakkirhussain011 Feb 25, 2024
5fc6caf
🤖 Lint code
ivy-branch Feb 25, 2024
66ad054
Resolved indent error sklearn_test_precision_score
muzakkirhussain011 Feb 27, 2024
7ddbe54
🤖 Lint code
ivy-branch Feb 27, 2024
99482a2
Resolved indent errors _classification.py
muzakkirhussain011 Feb 27, 2024
0480200
🤖 Lint code
ivy-branch Feb 27, 2024
ad4fa46
Updated dtypes in test_classification.py
muzakkirhussain011 Feb 27, 2024
d2acd5f
Updated test_classification.py
muzakkirhussain011 Feb 27, 2024
f904b6e
Updated test_classification.py
muzakkirhussain011 Feb 27, 2024
d4be9f4
🤖 Lint code
ivy-branch Feb 27, 2024
02ed1df
Updated test_classification.py
muzakkirhussain011 Feb 27, 2024
0bb503c
Updated test_classification.py
muzakkirhussain011 Feb 27, 2024
b1d4d9f
🤖 Lint code
ivy-branch Feb 27, 2024
22c6c68
updated test_classification.py
muzakkirhussain011 Feb 27, 2024
fd9198f
🤖 Lint code
ivy-branch Feb 27, 2024
80fe8b3
Update test_classification.py
muzakkirhussain011 Feb 27, 2024
4b2ee9d
🤖 Lint code
ivy-branch Feb 27, 2024
fef2396
Updated test_classification.py
muzakkirhussain011 Feb 27, 2024
9d583b3
🤖 Lint code
ivy-branch Feb 27, 2024
9e3c701
Updated test_classifciation.py
muzakkirhussain011 Feb 27, 2024
5fc8174
🤖 Lint code
ivy-branch Feb 27, 2024
e8cb8d1
Updated test_classification.py
muzakkirhussain011 Feb 27, 2024
46336db
🤖 Lint code
ivy-branch Feb 27, 2024
1033828
Update test_classification.py
muzakkirhussain011 Feb 27, 2024
54a4853
🤖 Lint code
ivy-branch Feb 27, 2024
cff01e9
Update test_classification.py
muzakkirhussain011 Feb 27, 2024
5d3a8d6
🤖 Lint code
ivy-branch Feb 27, 2024
9fe0c06
Updated test_classification.py
muzakkirhussain011 Feb 27, 2024
f790e91
🤖 Lint code
ivy-branch Feb 27, 2024
db753fb
Updated test_classification.py
muzakkirhussain011 Feb 27, 2024
72a1a70
🤖 Lint code
ivy-branch Feb 27, 2024
1ffa551
Update classification.py
muzakkirhussain011 Feb 27, 2024
69bdae8
🤖 Lint code
ivy-branch Feb 27, 2024
02095c1
Updated _classification.py
muzakkirhussain011 Feb 27, 2024
9bba1d0
🤖 Lint code
ivy-branch Feb 27, 2024
128c86c
Update config.py jax
muzakkirhussain011 Feb 28, 2024
20493de
Updated test_classification.py
muzakkirhussain011 Feb 28, 2024
9a7efde
🤖 Lint code
ivy-branch Feb 28, 2024
ad640d1
Update test_classification.py
muzakkirhussain011 Feb 28, 2024
964ca7b
Update test_classification.py
muzakkirhussain011 Feb 28, 2024
e4afd6e
Update test_classification.py
muzakkirhussain011 Feb 28, 2024
c59b930
🤖 Lint code
ivy-branch Feb 28, 2024
fe0887e
Update test_classification.py
muzakkirhussain011 Feb 28, 2024
9c2a268
🤖 Lint code
ivy-branch Feb 28, 2024
7575a01
Update test_classification.py
muzakkirhussain011 Feb 28, 2024
d7956f4
🤖 Lint code
ivy-branch Feb 28, 2024
52862bf
Update test_classication.py
muzakkirhussain011 Feb 28, 2024
1b6db54
Update test_classification.py
muzakkirhussain011 Feb 28, 2024
423d5a7
🤖 Lint code
ivy-branch Feb 28, 2024
a9fb4a4
Update test_classification.py
muzakkirhussain011 Feb 28, 2024
6b79f79
🤖 Lint code
ivy-branch Feb 28, 2024
22aebd2
Merge branch 'main' into patch-1
muzakkirhussain011 Feb 28, 2024
ab7d4ed
Updated
muzakkirhussain011 Feb 28, 2024
209e882
🤖 Lint code
ivy-branch Feb 28, 2024
bc1ae81
updated
muzakkirhussain011 Feb 28, 2024
1f9271d
🤖 Lint code
ivy-branch Feb 28, 2024
f559edb
updated jax __init__.py
muzakkirhussain011 Feb 28, 2024
ea488cd
🤖 Lint code
ivy-branch Feb 28, 2024
68f85d1
Updated test_classification.py
muzakkirhussain011 Feb 28, 2024
19dd4d8
🤖 Lint code
ivy-branch Feb 28, 2024
13aa09b
Updated test_classification.py
muzakkirhussain011 Feb 28, 2024
8843022
🤖 Lint code
ivy-branch Feb 28, 2024
Files changed
24 changes: 12 additions & 12 deletions ivy/functional/backends/jax/__init__.py
@@ -102,7 +102,7 @@ def _array_unflatten(aux_data, children):

# update these to add new dtypes
valid_dtypes = {
"0.4.24 and below": (
"0.4.25 and below": (
ivy.int8,
ivy.int16,
ivy.int32,
@@ -121,7 +121,7 @@ def _array_unflatten(aux_data, children):
)
}
valid_numeric_dtypes = {
"0.4.24 and below": (
"0.4.25 and below": (
ivy.int8,
ivy.int16,
ivy.int32,
@@ -140,7 +140,7 @@ def _array_unflatten(aux_data, children):
}

valid_int_dtypes = {
"0.4.24 and below": (
"0.4.25 and below": (
ivy.int8,
ivy.int16,
ivy.int32,
@@ -153,12 +153,12 @@ def _array_unflatten(aux_data, children):
}

valid_uint_dtypes = {
"0.4.24 and below": (ivy.uint8, ivy.uint16, ivy.uint32, ivy.uint64)
"0.4.25 and below": (ivy.uint8, ivy.uint16, ivy.uint32, ivy.uint64)
}
valid_float_dtypes = {
"0.4.24 and below": (ivy.bfloat16, ivy.float16, ivy.float32, ivy.float64)
"0.4.25 and below": (ivy.bfloat16, ivy.float16, ivy.float32, ivy.float64)
}
valid_complex_dtypes = {"0.4.24 and below": (ivy.complex64, ivy.complex128)}
valid_complex_dtypes = {"0.4.25 and below": (ivy.complex64, ivy.complex128)}


# leave these untouched
@@ -173,12 +173,12 @@ def _array_unflatten(aux_data, children):
# invalid data types

# update these to add new dtypes
invalid_dtypes = {"0.4.24 and below": ()}
invalid_numeric_dtypes = {"0.4.24 and below": ()}
invalid_int_dtypes = {"0.4.24 and below": ()}
invalid_float_dtypes = {"0.4.24 and below": ()}
invalid_uint_dtypes = {"0.4.24 and below": ()}
invalid_complex_dtypes = {"0.4.24 and below": ()}
invalid_dtypes = {"0.4.25 and below": ()}
invalid_numeric_dtypes = {"0.4.25 and below": ()}
invalid_int_dtypes = {"0.4.25 and below": ()}
invalid_float_dtypes = {"0.4.25 and below": ()}
invalid_uint_dtypes = {"0.4.25 and below": ()}
invalid_complex_dtypes = {"0.4.25 and below": ()}

# leave these untouched
invalid_dtypes = _dtype_from_version(invalid_dtypes, backend_version)
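For reference, these dtype tables are keyed by backend-version ranges and resolved against the installed JAX version via _dtype_from_version, so the bump from "0.4.24 and below" to "0.4.25 and below" presumably keeps the tables matching the newly released jax 0.4.25. A rough, illustrative resolver for such version-range keys (assumed semantics, not ivy's actual helper):

# Illustrative resolver for version-keyed dtype tables (assumed semantics,
# not ivy's actual _dtype_from_version implementation).
from packaging.version import Version

def resolve_for_version(table, backend_version):
    for key, dtypes in table.items():
        bound, _, direction = key.partition(" and ")
        if direction == "below" and Version(backend_version) <= Version(bound):
            return dtypes
        if direction == "above" and Version(backend_version) >= Version(bound):
            return dtypes
    return ()

valid_uint = {"0.4.25 and below": ("uint8", "uint16", "uint32", "uint64")}
print(resolve_for_version(valid_uint, "0.4.25"))  # falls inside the "and below" range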
35 changes: 34 additions & 1 deletion ivy/functional/frontends/sklearn/metrics/_classification.py
@@ -21,7 +21,7 @@ def accuracy_score(y_true, y_pred, *, normalize=True, sample_weight=None):


@to_ivy_arrays_and_back
- def recall_score(y_true, y_pred, *, sample_weight=None):
+ def precision_score(y_true, y_pred, *, sample_weight=None):
# Ensure that y_true and y_pred have the same shape
if y_true.shape != y_pred.shape:
raise IvyValueError("y_true and y_pred must have the same shape")
@@ -36,7 +36,39 @@ def recall_score(y_true, y_pred, *, sample_weight=None):
sample_weight = sample_weight / ivy.sum(sample_weight)
else:
sample_weight = ivy.ones_like(y_true)
# Calculate true positives and predicted positives
true_positives = ivy.logical_and(ivy.equal(y_true, 1), ivy.equal(y_pred, 1)).astype(
"int64"
)
predicted_positives = ivy.equal(y_pred, 1).astype("int64")

# Apply sample weights
weighted_true_positives = ivy.multiply(true_positives, sample_weight)
weighted_predicted_positives = ivy.multiply(predicted_positives, sample_weight)

# Compute precision
ret = ivy.sum(weighted_true_positives) / ivy.sum(weighted_predicted_positives)

ret = ret.astype("float64")
return ret


@to_ivy_arrays_and_back
def recall_score(y_true, y_pred, *, sample_weight=None):
# Ensure that y_true and y_pred have the same shape
if y_true.shape != y_pred.shape:
raise IvyValueError("y_true and y_pred must have the same shape")

# Check if sample_weight is provided and normalize it
if sample_weight is not None:
sample_weight = ivy.array(sample_weight)
if sample_weight.shape[0] != y_true.shape[0]:
raise IvyValueError(
"sample_weight must have the same length as y_true and y_pred"
)
sample_weight = sample_weight / ivy.sum(sample_weight)
else:
sample_weight = ivy.ones_like(y_true)
# Calculate true positives and actual positives
true_positives = ivy.logical_and(ivy.equal(y_true, 1), ivy.equal(y_pred, 1)).astype(
"int64"
@@ -49,5 +49,6 @@ def recall_score(y_true, y_pred, *, sample_weight=None):

# Compute recall
ret = ivy.sum(weighted_true_positives) / ivy.sum(weighted_actual_positives)

ret = ret.astype("float64")
return ret
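For reference, the precision computed above is the standard binary precision, TP / (TP + FP), with optional per-sample weights; normalizing sample_weight first does not change the ratio. A minimal NumPy sketch against sklearn's reference implementation (the arrays are illustrative, not taken from this PR):

import numpy as np
from sklearn.metrics import precision_score  # reference implementation

# Illustrative binary labels, predictions, and weights.
y_true = np.array([1, 0, 1, 1, 0])
y_pred = np.array([1, 0, 0, 1, 1])
sample_weight = np.array([1.0, 1.0, 0.5, 2.0, 1.0])

# Weighted true positives and predicted positives, mirroring the frontend above.
tp = ((y_true == 1) & (y_pred == 1)).astype(float)
pp = (y_pred == 1).astype(float)
precision = (tp * sample_weight).sum() / (pp * sample_weight).sum()

# sklearn agrees for binary inputs (its default average="binary").
assert np.isclose(
    precision, precision_score(y_true, y_pred, sample_weight=sample_weight)
)
print(precision)  # 0.75 for the values above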
6 changes: 5 additions & 1 deletion ivy/utils/assertions.py
@@ -344,7 +344,11 @@ def check_dev_correct_formatting(device):


def _check_jax_x64_flag(dtype):
if ivy.backend == "jax" and not ivy.functional.backends.jax.jax.config.x64_enabled:
if (
ivy.backend == "jax"
and not ivy.functional.backends.jax.jax.config.jax_enable_x64
):

ivy.utils.assertions.check_elem_in_list(
dtype,
["float64", "int64", "uint64", "complex128"],
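Context for this change: the check now reads jax.config.jax_enable_x64 rather than the older x64_enabled property; both refer to JAX's global 64-bit-types switch, which is off by default. A small sketch of toggling and reading the flag (attribute names and behavior can vary across JAX versions, so treat this as an assumption for recent 0.4.x releases):

import jax

# Enable 64-bit dtypes globally; without this, float64/int64/uint64/complex128
# inputs are downcast to their 32-bit counterparts, which is why the assertion
# above rejects those dtypes when the flag is off.
jax.config.update("jax_enable_x64", True)

# Read the flag via the attribute used in the updated check.
print(jax.config.jax_enable_x64)  # True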
test_classification.py
@@ -45,6 +45,71 @@ def test_sklearn_accuracy_score(
)


@handle_frontend_test(
fn_tree="sklearn.metrics.precision_score",
arrays_and_dtypes=helpers.dtype_and_values(
available_dtypes=helpers.get_dtypes("integer"),
num_arrays=2,
min_value=0,
max_value=1, # Precision score is for binary classification
shared_dtype=True,
shape=(helpers.ints(min_value=2, max_value=5)),
),
sample_weight=st.lists(
st.floats(min_value=0.1, max_value=1), min_size=2, max_size=5
),
)
def test_sklearn_precision_score(
arrays_and_dtypes,
on_device,
fn_tree,
frontend,
test_flags,
backend_fw,
sample_weight,
):
dtypes, values = arrays_and_dtypes
# Ensure the values are binary by rounding and converting to int
for i in range(2):
values[i] = np.round(values[i]).astype(int)

# Adjust sample_weight to have the correct length
sample_weight = np.array(sample_weight).astype(float)
if len(sample_weight) != len(values[0]):
# If sample_weight is shorter, extend it with ones
sample_weight = np.pad(
sample_weight,
(0, max(0, len(values[0]) - len(sample_weight))),
"constant",
constant_values=1.0,
)
# If sample_weight is longer, truncate it
sample_weight = sample_weight[: len(values[0])]

# Detach tensors if they require grad before converting to NumPy arrays
if backend_fw == "torch":
values = [
(
value.detach().numpy()
if isinstance(value, torch.Tensor) and value.requires_grad
else value
)
for value in values
]

helpers.test_frontend_function(
input_dtypes=dtypes,
backend_to_test=backend_fw,
test_flags=test_flags,
fn_tree=fn_tree,
frontend=frontend,
on_device=on_device,
y_true=values[0],
y_pred=values[1],
sample_weight=sample_weight,
)
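The pad-or-truncate adjustment of sample_weight above can be expressed as a small standalone helper; a sketch for illustration (the function name is hypothetical, not part of the test suite):

import numpy as np

def fit_sample_weight(sample_weight, n):
    # Pad with 1.0 or truncate so the weights cover exactly n samples,
    # mirroring the adjustment done in the test above.
    w = np.asarray(sample_weight, dtype=float)
    if len(w) < n:
        w = np.pad(w, (0, n - len(w)), "constant", constant_values=1.0)
    return w[:n]

print(fit_sample_weight([0.5, 0.2], 4))       # -> [0.5 0.2 1.  1. ]
print(fit_sample_weight([0.5, 0.2, 0.9], 2))  # -> [0.5 0.2]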


@handle_frontend_test(
fn_tree="sklearn.metrics.recall_score",
arrays_and_dtypes=helpers.dtype_and_values(