.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "gallery_2d/plot_invert_scattering_torch.py"

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_gallery_2d_plot_invert_scattering_torch.py:


Inverting scattering via MSE
============================

This script aims to quantify the information loss of the scattering transform
on natural images. Starting from random noise, it reconstructs an image from
the scattering coefficients of the original by minimizing the L2 distance
(mean squared error) between the two sets of coefficients. It then measures
how close the reconstruction is to the original with the peak signal-to-noise
ratio (PSNR).

Imports
-------

.. code-block:: default

    import matplotlib.pyplot as plt
    import numpy as np
    import torch
    import torch.nn.functional as F
    from PIL import Image
    from torch import optim

    try:
        # SciPy >= 1.10 provides the sample images in scipy.datasets
        # (fetching them requires the optional ``pooch`` package).
        from scipy.datasets import face
    except ImportError:
        # Older SciPy releases; scipy.misc.face is deprecated in newer ones.
        from scipy.misc import face

    from kymatio.torch import Scattering2D

    device = "cuda" if torch.cuda.is_available() else "cpu"

Load test image
---------------

.. code-block:: default

    src_img = Image.fromarray(face())
    # Image.LANCZOS is the same filter as the Image.ANTIALIAS alias removed in Pillow 10.
    src_img = src_img.resize((512, 384), Image.LANCZOS)
    src_img = np.array(src_img).astype(np.float32)
    src_img = src_img / 255.0  # scale pixel values to [0, 1]
    plt.imshow(src_img)
    plt.title("Original image")

    src_img = np.moveaxis(src_img, -1, 0)  # HWC to CHW
    max_iter = 15  # iteration bound for the optimizer; overridden on GPU below
    print("Image shape: ", src_img.shape)
    channels, height, width = src_img.shape

.. image-sg:: /gallery_2d/images/sphx_glr_plot_invert_scattering_torch_001.png
   :alt: Original image
   :srcset: /gallery_2d/images/sphx_glr_plot_invert_scattering_torch_001.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Image shape:  (3, 384, 512)


Main loop
---------

.. code-block:: default

    for order in [1]:
        for J in [2, 4]:

            # Compute scattering coefficients
            scattering = Scattering2D(J=J, shape=(height, width), max_order=order)
            if device == "cuda":
                scattering = scattering.cuda()
                max_iter = 500  # more optimization steps are affordable on a GPU
            src_img_tensor = torch.from_numpy(src_img).to(device).contiguous()
            scattering_coefficients = scattering(src_img_tensor)

            # Create trainable input image, initialized with uniform noise in [0, 1)
            input_tensor = torch.rand(src_img.shape, requires_grad=True, device=device)

            # Optimizer hyperparams
            optimizer = optim.Adam([input_tensor], lr=1)

            # Training: match the scattering coefficients of the input to the target
            best_img = None
            best_loss = float("inf")
            for epoch in range(1, max_iter):
                new_coefficients = scattering(input_tensor)
                loss = F.mse_loss(input=new_coefficients, target=scattering_coefficients)
                print("Epoch {}, loss: {}".format(epoch, loss.item()), end="\r")
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                # Keep the best image seen so far. The snapshot is taken after
                # optimizer.step(), i.e. one update past the iterate that produced loss.
                if loss < best_loss:
                    best_loss = loss.detach().cpu().item()
                    best_img = input_tensor.detach().cpu().numpy()

            best_img = np.clip(best_img, 0.0, 1.0)

            # PSNR
            mse = np.mean((src_img - best_img) ** 2)
            psnr = 20 * np.log10(1.0 / np.sqrt(mse))
            print("\nPSNR: {:.2f}dB for order {} and J={}".format(psnr, order, J))

            # Plot
            plt.figure()
            plt.imshow(np.moveaxis(best_img, 0, -1))
            plt.title("PSNR: {:.2f}dB (order {}, J={})".format(psnr, order, J))

    plt.show()

.. rst-class:: sphx-glr-horizontal


    *

      .. image-sg:: /gallery_2d/images/sphx_glr_plot_invert_scattering_torch_002.png
          :alt: PSNR: 14.85dB (order 1, J=2)
          :srcset: /gallery_2d/images/sphx_glr_plot_invert_scattering_torch_002.png
          :class: sphx-glr-multi-img

    *

      .. image-sg:: /gallery_2d/images/sphx_glr_plot_invert_scattering_torch_003.png
          :alt: PSNR: 14.97dB (order 1, J=4)
          :srcset: /gallery_2d/images/sphx_glr_plot_invert_scattering_torch_003.png
          :class: sphx-glr-multi-img

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Epoch 1, loss: 0.004395806696265936
    Epoch 2, loss: 0.010747981257736683
    Epoch 3, loss: 0.0029847484547644854
    Epoch 4, loss: 0.002685928950086236
    Epoch 5, loss: 0.004441110882908106
    Epoch 6, loss: 0.0030704180244356394
    Epoch 7, loss: 0.0015076962299644947
    Epoch 8, loss: 0.0015746791614219546
    Epoch 9, loss: 0.0020367270335555077
    Epoch 10, loss: 0.001731803989969194
    Epoch 11, loss: 0.0011120557319372892
    Epoch 12, loss: 0.0009237747872248292
    Epoch 13, loss: 0.0010906597599387169
    Epoch 14, loss: 0.0011006807908415794
    PSNR: 14.85dB for order 1 and J=2
    Epoch 1, loss: 0.0018843415891751647
    Epoch 2, loss: 0.0028222508262842894
    Epoch 3, loss: 0.0010740647558122873
    Epoch 4, loss: 0.0007734951213933527
    Epoch 5, loss: 0.0010741836158558726
    Epoch 6, loss: 0.0008093100623227656
    Epoch 7, loss: 0.00048384207184426486
    Epoch 8, loss: 0.00044804130448028445
    Epoch 9, loss: 0.0004951524315401912
    Epoch 10, loss: 0.0004230768245179206
    Epoch 11, loss: 0.00030509373755194247
    Epoch 12, loss: 0.00026347331004217267
    Epoch 13, loss: 0.00027792336186394095
    Epoch 14, loss: 0.00026439508656039834
    PSNR: 14.97dB for order 1 and J=4


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (1 minute 14.631 seconds)
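
Toy illustration: inverting a linear feature map
------------------------------------------------

The reconstruction idea can be sketched on a toy problem where the gradient of
the loss is available in closed form. This sketch is not part of kymatio and is
not the method used above. It replaces the scattering transform with a
hypothetical random linear feature map ``A`` that has fewer features than
unknowns, and runs plain gradient descent (instead of Adam) on the mean squared
error between features. The helpers ``psnr`` and ``invert_linear_features`` are
illustrative names. The PSNR uses the same formula as the main loop, and the
best iterate is recorded *before* each update, so the stored estimate is the one
whose loss was measured.

.. code-block:: python

    import numpy as np


    def psnr(reference, estimate, max_value=1.0):
        """PSNR in dB, 20 * log10(max_value / sqrt(MSE)), as in the main loop."""
        mse = np.mean((reference - estimate) ** 2)
        return 20 * np.log10(max_value / np.sqrt(mse))


    def invert_linear_features(A, target, x0, lr, n_steps):
        """Gradient descent on mean((A @ x - target) ** 2), keeping the best iterate.

        ``A`` is a hypothetical linear stand-in for a feature map; the
        scattering transform itself is nonlinear.
        """
        x = x0.copy()
        best_x, best_loss = x.copy(), np.inf
        losses = []
        for _ in range(n_steps):
            residual = A @ x - target
            loss = np.mean(residual ** 2)
            losses.append(loss)
            # Snapshot before the update, so best_x is the iterate that produced loss.
            if loss < best_loss:
                best_x, best_loss = x.copy(), loss
            # Gradient of the mean squared error: (2 / m) * A^T (A x - target).
            grad = (2.0 / residual.size) * (A.T @ residual)
            x = x - lr * grad
        return best_x, best_loss, losses


    rng = np.random.default_rng(0)
    n_pixels, n_features = 64, 32  # fewer features than unknowns: information is lost
    x_true = rng.random(n_pixels)  # toy "image" with values in [0, 1)
    A = rng.standard_normal((n_features, n_pixels)) / np.sqrt(n_pixels)
    target = A @ x_true  # features of the true signal

    x0 = rng.random(n_pixels)  # random initialization, like torch.rand in the example
    x_rec, best_loss, losses = invert_linear_features(A, target, x0, lr=4.0, n_steps=5000)
    x_rec = np.clip(x_rec, 0.0, 1.0)

    print("initial loss: {:.3e}, best loss: {:.3e}".format(losses[0], best_loss))
    print("PSNR: {:.2f}dB".format(psnr(x_true, x_rec)))

Because ``A`` has a nontrivial null space, gradient descent only moves the
estimate within the row space of ``A``. The feature loss can therefore be driven
close to zero while the null-space component of the random initialization
remains, so the PSNR stays finite. This is the kind of information loss the
example above sets out to measure.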