Skip to content

Commit

Permalink
[metrics] Update SSIM (#4566)
Browse files Browse the repository at this point in the history
* [metrics] Update SSIM

* [metrics] Update SSIM

* [metrics] Update SSIM

* [metrics] Update SSIM

* [metrics] update ssim

* dist_sync_on_step True

* [metrics] update ssim

* Update tests/metrics/regression/test_ssim.py

Co-authored-by: chaton <thomas@grid.ai>

* Update pytorch_lightning/metrics/functional/ssim.py

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* ddp=True

* Update test_ssim.py

Co-authored-by: chaton <thomas@grid.ai>
Co-authored-by: Sean Naren <sean.narenthiran@gmail.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>
  • Loading branch information
5 people authored and rohitgr7 committed Nov 21, 2020
1 parent 06a301b commit bde885c
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 21 deletions.
31 changes: 19 additions & 12 deletions pytorch_lightning/metrics/functional/ssim.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,17 @@
from torch.nn import functional as F


def _gaussian_kernel(channel, kernel_size, sigma, device):
def _gaussian(kernel_size, sigma, device):
gauss = torch.arange(
start=(1 - kernel_size) / 2, end=(1 + kernel_size) / 2, step=1, dtype=torch.float32, device=device
)
gauss = torch.exp(-gauss.pow(2) / (2 * pow(sigma, 2)))
return (gauss / gauss.sum()).unsqueeze(dim=0) # (1, kernel_size)
def _gaussian(kernel_size: int, sigma: int, dtype: torch.dtype, device: torch.device):
dist = torch.arange(start=(1 - kernel_size) / 2, end=(1 + kernel_size) / 2, step=1, dtype=dtype, device=device)
gauss = torch.exp(-torch.pow(dist / sigma, 2) / 2)
return (gauss / gauss.sum()).unsqueeze(dim=0) # (1, kernel_size)


gaussian_kernel_x = _gaussian(kernel_size[0], sigma[0], device)
gaussian_kernel_y = _gaussian(kernel_size[1], sigma[1], device)
kernel = torch.matmul(gaussian_kernel_x.t(), gaussian_kernel_y)
def _gaussian_kernel(channel: int, kernel_size: Sequence[int], sigma: Sequence[float],
dtype: torch.dtype, device: torch.device):
gaussian_kernel_x = _gaussian(kernel_size[0], sigma[0], dtype, device)
gaussian_kernel_y = _gaussian(kernel_size[1], sigma[1], dtype, device)
kernel = torch.matmul(gaussian_kernel_x.t(), gaussian_kernel_y) # (kernel_size, 1) * (1, kernel_size)

return kernel.expand(channel, 1, kernel_size[0], kernel_size[1])

Expand Down Expand Up @@ -82,9 +82,15 @@ def _ssim_compute(
device = preds.device

channel = preds.size(1)
kernel = _gaussian_kernel(channel, kernel_size, sigma, device)
dtype = preds.dtype
kernel = _gaussian_kernel(channel, kernel_size, sigma, dtype, device)
pad_w = (kernel_size[0] - 1) // 2
pad_h = (kernel_size[1] - 1) // 2

preds = F.pad(preds, (pad_w, pad_w, pad_h, pad_h), mode='reflect')
target = F.pad(target, (pad_w, pad_w, pad_h, pad_h), mode='reflect')

input_list = torch.cat([preds, target, preds * preds, target * target, preds * target]) # (5 * B, C, H, W)
input_list = torch.cat((preds, target, preds * preds, target * target, preds * target)) # (5 * B, C, H, W)
outputs = F.conv2d(input_list, kernel, groups=channel)
output_list = [outputs[x * preds.size(0): (x + 1) * preds.size(0)] for x in range(len(outputs))]

Expand All @@ -100,6 +106,7 @@ def _ssim_compute(
lower = sigma_pred_sq + sigma_target_sq + c2

ssim_idx = ((2 * mu_pred_target + c1) * upper) / ((mu_pred_sq + mu_target_sq + c1) * lower)
ssim_idx = ssim_idx[..., pad_h:-pad_h, pad_w:-pad_w]

return reduce(ssim_idx, reduction)

Expand Down
17 changes: 9 additions & 8 deletions tests/metrics/regression/test_ssim.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@


_inputs = []
for size, channel, coef, multichannel in [
(16, 1, 0.9, False),
(32, 3, 0.8, True),
(48, 4, 0.7, True),
(64, 5, 0.6, True),
for size, channel, coef, multichannel, dtype in [
(12, 3, 0.9, True, torch.float),
(13, 1, 0.8, False, torch.float32),
(14, 1, 0.7, False, torch.double),
(15, 3, 0.6, True, torch.float64),
]:
preds = torch.rand(NUM_BATCHES, BATCH_SIZE, channel, size, size)
preds = torch.rand(NUM_BATCHES, BATCH_SIZE, channel, size, size, dtype=dtype)
_inputs.append(
Input(
preds=preds,
Expand All @@ -41,7 +41,8 @@ def _sk_metric(preds, target, data_range, multichannel):
sk_target = sk_target[:, :, :, 0]

return structural_similarity(
sk_target, sk_preds, data_range=data_range, multichannel=multichannel, gaussian_weights=True, win_size=11
sk_target, sk_preds, data_range=data_range, multichannel=multichannel,
gaussian_weights=True, win_size=11, sigma=1.5, use_sample_covariance=False
)


Expand All @@ -50,7 +51,7 @@ def _sk_metric(preds, target, data_range, multichannel):
[(i.preds, i.target, i.multichannel) for i in _inputs],
)
class TestSSIM(MetricTester):
atol = 1e-3 # TODO: ideally tests should pass with lower tolerance
atol = 6e-5

# TODO: for some reason this test hangs with ddp=True
# @pytest.mark.parametrize("ddp", [True, False])
Expand Down
2 changes: 1 addition & 1 deletion tests/metrics/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def setup_ddp(rank, world_size):
os.environ["MASTER_ADDR"] = 'localhost'
os.environ['MASTER_PORT'] = '8088'

if torch.distributed.is_available() and sys.platform not in ['win32', 'cygwin']:
if torch.distributed.is_available() and sys.platform not in ('win32', 'cygwin'):
torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size)


Expand Down

0 comments on commit bde885c

Please sign in to comment.