.. DO NOT EDIT. .. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY. .. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE: .. "generated/gallery/examples/plot_digits.py" .. LINE NUMBERS ARE GIVEN BELOW. .. only:: html .. note:: :class: sphx-glr-download-link-note :ref:`Go to the end ` to download the full example code. .. rst-class:: sphx-glr-example-title .. _sphx_glr_generated_gallery_examples_plot_digits.py: Pixel-wise inference on digit classification ============================================ This example illustrates how to perform pixel-wise inference using Ensemble Clustered Inference with Desparsified Lasso (EnCluDL) on digit classification tasks. We use the MNIST dataset and consider binary classification between digits 1 vs 7 and 0 vs 1. The MNIST dataset contains 28x28 pixel images of handwritten digits. .. GENERATED FROM PYTHON SOURCE LINES 11-17 Loading the MNIST dataset ------------------------- We start by loading the MNIST dataset from OpenML. We then filter the dataset to include only digits 4 and 7 for the first classification task, and digits 0 and 1 for the second task and digits 0 and 9 for the third task. To speed up the example, we downsample the dataset to 4000 samples for each task. .. GENERATED FROM PYTHON SOURCE LINES 17-47 .. code-block:: Python import matplotlib.pyplot as plt import numpy as np from sklearn.datasets import fetch_openml from sklearn.utils import resample mnist_dataset = fetch_openml("mnist_784", version=1, as_frame=False) X_mnist, y_mnist = mnist_dataset.data, mnist_dataset.target # Downsample to speed up the example n_samples = 5000 mask_4_7 = (y_mnist == "4") | (y_mnist == "7") X_4_7, y_4_7 = X_mnist[mask_4_7], y_mnist[mask_4_7].astype(int) X_4_7, y_4_7 = resample( X_4_7, y_4_7, n_samples=n_samples, replace=False, random_state=0, stratify=y_4_7 ) mask_0_1 = (y_mnist == "0") | (y_mnist == "1") X_0_1, y_0_1 = X_mnist[mask_0_1], y_mnist[mask_0_1].astype(int) X_0_1, y_0_1 = resample( X_0_1, y_0_1, n_samples=n_samples, replace=False, random_state=0, stratify=y_0_1 ) mask_0_9 = (y_mnist == "0") | (y_mnist == "9") X_0_9, y_0_9 = X_mnist[mask_0_9], y_mnist[mask_0_9].astype(int) X_0_9, y_0_9 = resample( X_0_9, y_0_9, n_samples=n_samples, replace=False, random_state=0, stratify=y_0_9 ) .. GENERATED FROM PYTHON SOURCE LINES 48-49 Visualizing samples from each classification task .. GENERATED FROM PYTHON SOURCE LINES 49-69 .. code-block:: Python _, axes = plt.subplots(3, 5, figsize=(6, 4), subplot_kw={"xticks": [], "yticks": []}) for i in range(5): # Plot 0 vs 1 label = 1 if i % 2 == 0 else 0 axes[0, i].imshow(X_0_1[y_0_1 == label][i].reshape(28, 28), cmap="gray") # Plot 4 vs 7 label = 7 if i % 2 == 0 else 4 axes[1, i].imshow(X_4_7[y_4_7 == label][i].reshape(28, 28), cmap="gray") # PLot 0 vs 9 label = 9 if i % 2 == 0 else 0 axes[2, i].imshow(X_0_9[y_0_9 == label][i].reshape(28, 28), cmap="gray") axes[0, 2].set_title("Digits 0 vs 1", fontweight="bold", y=1.0) axes[1, 2].set_title("Digits 4 vs 7", fontweight="bold", y=1.0) axes[2, 2].set_title("Digits 0 vs 9", fontweight="bold", y=1.0) _ = plt.tight_layout() .. image-sg:: /generated/gallery/examples/images/sphx_glr_plot_digits_001.png :alt: Digits 0 vs 1, Digits 4 vs 7, Digits 0 vs 9 :srcset: /generated/gallery/examples/images/sphx_glr_plot_digits_001.png :class: sphx-glr-single-img .. GENERATED FROM PYTHON SOURCE LINES 70-80 Randomness in clustering affects inference stability ---------------------------------------------------- A key limitation of clustered inference procedures is that the clustering step introduces randomness into the inference process. Data perturbations can produce different clustering results, which subsequently lead to varying inference outcomes. This effect can be demonstrated by running the clustering algorithm multiple times on different subsamples of the data. Here we use the Ward hierarchical clustering algorithm with spatial connectivity constraints (each pixel is connected to its immediate neighbors) to cluster the pixels of the images. We use 100 clusters and visualize the clustering results for four different subsamples of the data. .. GENERATED FROM PYTHON SOURCE LINES 80-101 .. code-block:: Python from sklearn.cluster import FeatureAgglomeration from sklearn.feature_extraction import image n_clusters = 100 shape = (28, 28) connectivity = image.grid_to_graph(n_x=shape[0], n_y=shape[1]) _, axes = plt.subplots(1, 4, figsize=(9, 3)) for i, ax in enumerate(axes): X_cluster = resample(X_4_7, n_samples=n_samples // 2, replace=False, random_state=i) clustering = FeatureAgglomeration(n_clusters=n_clusters, connectivity=connectivity) clustering.fit(X_cluster) ax.imshow(clustering.labels_.reshape(28, 28), cmap="Set2") ax.axis("off") ax.set_title(f"Clustering {i}") _ = plt.tight_layout() .. image-sg:: /generated/gallery/examples/images/sphx_glr_plot_digits_002.png :alt: Clustering 0, Clustering 1, Clustering 2, Clustering 3 :srcset: /generated/gallery/examples/images/sphx_glr_plot_digits_002.png :class: sphx-glr-single-img .. GENERATED FROM PYTHON SOURCE LINES 102-110 Ensemble Clustered Inference with Desparsified Lasso ---------------------------------------------------- To mitigate the randomness introduced by clustering, we ensemble the results from multiple clustered inference procedures. This approach derandomizes the inference procedure and produces more stable results. We use the class:`hidimstat.EnCluDL` for this purpose with 10 bootstraps. While this approach is more computationally intensive than single clustered inference, the procedure can be parallelized across clustering repetitions using the ``n_jobs`` parameter. .. GENERATED FROM PYTHON SOURCE LINES 110-147 .. code-block:: Python from sklearn.cluster import FeatureAgglomeration from sklearn.feature_extraction import image from sklearn.linear_model import LassoCV from hidimstat import DesparsifiedLasso from hidimstat.ensemble_clustered_inference import EnCluDL fwer = 0.1 n_jobs = 5 n_clusters = 100 encludl = EnCluDL( clustering=clustering, desparsified_lasso=DesparsifiedLasso(estimator=LassoCV(max_iter=1000)), n_bootstraps=10, n_jobs=n_jobs, random_state=0, cluster_boostrap_size=0.5, ) encludl.fit_importance(X_4_7, y_4_7) selected_4_7 = encludl.fwer_selection( fwer=fwer, n_tests=n_clusters, two_tailed_test=True ) encludl.fit_importance(X_0_1, y_0_1) selected_0_1 = encludl.fwer_selection( fwer=fwer, n_tests=n_clusters, two_tailed_test=True ) encludl.fit_importance(X_0_9, y_0_9) selected_0_9 = encludl.fwer_selection( fwer=fwer, n_tests=n_clusters, two_tailed_test=True ) .. rst-class:: sphx-glr-script-out .. code-block:: none Fitting clustered inferences: 0%| | 0/10 [00:00` .. container:: sphx-glr-download sphx-glr-download-python :download:`Download Python source code: plot_digits.py ` .. container:: sphx-glr-download sphx-glr-download-zip :download:`Download zipped: plot_digits.zip ` .. only:: html .. rst-class:: sphx-glr-signature `Gallery generated by Sphinx-Gallery `_