diff --git a/ivy/functional/backends/jax/__init__.py b/ivy/functional/backends/jax/__init__.py
index 6168ec1b4ac84..7efc8b9c4f522 100644
--- a/ivy/functional/backends/jax/__init__.py
+++ b/ivy/functional/backends/jax/__init__.py
@@ -102,7 +102,7 @@ def _array_unflatten(aux_data, children):
 
 # update these to add new dtypes
 valid_dtypes = {
-    "0.4.24 and below": (
+    "0.4.25 and below": (
         ivy.int8,
         ivy.int16,
         ivy.int32,
@@ -121,7 +121,7 @@ def _array_unflatten(aux_data, children):
     )
 }
 valid_numeric_dtypes = {
-    "0.4.24 and below": (
+    "0.4.25 and below": (
         ivy.int8,
         ivy.int16,
         ivy.int32,
@@ -140,7 +140,7 @@ def _array_unflatten(aux_data, children):
     )
 }
 valid_int_dtypes = {
-    "0.4.24 and below": (
+    "0.4.25 and below": (
         ivy.int8,
         ivy.int16,
         ivy.int32,
@@ -153,12 +153,12 @@ def _array_unflatten(aux_data, children):
     )
 }
 valid_uint_dtypes = {
-    "0.4.24 and below": (ivy.uint8, ivy.uint16, ivy.uint32, ivy.uint64)
+    "0.4.25 and below": (ivy.uint8, ivy.uint16, ivy.uint32, ivy.uint64)
 }
 valid_float_dtypes = {
-    "0.4.24 and below": (ivy.bfloat16, ivy.float16, ivy.float32, ivy.float64)
+    "0.4.25 and below": (ivy.bfloat16, ivy.float16, ivy.float32, ivy.float64)
 }
-valid_complex_dtypes = {"0.4.24 and below": (ivy.complex64, ivy.complex128)}
+valid_complex_dtypes = {"0.4.25 and below": (ivy.complex64, ivy.complex128)}
 
 # leave these untouched
 valid_dtypes = _dtype_from_version(valid_dtypes, backend_version)
@@ -173,12 +173,12 @@ def _array_unflatten(aux_data, children):
 
 # invalid data types
 # update these to add new dtypes
-invalid_dtypes = {"0.4.24 and below": ()}
-invalid_numeric_dtypes = {"0.4.24 and below": ()}
-invalid_int_dtypes = {"0.4.24 and below": ()}
-invalid_float_dtypes = {"0.4.24 and below": ()}
-invalid_uint_dtypes = {"0.4.24 and below": ()}
-invalid_complex_dtypes = {"0.4.24 and below": ()}
+invalid_dtypes = {"0.4.25 and below": ()}
+invalid_numeric_dtypes = {"0.4.25 and below": ()}
+invalid_int_dtypes = {"0.4.25 and below": ()}
+invalid_float_dtypes = {"0.4.25 and below": ()}
+invalid_uint_dtypes = {"0.4.25 and below": ()}
+invalid_complex_dtypes = {"0.4.25 and below": ()}
 
 # leave these untouched
 invalid_dtypes = _dtype_from_version(invalid_dtypes, backend_version)
diff --git a/ivy/functional/frontends/sklearn/metrics/_classification.py b/ivy/functional/frontends/sklearn/metrics/_classification.py
index 2c05ae5da5edf..29395a8d14d9f 100644
--- a/ivy/functional/frontends/sklearn/metrics/_classification.py
+++ b/ivy/functional/frontends/sklearn/metrics/_classification.py
@@ -21,7 +21,7 @@ def accuracy_score(y_true, y_pred, *, normalize=True, sample_weight=None):
 
 
 @to_ivy_arrays_and_back
-def recall_score(y_true, y_pred, *, sample_weight=None):
+def precision_score(y_true, y_pred, *, sample_weight=None):
     # Ensure that y_true and y_pred have the same shape
     if y_true.shape != y_pred.shape:
         raise IvyValueError("y_true and y_pred must have the same shape")
@@ -36,7 +36,39 @@ def recall_score(y_true, y_pred, *, sample_weight=None):
         sample_weight = sample_weight / ivy.sum(sample_weight)
     else:
         sample_weight = ivy.ones_like(y_true)
+    # Calculate true positives and predicted positives
+    true_positives = ivy.logical_and(ivy.equal(y_true, 1), ivy.equal(y_pred, 1)).astype(
+        "int64"
+    )
+    predicted_positives = ivy.equal(y_pred, 1).astype("int64")
+
+    # Apply sample weights
+    weighted_true_positives = ivy.multiply(true_positives, sample_weight)
+    weighted_predicted_positives = ivy.multiply(predicted_positives, sample_weight)
+
+    # Compute precision
+    ret = ivy.sum(weighted_true_positives) / ivy.sum(weighted_predicted_positives)
+    ret = ret.astype("float64")
+    return ret
+
+
+@to_ivy_arrays_and_back
+def recall_score(y_true, y_pred, *, sample_weight=None):
+    # Ensure that y_true and y_pred have the same shape
+    if y_true.shape != y_pred.shape:
+        raise IvyValueError("y_true and y_pred must have the same shape")
+
+    # Check if sample_weight is provided and normalize it
+    if sample_weight is not None:
+        sample_weight = ivy.array(sample_weight)
+        if sample_weight.shape[0] != y_true.shape[0]:
+            raise IvyValueError(
+                "sample_weight must have the same length as y_true and y_pred"
+            )
+        sample_weight = sample_weight / ivy.sum(sample_weight)
+    else:
+        sample_weight = ivy.ones_like(y_true)
     # Calculate true positives and actual positives
     true_positives = ivy.logical_and(ivy.equal(y_true, 1), ivy.equal(y_pred, 1)).astype(
         "int64"
@@ -49,5 +81,6 @@ def recall_score(y_true, y_pred, *, sample_weight=None):
 
     # Compute recall
     ret = ivy.sum(weighted_true_positives) / ivy.sum(weighted_actual_positives)
+    ret = ret.astype("float64")
     return ret
 
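Note (not part of the patch): a minimal sanity check of the two frontend functions added above. It assumes the numpy backend and that precision_score / recall_score are re-exported from ivy.functional.frontends.sklearn.metrics in the same way as the existing accuracy_score:

import ivy
# Assumes the metrics package re-exports _classification, as for accuracy_score
from ivy.functional.frontends.sklearn.metrics import precision_score, recall_score

ivy.set_backend("numpy")

# Binary labels: 2 true positives, 3 predicted positives, 3 actual positives
y_true = ivy.array([1, 0, 1, 1, 0])
y_pred = ivy.array([1, 1, 1, 0, 0])

print(precision_score(y_true, y_pred))  # 2/3, as float64
print(recall_score(y_true, y_pred))  # 2/3, as float64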
diff --git a/ivy/utils/assertions.py b/ivy/utils/assertions.py
index e3c7323c481ba..185a745400445 100644
--- a/ivy/utils/assertions.py
+++ b/ivy/utils/assertions.py
@@ -344,7 +344,11 @@ def check_dev_correct_formatting(device):
 
 
 def _check_jax_x64_flag(dtype):
-    if ivy.backend == "jax" and not ivy.functional.backends.jax.jax.config.x64_enabled:
+    if (
+        ivy.backend == "jax"
+        and not ivy.functional.backends.jax.jax.config.jax_enable_x64
+    ):
+
         ivy.utils.assertions.check_elem_in_list(
             dtype,
             ["float64", "int64", "uint64", "complex128"],
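Note (not part of the patch): jax_enable_x64 is the flag's registered name on jax.config, and jax.config.update is the documented way to toggle it. A quick illustration, assuming JAX is installed (the flag defaults to False):

import jax

print(jax.config.jax_enable_x64)  # False unless x64 has been enabled
jax.config.update("jax_enable_x64", True)
print(jax.config.jax_enable_x64)  # True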
diff --git a/ivy_tests/test_ivy/test_frontends/test_sklearn/test_metrics/test_classification.py b/ivy_tests/test_ivy/test_frontends/test_sklearn/test_metrics/test_classification.py
index 1d8342188abee..dbca7729a356d 100644
--- a/ivy_tests/test_ivy/test_frontends/test_sklearn/test_metrics/test_classification.py
+++ b/ivy_tests/test_ivy/test_frontends/test_sklearn/test_metrics/test_classification.py
@@ -45,6 +45,71 @@ def test_sklearn_accuracy_score(
     )
 
 
+@handle_frontend_test(
+    fn_tree="sklearn.metrics.precision_score",
+    arrays_and_dtypes=helpers.dtype_and_values(
+        available_dtypes=helpers.get_dtypes("integer"),
+        num_arrays=2,
+        min_value=0,
+        max_value=1,  # Precision score is for binary classification
+        shared_dtype=True,
+        shape=(helpers.ints(min_value=2, max_value=5)),
+    ),
+    sample_weight=st.lists(
+        st.floats(min_value=0.1, max_value=1), min_size=2, max_size=5
+    ),
+)
+def test_sklearn_precision_score(
+    arrays_and_dtypes,
+    on_device,
+    fn_tree,
+    frontend,
+    test_flags,
+    backend_fw,
+    sample_weight,
+):
+    dtypes, values = arrays_and_dtypes
+    # Ensure the values are binary by rounding and converting to int
+    for i in range(2):
+        values[i] = np.round(values[i]).astype(int)
+
+    # Adjust sample_weight to have the correct length
+    sample_weight = np.array(sample_weight).astype(float)
+    if len(sample_weight) != len(values[0]):
+        # If sample_weight is shorter, extend it with ones
+        sample_weight = np.pad(
+            sample_weight,
+            (0, max(0, len(values[0]) - len(sample_weight))),
+            "constant",
+            constant_values=1.0,
+        )
+        # If sample_weight is longer, truncate it
+        sample_weight = sample_weight[: len(values[0])]
+
+    # Detach tensors if they require grad before converting to NumPy arrays
+    if backend_fw == "torch":
+        values = [
+            (
+                value.detach().numpy()
+                if isinstance(value, torch.Tensor) and value.requires_grad
+                else value
+            )
+            for value in values
+        ]
+
+    helpers.test_frontend_function(
+        input_dtypes=dtypes,
+        backend_to_test=backend_fw,
+        test_flags=test_flags,
+        fn_tree=fn_tree,
+        frontend=frontend,
+        on_device=on_device,
+        y_true=values[0],
+        y_pred=values[1],
+        sample_weight=sample_weight,
+    )
+
+
 @handle_frontend_test(
     fn_tree="sklearn.metrics.recall_score",
     arrays_and_dtypes=helpers.dtype_and_values(