Use Pytorch SSIM loss function in my model

Solution 1

The author is trying to maximize the SSIM value. PyTorch loss functions and optimizers are designed to minimize the loss, but SSIM is a quality measure, so higher is better. That is why the author uses
loss = - criterion(inputs, outputs)

You can instead try using
loss = 1 - criterion(inputs, outputs)
as described in this paper.
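
Applied to the training loop from the question, that is a one-line change. A minimal sketch, assuming criterion is the repo's pytorch_ssim.SSIM() module and that inputs and outputs are already scaled to [0, 1]:

import pytorch_ssim

criterion = pytorch_ssim.SSIM()  # SSIM module from the repo linked in the question

# inside the training loop, after computing `outputs`:
loss = 1 - criterion(outputs, inputs)  # minimizing 1 - SSIM maximizes SSIM
ssim_value = 1 - loss.item()           # recover the actual SSIM value for logging
loss.backward()
optimizer.step()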


Modified code (max_ssim.py) for testing the above, using this repo:

import pytorch_ssim
import torch
from torch.autograd import Variable
from torch import optim
import cv2
import numpy as np

# Load the target image and scale it to [0, 1]; shape becomes (1, C, H, W)
npImg1 = cv2.imread("einstein.png")
img1 = torch.from_numpy(np.rollaxis(npImg1, 2)).float().unsqueeze(0) / 255.0

# Start from random noise of the same shape; this is what gets optimized
img2 = torch.rand(img1.size())

if torch.cuda.is_available():
    img1 = img1.cuda()
    img2 = img2.cuda()

# Variable is a no-op wrapper in modern PyTorch; img2.requires_grad_(True)
# would be the idiomatic equivalent
img1 = Variable(img1, requires_grad=False)
img2 = Variable(img2, requires_grad=True)

print(img1.shape)
print(img2.shape)

# Functional form: pytorch_ssim.ssim(img1, img2, window_size=11, size_average=True)
ssim_value = 1 - pytorch_ssim.ssim(img1, img2).item()
print("Initial loss (1 - SSIM):", ssim_value)

# Module form: pytorch_ssim.SSIM(window_size=11, size_average=True)
ssim_loss = pytorch_ssim.SSIM()

# Optimize the noise image directly so that it approaches img1 under SSIM
optimizer = optim.Adam([img2], lr=0.01)

# Stop once 1 - SSIM drops below 0.05, i.e. SSIM(img1, img2) > 0.95
while ssim_value > 0.05:
    optimizer.zero_grad()
    ssim_out = 1 - ssim_loss(img1, img2)  # minimizing 1 - SSIM maximizes SSIM
    ssim_value = ssim_out.item()
    print(ssim_value)
    ssim_out.backward()
    optimizer.step()
    # Show the current estimate; press a key to advance to the next iteration
    cv2.imshow('op', np.transpose(img2.cpu().detach().numpy()[0], (1, 2, 0)))
    cv2.waitKey()

Solution 2

The usual way to transform a similarity (higher is better) into a loss is to compute 1 - similarity(x, y).

To create this loss you can define a new function.

def ssim_loss(x, y):
    return 1. - ssim(x, y)

Alternatively, if the similarity is implemented as a class (nn.Module), you can subclass it and override its forward method.

class SSIMLoss(SSIM):
    def forward(self, x, y):
        return 1. - super().forward(x, y)

Also, there are better implementations of SSIM than the one in this repo. For example, the one in the piqa Python package is faster. The package can be installed with

pip install piqa

For your problem

from piqa import SSIM

class SSIMLoss(SSIM):
    def forward(self, x, y):
        return 1. - super().forward(x, y)

criterion = SSIMLoss() # .cuda() if you need GPU support

...
loss = criterion(x, y)
...

should work well.
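
For completeness, here is a minimal sketch of how this could be wired into the training loop from the question, using the SSIMLoss class above, assuming single-channel 128x128 images scaled to [0, 1] and assuming piqa's SSIM accepts an n_channels argument (check the piqa documentation for your version):

# n_channels=1 because the question reshapes inputs to (bs, 1, 128, 128)
criterion = SSIMLoss(n_channels=1)  # .cuda() if you need GPU support

model.train()
for epo in range(epoch):
    for i, data in enumerate(trainloader, 0):
        inputs = data.view(bs, 1, 128, 128)   # values assumed to be in [0, 1]
        optimizer.zero_grad()
        top = model.upward(inputs)
        outputs = model.downward(top, shortcut=True).view(bs, 1, 128, 128)
        loss = criterion(outputs, inputs)     # 1 - SSIM, so lower is better
        ssim_value = 1 - loss.item()          # actual SSIM, for logging
        print(ssim_value)
        loss.backward()
        optimizer.step()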

Author: sealpuppy
Updated on June 23, 2022

Comments

  • sealpuppy, almost 2 years ago

    I am trying out the SSIM loss implemented by this repo for image restoration.

    Following the original sample code on the author's GitHub, I tried:

    model.train()
    for epo in range(epoch):
        for i, data in enumerate(trainloader, 0):
            inputs = data
            inputs = Variable(inputs)
            optimizer.zero_grad()
            inputs = inputs.view(bs, 1, 128, 128)
            top = model.upward(inputs)
            outputs = model.downward(top, shortcut = True)
            outputs = outputs.view(bs, 1, 128, 128)
    
            if i % 20 == 0:
                out = outputs[0].view(128, 128).detach().numpy() * 255
                cv2.imwrite("/home/tk/Documents/recover/SSIM/" + str(epo) + "_" + str(i) + "_re.png", out)
    
            loss = - criterion(inputs, outputs)
            ssim_value = - loss.data.item()
            print (ssim_value)
            loss.backward()
            optimizer.step()
    

    However, the results didn't come out as I expected. After the first 10 epochs, the output images were all black.

    loss = - criterion(inputs, outputs) is what the author proposes; however, classical PyTorch training code takes the form loss = criterion(y_pred, target), so it should be loss = criterion(inputs, outputs) here.

    However, I tried loss = criterion(inputs, outputs), but the results were still the same.

    Can anyone share some thoughts about how to properly utilize SSIM loss? Thanks.