.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "gallery_2d/plot_sklearn.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here `
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_gallery_2d_plot_sklearn.py:


Scikit-learn transformer example
================================
Here we demonstrate a simple application of the scattering transform as a
scikit-learn transformer.

.. GENERATED FROM PYTHON SOURCE LINES 8-12

Preliminaries
-------------

Import the relevant classes and functions from scikit-learn.

.. GENERATED FROM PYTHON SOURCE LINES 12-20

.. code-block:: default


    from sklearn.pipeline import Pipeline
    from sklearn.model_selection import train_test_split
    from sklearn import datasets
    from sklearn.linear_model import LogisticRegression
    from sklearn.metrics import accuracy_score
    from sklearn.preprocessing import StandardScaler


.. GENERATED FROM PYTHON SOURCE LINES 21-22

Import the scikit-learn `Scattering2D` frontend.

.. GENERATED FROM PYTHON SOURCE LINES 22-25

.. code-block:: default


    from kymatio.sklearn import Scattering2D


.. GENERATED FROM PYTHON SOURCE LINES 26-32

Preparing the data
------------------

First, we load the dataset. In this case, it's the UCI ML digits dataset
included with scikit-learn, consisting of 8×8 images of handwritten digits
from zero to nine.

.. GENERATED FROM PYTHON SOURCE LINES 32-35

.. code-block:: default


    digits = datasets.load_digits()


.. GENERATED FROM PYTHON SOURCE LINES 36-38

We then extract the images and reshape them to an array of size
`(n_samples, n_features)`, as needed for processing in a scikit-learn
pipeline.

.. GENERATED FROM PYTHON SOURCE LINES 38-42

.. code-block:: default


    images = digits.images
    images = images.reshape((images.shape[0], -1))


.. GENERATED FROM PYTHON SOURCE LINES 43-44

We then split the images (and their labels) into a train and a test set.

.. GENERATED FROM PYTHON SOURCE LINES 44-48

.. code-block:: default


    x_train, x_test, y_train, y_test = train_test_split(images, digits.target,
                                                        test_size=0.5,
                                                        shuffle=False)


.. GENERATED FROM PYTHON SOURCE LINES 49-55

Training and testing the model
------------------------------

Create a `Scattering2D` object, which implements a scikit-learn
`Transformer`. We set the input shape to match that of the images (8×8)
and set the averaging scale to `J = 1`, which means that the local
invariance is `2 ** 1 = 2`.

.. GENERATED FROM PYTHON SOURCE LINES 55-58

.. code-block:: default


    S = Scattering2D(shape=(8, 8), J=1)


.. GENERATED FROM PYTHON SOURCE LINES 59-61

We then plug this into a scikit-learn pipeline which takes the scattering
features, scales them, then provides them to a `LogisticRegression`
classifier.

.. GENERATED FROM PYTHON SOURCE LINES 61-66

.. code-block:: default


    classifier = LogisticRegression(max_iter=150)
    estimators = [('scatter', S), ('scaler', StandardScaler()),
                  ('clf', classifier)]
    pipeline = Pipeline(estimators)


.. GENERATED FROM PYTHON SOURCE LINES 67-69

Given the pipeline, we train it on `(x_train, y_train)` using
`pipeline.fit`.

.. GENERATED FROM PYTHON SOURCE LINES 69-72

.. code-block:: default


    pipeline.fit(x_train, y_train)


.. rst-class:: sphx-glr-script-out

 .. code-block:: none


    Pipeline(steps=[('scatter', Scattering2D(J=1, backend=, shape=(8, 8))), ('scaler', StandardScaler()), ('clf', LogisticRegression(max_iter=150))])


.. GENERATED FROM PYTHON SOURCE LINES 73-75

Finally, we calculate the predicted labels on the test data and output the
classification accuracy.

.. GENERATED FROM PYTHON SOURCE LINES 75-79

.. code-block:: default


    y_pred = pipeline.predict(x_test)

    print('Accuracy:', accuracy_score(y_test, y_pred))


.. rst-class:: sphx-glr-script-out

 .. code-block:: none


    Accuracy: 0.9755283648498332


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes 0.403 seconds)


.. _sphx_glr_download_gallery_2d_plot_sklearn.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example


    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_sklearn.py `

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_sklearn.ipynb `


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_
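
The reshaping and splitting steps from "Preparing the data" can be sketched with plain NumPy. This is a minimal illustration and not the scikit-learn implementation: it uses a zero-filled placeholder array with the shape of the digits dataset (1797 images of 8×8 pixels) instead of loading the data, and `flat` and `n_train` are names chosen here for illustration. The slicing assumes that, with `shuffle=False` and `test_size=0.5`, the split preserves the original order and assigns the odd sample to the test set.

```python
import numpy as np

# Placeholder standing in for digits.images (assumed shape: 1797 images, 8x8 pixels).
images = np.zeros((1797, 8, 8))

# Flatten each image into one row: (n_samples, 8, 8) -> (n_samples, 64).
flat = images.reshape((images.shape[0], -1))

# Order-preserving half/half split. With an odd number of samples, the
# extra sample goes to the test half (an assumption about the rounding).
n_train = flat.shape[0] // 2
x_train, x_test = flat[:n_train], flat[n_train:]

print(flat.shape, x_train.shape, x_test.shape)
# (1797, 64) (898, 64) (899, 64)
```

Because the split is not shuffled, the classifier is trained on the first half of the dataset and evaluated on the second half.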