Plot the 2D wavelet filters

See kymatio.scattering2d.filter_bank() for more informations about the used wavelets.

import matplotlib.pyplot as plt
import numpy as np
from kymatio.scattering2d.filter_bank import filter_bank
from kymatio.scattering2d.utils import fft2

Initial parameters of the filter bank

M = 32
J = 3
L = 8
filters_set = filter_bank(M, M, J, L=L)

Imshow complex images

Thanks to https://stackoverflow.com/questions/17044052/mathplotlib-imshow-complex-2d-array

from colorsys import hls_to_rgb
def colorize(z):
    n, m = z.shape
    c = np.zeros((n, m, 3))
    c[np.isinf(z)] = (1.0, 1.0, 1.0)
    c[np.isnan(z)] = (0.5, 0.5, 0.5)

    idx = ~(np.isinf(z) + np.isnan(z))
    A = (np.angle(z[idx]) + np.pi) / (2*np.pi)
    A = (A + 0.5) % 1.0
    B =  1.0/(1.0+abs(z[idx])**0.3)
    c[idx] = [hls_to_rgb(a, b, 0.8) for a,b in zip(A,B)]
    return c

Bandpass filters

First, we display each wavelet according to its scale and orientation.

fig, axs = plt.subplots(J, L, sharex=True, sharey=True)
fig.set_figheight(6)
fig.set_figwidth(6)
plt.rc('text', usetex=True)
plt.rc('font', family='serif')
i=0
for filter in filters_set['psi']:
    f_r = filter[0][...,0].numpy()
    f_i = filter[0][..., 1].numpy()
    f = f_r + 1j*f_i
    filter_c = fft2(f)
    filter_c = np.fft.fftshift(filter_c)
    axs[i // L, i % L].imshow(colorize(filter_c))
    axs[i // L, i % L].axis('off')
    axs[i // L, i % L].set_title("$j = {}$ \n $\\theta={}$".format(i // L, i % L))
    i = i+1

fig.suptitle("Wavelets for each scales $j$ and angles $\\theta$ used."
"\n Color saturation and color hue respectively denote complex magnitude and complex phase.", fontsize=13)
fig.show()
../_images/sphx_glr_plot_filters_0011.png

Lowpass filter

We finally display the low-pass filter.

plt.figure()
plt.rc('text', usetex=True)
plt.rc('font', family='serif')
plt.axis('off')
plt.set_cmap('gray_r')

f_r = filters_set['phi'][0][..., 0].numpy()
f_i = filters_set['phi'][0][..., 1].numpy()
f = f_r + 1j*f_i

filter_c = fft2(f)
filter_c = np.fft.fftshift(filter_c)
plt.suptitle("The corresponding low-pass filter, also known as scaling function."
"Color saturation and color hue respectively denote complex magnitude and complex phase", fontsize=13)
filter_c = np.abs(filter_c)
plt.imshow(filter_c)
../_images/sphx_glr_plot_filters_0021.png

Total running time of the script: ( 0 minutes 1.471 seconds)

Gallery generated by Sphinx-Gallery