diff --git a/ivy/functional/ivy/losses.py b/ivy/functional/ivy/losses.py index 1e8d6cebb511e..89820300d16df 100644 --- a/ivy/functional/ivy/losses.py +++ b/ivy/functional/ivy/losses.py @@ -381,3 +381,63 @@ def sparse_cross_entropy( return ivy.cross_entropy( true, pred, axis=axis, epsilon=epsilon, reduction=reduction, out=out ) + + +@handle_exceptions +@handle_nestable +@handle_array_like_without_promotion +@inputs_to_ivy_arrays +@handle_array_function +def ssim_loss( + true: Union[ivy.Array, ivy.NativeArray], + pred: Union[ivy.Array, ivy.NativeArray], + out: Optional[ivy.Array] = None, +) -> ivy.Array: + """Calculate the Structural Similarity Index (SSIM) loss between two + images. + + Parameters + ---------- + true: A 4D image array of shape (batch_size, channels, height, width). + pred: A 4D image array of shape (batch_size, channels, height, width). + + Returns + ------- + ivy.Array: The SSIM loss measure similarity between the two images. + + Examples + -------- + With :class:`ivy.Array` input: + >>> import ivy + >>> x = ivy.ones((5, 3, 28, 28)) + >>> y = ivy.zeros((5, 3, 28, 28)) + >>> ivy.ssim_loss(x, y) + ivy.array(0.99989986) + """ + # Constants for stability + C1 = 0.01 ** 2 + C2 = 0.03 ** 2 + + # Calculate the mean of the two images + mu_x = ivy.avg_pool2d(pred, (3, 3), (1, 1), "SAME") + mu_y = ivy.avg_pool2d(true, (3, 3), (1, 1), "SAME") + + # Calculate variance and covariance + sigma_x2 = ivy.avg_pool2d(pred * pred, (3, 3), (1, 1), "SAME") - mu_x * mu_x + sigma_y2 = ivy.avg_pool2d(true * true, (3, 3), (1, 1), "SAME") - mu_y * mu_y + sigma_xy = ivy.avg_pool2d(pred * true, (3, 3), (1, 1), "SAME") - mu_x * mu_y + + # Calculate SSIM + ssim = ((2 * mu_x * mu_y + C1) * (2 * sigma_xy + C2)) / ( + (mu_x ** 2 + mu_y ** 2 + C1) * (sigma_x2 + sigma_y2 + C2) + ) + + # Convert SSIM to loss + ssim_loss_value = 1 - ssim + + # Return mean SSIM loss + ret = ivy.mean(ssim_loss_value) + + if ivy.exists(out): + ret = ivy.inplace_update(out, ret) + return ret diff --git a/ivy_tests/test_ivy/test_functional/test_nn/test_losses.py b/ivy_tests/test_ivy/test_functional/test_nn/test_losses.py index bba4b09173af4..92d4407a8acdc 100644 --- a/ivy_tests/test_ivy/test_functional/test_nn/test_losses.py +++ b/ivy_tests/test_ivy/test_functional/test_nn/test_losses.py @@ -203,3 +203,41 @@ def test_sparse_cross_entropy( epsilon=epsilon, reduction=reduction, ) + + +@handle_test( + fn_tree="functional.ivy.ssim_loss", + dtype_and_true=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + min_value=-1, + max_value=1, + min_num_dims=4, + max_num_dims=4, + min_dim_size=2, + ), + dtype_and_pred=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + min_value=-1, + max_value=1, + min_num_dims=4, + max_num_dims=4, + min_dim_size=2, + ), +) +def test_ssim_loss( + dtype_and_true, dtype_and_pred, test_flags, backend_fw, fn_name, on_device +): + true_dtype, true = dtype_and_true + pred_dtype, pred = dtype_and_pred + + helpers.test_function( + input_dtypes=pred_dtype + true_dtype, + test_flags=test_flags, + backend_to_test=backend_fw, + fn_name=fn_name, + on_device=on_device, + true=true[0], + pred=pred[0], + rtol_=1e-02, + atol_=1e-02, + )