hidimstat.D0CRT#

class hidimstat.D0CRT(estimator, method: str = 'predict', estimated_coef=None, sigma_X=None, lasso_screening=LassoCV(fit_intercept=False, n_alphas=10, tol=1e-06), model_distillation_x=LassoCV(n_alphas=10), refit=False, screening_threshold=10, centered=True, n_jobs=1, joblib_verbose=0, fit_y=False, scaled_statistics=False, reuse_screening_model=True, random_state=None)[source]#

Bases: BaseVariableImportance

Implements the distilled conditional randomization test (dCRT) without interactions.

This class provides a fast implementation of the Conditional Randomization Test (Candes et al. [1]) using the distillation process from Liu et al. [2]. The approach accelerates variable selection by combining Lasso-based screening with residual-based test statistics. It is based on the original implementation at moleibobliu/Distillation-CRT. The y-distillation is based on a given estimator, while the x-distillation is based on a Lasso estimator.

Parameters:
estimator : sklearn estimator

The base estimator used for y-distillation and prediction (e.g., Lasso, RandomForest, …).

method : str, default="predict"

Method of the estimator to use for predictions (“predict”, “predict_proba”, “decision_function”).

estimated_coef : array-like of shape (n_features,) or None, default=None

Pre-computed feature coefficients. If None, coefficients are estimated via Lasso.

sigma_X : array-like of shape (n_features, n_features) or None, default=None

Covariance matrix of X. If None, Lasso is used for X distillation.

lasso_screening : sklearn estimator, default=LassoCV(n_alphas=10, tol=1e-6, fit_intercept=False)

Estimator for variable screening (typically LassoCV or Lasso).

model_distillation_x : sklearn estimator, default=LassoCV(n_alphas=10)

Estimator for X distillation (typically LassoCV or Lasso).

refit : bool, default=False

Whether to refit the model on selected features after screening.

screening_threshold : float, default=10

Percentile threshold for screening (0-100). Larger values keep more variables after screening; screening_threshold=100 keeps all variables.

centered : bool, default=True

Whether to center and scale features using StandardScaler.

n_jobs : int, default=1

Number of parallel jobs.

joblib_verbose : int, default=0

Verbosity level for parallel jobs.

fit_y : bool, default=False

Controls y-distillation behavior:

- If False and the estimator is linear, the sub-model predicting y from X^{-j} is built by simply removing the j-th coefficient from the full model (no fitting is performed).
- If True, a clone of estimator is fitted on (X^{-j}, y).
- For non-linear estimators, a clone of estimator is always fitted on (X^{-j}, y), regardless of fit_y.

scaled_statistics : bool, default=False

Whether to use scaled statistics when computing importance.

random_state : int, default=None

Random seed for reproducibility.

reuse_screening_model: bool, default=True

Whether to reuse the screening model for y-distillation.

Attributes:
coefficient_ : ndarray of shape (n_features,)

Estimated feature coefficients after screening (and optional refitting), computed during fit.

selection_set_ : ndarray of shape (n_features,)

Boolean mask indicating selected features after screening.

model_x_ : list of estimators

Fitted models for X distillation (Lasso or None if using sigma_X).

model_y_ : list of estimators

Fitted models for y-distillation (sklearn estimator, or None if using estimated_coef with a Lasso estimator).

importances_ : ndarray of shape (n_features,)

Importance scores for each feature: test statistics that follow a standard normal distribution under the null.

pvalues_ : ndarray of shape (n_features,)

Computed p-values for each feature.

Notes

The implementation follows Liu et al. (2022), introducing distillation to speed up conditional randomization testing. Key steps:

1. Optional screening using Lasso coefficients to reduce dimensionality.
2. Distillation to estimate conditional distributions.
3. Test statistic computation using residual correlations.
4. P-value calculation assuming a Gaussian null distribution.

The implementation currently allows flexible models for the y-distillation step. However, the x-distillation step only supports linear models.

The random_state parameter of the different x-distillation and y-distillation models is set by spawning independent Generators from the main random_state of the D0CRT instance.
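
A minimal usage sketch based on the constructor and methods documented on this page; the synthetic data and the choice of LassoCV as base estimator are illustrative:

from sklearn.datasets import make_regression
from sklearn.linear_model import LassoCV
from hidimstat import D0CRT

# Illustrative synthetic data; any regression problem (X, y) would do.
X, y = make_regression(n_samples=200, n_features=50, n_informative=5, random_state=0)

d0crt = D0CRT(estimator=LassoCV(n_alphas=10), random_state=0)
d0crt.fit(X, y)                        # screening + preparation of the distillation models
importances = d0crt.importance(X, y)   # test statistics (approximately N(0, 1) under the null)
pvalues = d0crt.pvalues_               # two-sided p-values, one per feature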

References

[1] Candès, E., Fan, Y., Janson, L., and Lv, J. (2018). Panning for gold: 'model-X' knockoffs for high dimensional controlled variable selection. Journal of the Royal Statistical Society: Series B.

[2] Liu, M., Katsevich, E., Janson, L., and Ramdas, A. (2022). Fast and powerful conditional randomization testing via distillation. Biometrika.

__init__(estimator, method: str = 'predict', estimated_coef=None, sigma_X=None, lasso_screening=LassoCV(fit_intercept=False, n_alphas=10, tol=1e-06), model_distillation_x=LassoCV(n_alphas=10), refit=False, screening_threshold=10, centered=True, n_jobs=1, joblib_verbose=0, fit_y=False, scaled_statistics=False, reuse_screening_model=True, random_state=None)[source]#
fit(X, y)[source]#

Fit the dCRT model.

This method fits the distilled conditional randomization test (dCRT) model as described in Liu et al. [2]. It performs optional feature screening using Lasso, computes coefficients, and prepares the model for importance and p-value computation.

Parameters:
X : array-like of shape (n_samples, n_features)

Training data matrix.

y : array-like of shape (n_samples,)

Target values.

Returns:
self : object

Returns the fitted instance.

Notes

Main steps:

1. Optional data centering with StandardScaler.
2. Lasso screening of variables (if no estimated coefficients are provided).
3. Feature selection based on coefficient magnitudes.
4. Model refitting on selected features (if refit=True).
5. Fitting of the model used for the subsequent distillation.

The screening threshold controls which features are kept based on their Lasso coefficients. Features with coefficients below the threshold are set to zero.
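
For instance, the outcome of the screening step can be inspected after fit() through the selection_set_ and coefficient_ attributes documented above (sketch; data and estimator are illustrative):

from sklearn.datasets import make_regression
from sklearn.linear_model import LassoCV
from hidimstat import D0CRT

X, y = make_regression(n_samples=150, n_features=30, n_informative=3, random_state=1)

d0crt = D0CRT(estimator=LassoCV(n_alphas=10), random_state=1)
d0crt.fit(X, y)

# Boolean mask of features kept by the Lasso screening step, and the
# corresponding coefficients (zero for screened-out features).
kept = d0crt.selection_set_
coefs = d0crt.coefficient_
print(f"{kept.sum()} / {kept.size} features kept after screening")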

References

importance(X, y)[source]#

Compute feature importance scores using distilled CRT.

Calculates test statistics and p-values for each feature using residual correlations after the distillation process.

Parameters:
X : array-like of shape (n_samples, n_features)

Input data matrix.

y : array-like of shape (n_samples,)

Target values.

Attributes:
importances_ : same as the return value
pvalues_ : ndarray of shape (n_features,)

Two-sided p-values for each feature under Gaussian null.

Returns:
importances_ : ndarray of shape (n_features,)

Test statistics/importance scores for each feature. For unselected features, the score is set to 0.

Notes

For each selected feature j:

1. Compute the residuals from regressing X_j on the other features.
2. Compute the residuals from regressing y on the other features.
3. Compute the test statistic from the correlation of the two residual vectors.
4. Compute the p-value assuming a standard normal null distribution.
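
As a purely illustrative sketch (not the hidimstat implementation itself), steps 3 and 4 amount to a standardized residual correlation with a two-sided Gaussian p-value:

import numpy as np
from scipy import stats

def residual_correlation_statistic(x_res, y_res):
    """Standardized correlation between the X_j-residuals and the y-residuals.

    Under the null hypothesis that X_j carries no information about y given
    the other features, this statistic is approximately standard normal.
    """
    n = x_res.shape[0]
    stat = np.sqrt(n) * np.dot(x_res, y_res) / (
        np.linalg.norm(x_res) * np.linalg.norm(y_res)
    )
    pvalue = 2 * stats.norm.sf(np.abs(stat))  # two-sided p-value under N(0, 1)
    return stat, pvalue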

fit_importance(X, y, cv=None)[source]#

Fits the model to the data and computes feature importance.

A convenience method that combines fit() and importance() into a single call. First fits the dCRT model to the data, then calculates importance scores.

Parameters:
X : array-like of shape (n_samples, n_features)

Training data matrix.

y : array-like of shape (n_samples,)

Target values.

cv : None or int, optional (default=None)

Not used. Included for compatibility. A warning will be issued if provided.

Returns:
importance : ndarray of shape (n_features,)

Feature importance scores/test statistics. For features not selected during screening, scores are set to 0.

Notes

Also sets the importances_ and pvalues_ attributes on the instance. See fit() and importance() for details on the underlying computations.
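
A one-call sketch (data and estimator choice are illustrative):

from sklearn.datasets import make_regression
from sklearn.linear_model import LassoCV
from hidimstat import D0CRT

X, y = make_regression(n_samples=150, n_features=30, random_state=0)

d0crt = D0CRT(estimator=LassoCV(n_alphas=10), random_state=0)
importances = d0crt.fit_importance(X, y)  # equivalent to fit(X, y) followed by importance(X, y)
pvalues = d0crt.pvalues_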

get_metadata_routing()[source]#

Get metadata routing of this object.

Please check the User Guide on how the routing mechanism works.

Returns:
routing : MetadataRequest

A MetadataRequest encapsulating routing information.

get_params(deep=True)[source]#

Get parameters for this estimator.

Parameters:
deep : bool, default=True

If True, will return the parameters for this estimator and contained subobjects that are estimators.

Returns:
params : dict

Parameter names mapped to their values.

plot_importance(ax=None, ascending=False, **seaborn_barplot_kwargs)[source]#

Plot feature importances as a horizontal bar plot.

Parameters:
ax : matplotlib.axes.Axes or None, default=None

Axes object to draw the plot onto, otherwise uses the current Axes.

ascending: bool, default=False

Whether to sort features by ascending importance.

**seaborn_barplot_kwargs : additional keyword arguments

Additional arguments passed to seaborn.barplot. https://seaborn.pydata.org/generated/seaborn.barplot.html

Returns:
ax : matplotlib.axes.Axes

The Axes object with the plot.
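
A plotting sketch; it assumes matplotlib and seaborn are installed, and the data and estimator are illustrative:

import matplotlib.pyplot as plt
from sklearn.datasets import make_regression
from sklearn.linear_model import LassoCV
from hidimstat import D0CRT

X, y = make_regression(n_samples=150, n_features=10, random_state=0)

d0crt = D0CRT(estimator=LassoCV(n_alphas=10), random_state=0)
d0crt.fit_importance(X, y)                  # populates importances_
ax = d0crt.plot_importance(ascending=True)  # horizontal bar plot of importances
plt.show()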

selection(k_best=None, percentile=None, threshold=None, threshold_pvalue=None)[source]#

Selects features based on variable importance. If several arguments are not None, the returned selection is the conjunction (logical AND) of all of them.

Parameters:
k_best : int, optional, default=None

Selects the top k features based on importance scores.

percentile : float, optional, default=None

Selects features based on a specified percentile of importance scores.

threshold : float, optional, default=None

Selects features with importance scores above the specified threshold.

threshold_pvalue : float, optional, default=None

Selects features with p-values below the specified threshold.

Returns:
selection : array-like of shape (n_features,)

Binary array indicating the selected features.
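
A selection sketch; the 0.05 p-value cutoff, data, and estimator are illustrative:

from sklearn.datasets import make_regression
from sklearn.linear_model import LassoCV
from hidimstat import D0CRT

X, y = make_regression(n_samples=150, n_features=30, random_state=0)

d0crt = D0CRT(estimator=LassoCV(n_alphas=10), random_state=0)
d0crt.fit_importance(X, y)
selected = d0crt.selection(threshold_pvalue=0.05)  # binary mask of shape (n_features,)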

set_params(**params)[source]#

Set the parameters of this estimator.

The method works on simple estimators as well as on nested objects (such as Pipeline). The latter have parameters of the form <component>__<parameter> so that it’s possible to update each component of a nested object.

Parameters:
**params : dict

Estimator parameters.

Returns:
self : estimator instance

Estimator instance.
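
A sketch of the nested-parameter syntax, assuming D0CRT exposes the parameters of its sub-estimators the same way scikit-learn estimators do:

from sklearn.linear_model import LassoCV
from hidimstat import D0CRT

d0crt = D0CRT(estimator=LassoCV(n_alphas=10))
# Top-level parameter plus a nested parameter of the screening estimator.
d0crt.set_params(screening_threshold=100, lasso_screening__n_alphas=20)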

Examples using hidimstat.D0CRT#

Distilled Conditional Randomization Test (dCRT) using Lasso vs Random Forest learners

Variable Selection Under Model Misspecification