# Source code for kymatio.scattering1d.scattering1d

# Authors: Mathieu Andreux, Joakim Anden, Edouard Oyallon
# Scientific Ancestry: Joakim Anden, Mathieu Andreux, Vincent Lostanlen

import math
import numbers
import numpy as np
import torch

from .backend import (fft1d_c2c, ifft1d_c2c, modulus_complex, pad, real,
from .filter_bank import (calibrate_scattering_filters,
scattering_filter_factory)
from .utils import cast_phi, cast_psi, compute_border_indices, compute_padding

__all__ = ['Scattering1D']

class Scattering1D(object):
    """The 1D scattering transform

    The scattering transform computes a cascade of wavelet transforms
    alternated with a complex modulus non-linearity. The scattering transform
    of a 1D signal :math:`x(t)` may be written as

        :math:`S_J x = [S_J^{(0)} x, S_J^{(1)} x, S_J^{(2)} x]`

    where

        :math:`S_J^{(0)} x(t) = x \\star \\phi_J(t)`,

        :math:`S_J^{(1)} x(t, \\lambda) = |x \\star \\psi_\\lambda^{(1)}| \\star \\phi_J`, and

        :math:`S_J^{(2)} x(t, \\lambda, \\mu) = |\\,| x \\star \\psi_\\lambda^{(1)}| \\star \\psi_\\mu^{(2)} | \\star \\phi_J`.

    In the above formulas, :math:`\\star` denotes convolution in time. The
    filters :math:`\\psi_\\lambda^{(1)}(t)` and :math:`\\psi_\\mu^{(2)}(t)`
    are analytic wavelets with center frequencies :math:`\\lambda` and
    :math:`\\mu`, while :math:`\\phi_J(t)` is a real lowpass filter centered
    at the zero frequency.

    The `Scattering1D` class implements the 1D scattering transform for a
    given set of filters whose parameters are specified at initialization.
    While the wavelets are fixed, other parameters may be changed after the
    object is created, such as whether to compute all of
    :math:`S_J^{(0)} x`, :math:`S_J^{(1)} x`, and :math:`S_J^{(2)} x` or just
    :math:`S_J^{(0)} x` and :math:`S_J^{(1)} x`.

    The scattering transform may be computed on the CPU (the default) or a
    GPU, if available. A `Scattering1D` object may be transferred from one
    to the other using the `cuda()` and `cpu()` methods.

    Given an input Tensor `x` of size `(B, T)`, where `B` is the number of
    signals to transform (the batch size) and `T` is the length of the signal,
    we compute its scattering transform by passing it to the `forward()`
    method.

    Example
    -------
    ::

        # Set the parameters of the scattering transform.
        J = 6
        T = 2**13
        Q = 8

        # Generate a sample signal.
        x = torch.randn(1, 1, T)

        # Define a Scattering1D object.
        S = Scattering1D(J, T, Q)

        # Calculate the scattering transform.
        Sx = S.forward(x)

    Above, the length of the signal is `T = 2**13 = 8192`, while the maximum
    scale of the scattering transform is set to `2**J = 2**6 = 64`. The
    time-frequency resolution of the first-order wavelets
    :math:`\\psi_\\lambda^{(1)}(t)` is set to `Q = 8` wavelets per octave.
    The second-order wavelets :math:`\\psi_\\mu^{(2)}(t)` always have one
    wavelet per octave.

    Parameters
    ----------
    J : int
        The maximum log-scale of the scattering transform. In other words,
        the maximum scale is given by `2**J`.
    shape : int or 1-tuple of int
        The length of the input signals.
    Q : int >= 1
        The number of first-order wavelets per octave (second-order wavelets
        are fixed to one wavelet per octave). Defaults to `1`.
    max_order : int, optional
        The maximum order of scattering coefficients to compute. Must be either
        `1` or `2`. Defaults to `2`.
    average : boolean, optional
        Determines whether the output is averaged in time or not. The averaged
        output corresponds to the standard scattering transform, while the
        un-averaged output skips the last convolution by :math:`\\phi_J(t)`.
        This parameter may be modified after object creation.
        Defaults to `True`.
    oversampling : integer >= 0, optional
        Controls the oversampling factor relative to the default as a power
        of two. Since the convolving by wavelets (or lowpass filters) and
        taking the modulus reduces the high-frequency content of the signal,
        we can subsample to save space and improve performance. However, this
        may reduce precision in the calculation. If this is not desirable,
        `oversampling` can be set to a large value to prevent too much
        subsampling. This parameter may be modified after object creation.
        Defaults to `0`.
    vectorize : boolean, optional
        Determines whether to return a vectorized scattering transform (that
        is, a large array containing the output) or a dictionary (where each
        entry corresponds to a separate scattering coefficient). This parameter
        may be modified after object creation. Defaults to `True`.

    Attributes
    ----------
    J : int
        The maximum log-scale of the scattering transform. In other words,
        the maximum scale is given by `2**J`.
    shape : int or 1-tuple of int
        The length of the input signals.
    Q : int
        The number of first-order wavelets per octave (second-order wavelets
        are fixed to one wavelet per octave).
    J_pad : int
        The logarithm of the padded length of the signals.
    pad_left : int
        The amount of padding to the left of the signal.
    pad_right : int
        The amount of padding to the right of the signal.
    phi_f : dictionary
        A dictionary containing the lowpass filter at all resolutions. See
        `filter_bank.scattering_filter_factory` for an exact description.
    psi1_f : dictionary
        A dictionary containing all the first-order wavelet filters, each
        represented as a dictionary containing that filter at all
        resolutions. See `filter_bank.scattering_filter_factory` for an exact
        description.
    psi2_f : dictionary
        A dictionary containing all the second-order wavelet filters, each
        represented as a dictionary containing that filter at all
        resolutions. See `filter_bank.scattering_filter_factory` for an exact
        description.
    max_order : int
        The maximum scattering order of the transform.
    average : boolean
        Controls whether the output should be averaged (the standard
        scattering transform) or not (resulting in wavelet modulus
        coefficients). Note that to obtain unaveraged output, the `vectorize`
        flag must be set to `False`.
    oversampling : int
        The number of powers of two to oversample the output compared to the
        default subsampling rate determined from the filters.
    vectorize : boolean
        Controls whether the output should be vectorized into a single Tensor
        or collected into a dictionary. For more details, see the
        documentation for `forward()`.
    """
    def __init__(self, J, shape, Q=1, max_order=2, average=True,
                 oversampling=0, vectorize=True):
        super(Scattering1D, self).__init__()
        # Store the parameters
        self.J = J
        self.shape = shape
        self.Q = Q

        self.max_order = max_order
        self.average = average
        self.oversampling = oversampling
        self.vectorize = vectorize

        # Build internal values
        self.build()

    def build(self):
        """Set up padding and filters

        Certain internal data, such as the amount of padding and the wavelet
        filters to be used in the scattering transform, need to be computed
        from the parameters given during construction. This function is called
        automatically during object creation and no subsequent calls are
        therefore needed.

        Raises
        ------
        ValueError
            If `shape` is neither an integer nor a tuple holding exactly one
            element.
        """

        # Set these default values for now. In the future, we'll want some
        # flexibility for these, but for now, let's keep them fixed.
        self.r_psi = math.sqrt(0.5)
        self.sigma0 = 0.1
        self.alpha = 5.
        self.P_max = 5
        self.eps = 1e-7
        self.criterion_amplitude = 1e-3
        self.normalize = 'l1'

        # check the shape; the length test comes before indexing so that an
        # empty tuple is reported as a ValueError rather than an IndexError
        if isinstance(self.shape, numbers.Integral):
            self.T = self.shape
        elif isinstance(self.shape, tuple):
            if len(self.shape) != 1:
                raise ValueError("If shape is specified as a tuple, it must "
                                 "have exactly one element")
            self.T = self.shape[0]
        else:
            raise ValueError("shape must be an integer or a 1-tuple")

        # Compute the minimum support to pad (ideally)
        min_to_pad = compute_minimum_support_to_pad(
            self.T, self.J, self.Q, r_psi=self.r_psi, sigma0=self.sigma0,
            alpha=self.alpha, P_max=self.P_max, eps=self.eps,
            criterion_amplitude=self.criterion_amplitude,
            normalize=self.normalize)
        # to avoid padding more than T - 1 on the left and on the right,
        # since otherwise torch sends nans
        J_max_support = int(np.floor(np.log2(3 * self.T - 2)))
        self.J_pad = min(int(np.ceil(np.log2(self.T + 2 * min_to_pad))),
                         J_max_support)
        # compute the padding quantities
        self.pad_left, self.pad_right = compute_padding(self.J_pad, self.T)
        # compute start and end indices
        # NOTE(review): arguments restored as (J, start, end) of the unpadded
        # signal inside the padded one -- confirm against utils.
        self.ind_start, self.ind_end = compute_border_indices(
            self.J, self.pad_left, self.pad_left + self.T)
        # Finally, precompute the filters
        phi_f, psi1_f, psi2_f, _ = scattering_filter_factory(
            self.J_pad, self.J, self.Q, normalize=self.normalize,
            to_torch=True, criterion_amplitude=self.criterion_amplitude,
            r_psi=self.r_psi, sigma0=self.sigma0, alpha=self.alpha,
            P_max=self.P_max, eps=self.eps)
        self.psi1_f = psi1_f
        self.psi2_f = psi2_f
        self.phi_f = phi_f
        self._type(torch.FloatTensor)

    def _type(self, target_type):
        """Change the datatype of the filters

        This function is used internally to convert the filters. It does not
        need to be called explicitly.

        Parameters
        ----------
        target_type : type
            The desired type of the filters, typically `torch.FloatTensor`
            or `torch.cuda.FloatTensor`.

        Returns
        -------
        self : Scattering1D
            The object itself, so that calls can be chained.
        """
        cast_psi(self.psi1_f, target_type)
        cast_psi(self.psi2_f, target_type)
        cast_phi(self.phi_f, target_type)
        return self

    def cpu(self):
        """Move to the CPU

        This function prepares the object to accept input Tensors on the CPU.
        """
        return self._type(torch.FloatTensor)

    def cuda(self):
        """Move to the GPU

        This function prepares the object to accept input Tensors on the GPU.
        """
        return self._type(torch.cuda.FloatTensor)

    def meta(self):
        """Get meta information on the transform

        Calls the static method `compute_meta_scattering()` with the
        parameters of the transform object.

        Returns
        ------
        meta : dictionary
            See the documentation for `compute_meta_scattering()`.
        """
        return Scattering1D.compute_meta_scattering(
            self.J, self.Q, max_order=self.max_order)

    def output_size(self, detail=False):
        """Get size of the scattering transform

        Calls the static method `precompute_size_scattering()` with the
        parameters of the transform object.

        Parameters
        ----------
        detail : boolean, optional
            Specifies whether to provide a detailed size (number of coefficient
            per order) or an aggregate size (total number of coefficients).

        Returns
        ------
        size : int or tuple
            See the documentation for `precompute_size_scattering()`.
        """
        return Scattering1D.precompute_size_scattering(
            self.J, self.Q, max_order=self.max_order, detail=detail)

    def forward(self, x):
        """Apply the scattering transform

        Given an input Tensor of size `(B, T0)`, where `B` is the batch
        size and `T0` is the length of the individual signals, this function
        computes its scattering transform. If the `vectorize` flag is set to
        `True`, the output is in the form of a Tensor or size `(B, C, T1)`,
        where `T1` is the signal length after subsampling to the scale `2**J`
        (with the appropriate oversampling factor to reduce aliasing), and
        `C` is the number of scattering coefficients.  If `vectorize` is set
        `False`, however, the output is a dictionary containing `C` keys, each
        a tuple whose length corresponds to the scattering order and whose
        elements are the sequence of filter indices used.

        Furthermore, if the `average` flag is set to `False`, these outputs
        are not averaged, but are simply the wavelet modulus coefficients of
        the filters.

        All axes of `x` except the last one are treated as batch axes: they
        are flattened before the transform and restored on the output.

        Parameters
        ----------
        x : tensor
            An input Tensor of size `(B, T0)`.

        Returns
        -------
        S : tensor or dictionary
            If the `vectorize` flag is `True`, the output is a Tensor
            containing the scattering coefficients, while if `vectorize`
            is `False`, it is a dictionary indexed by tuples of filter indices.

        Raises
        ------
        ValueError
            If `x` has no axis, or if `vectorize` is `True` while `average`
            is `False`.
        """
        # basic checking, should be improved
        if len(x.shape) < 1:
            raise ValueError(
                'Input tensor x should have at least one axis, got {}'.format(
                    len(x.shape)))

        batch_shape = x.shape[:-1]
        signal_shape = x.shape[-1:]

        # collapse every batch axis into one and add a singleton channel axis
        x = x.reshape((-1, 1) + signal_shape)

        # get the arguments before calling the scattering
        # treat the arguments
        if self.vectorize:
            if not self.average:
                raise ValueError(
                    'Options average=False and vectorize=True are ' +
                    'mutually incompatible. Please set vectorize to False.')
            size_scattering = self.precompute_size_scattering(
                self.J, self.Q, max_order=self.max_order, detail=True)
        else:
            # only read by scattering() when vectorizing
            size_scattering = (0, 0, 0)

        S = scattering(x, self.psi1_f, self.psi2_f, self.phi_f,
                       self.J, max_order=self.max_order, average=self.average,
                       pad_left=self.pad_left, pad_right=self.pad_right,
                       ind_start=self.ind_start, ind_end=self.ind_end,
                       oversampling=self.oversampling,
                       vectorize=self.vectorize,
                       size_scattering=size_scattering)

        # restore the original batch axes in front of the last two axes
        if self.vectorize:
            scattering_shape = S.shape[-2:]
            S = S.reshape(batch_shape + scattering_shape)
        else:
            for k, v in S.items():
                scattering_shape = v.shape[-2:]
                S[k] = v.reshape(batch_shape + scattering_shape)

        return S

    def __call__(self, x):
        """Apply the scattering transform; shorthand for `forward(x)`."""
        return self.forward(x)

    @staticmethod
    def compute_meta_scattering(J, Q, max_order=2):
        """Get metadata on the transform.

        This information specifies the content of each scattering coefficient,
        which order, which frequencies, which filters were used, and so on.

        Parameters
        ----------
        J : int
            The maximum log-scale of the scattering transform.
            In other words, the maximum scale is given by `2**J`.
        Q : int >= 1
            The number of first-order wavelets per octave.
            Second-order wavelets are fixed to one wavelet per octave.
        max_order : int, optional
            The maximum order of scattering coefficients to compute.
            Must be either equal to `1` or `2`. Defaults to `2`.

        Returns
        -------
        meta : dictionary
            A dictionary with the following keys:

            - `'order`' : tensor
                A Tensor of length `C`, the total number of scattering
                coefficients, specifying the scattering order.
            - `'xi'` : tensor
                A Tensor of size `(C, max_order)`, specifying the center
                frequency of the filter used at each order (padded with NaNs).
            - `'sigma'` : tensor
                A Tensor of size `(C, max_order)`, specifying the frequency
                bandwidth of the filter used at each order (padded with NaNs).
            - `'j'` : tensor
                A Tensor of size `(C, max_order)`, specifying the dyadic scale
                of the filter used at each order (padded with NaNs).
            - `'n'` : tensor
                A Tensor of size `(C, max_order)`, specifying the indices of
                the filters used at each order (padded with NaNs).
            - `'key'` : list
                The tuples indexing the corresponding scattering coefficient
                in the non-vectorized output.
        """
        _, xi1s, sigma1s, j1s, xi2s, sigma2s, j2s = \
            calibrate_scattering_filters(J, Q)

        # one bucket per scattering order (0, 1, 2) for every field
        field_names = ('order', 'xi', 'sigma', 'j', 'n', 'key')
        meta = {field: [[], [], []] for field in field_names}

        # zeroth order: a single coefficient that uses no wavelet
        meta['order'][0].append(0)
        for field in ('xi', 'sigma', 'j', 'n', 'key'):
            meta[field][0].append(())

        # materialized once because it is scanned for every first-order filter
        second_order = list(enumerate(zip(xi2s, sigma2s, j2s)))

        for n1, (xi1, sigma1, j1) in enumerate(zip(xi1s, sigma1s, j1s)):
            meta['order'][1].append(1)
            meta['xi'][1].append((xi1,))
            meta['sigma'][1].append((sigma1,))
            meta['j'][1].append((j1,))
            meta['n'][1].append((n1,))
            meta['key'][1].append((n1,))

            if max_order < 2:
                continue

            for n2, (xi2, sigma2, j2) in second_order:
                # only paths with a strictly increasing scale are kept
                if j2 > j1:
                    meta['order'][2].append(2)
                    meta['xi'][2].append((xi1, xi2))
                    meta['sigma'][2].append((sigma1, sigma2))
                    meta['j'][2].append((j1, j2))
                    meta['n'][2].append((n1, n2))
                    meta['key'][2].append((n1, n2))

        # concatenate the three orders, lowest order first
        meta = {field: per_order[0] + per_order[1] + per_order[2]
                for field, per_order in meta.items()}

        # pad every tuple with NaNs up to max_order entries so that the
        # fields below can be stacked into rectangular arrays
        pad_len = max_order
        for field in ('xi', 'sigma', 'j', 'n'):
            meta[field] = [entry + (math.nan,) * (pad_len - len(entry))
                           for entry in meta[field]]

        for field in ('order', 'xi', 'sigma', 'j', 'n'):
            meta[field] = torch.from_numpy(np.array(meta[field]))

        return meta

    @staticmethod
    def precompute_size_scattering(J, Q, max_order=2, detail=False):
        """Get size of the scattering transform

        The number of scattering coefficients depends on the filter
        configuration and so can be calculated using a few of the scattering
        transform parameters.

        Parameters
        ----------
        J : int
            The maximum log-scale of the scattering transform.
            In other words, the maximum scale is given by `2**J`.
        Q : int >= 1
            The number of first-order wavelets per octave.
            Second-order wavelets are fixed to one wavelet per octave.
        max_order : int, optional
            The maximum order of scattering coefficients to compute.
            Must be either equal to `1` or `2`. Defaults to `2`.
        detail : boolean, optional
            Specifies whether to provide a detailed size (number of coefficient
            per order) or an aggregate size (total number of coefficients).

        Returns
        -------
        size : int or tuple
            If `detail` is `False`, returns the number of coefficients as an
            integer. If `True`, returns a tuple containing the number of
            coefficients in each order, from order 0 up to `max_order`.
        """
        _, xi1, _, j1, _, _, j2 = calibrate_scattering_filters(J, Q)

        # order 0 is the single lowpass coefficient, order 1 has one
        # coefficient per first-order wavelet
        sizes = (1, len(xi1))
        if max_order == 2:
            # a second-order path exists only when its scale is strictly
            # larger than the first-order one
            size_order2 = sum(1 for scale1 in j1 for scale2 in j2
                              if scale2 > scale1)
            sizes = sizes + (size_order2,)

        return sizes if detail else sum(sizes)

def compute_minimum_support_to_pad(T, J, Q, criterion_amplitude=1e-3,
                                   normalize='l1', r_psi=math.sqrt(0.5),
                                   sigma0=1e-1, alpha=5., P_max=5, eps=1e-7):
    """
    Computes the support to pad given the input size and the parameters of the
    scattering transform.

    Parameters
    ----------
    T : int
        temporal size of the input signal
    J : int
        scale of the scattering
    Q : int
        number of wavelets per octave
    normalize : string, optional
        normalization type for the wavelets.
        Only `'l2'` or `'l1'` normalizations are supported.
        Defaults to `'l1'`
    criterion_amplitude: float `>0` and `<1`, optional
        Represents the numerical error which is allowed to be lost after
        convolution and padding.
        The larger criterion_amplitude, the smaller the padding size is.
        Defaults to `1e-3`
    r_psi : float, optional
        Should be `>0` and `<1`. Controls the redundancy of the filters
        (the larger r_psi, the larger the overlap between adjacent
        wavelets).
        Defaults to `sqrt(0.5)`.
    sigma0 : float, optional
        parameter controlling the frequential width of the
        low-pass filter at J_scattering=0; at an absolute J_scattering,
        it is equal to :math:`\\frac{\\sigma_0}{2^J}`.
        Defaults to `1e-1`.
    alpha : float, optional
        tolerance factor for the aliasing after subsampling.
        The larger the alpha, the more conservative the value of maximal
        subsampling is.
        Defaults to `5`.
    P_max : int, optional
        maximal number of periods to use to make sure that the Fourier
        transform of the filters is periodic.
        `P_max = 5` is more than enough for double precision.
        Defaults to `5`.
    eps : float, optional
        required machine precision for the periodization (single
        floating point is enough for deep learning applications).
        Defaults to `1e-7`.

    Returns
    -------
    min_to_pad: int
        minimal value to pad the signal on one size to avoid any
        boundary error.
    """
    # build the filters on the smallest power-of-two support that holds the
    # signal; only the fourth value returned by the factory is needed here
    J_tentative = int(np.ceil(np.log2(T)))
    _, _, _, t_max_phi = scattering_filter_factory(
        J_tentative, J, Q, normalize=normalize, to_torch=False,
        max_subsampling=0, criterion_amplitude=criterion_amplitude,
        r_psi=r_psi, sigma0=sigma0, alpha=alpha, P_max=P_max, eps=eps)
    # NOTE(review): the factor of 3 on t_max_phi is restored from the upstream
    # implementation -- confirm before relying on it.
    min_to_pad = 3 * t_max_phi
    return min_to_pad

ind_start=None, ind_end=None, oversampling=0,
max_order=2, average=True, size_scattering=(0, 0, 0), vectorize=False):
"""
Main function implementing the 1-D scattering transform.

Parameters
----------
x : Tensor
a torch Tensor of size (B, 1, T) where T is the temporal size
psi1 : dictionary
a dictionary of filters (in the Fourier domain), with keys (j, q).
j corresponds to the downsampling factor for
:math:x \\ast psi1[(j, q)], and q corresponds to a pitch class
(chroma).
* psi1[(j, n)] is itself a dictionary, with keys corresponding to the
dilation factors: psi1[(j, n)][j2] corresponds to a support of size
:math:2^{J_\\text{max} - j_2}, where :math:J_\\text{max} has been
defined a priori (J_max = size of the padding support of the input)
* psi1[(j, n)] only has real values;
the tensors are complex so that broadcasting applies
psi2 : dictionary
a dictionary of filters, with keys (j2, n2). Same remarks as for psi1
phi : dictionary
a dictionary of filters of scale :math:2^J with keys (j)
where :math:2^j is the downsampling factor.
The array phi[j] is a real-valued filter.
J : int
scale of the scattering
how much to pad the signal on the left. Defaults to 0
how much to pad the signal on the right. Defaults to 0
ind_start : dictionary of ints, optional
indices to truncate the signal to recover only the
parts which correspond to the actual signal after padding and
downsampling. Defaults to None
ind_end : dictionary of ints, optional
See description of ind_start
oversampling : int, optional
how much to oversample the scattering (with respect to :math:2^J):
the higher, the larger the resulting scattering
tensor along time. Defaults to 0
order2 : boolean, optional
Whether to compute the 2nd order or not. Defaults to False.
average_U1 : boolean, optional
whether to average the first order vector. Defaults to True
size_scattering : tuple
Contains the number of channels of the scattering, precomputed for
speed-up. Defaults to (0, 0, 0).
vectorize : boolean, optional
whether to return a dictionary or a tensor. Defaults to False.

"""
# S is simply a dictionary if we do not perform the averaging...
if vectorize:
batch_size = x.shape[0]
kJ = max(J - oversampling, 0)
temporal_size = ind_end[kJ] - ind_start[kJ]
S = x.new(batch_size, sum(size_scattering), temporal_size).fill_(0.)
else:
S = {}

