hidimstat.D0CRT
- class hidimstat.D0CRT(estimator, method: str = 'predict', estimated_coef=None, sigma_X=None, lasso_screening=LassoCV(fit_intercept=False, n_alphas=10, tol=1e-06), model_distillation_x=LassoCV(n_alphas=10), refit=False, screening_threshold=10, centered=True, n_jobs=1, joblib_verbose=0, fit_y=False, scaled_statistics=False, reuse_screening_model=True, random_state=None)
Bases: BaseVariableImportance
Implements the distilled conditional randomization test (dCRT) without interactions.
This class provides a fast implementation of the Conditional Randomization Test (Candes et al. [1]) using the distillation process from Liu et al. [2]. The approach accelerates variable selection by combining Lasso-based screening with residual-based test statistics. The y-distillation is based on a given estimator and the x-distillation is based on a Lasso estimator.
Based on the original implementation at: moleibobliu/Distillation-CRT
- Parameters:
- estimator: sklearn estimator
The base estimator used for y-distillation and prediction (e.g., Lasso, RandomForest, …).
- method: str, default="predict"
Method of the estimator to use for predictions ("predict", "predict_proba", "decision_function").
- estimated_coef: array-like of shape (n_features,) or None, default=None
Pre-computed feature coefficients. If None, coefficients are estimated via Lasso.
- sigma_X: array-like of shape (n_features, n_features) or None, default=None
Covariance matrix of X. If None, Lasso is used for X distillation.
- lasso_screening: sklearn estimator, default=LassoCV(n_alphas=10, tol=1e-6, fit_intercept=False)
Estimator for variable screening (typically LassoCV or Lasso).
- model_distillation_x: sklearn estimator, default=LassoCV(n_alphas=10)
Estimator for X distillation (typically LassoCV or Lasso).
- refit: bool, default=False
Whether to refit the model on selected features after screening.
- screening_threshold: float, default=10
Percentile threshold for screening (0-100). Larger values keep more variables at screening; screening_threshold=100 keeps all variables.
- centered: bool, default=True
Whether to center and scale features using StandardScaler.
- n_jobs: int, default=1
Number of parallel jobs.
- joblib_verbose: int, default=0
Verbosity level for parallel jobs.
- fit_y: bool, default=False
Controls y-distillation behavior:
  - If False and the estimator is linear, the sub-model predicting y from X^{-j} is obtained by removing the j-th coefficient from the full model (no fitting is performed).
  - If True, a clone of estimator is fitted on (X^{-j}, y).
  - For non-linear estimators, a clone of estimator is always fitted on (X^{-j}, y), regardless of fit_y.
- scaled_statistics: bool, default=False
Whether to use scaled statistics when computing importance.
- random_state: int, default=None
Random seed for reproducibility.
- reuse_screening_model: bool, default=True
Whether to reuse the screening model for y-distillation.
- Attributes:
- coefficient_: ndarray of shape (n_features,)
Feature coefficients estimated during fit, after screening and optional refitting.
- selection_set_: ndarray of shape (n_features,)
Boolean mask indicating the features selected by screening.
- model_x_: list of estimators
Fitted models for X distillation (Lasso, or None if using sigma_X).
- model_y_: list of estimators
Fitted models for y distillation (sklearn estimator, or None if using estimated_coef and a Lasso estimator).
- importances_: ndarray of shape (n_features,)
Importance scores for each feature. These are test statistics that follow a standard normal distribution under the null hypothesis.
- pvalues_: ndarray of shape (n_features,)
Computed p-values for each feature.
Notes
The implementation follows Liu et al. (2022), introducing distillation to speed up conditional randomization testing. Key steps:
1. Optional screening using Lasso coefficients to reduce dimensionality.
2. Distillation to estimate conditional distributions.
3. Test statistic computation using residual correlations.
4. P-value calculation assuming Gaussian null distribution.
The implementation currently allows flexible models for the y-distillation step. However, the x-distillation step only supports linear models.
The random_state parameter of the different x-distillation and y-distillation models is set by spawning independent Generators from the main random_state of the D0CRT instance.
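The fit_y=False shortcut for linear estimators can be illustrated with a minimal sketch. It assumes a plain linear model without intercept; the function name `y_residual_without_feature` and the toy data are hypothetical, and this is not the library's own code.

```python
def y_residual_without_feature(X, y, coef, j):
    """Residual of y after predicting it from all features except j.

    Sketch of the fit_y=False idea for a linear model: the sub-model for
    X^{-j} reuses the full-model coefficients and simply skips coef[j],
    so nothing is refitted. Assumes no intercept (hypothetical setup).
    """
    residuals = []
    for row, target in zip(X, y):
        prediction = sum(c * v for k, (c, v) in enumerate(zip(coef, row)) if k != j)
        residuals.append(target - prediction)
    return residuals


# Toy data: y = 2 * x0 + 1 * x1 exactly, so the full model has coef = [2, 1].
X = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
y = [2.0, 1.0, 3.0]
coef = [2.0, 1.0]

# Dropping feature 0 leaves its contribution (2 * x0) in the residual.
res_0 = y_residual_without_feature(X, y, coef, j=0)
print(res_0)  # [2.0, 0.0, 2.0]
```

The residual keeps exactly the part of y that the remaining features do not explain, which is what the later test statistic compares against the residual of X_j.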
References
- __init__(estimator, method: str = 'predict', estimated_coef=None, sigma_X=None, lasso_screening=LassoCV(fit_intercept=False, n_alphas=10, tol=1e-06), model_distillation_x=LassoCV(n_alphas=10), refit=False, screening_threshold=10, centered=True, n_jobs=1, joblib_verbose=0, fit_y=False, scaled_statistics=False, reuse_screening_model=True, random_state=None)
- fit(X, y)
Fit the dCRT model.
This method fits the distilled Conditional Randomization Test (dCRT) model as described in Liu et al. [2]. It performs optional feature screening using Lasso, computes coefficients, and prepares the model for importance and p-value computation.
- Parameters:
- X: array-like of shape (n_samples, n_features)
Training data matrix.
- y: array-like of shape (n_samples,)
Target values.
- Returns:
- self: object
Returns the fitted instance.
Notes
Main steps:
1. Optional data centering with StandardScaler.
2. Lasso screening of variables (if no estimated coefficients are provided).
3. Feature selection based on coefficient magnitudes.
4. Model refitting on selected features (if refit=True).
5. Fit the models used later for distillation.
The screening threshold controls which features are kept based on their Lasso coefficients. Coefficients of features that fall below the threshold are set to zero.
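The percentile-based screening can be sketched as follows. This is an illustration only: it assumes the cutoff is the (100 - screening_threshold)-th percentile of the absolute coefficients and that features at or above the cutoff are kept; the helper names `percentile` and `screen` are hypothetical, and the library may treat ties and zero coefficients differently.

```python
def percentile(values, q):
    """q-th percentile (0-100) with linear interpolation between order statistics."""
    ordered = sorted(values)
    position = (len(ordered) - 1) * q / 100.0
    lower = int(position)
    upper = min(lower + 1, len(ordered) - 1)
    fraction = position - lower
    return ordered[lower] + fraction * (ordered[upper] - ordered[lower])


def screen(coef, screening_threshold):
    """Keep features whose |coefficient| reaches the cutoff percentile.

    Assumption: the cutoff is the (100 - screening_threshold)-th percentile
    of the absolute coefficients; coefficients below it are set to zero.
    """
    magnitudes = [abs(c) for c in coef]
    cutoff = percentile(magnitudes, 100 - screening_threshold)
    mask = [m >= cutoff for m in magnitudes]
    screened = [c if keep else 0.0 for c, keep in zip(coef, mask)]
    return mask, screened


# Hypothetical Lasso coefficients for 10 features.
coef = [0.0, -3.0, 0.5, 0.0, 1.5, 0.0, 0.0, 0.1, 0.0, 0.0]
mask, screened = screen(coef, screening_threshold=20)
print([i for i, keep in enumerate(mask) if keep])  # [1, 4]
```

Under this assumption, screening_threshold=100 puts the cutoff at the smallest magnitude, so every feature is kept.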
References
- importance(X, y)
Compute feature importance scores using distilled CRT.
Calculates test statistics and p-values for each feature using residual correlations after the distillation process.
- Parameters:
- X: array-like of shape (n_samples, n_features)
Input data matrix.
- y: array-like of shape (n_samples,)
Target values.
- Attributes:
- importances_: same as return value
- pvalues_: ndarray of shape (n_features,)
Two-sided p-values for each feature under a Gaussian null.
- Returns:
- importances_: ndarray of shape (n_features,)
Test statistics/importance scores for each feature. For unselected features, the score is set to 0.
Notes
For each selected feature j:
1. Computes residuals from regressing X_j on other features.
2. Computes residuals from regressing y on other features.
3. Calculates test statistic from correlation of residuals.
4. Computes p-value assuming standard normal distribution.
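Steps 3 and 4 above can be sketched once the two residual vectors are available. The sketch assumes the statistic is the square root of n times the uncentered correlation of the residuals, which is approximately standard normal under the null; the library's exact normalization may differ, and the function names are hypothetical.

```python
import math


def residual_statistic(x_residual, y_residual):
    """Normalized inner product of two residual vectors.

    Assumption: sqrt(n) times the uncentered correlation of the residuals,
    approximately N(0, 1) when feature j is conditionally independent of y.
    """
    n = len(x_residual)
    inner = sum(a * b for a, b in zip(x_residual, y_residual))
    scale_x = sum(a * a for a in x_residual) / n
    scale_y = sum(b * b for b in y_residual) / n
    return inner / math.sqrt(n * scale_x * scale_y)


def two_sided_pvalue(statistic):
    """Two-sided p-value under a standard normal null: 2 * P(Z > |t|)."""
    return math.erfc(abs(statistic) / math.sqrt(2.0))


# Orthogonal residuals: no remaining association, so the statistic is 0.
x_res = [1.0, -1.0, 1.0, -1.0]
y_res = [1.0, 1.0, -1.0, -1.0]
t_null = residual_statistic(x_res, y_res)

# Identical residuals: perfectly aligned, so the statistic equals sqrt(n) = 2.
t_aligned = residual_statistic(x_res, x_res)

print(t_null, two_sided_pvalue(t_null))
print(t_aligned, two_sided_pvalue(t_aligned))
```

A statistic of 0 maps to a p-value of 1, while larger absolute values push the p-value toward 0.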
- fit_importance(X, y, cv=None)
Fits the model to the data and computes feature importance.
A convenience method that combines fit() and importance() into a single call. First fits the dCRT model to the data, then calculates importance scores.
- Parameters:
- X: array-like of shape (n_samples, n_features)
Training data matrix.
- y: array-like of shape (n_samples,)
Target values.
- cv: None or int, optional, default=None
Not used. Included for compatibility. A warning will be issued if provided.
- Returns:
- importance: ndarray of shape (n_features,)
Feature importance scores/test statistics. For features not selected during screening, scores are set to 0.
Notes
Also sets the importances_ and pvalues_ attributes on the instance. See fit() and importance() for details on the underlying computations.
- get_metadata_routing()
Get metadata routing of this object.
Please check User Guide on how the routing mechanism works.
- Returns:
- routing: MetadataRequest
A MetadataRequest encapsulating routing information.
- get_params(deep=True)
Get parameters for this estimator.
- Parameters:
- deep: bool, default=True
If True, will return the parameters for this estimator and contained subobjects that are estimators.
- Returns:
- params: dict
Parameter names mapped to their values.
- plot_importance(ax=None, ascending=False, **seaborn_barplot_kwargs)
Plot feature importances as a horizontal bar plot.
- Parameters:
- ax: matplotlib.axes.Axes or None, default=None
Axes object to draw the plot onto, otherwise uses the current Axes.
- ascending: bool, default=False
Whether to sort features by ascending importance.
- **seaborn_barplot_kwargs: additional keyword arguments
Additional arguments passed to seaborn.barplot (see https://seaborn.pydata.org/generated/seaborn.barplot.html).
- Returns:
- ax: matplotlib.axes.Axes
The Axes object with the plot.
- selection(k_best=None, percentile=None, threshold=None, threshold_pvalue=None)
Selects features based on variable importance. If several arguments are not None, the returned selection is the conjunction (logical AND) of all of them.
- Parameters:
- k_best: int, optional, default=None
Selects the top k features based on importance scores.
- percentile: float, optional, default=None
Selects features based on a specified percentile of importance scores.
- threshold: float, optional, default=None
Selects features with importance scores above the specified threshold.
- threshold_pvalue: float, optional, default=None
Selects features with p-values below the specified threshold.
- Returns:
- selection: array-like of shape (n_features,)
Binary array indicating the selected features.
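How several criteria combine into one mask can be illustrated with a small sketch. The helper `select` is hypothetical and omits the percentile criterion for brevity; it assumes plain ranking by score and strict inequalities, which may not match the library's handling of ties.

```python
def select(importances, pvalues, k_best=None, threshold=None, threshold_pvalue=None):
    """Combine the criteria that are not None with a logical AND (hypothetical helper)."""
    n = len(importances)
    mask = [True] * n
    if k_best is not None:
        # Indices of the k largest importance scores.
        top = set(sorted(range(n), key=lambda i: importances[i], reverse=True)[:k_best])
        mask = [m and (i in top) for i, m in enumerate(mask)]
    if threshold is not None:
        mask = [m and (s > threshold) for m, s in zip(mask, importances)]
    if threshold_pvalue is not None:
        mask = [m and (p < threshold_pvalue) for m, p in zip(mask, pvalues)]
    return mask


# Made-up scores and p-values for four features.
importances = [3.1, 0.0, 2.2, 0.4]
pvalues = [0.002, 1.0, 0.03, 0.7]

# Feature 3 is among the top 3 scores but fails the p-value criterion.
selected = select(importances, pvalues, k_best=3, threshold_pvalue=0.05)
print(selected)  # [True, False, True, False]
```

Because the criteria are intersected, adding a criterion can only shrink the selected set.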
- set_params(**params)
Set the parameters of this estimator.
The method works on simple estimators as well as on nested objects (such as Pipeline). The latter have parameters of the form <component>__<parameter> so that it's possible to update each component of a nested object.
- Parameters:
- **params: dict
Estimator parameters.
- Returns:
- self: estimator instance
Estimator instance.
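The <component>__<parameter> naming can be illustrated with a toy sketch in which plain dicts stand in for estimator objects. The helper `set_nested` and the parameter tree are hypothetical; this is not scikit-learn's implementation, only the key-splitting idea.

```python
def set_nested(params, **updates):
    """Apply updates whose keys may use the <component>__<parameter> form."""
    for key, value in updates.items():
        name, delimiter, sub_key = key.partition("__")
        if delimiter:
            # Nested key: recurse into the component's own parameters.
            set_nested(params[name], **{sub_key: value})
        else:
            params[name] = value
    return params


# Hypothetical parameter tree: dicts stand in for estimator objects.
params = {
    "screening_threshold": 10,
    "lasso_screening": {"n_alphas": 10, "tol": 1e-6},
}
set_nested(params, screening_threshold=20, lasso_screening__n_alphas=50)
print(params)
```

By the same convention, a key such as lasso_screening__n_alphas would address the n_alphas parameter of the lasso_screening estimator.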
Examples using hidimstat.D0CRT

Distilled Conditional Randomization Test (dCRT) using Lasso vs Random Forest learners