Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

I added structural similarity index (SSIM) loss. #27134

Merged
merged 12 commits into from
Jul 13, 2024
60 changes: 60 additions & 0 deletions ivy/functional/ivy/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,3 +381,63 @@ def sparse_cross_entropy(
return ivy.cross_entropy(
true, pred, axis=axis, epsilon=epsilon, reduction=reduction, out=out
)


@handle_exceptions
@handle_nestable
@handle_array_like_without_promotion
@inputs_to_ivy_arrays
@handle_array_function
def ssim_loss(
true: Union[ivy.Array, ivy.NativeArray],
pred: Union[ivy.Array, ivy.NativeArray],
out: Optional[ivy.Array] = None,
) -> ivy.Array:
"""Calculate the Structural Similarity Index (SSIM) loss between two
images.

Parameters
----------
true: A 4D image array of shape (batch_size, channels, height, width).
pred: A 4D image array of shape (batch_size, channels, height, width).

Returns
-------
ivy.Array: The SSIM loss measure similarity between the two images.

Examples
--------
With :class:`ivy.Array` input:
>>> import ivy
>>> x = ivy.ones((5, 3, 28, 28))
>>> y = ivy.zeros((5, 3, 28, 28))
>>> ivy.ssim_loss(x, y)
ivy.array(0.99989986)
"""
# Constants for stability
C1 = 0.01 ** 2
C2 = 0.03 ** 2

# Calculate the mean of the two images
mu_x = ivy.avg_pool2d(pred, (3, 3), (1, 1), "SAME")
mu_y = ivy.avg_pool2d(true, (3, 3), (1, 1), "SAME")

# Calculate variance and covariance
sigma_x2 = ivy.avg_pool2d(pred * pred, (3, 3), (1, 1), "SAME") - mu_x * mu_x
sigma_y2 = ivy.avg_pool2d(true * true, (3, 3), (1, 1), "SAME") - mu_y * mu_y
sigma_xy = ivy.avg_pool2d(pred * true, (3, 3), (1, 1), "SAME") - mu_x * mu_y

# Calculate SSIM
ssim = ((2 * mu_x * mu_y + C1) * (2 * sigma_xy + C2)) / (
(mu_x ** 2 + mu_y ** 2 + C1) * (sigma_x2 + sigma_y2 + C2)
)

# Convert SSIM to loss
ssim_loss_value = 1 - ssim

# Return mean SSIM loss
ret = ivy.mean(ssim_loss_value)

if ivy.exists(out):
ret = ivy.inplace_update(out, ret)
return ret
38 changes: 38 additions & 0 deletions ivy_tests/test_ivy/test_functional/test_nn/test_losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,3 +203,41 @@ def test_sparse_cross_entropy(
epsilon=epsilon,
reduction=reduction,
)


@handle_test(
fn_tree="functional.ivy.ssim_loss",
dtype_and_true=helpers.dtype_and_values(
available_dtypes=helpers.get_dtypes("float"),
min_value=-1,
max_value=1,
min_num_dims=4,
max_num_dims=4,
min_dim_size=2,
),
dtype_and_pred=helpers.dtype_and_values(
available_dtypes=helpers.get_dtypes("float"),
min_value=-1,
max_value=1,
min_num_dims=4,
max_num_dims=4,
min_dim_size=2,
),
)
def test_ssim_loss(
dtype_and_true, dtype_and_pred, test_flags, backend_fw, fn_name, on_device
):
true_dtype, true = dtype_and_true
pred_dtype, pred = dtype_and_pred

helpers.test_function(
input_dtypes=pred_dtype + true_dtype,
test_flags=test_flags,
backend_to_test=backend_fw,
fn_name=fn_name,
on_device=on_device,
true=true[0],
pred=pred[0],
rtol_=1e-02,
atol_=1e-02,
)
Loading