Inverting scattering via MSE

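This script aims to quantify the information lost by the scattering transform on natural images. It reconstructs an image from its scattering coefficients by minimizing the L2 distance (mean squared error) between the scattering coefficients of a trainable image and those of the original image.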

Imports

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torch import optim
from scipy.misc import face

from kymatio import Scattering2D

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

Load test image

# Load the sample image returned by scipy's `face()` and resize it to
# 512x384 (width x height).
# Image.LANCZOS is the same filter that used to be exposed as
# Image.ANTIALIAS; that alias was removed in Pillow 10, so the old name
# raises AttributeError on current Pillow versions.
src_img = Image.fromarray(face())
src_img = src_img.resize((512, 384), Image.LANCZOS)
src_img = np.array(src_img).astype(np.float32)
src_img = src_img / 255.0  # rescale to [0, 1] (assumes 8-bit pixel values)
plt.imshow(src_img)
plt.title("Original image")

src_img = np.moveaxis(src_img, -1, 0)  # HWC to CHW
max_iter = 5  # iteration budget for the GD on CPU; raised to 500 later when a GPU is used
print("Image shape: ", src_img.shape)
channels, height, width = src_img.shape
../_images/sphx_glr_plot_invert_scattering_001.png

Out:

Image shape:  (3, 384, 512)

Main loop

# Reconstruct the source image from its scattering coefficients for each
# (order, J) setting, report the PSNR of the reconstruction and plot it.

# The iteration budget and the target image do not depend on (order, J), so
# they are prepared once, before entering the loops.
if device == "cuda":
    max_iter = 500  # a GPU makes a much longer optimization affordable
src_img_tensor = torch.from_numpy(src_img).to(device).contiguous()

for order in [1]:
    for J in [2, 4]:

        # Compute the scattering coefficients of the reference image; these
        # are the fixed target of the optimization.
        scattering = Scattering2D(J=J, shape=(height, width), max_order=order)
        if device == "cuda":
            scattering = scattering.cuda()
        scattering_coefficients = scattering(src_img_tensor)

        # Create the trainable input image, initialized with uniform noise.
        input_tensor = torch.rand(src_img.shape, requires_grad=True, device=device)

        # Optimizer hyperparams
        optimizer = optim.Adam([input_tensor], lr=1)

        # Training: run exactly `max_iter` gradient steps (epochs 1..max_iter).
        best_img = None
        best_loss = float("inf")
        for epoch in range(1, max_iter + 1):
            new_coefficients = scattering(input_tensor)
            loss = F.mse_loss(input=new_coefficients, target=scattering_coefficients)
            loss_value = loss.item()
            print("Epoch {}, loss: {}".format(epoch, loss_value), end="\r")

            # Snapshot the image *before* the optimizer step so that it is
            # the image this loss was measured on. The explicit copy is
            # required: on CPU, `.numpy()` shares memory with the tensor, so
            # without it later in-place updates would overwrite the snapshot.
            if loss_value < best_loss:
                best_loss = loss_value
                best_img = input_tensor.detach().cpu().numpy().copy()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Bring the reconstruction back into the valid [0, 1] pixel range.
        best_img = np.clip(best_img, 0.0, 1.0)

        # PSNR with respect to the original image (peak value is 1.0).
        mse = np.mean((src_img - best_img) ** 2)
        psnr = 20 * np.log10(1.0 / np.sqrt(mse))
        print("\nPSNR: {:.2f}dB for order {} and J={}".format(psnr, order, J))

        # Plot (CHW back to HWC for imshow)
        plt.figure()
        plt.imshow(np.moveaxis(best_img, 0, -1))
        plt.title("PSNR: {:.2f}dB (order {}, J={})".format(psnr, order, J))

plt.show()
  • ../_images/sphx_glr_plot_invert_scattering_002.png
  • ../_images/sphx_glr_plot_invert_scattering_003.png

Out:

Epoch 1, loss: 0.0043935696594417095
Epoch 2, loss: 0.010752280242741108
Epoch 3, loss: 0.002986006671562791
Epoch 4, loss: 0.0026848462875932455
PSNR: 9.62dB for order 1 and J=2
Epoch 1, loss: 0.0018797414377331734
Epoch 2, loss: 0.0028183492831885815
Epoch 3, loss: 0.0010696848621591926
Epoch 4, loss: 0.0007753887330181897
PSNR: 11.42dB for order 1 and J=4

Total running time of the script: ( 2 minutes 3.451 seconds)

Gallery generated by Sphinx-Gallery