diff --git a/capabilities/interpretability.mdx b/capabilities/interpretability.mdx index 1a5cdc0..794ade5 100644 --- a/capabilities/interpretability.mdx +++ b/capabilities/interpretability.mdx @@ -1,75 +1,440 @@ --- title: "Interpretability" -description: "Explain TabPFN predictions with SHAP values and feature selection." +description: "Explain TabPFN predictions with Shapley values, feature interactions, and partial dependence plots." --- -The [Interpretability Extension](https://github.com/PriorLabs/tabpfn-extensions/tree/main/src/tabpfn_extensions/interpretability) adds SHAP (SHapley Additive exPlanations) support to quantify the contribution of each input feature to an individual prediction. This can be used to: +The Interpretability Extension adds SHAP (SHapley Additive exPlanations) support to quantify the contribution of each input feature to an individual prediction. + +SHAP values explain a single prediction by attributing the prediction's deviation from the baseline (mean prediction) to individual features. They provide a consistent, game-theoretic measure of feature influence. Mathematically, each SHAP value represents the marginal contribution of a feature across all possible feature combinations. + +This can be used to: - See which features drive model predictions. - Compare feature importance across samples. - Debug unexpected model behavior. -**Shapley Values** - -SHAP values explain a single prediction by attributing the prediction's deviation from the baseline (mean prediction) to individual features. They provide a consistent, game-theoretic measure of feature influence. Mathematically, each SHAP value represents the marginal contribution of a feature across all possible feature combinations. +The extension also provides an easy interface for TabPFN Partial Dependence Plots and feature selection. -Data generation example Data generation example +
+ SHAP waterfall plot + SHAP feature importance +
-## Getting Started +--- -Install the `interpretability` extension: +## Installation ```bash -pip install "tabpfn-extensions[interpretability]" +pip install tabpfn-client "tabpfn-extensions[interpretability]" ``` -Then, use SHAP with any trained TabPFN model. This example shows how to use the `TabPFNClassifier`, however, a `TabPFNRegressor` can be used analogously. +This installs `shapiq`, `shap`, and the other dependencies needed for all three +methods. + +--- + +## Quickstart + +Train a model, explain a single prediction, and plot the result: + + +Interpretability computations are resource-intensive. This tutorial uses our API client. +For fully local execution instead of the cloud API, replace the `tabpfn_client` import with +`tabpfn` and ensure you have a GPU available. See [best practices](/best-practices) for GPU setup. All code +examples below work identically with either backend. + +```python +from sklearn.datasets import load_iris +from sklearn.model_selection import train_test_split +from tabpfn_client import TabPFNClassifier +from tabpfn_extensions.interpretability.shapiq import get_tabpfn_explainer + +X, y = load_iris(return_X_y=True, as_frame=True) +X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42) + +clf = TabPFNClassifier() +clf.fit(X_train, y_train) + +explainer = get_tabpfn_explainer(model=clf, data=X_train, labels=y_train) +sv = explainer.explain(X_test.iloc[0:1].values, budget=128) +sv.plot_waterfall() +``` + +--- + +## Choosing a Method + +Before diving into each method, here is a summary to help you pick the right +tool for the question you are trying to answer. + + +**shapiq vs SHAP** — shapiq's `TabPFNExplainer` removes features and +re-contextualizes the model, which matches how TabPFN natively handles missing +data. SHAP replaces absent features with random background samples. shapiq is +faster and produces explanations that are more faithful to the TabPFN models. We +recommend it as the default. 
+ + +| Method | What it tells you | When to reach for it | +|--------|-------------------|----------------------| +| **shapiq** (recommended) | A modern extension of the classic SHAP library. Tells you which features drove a specific prediction, using a remove-and-recontextualize strategy that is native to how TabPFN handles missing data. | You want per-sample explanations and care about feature interactions, or you want the fastest Shapley-based method for TabPFN. | +| **SHAP** | Per-prediction feature attributions via imputation-based permutation. | You need explanations that are directly comparable to SHAP values from other models (XGBoost, Random Forest, etc.), or you are already using the SHAP library in your workflow. | +| **Partial Dependence / ICE** | The global, marginal effect of one or two features across the entire dataset. | You want to understand how a feature affects the model *on average* rather than for a single sample, or you want to visually compare TabPFN against another sklearn estimator. | +| **Feature Selection** | Which minimal subset of features preserves model performance. | You want to simplify your model or identify redundant features before deployment. | + +If you are still unsure which method to use, the table below maps the most common questions to the right tool. + +| Question | Method | +|----------|--------| +| "Why did the model predict *this* for *this sample*?" | shapiq — `get_tabpfn_explainer` | +| "Which feature pairs interact most?" | shapiq — `get_tabpfn_explainer` with `index="k-SII"`, `max_order=2` | +| "How does feature X affect predictions globally?" | Partial Dependence — `partial_dependence_plots` | +| "I need SHAP values compatible with other models' explanations" | SHAP — `get_shap_values` or shapiq — `get_tabpfn_imputation_explainer` | +| "Which features can I drop without losing accuracy?"
| Feature Selection — `feature_selection` | + +--- + +## Use Cases + +### Explain a prediction with shapiq + +Use Shapley interaction indices to understand not just which features matter, +but which feature *pairs* drive a prediction together. + + + ```python from sklearn.datasets import load_breast_cancer from sklearn.model_selection import train_test_split +from tabpfn_client import TabPFNClassifier +from tabpfn_extensions.interpretability.shapiq import get_tabpfn_explainer + +X, y = load_breast_cancer(return_X_y=True, as_frame=True) +X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42) + +clf = TabPFNClassifier() +clf.fit(X_train, y_train) + +# k-SII captures pairwise feature interactions +explainer = get_tabpfn_explainer( + model=clf, + data=X_train, + labels=y_train, + index="k-SII", + max_order=2, +) + +sv = explainer.explain(X_test.iloc[0:1].values, budget=128) +print(sv) # top interactions ranked by magnitude +sv.plot_waterfall() # waterfall plot showing additive contributions +``` + + +```python +from sklearn.datasets import load_diabetes +from sklearn.model_selection import train_test_split +from tabpfn_client import TabPFNRegressor +from tabpfn_extensions.interpretability.shapiq import get_tabpfn_explainer -from tabpfn_extensions import TabPFNClassifier, interpretability +X, y = load_diabetes(return_X_y=True, as_frame=True) +X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42) -# Load example dataset -data = load_breast_cancer() -X, y = data.data, data.target -feature_names = data.feature_names -n_samples = 50 +reg = TabPFNRegressor() +reg.fit(X_train, y_train) -# Split data -X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5) +# SV with max_order=1 gives plain Shapley values (no interactions) +explainer = get_tabpfn_explainer( + model=reg, + data=X_train, + labels=y_train, + index="SV", + max_order=1, +) + +sv = explainer.explain(X_test.iloc[0:1].values, 
budget=128) +sv.plot_waterfall() +``` + + + +### Visualize global feature effects with Partial Dependence Plots + +PDP and ICE curves show how a feature affects predictions across the whole +dataset, not just one sample. + + + +```python +from sklearn.datasets import load_breast_cancer +from sklearn.model_selection import train_test_split +from tabpfn_client import TabPFNClassifier +from tabpfn_extensions.interpretability.pdp import partial_dependence_plots + +X, y = load_breast_cancer(return_X_y=True, as_frame=True) +X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42) -# Initialize and train model clf = TabPFNClassifier() clf.fit(X_train, y_train) -# Calculate SHAP values -shap_values = interpretability.shap.get_shap_values( - estimator=clf, - test_x=X_test[:n_samples], - attribute_names=feature_names, - algorithm="permutation", +# PDP for two features; set kind="individual" for ICE, or "both" for overlay +partial_dependence_plots( + clf, X_test.values, + features=[0, 1], + kind="average", + target_class=1, +) +``` + + +```python +from sklearn.datasets import load_diabetes +from sklearn.model_selection import train_test_split +from tabpfn_client import TabPFNRegressor +from tabpfn_extensions.interpretability.pdp import partial_dependence_plots + +X, y = load_diabetes(return_X_y=True, as_frame=True) +X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42) + +reg = TabPFNRegressor() +reg.fit(X_train, y_train) + +# 1D partial dependence for features 0 and 2, with ICE overlay +partial_dependence_plots(reg, X_test.values, features=[0, 2], kind="both") + +# 2D interaction plot +partial_dependence_plots(reg, X_test.values, features=[(0, 2)]) +``` + + + +### Compare TabPFN vs other model explanations side-by-side + +If you need to compare TabPFN explanations against SHAP explanations of another +model using the exact same imputation strategy, use `get_tabpfn_imputation_explainer`. 
This wraps shapiq's generic +`TabularExplainer` with marginal imputation — the same approach the SHAP library uses. + +```python +from tabpfn_extensions.interpretability.shapiq import get_tabpfn_imputation_explainer + +# Imputation-based explanation (same strategy as SHAP) +impute_explainer = get_tabpfn_imputation_explainer( + model=clf, + data=X_train, + index="SV", + max_order=1, + imputer="marginal", +) +sv_impute = impute_explainer.explain(X_test.iloc[0:1].values, budget=128) +``` -# Create visualization -fig = interpretability.shap.plot_shap(shap_values) +### Feature selection + +Sequential feature selection identifies the minimal subset of features that +contributes most to model performance: + +```python +from tabpfn_extensions.interpretability.feature_selection import feature_selection + +selector = feature_selection(clf, X_train.values, y_train.values, n_features_to_select=5) +X_selected = selector.transform(X_test.values) +print("Selected feature indices:", selector.get_support(indices=True)) +``` + +### Controlling the budget parameter + +The `budget` parameter in `explainer.explain()` sets how many coalition samples +shapiq evaluates to approximate Shapley values. Each coalition is a subset of +features — evaluating more of them produces more accurate estimates but costs +more model calls. + +In theory, exact Shapley values require evaluating all 2^n feature subsets +(e.g. 1024 for 10 features, ~1 billion for 30). In practice, shapiq's +approximation algorithms converge well before that: + +| Number of features | Suggested budget | Notes | +|-------------------|-----------------|-------| +| Few features (< 10) | `64`–`128` | Converges quickly; low budgets are fine | +| Medium (10–20 features) | `128`–`512` | Good accuracy/speed tradeoff | +| Many features (20+) | `512`–`2048` | Higher budgets help, but returns diminish | + +Start low (e.g. `budget=128`) and increase only if the resulting explanations +look noisy or unstable across repeated runs.
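The coalition arithmetic behind these recommendations can be checked in plain Python. A minimal sketch (the helpers `coalition_count` and `budget_coverage` are illustrative, not part of `tabpfn-extensions`):

```python
def coalition_count(n_features: int) -> int:
    # Exact Shapley values require evaluating every subset of features.
    return 2 ** n_features

def budget_coverage(n_features: int, budget: int) -> float:
    # Fraction of all coalitions a given budget can evaluate (capped at 100%).
    return min(1.0, budget / coalition_count(n_features))

for n in (8, 10, 20, 30):
    print(f"{n} features: {coalition_count(n):>13,} coalitions, "
          f"budget=128 covers {budget_coverage(n, 128):.4%}")
```

With 8 features a budget of 256 already enumerates every coalition exactly, while at 30 features even `budget=2048` samples only a vanishing fraction of them, which is why the budget controls approximation quality rather than exhaustiveness.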
+ +### Compute SHAP values using the `shap` library + +The classic SHAP library uses permutation-based imputation. It is less computationally efficient than `shapiq`. + + +SHAP's permutation explainer scales with the number of features. On datasets +with many features, expect longer runtimes. For faster results, +consider using `shapiq` or passing a smaller subset to `get_shap_values`. + + + +```python +from sklearn.datasets import load_iris +from sklearn.model_selection import train_test_split +from tabpfn_client import TabPFNClassifier +from tabpfn_extensions.interpretability.shap import get_shap_values, plot_shap + +X, y = load_iris(return_X_y=True, as_frame=True) +X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42) + +clf = TabPFNClassifier() +clf.fit(X_train, y_train) + +shap_values = get_shap_values(clf, X_test) + +# Aggregate bar chart + per-sample beeswarm +plot_shap(shap_values) +``` + +```python +from sklearn.datasets import load_diabetes +from sklearn.model_selection import train_test_split +from tabpfn_client import TabPFNRegressor +from tabpfn_extensions.interpretability.shap import get_shap_values, plot_shap + +X, y = load_diabetes(return_X_y=True, as_frame=True) +X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42) + +reg = TabPFNRegressor() +reg.fit(X_train, y_train) + +shap_values = get_shap_values(reg, X_test) +plot_shap(shap_values) +``` + + + +--- + +## Library Reference + +### `interpretability.shapiq.get_tabpfn_explainer` + +Creates a shapiq `TabPFNExplainer` that uses the remove-and-recontextualize +paradigm for TabPFN models.
+ +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `model` | `TabPFNClassifier \| TabPFNRegressor` | *required* | Fitted TabPFN model | +| `data` | `DataFrame \| ndarray` | *required* | Background / training data | +| `labels` | `DataFrame \| ndarray` | *required* | Labels for the background data | +| `index` | `str` | `"k-SII"` | Shapley index type. Options: `"SV"` (Shapley values), `"k-SII"` (k-Shapley interaction index), `"SII"`, `"FSII"`, `"FBII"`, `"STII"`. With `max_order=1`, `"k-SII"` reduces to standard Shapley values. | +| `max_order` | `int` | `2` | Maximum interaction order. Set to `1` for single-feature attributions only (no interactions). | +| `class_index` | `int \| None` | `None` | Class to explain for classification models. Defaults to class `1` when `None`. Ignored for regression. | +| `**kwargs` | | | Additional keyword arguments forwarded to `shapiq.TabPFNExplainer` | -### Core Functions +**Returns:** `shapiq.TabPFNExplainer` -| Function | Description | -| -------- | ----------- | -| **`get_shap_values`** | Calculates SHAP values for the provided model and data subset. | -| **`plot_shap`** | Generates an interactive visualization showing feature contributions for each prediction. | +Call `.explain(x, budget=N)` where `x` is a 2D numpy array of shape `(1, n_features)` and `budget` is the number of coalition samples to evaluate (see [Controlling the budget parameter](#controlling-the-budget-parameter)). Returns a `shapiq.InteractionValues` object with `.plot_waterfall()`, `.plot_force()`, and other visualization methods. + +--- + +### `interpretability.shapiq.get_tabpfn_imputation_explainer` + +Creates a shapiq `TabularExplainer` that uses imputation-based feature removal +(same strategy as SHAP). 
+ +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `model` | `TabPFNClassifier \| TabPFNRegressor` | *required* | Fitted TabPFN model | +| `data` | `DataFrame \| ndarray` | *required* | Background data for imputation sampling | +| `index` | `str` | `"k-SII"` | Shapley index type (same options as above) | +| `max_order` | `int` | `2` | Maximum interaction order | +| `imputer` | `str` | `"marginal"` | Imputation method. See shapiq docs for available options. | +| `class_index` | `int \| None` | `None` | Class to explain (classification only) | +| `**kwargs` | | | Additional keyword arguments forwarded to `shapiq.TabularExplainer` | + +**Returns:** `shapiq.TabularExplainer` + +Same `.explain(x, budget=N)` interface as above. + +--- + +### `interpretability.shap.get_shap_values` + +Computes SHAP values using a permutation-based explainer with automatic backend +selection for TabPFN models. + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `estimator` | sklearn-compatible model | *required* | Fitted model (TabPFN or any sklearn estimator) | +| `test_x` | `DataFrame \| ndarray \| Tensor` | *required* | Samples to explain | +| `attribute_names` | `list[str] \| None` | `None` | Feature names when `test_x` is a numpy array | +| `**kwargs` | | | Forwarded to the underlying `shap.Explainer` | + +**Returns:** `shap.Explanation` — access `.values` for a numpy array of shape +`(n_samples, n_features)` (regression) or `(n_samples, n_features, n_classes)` +(classification). + +--- + +### `interpretability.shap.plot_shap` + +Visualizes SHAP values as an aggregate bar chart, a per-sample beeswarm plot, +and (if more than one sample) an interaction scatter for the most important +feature. 
+ +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `shap_values` | `shap.Explanation` | *required* | Output from `get_shap_values` | + +**Returns:** `None` (displays matplotlib figures). + +--- + +### `interpretability.pdp.partial_dependence_plots` + +Convenience wrapper around sklearn's `PartialDependenceDisplay.from_estimator`. + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `estimator` | sklearn-compatible model | *required* | Fitted estimator | +| `X` | `ndarray` | *required* | Input features | +| `features` | `list[int \| tuple[int, int]]` | *required* | Feature indices for 1D plots, or `(i, j)` tuples for 2D interaction plots | +| `grid_resolution` | `int` | `20` | Number of grid points per feature axis | +| `kind` | `str` | `"average"` | `"average"` for PDP, `"individual"` for ICE curves, `"both"` for overlay | +| `target_class` | `int \| None` | `None` | For classifiers: which class probability to plot | +| `ax` | `matplotlib.axes.Axes \| None` | `None` | Optional axes to plot into | +| `**kwargs` | | | Forwarded to `PartialDependenceDisplay.from_estimator` | + +**Returns:** `sklearn.inspection.PartialDependenceDisplay` + +--- + +### `interpretability.feature_selection.feature_selection` + +Forward sequential feature selection using cross-validation. 
+ +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `estimator` | sklearn-compatible model | *required* | Fitted estimator | +| `X` | `ndarray` | *required* | Input features | +| `y` | `ndarray` | *required* | Target values | +| `n_features_to_select` | `int` | `3` | Number of features to select | +| `feature_names` | `list[str] \| None` | `None` | Feature names (optional) | +| `**kwargs` | | | Forwarded to `sklearn.feature_selection.SequentialFeatureSelector` | + +**Returns:** `sklearn.feature_selection.SequentialFeatureSelector` — call +`.transform(X)` to reduce features, or `.get_support(indices=True)` to get +selected indices. + +--- - - Check out our Google Colab for a demo. - + + + GPU setup, batch inference, and performance tuning. + + + Binary and multi-class classification guide. + + + Point estimates, quantiles, and full distributions. + + + Adapt TabPFN to your domain-specific data. + +
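As a closing aside, the game-theoretic definition from the introduction (each Shapley value is a feature's marginal contribution averaged over all feature coalitions) can be made concrete with a brute-force computation on a toy model. This is an illustrative sketch, not part of `tabpfn-extensions`; real workflows should rely on shapiq's sampling-based approximations:

```python
from itertools import combinations
from math import factorial

def shapley_values(predict, x, baseline):
    """Exact Shapley values by enumerating all 2^n coalitions.

    predict  : callable mapping a feature vector to a scalar prediction
    x        : the instance being explained
    baseline : reference values standing in for "absent" features
    """
    n = len(x)

    def v(S):
        # Features in coalition S take the instance's values; the rest
        # stay at the baseline (the "removed" state).
        z = [x[i] if i in S else baseline[i] for i in range(n)]
        return predict(z)

    phi = []
    for i in range(n):
        others = [j for j in range(n) if j != i]
        total = 0.0
        for size in range(n):
            for S in combinations(others, size):
                # Shapley kernel weight |S|! (n - |S| - 1)! / n!
                weight = factorial(size) * factorial(n - size - 1) / factorial(n)
                total += weight * (v(set(S) | {i}) - v(set(S)))
        phi.append(total)
    return phi

# Toy linear model: for linear models the Shapley value of feature i
# is exactly coef_i * (x_i - baseline_i).
predict = lambda z: 3 * z[0] + 2 * z[1] - z[2]
phi = shapley_values(predict, [1.0, 2.0, 3.0], [0.0, 0.0, 0.0])
print(phi)  # approximately [3.0, 4.0, -3.0]
```

The efficiency property (attributions sum to the prediction's deviation from the baseline) is what makes waterfall plots additive; shapiq estimates the same quantity with far fewer than 2^n model calls.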