.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "gallery_1d/reconstruct_torch.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here <sphx_glr_download_gallery_1d_reconstruct_torch.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_gallery_1d_reconstruct_torch.py:


Reconstruct a synthetic signal from its scattering transform
============================================================

In this example, we generate a harmonic signal made of a few notes of
different frequencies and analyze it with the 1D scattering transform. We then
start from random noise and use gradient descent to recover a signal whose
scattering transform matches that of the original.

.. GENERATED FROM PYTHON SOURCE LINES 10-12

Import the necessary packages
-----------------------------

.. GENERATED FROM PYTHON SOURCE LINES 12-21

.. code-block:: default


    import numpy as np
    import torch
    from kymatio.torch import Scattering1D

    from torch.autograd import backward
    import matplotlib.pyplot as plt

    # Run on the GPU when one is available, otherwise on the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


.. GENERATED FROM PYTHON SOURCE LINES 22-32

Write a function that can generate a harmonic signal
----------------------------------------------------

Let's write a function that generates simple blip-type sounds with decaying
harmonics. It takes four arguments:

- T, the length of the output vector;
- num_intervals, the number of intervals into which T is split (the signal
  contains 2 * (num_intervals - 1) + 1 notes, each overlapping its neighbours
  by half);
- gamma, the exponential decay factor of the harmonics (the component at
  (k + 1) times the base frequency is scaled by gamma^k);
- random_state, a random seed used to draw the pitches and phase shifts.

The function splits the time length T into intervals and chooses a base
frequency and a phase for each note. It then generates the sinusoidal notes
and adds a Hann-windowed version of each one to the output signal.

.. GENERATED FROM PYTHON SOURCE LINES 32-58

.. code-block:: default


    def generate_harmonic_signal(T, num_intervals=4, gamma=0.9, random_state=42):
        """
        Generates a harmonic signal, which is made of piecewise constant notes
        (of random fundamental frequency), with half overlap
        """
        rng = np.random.RandomState(random_state)
        # Notes start every half_support samples and last support samples.
        num_notes = 2 * (num_intervals - 1) + 1
        support = T // num_intervals
        half_support = support // 2

        # Random fundamental frequencies (in cycles per sample) and phases.
        base_freq = 0.1 * rng.rand(num_notes) + 0.05
        phase = 2 * np.pi * rng.rand(num_notes)
        window = np.hanning(support)
        x = np.zeros(T, dtype='float32')
        t = np.arange(0, support)
        u = 2 * np.pi * t
        for i in range(num_notes):
            ind_start = i * half_support
            note = np.zeros(support)
            # Only the fundamental (k = 0) is summed here, so gamma has no
            # effect; enlarge the range to add harmonics at (k + 1) * base_freq,
            # each damped by gamma ** k.
            for k in range(1):
                note += (np.power(gamma, k) *
                         np.cos(u * (k + 1) * base_freq[i] + phase[i]))
            # Window the note and add it to the output signal.
            x[ind_start:ind_start + support] += note * window

        return x


.. GENERATED FROM PYTHON SOURCE LINES 59-60

Let’s see what such a signal looks like.

.. GENERATED FROM PYTHON SOURCE LINES 60-67

.. code-block:: default


    T = 2 ** 13
    x = torch.from_numpy(generate_harmonic_signal(T))
    plt.figure(figsize=(8, 2))
    plt.plot(x.numpy())
    plt.title("Original signal")


.. GENERATED FROM PYTHON SOURCE LINES 68-69

Let’s also look at its spectrogram.

.. GENERATED FROM PYTHON SOURCE LINES 69-74

.. code-block:: default


    plt.figure(figsize=(8, 8))
    plt.specgram(x.numpy(), Fs=1024)
    plt.title("Spectrogram of original signal")


.. GENERATED FROM PYTHON SOURCE LINES 75-76

Compute the scattering transform of the signal, with a maximum scale of 2^J
samples and Q first-order wavelets per octave. Then set the parameters of the
gradient descent used for the reconstruction.

.. GENERATED FROM PYTHON SOURCE LINES 76-90

.. code-block:: default


    J = 6
    Q = 16

    scattering = Scattering1D(J, T, Q).to(device)

    x = x.to(device)
    Sx = scattering(x)

    # Initial step size, the two step-size multipliers of the "bold driver"
    # heuristic, and the number of gradient-descent iterations.
    learning_rate = 100
    bold_driver_accelerator = 1.1
    bold_driver_brake = 0.55
    n_iterations = 200


.. GENERATED FROM PYTHON SOURCE LINES 91-92

Reconstruct a signal from the scattering transform of the original one.
Starting from Gaussian white noise, we run gradient descent on the Euclidean
distance between the scattering coefficients of the current estimate and
those of the original signal. The step size follows a "bold driver"
heuristic. After each iteration, it is multiplied by bold_driver_accelerator
if the loss decreased and by bold_driver_brake if it increased.

.. GENERATED FROM PYTHON SOURCE LINES 92-144

.. code-block:: default


    # Random guess to initialize.
    torch.manual_seed(0)
    y = torch.randn((T,), requires_grad=True, device=device)
    Sy = scattering(y)

    history = []
    signal_update = torch.zeros_like(x, device=device)

    # Iterate to bring the random guess closer to the target.
    for k in range(n_iterations):
        # Measure the loss: distance between the scattering coefficients.
        err = torch.norm(Sx - Sy)

        if k % 10 == 0:
            print('Iteration %3d, loss %.2f' % (k, err.item()))

        # Record the loss.
        history.append(err.item())

        # Backpropagation.
        backward(err)

        delta_y = y.grad

        # Gradient descent
        with torch.no_grad():
            signal_update = - learning_rate * delta_y
            new_y = y + signal_update
        new_y.requires_grad = True

        # New forward propagation.
        Sy = scattering(new_y)

        # Bold driver: brake if the loss increased, accelerate otherwise.
        if k > 0 and history[k] > history[k - 1]:
            learning_rate *= bold_driver_brake
        else:
            learning_rate *= bold_driver_accelerator
        y = new_y

    plt.figure(figsize=(8, 2))
    plt.plot(history)
    plt.title("Loss vs. iterations")

    plt.figure(figsize=(8, 2))
    plt.plot(y.detach().cpu().numpy())
    plt.title("Reconstructed signal")

    plt.figure(figsize=(8, 8))
    plt.specgram(y.detach().cpu().numpy(), Fs=1024)
    plt.title("Spectrogram of reconstructed signal")

    plt.show()


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes 0.000 seconds)


.. _sphx_glr_download_gallery_1d_reconstruct_torch.py:


.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example


    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: reconstruct_torch.py <reconstruct_torch.py>`

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: reconstruct_torch.ipynb <reconstruct_torch.ipynb>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    Gallery generated by Sphinx-Gallery
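In generate_harmonic_signal, note i starts at i * half_support and spans support samples, so consecutive notes overlap by half. The sketch below is a quick check of that index arithmetic. It uses note_spans, a hypothetical helper that is not part of the example or of kymatio and repeats only the start and end computations. For T = 2 ** 13 and the default num_intervals = 4, it yields seven notes that tile the whole signal.

.. code-block:: python

    def note_spans(T, num_intervals=4):
        """Hypothetical helper: (start, end) sample indices of each note,
        using the same arithmetic as generate_harmonic_signal."""
        num_notes = 2 * (num_intervals - 1) + 1
        support = T // num_intervals
        half_support = support // 2
        return [(i * half_support, i * half_support + support)
                for i in range(num_notes)]


    spans = note_spans(2 ** 13)
    print(len(spans), spans[0], spans[-1])  # prints: 7 (0, 2048) (6144, 8192)

When T is not a multiple of num_intervals, or support is odd, the last note ends before T and the tail of the signal stays zero.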
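The reconstruction loop above adapts its step size with a "bold driver" heuristic. The sketch below isolates that rule. It follows the same update order as the loop: record the loss, step with the current step size, then accelerate or brake. Plain Python floats replace tensors, and the toy objective f(y) = y^2 replaces the scattering loss. The function bold_driver_descent is hypothetical and not part of kymatio. Its multipliers mirror bold_driver_accelerator and bold_driver_brake from the example, while the initial step size 0.1 and the starting point 5.0 are arbitrary choices for this toy problem.

.. code-block:: python

    def bold_driver_descent(loss, grad, y0, learning_rate=0.1,
                            accelerator=1.1, brake=0.55, n_iterations=50):
        """Hypothetical sketch: gradient descent with a bold-driver step size."""
        y = y0
        history = []
        for k in range(n_iterations):
            # Record the loss at the current point.
            history.append(loss(y))
            # Take a gradient step with the current step size.
            new_y = y - learning_rate * grad(y)
            # Shrink the step if the loss went up since the previous
            # iteration, grow it otherwise.
            if k > 0 and history[k] > history[k - 1]:
                learning_rate *= brake
            else:
                learning_rate *= accelerator
            y = new_y
        return y, history


    # Toy problem: minimize f(y) = y ** 2, whose gradient is 2 * y.
    y_final, history = bold_driver_descent(lambda y: y ** 2, lambda y: 2 * y, 5.0)
    print(len(history), history[0], abs(y_final) < 1e-6)  # prints: 50 25.0 True

This sketch always keeps the new point. Some bold-driver variants instead undo a step that increased the loss before braking.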