Scikit-learn transformer example
Here we demonstrate a simple application of the scattering transform as a scikit-learn transformer.
Import the relevant classes and functions from scikit-learn.
from sklearn.pipeline import Pipeline
from sklearn.model_selection import train_test_split
from sklearn import datasets
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
Import the scikit-learn Scattering2D frontend.
from kymatio.sklearn import Scattering2D
Preparing the data
First, we load the dataset. In this case, it’s the UCI ML digits dataset included with scikit-learn, consisting of 8×8 images of handwritten digits from zero to nine.
digits = datasets.load_digits()
We then extract the images and reshape them into an array of shape (n_samples, n_features), as needed for processing in a scikit-learn pipeline.
images = digits.images
# Keep the sample axis and flatten each 8x8 image into 64 features.
images = images.reshape((images.shape[0], -1))
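To make the reshape concrete, here is a small sketch on a stand-in array. The random values replace the digits images, an assumption made only so the snippet needs nothing beyond NumPy. It shows that each 8×8 image becomes one row of 64 features.

```python
import numpy as np

# Stand-in for digits.images: 5 fake "images" of size 8x8 (hypothetical data).
rng = np.random.default_rng(0)
fake_images = rng.random((5, 8, 8))

# Keep the sample axis and flatten the two spatial axes into one feature axis.
flat = fake_images.reshape((fake_images.shape[0], -1))

print(flat.shape)  # (5, 64)

# Row i of the flattened array is image i read in row-major order.
same_as_ravel = np.array_equal(flat[2], fake_images[2].ravel())
print(same_as_ravel)  # True
```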
We then split the images (and their labels) into a train and a test set.
x_train, x_test, y_train, y_test = train_test_split(images, digits.target, test_size=0.5, shuffle=False)
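With shuffle=False the split is a plain cut: the first half of the samples goes to the training set and the second half to the test set, in their original order. A minimal sketch on toy arrays (hypothetical data, not the digits) illustrates this.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Toy data (hypothetical): 10 samples with 2 features each, labelled 0..9.
X = np.arange(20).reshape(10, 2)
y = np.arange(10)

X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.5, shuffle=False)

print(y_tr)  # [0 1 2 3 4]
print(y_te)  # [5 6 7 8 9]
```

Because nothing is shuffled, the result is deterministic and no random_state is needed.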
Training and testing the model
Create a Scattering2D object, which implements a scikit-learn Transformer. We set the input shape to match that of the images (8×8) and set the averaging scale to J = 1, which means that the local invariance is 2 ** 1 = 2.
S = Scattering2D(shape=(8, 8), J=1)
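As a rough guide to what the transformer produces, the sketch below estimates the number of scattering features per image. It rests on assumptions not stated in this example: second-order scattering, L = 8 orientations, and an output subsampled by 2 ** J along each axis. The helper scattering2d_feature_count is hypothetical and not part of kymatio; check the result against the library's documentation.

```python
def scattering2d_feature_count(shape, J, L=8):
    """Estimate the number of scattering features per image.

    Assumes second-order scattering with L orientations (L=8 is an
    assumption here) and an output subsampled by 2 ** J along each axis.
    """
    M, N = shape
    # One zeroth-order channel, L * J first-order channels, and
    # L**2 * J * (J - 1) / 2 second-order channels.
    n_channels = 1 + L * J + (L ** 2) * J * (J - 1) // 2
    # Each channel is averaged and subsampled by a factor 2 ** J.
    out_h, out_w = M // 2 ** J, N // 2 ** J
    return n_channels, (out_h, out_w), n_channels * out_h * out_w


channels, spatial, total = scattering2d_feature_count((8, 8), J=1)
print(channels, spatial, total)  # 9 (4, 4) 144
```

Under these assumptions, J = 1 leaves no second-order terms, so each 8×8 image maps to 9 channels of size 4×4.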
We then plug this into a scikit-learn pipeline which takes the scattering features and provides them to a LogisticRegression classifier.
classifier = LogisticRegression()
estimators = [('scatter', S), ('clf', classifier)]
pipeline = Pipeline(estimators)
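Any object that provides fit and transform can take the place of the first pipeline step. The sketch below uses a hypothetical stand-in, BlockAverage, which averages 2×2 pixel blocks. It is not the scattering transform and is not part of kymatio or scikit-learn; it only illustrates the interface a pipeline expects from a transformer.

```python
import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline


class BlockAverage(BaseEstimator, TransformerMixin):
    """Hypothetical stand-in transformer: 2x2 block averaging of 8x8 images.

    This is NOT the scattering transform; it only shows the fit/transform
    interface that a pipeline step has to provide.
    """

    def __init__(self, shape=(8, 8)):
        self.shape = shape

    def fit(self, X, y=None):
        # Nothing to learn; return self so the step can be chained.
        return self

    def transform(self, X):
        M, N = self.shape
        imgs = np.asarray(X).reshape(-1, M // 2, 2, N // 2, 2)
        # Average each 2x2 block, then flatten back to (n_samples, n_features).
        return imgs.mean(axis=(2, 4)).reshape(len(imgs), -1)


toy_pipeline = Pipeline([("avg", BlockAverage()), ("clf", LogisticRegression())])

# Toy data (hypothetical): 20 random 8x8 images flattened to 64 features.
rng = np.random.default_rng(0)
X_toy = rng.random((20, 64))
y_toy = np.arange(20) % 2

toy_pipeline.fit(X_toy, y_toy)
reduced = toy_pipeline.named_steps["avg"].transform(X_toy)
print(reduced.shape)  # (20, 16)
```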
Given the pipeline, we train it on (x_train, y_train) using pipeline.fit.
pipeline.fit(x_train, y_train)
/opt/hostedtoolcache/Python/3.7.12/x64/lib/python3.7/site-packages/sklearn/linear_model/_logistic.py:818: ConvergenceWarning: lbfgs failed to converge (status=1):
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  extra_warning_msg=_LOGISTIC_SOLVER_CONVERGENCE_MSG,

Pipeline(steps=[('scatter',
                 Scattering2D(J=1,
                              backend=<class 'kymatio.scattering2d.backend.numpy_backend.backend'>,
                              shape=(8, 8))),
                ('clf', LogisticRegression())])
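The ConvergenceWarning in the output above suggests two remedies: raising max_iter or scaling the data. The sketch below applies both. It assumes raw pixel features rather than scattering features, purely so the snippet does not depend on kymatio; in the pipeline of this example, the scaler would sit between the 'scatter' and 'clf' steps. Whether the warning disappears depends on the data and solver, so treat this as a starting point.

```python
from sklearn import datasets
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

digits = datasets.load_digits()
X = digits.images.reshape((digits.images.shape[0], -1))
X_tr, X_te, y_tr, y_te = train_test_split(
    X, digits.target, test_size=0.5, shuffle=False
)

# Standardise the features and allow more solver iterations than the default.
scaled_pipeline = Pipeline([
    ("scale", StandardScaler()),
    ("clf", LogisticRegression(max_iter=1000)),
])
scaled_pipeline.fit(X_tr, y_tr)

acc = accuracy_score(y_te, scaled_pipeline.predict(X_te))
print("Accuracy:", acc)
```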
Finally, we calculate the predicted labels on the test data and output the classification accuracy.
y_pred = pipeline.predict(x_test)
print('Accuracy:', accuracy_score(y_test, y_pred))
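For reference, the accuracy reported here is simply the fraction of test samples whose predicted label matches the true one. The sketch below checks this on hypothetical labels, not on results from the digits experiment.

```python
import numpy as np
from sklearn.metrics import accuracy_score

# Hypothetical labels, not results from the digits experiment.
y_true = np.array([0, 1, 2, 2])
y_hat = np.array([0, 1, 1, 2])

# Three of the four predictions match, so the accuracy is 3 / 4.
manual = np.mean(y_true == y_hat)
print(manual, accuracy_score(y_true, y_hat))  # 0.75 0.75
```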