# compute the Fourier transform
U0_hat = fft1d_c2c(U0)
if vectorize:
# initialize the cursor
cc = [0] + list(size_scattering[:-1])  # current coordinate
cc[1] = cc[0] + cc[1]
if max_order == 2:
cc[2] = cc[1] + cc[2]
# Get S0
k0 = max(J - oversampling, 0)
if average:
S0_J_hat = subsample_fourier(U0_hat * phi[0], 2**k0)
ind_start[k0], ind_end[k0])
else:
S0_J = x
if vectorize:
S[:, cc[0], :] = S0_J.squeeze(dim=1)
cc[0] += 1
else:
S[()] = S0_J
# First order:
for n1 in range(len(psi1)):
# Convolution + downsampling
j1 = psi1[n1]['j']
k1 = max(j1 - oversampling, 0)
assert psi1[n1]['xi'] < 0.5 / (2**k1)
U1_hat = subsample_fourier(U0_hat * psi1[n1][0], 2**k1)
# Take the modulus
U1 = modulus_complex(ifft1d_c2c(U1_hat))
if average or max_order > 1:
U1_hat = fft1d_c2c(U1)
if average:
# Convolve with phi_J
k1_J = max(J - k1 - oversampling, 0)
S1_J_hat = subsample_fourier(U1_hat * phi[k1], 2**k1_J)
ind_start[k1_J + k1], ind_end[k1_J + k1])
else:
# just take the real value and unpad
if vectorize:
S[:, cc[1], :] = S1_J.squeeze(dim=1)
cc[1] += 1
else:
S[(n1,)] = S1_J
if max_order == 2:
# 2nd order
for n2 in range(len(psi2)):
j2 = psi2[n2]['j']
if j2 > j1:
assert psi2[n2]['xi'] < psi1[n1]['xi']
# convolution + downsampling
k2 = max(j2 - k1 - oversampling, 0)
U2_hat = subsample_fourier(U1_hat * psi2[n2][k1],
2**k2)
# take the modulus and go back in Fourier
U2 = modulus_complex(ifft1d_c2c(U2_hat))
if average:
U2_hat = fft1d_c2c(U2)
# Convolve with phi_J
k2_J = max(J - k2 - k1 - oversampling, 0)
S2_J_hat = subsample_fourier(U2_hat * phi[k1 + k2],
2**k2_J)
ind_start[k1 + k2 + k2_J],
ind_end[k1 + k2 + k2_J])
else:
# just take the real value and unpad
`