.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "gallery_2d/mnist_keras.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here ` to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_gallery_2d_mnist_keras.py:


Classification of MNIST with scattering
=======================================

Here we demonstrate a simple application of scattering on the MNIST dataset.
We use 10000 training images (20% of which are held out for validation) to
train a linear classifier on the flattened scattering coefficients.

.. GENERATED FROM PYTHON SOURCE LINES 10-15

Preliminaries
-------------

Since we use TensorFlow and Keras to train the model, we first import the
relevant modules.

.. GENERATED FROM PYTHON SOURCE LINES 15-20

.. code-block:: default


    import tensorflow as tf
    from tensorflow.keras.models import Model
    from tensorflow.keras.layers import Input, Flatten, Dense


.. GENERATED FROM PYTHON SOURCE LINES 21-23

Finally, we import the ``Scattering2D`` class from the ``kymatio.keras``
package.

.. GENERATED FROM PYTHON SOURCE LINES 23-26

.. code-block:: default


    from kymatio.keras import Scattering2D


.. GENERATED FROM PYTHON SOURCE LINES 27-31

Training and testing the model
------------------------------

First, we load the MNIST data and rescale the pixel values from the integer
range [0, 255] to floats in [0, 1].

.. GENERATED FROM PYTHON SOURCE LINES 31-36

.. code-block:: default


    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

    x_train, x_test = x_train / 255., x_test / 255.


.. GENERATED FROM PYTHON SOURCE LINES 37-39

We then create a Keras model that applies the scattering transform, flattens
its output, and feeds it to a dense layer with a softmax activation.

.. GENERATED FROM PYTHON SOURCE LINES 39-46

.. code-block:: default


    inputs = Input(shape=(28, 28))              # 28x28 grayscale images
    x = Scattering2D(J=3, L=8)(inputs)          # 3 scales, 8 orientations
    x = Flatten()(x)                            # one feature vector per image
    x_out = Dense(10, activation='softmax')(x)  # linear classifier over 10 digits
    model = Model(inputs, x_out)

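
As a rough sanity check on the model size, the sketch below counts the
coefficients that ``Scattering2D(J=3, L=8)`` passes to ``Flatten``. It is
illustrative only. The helper names are hypothetical and not part of kymatio.
The sketch assumes the default second-order transform, with
1 + J*L + L*L*J*(J-1)/2 channels, and a spatial output size of
``M // 2**J`` by ``N // 2**J``. The output of ``model.summary()`` remains the
authoritative source for the actual shapes.

.. code-block:: python

    # Hypothetical helpers, not part of kymatio.
    def scattering2d_channels(J, L):
        """Channels of a second-order 2D scattering transform."""
        order0 = 1                          # low-pass average of the image
        order1 = J * L                      # one channel per (scale, angle)
        order2 = L * L * J * (J - 1) // 2   # scale pairs j1 < j2, L*L angle pairs each
        return order0 + order1 + order2


    def flattened_size(M, N, J, L):
        # Assumption: each channel is subsampled to (M // 2**J) x (N // 2**J).
        return scattering2d_channels(J, L) * (M // 2**J) * (N // 2**J)


    K = scattering2d_channels(J=3, L=8)
    F = flattened_size(28, 28, J=3, L=8)
    dense_params = F * 10 + 10  # Dense(10): one weight per (feature, class) plus 10 biases
    print(K, F, dense_params)   # prints: 217 1953 19540

Because the scattering filters are fixed wavelets rather than learned weights,
the ``Dense(10)`` layer should hold the model's trainable parameters. This
count can therefore be compared against the trainable-parameter total that
``model.summary()`` reports.
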

.. GENERATED FROM PYTHON SOURCE LINES 47-48

Display the created model.

.. GENERATED FROM PYTHON SOURCE LINES 48-51

.. code-block:: default


    model.summary()


.. GENERATED FROM PYTHON SOURCE LINES 52-54

Once the model is created, we compile it with the Adam optimizer and a sparse
categorical cross-entropy loss. This loss expects integer class labels, such
as those returned by ``mnist.load_data()``.

.. GENERATED FROM PYTHON SOURCE LINES 54-59

.. code-block:: default


    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])


.. GENERATED FROM PYTHON SOURCE LINES 60-61

We then train the model with ``model.fit`` on the first 10000 training images
for 15 epochs, holding out the last 20% of them for validation.

.. GENERATED FROM PYTHON SOURCE LINES 61-65

.. code-block:: default


    model.fit(x_train[:10000], y_train[:10000], epochs=15,
              batch_size=64, validation_split=0.2)


.. GENERATED FROM PYTHON SOURCE LINES 66-67

Finally, we evaluate the model on the held-out test data.

.. GENERATED FROM PYTHON SOURCE LINES 67-69

.. code-block:: default

    model.evaluate(x_test, y_test)


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes 0.000 seconds)


.. _sphx_glr_download_gallery_2d_mnist_keras.py:


.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example


    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: mnist_keras.py `


    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: mnist_keras.ipynb `


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_