From 23bccf8116e12a51eeefe4552bb52bbaf2e6eacc Mon Sep 17 00:00:00 2001 From: Diana Kriuchkova Date: Thu, 2 Apr 2026 15:45:10 +0200 Subject: [PATCH 1/4] Rewrote interpretability doc --- capabilities/interpretability.mdx | 451 +++++++++++++++++++++++++++--- 1 file changed, 408 insertions(+), 43 deletions(-) diff --git a/capabilities/interpretability.mdx b/capabilities/interpretability.mdx index 1a5cdc0..0b28559 100644 --- a/capabilities/interpretability.mdx +++ b/capabilities/interpretability.mdx @@ -1,75 +1,440 @@ --- title: "Interpretability" -description: "Explain TabPFN predictions with SHAP values and feature selection." +description: "Explain TabPFN predictions with Shapley values, feature interactions, and partial dependence plots." --- -The [Interpretability Extension](https://github.com/PriorLabs/tabpfn-extensions/tree/main/src/tabpfn_extensions/interpretability) adds SHAP (SHapley Additive exPlanations) support to quantify the contribution of each input feature to an individual prediction. This can be used to: +The Interpretability Extension adds SHAP (SHapley Additive exPlanations) support to quantify the contribution of each input feature to an individual prediction. This can be used to: - See which features drive model predictions. - Compare feature importance across samples. - Debug unexpected model behavior. -**Shapley Values** +The extension also provides an easy interface for TabPFN Partial Dependence Plots. -SHAP values explain a single prediction by attributing the prediction's deviation from the baseline (mean prediction) to individual features. They provide a consistent, game-theoretic measure of feature influence. Mathematically, each SHAP value represents the marginal contribution of a feature across all possible feature combinations. +--- -Data generation example Data generation example +## Installation -## Getting Started +```bash +pip install tabpfn-client "tabpfn-extensions[interpretability]" +``` -Install the `interpretability` extension: +This installs `shapiq`, `shap`, and the other dependencies needed for all three +methods. -```bash -pip install "tabpfn-extensions[interpretability]" +--- + +## Quickstart + +Train a model, explain a single prediction, and plot the result: + + +Interpretability computations are resource-intensive. This tutorial uses our API client. +For fully local execution instead of the cloud API, replace the `tabpfn_client` import with +`tabpfn` and ensure you have a GPU available. See [best practices](/best-practices) for GPU setup. All code +examples below work identically with either backend. + + +```python +from sklearn.datasets import load_iris +from sklearn.model_selection import train_test_split +from tabpfn_client import TabPFNClassifier +from tabpfn_extensions.interpretability.shapiq import get_tabpfn_explainer + +X, y = load_iris(return_X_y=True, as_frame=True) +X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42) + +clf = TabPFNClassifier() +clf.fit(X_train, y_train) + +explainer = get_tabpfn_explainer(model=clf, data=X_train, labels=y_train) +sv = explainer.explain(X_test.iloc[0:1].values, budget=128) +sv.plot_waterfall() ``` -Then, use SHAP with any trained TabPFN model. This example shows how to use the `TabPFNClassifier`, however, a `TabPFNRegressor` can be used analogously. +--- +## Choosing a Method + +Before diving into each method, here is a summary to help you pick the right +tool for the question you are trying to answer. + +| Method | What it tells you | When to reach for it | +|--------|-------------------|----------------------| +| **shapiq** (recommended) | This is a modern version of classic SHAP library. Tells you which features drove a specific prediction. Uses a remove-and-recontextualize strategy that is native to how TabPFN handles missing data. | You want per-sample explanations and care about feature interactions, or you want the fastest Shapley-based method for TabPFN. | +| **SHAP** | Per-prediction feature attributions via imputation-based permutation. | You need explanations that are directly comparable to SHAP values from other models (XGBoost, Random Forest, etc.), or you are already using the SHAP library in your workflow. | +| **Partial Dependence / ICE** | The global, marginal effect of one or two features across the entire dataset. | You want to understand how a feature affects the model *on average* rather than for a single sample, or you want to visually compare TabPFN against another sklearn estimator. | + + +**shapiq vs SHAP** — shapiq's `TabPFNExplainer` removes features and +re-contextualises the model, which matches how TabPFN natively handles missing +data. SHAP replaces absent features with random background samples. shapiq is +faster and produces explanations that are more faithful to the TabPFN models. We +recommend it as the default. + + +--- + +## Use Cases + + +### Choosing the right method + +| Question | Method | +|----------|--------| +| "Why did the model predict *this* for *this sample*?" | shapiq `get_tabpfn_explainer` | +| "Which feature pairs interact most?" | shapiq with `index="k-SII"`, `max_order=2` | +| "How does feature X affect predictions globally?" | Partial dependence plots | +| "I need SHAP values compatible with other models' explanations" | `get_shap_values` or `get_tabpfn_imputation_explainer` | +| "Which features can I drop without losing accuracy?" | `feature_selection` | + +### Explain a prediction with shapiq + +Use Shapley interaction indices to understand not just which features matter, +but which feature *pairs* drive a prediction together. + + + ```python from sklearn.datasets import load_breast_cancer from sklearn.model_selection import train_test_split +from tabpfn_client import TabPFNClassifier +from tabpfn_extensions.interpretability.shapiq import get_tabpfn_explainer + +X, y = load_breast_cancer(return_X_y=True, as_frame=True) +X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42) + +clf = TabPFNClassifier() +clf.fit(X_train, y_train) + +# k-SII captures pairwise feature interactions +explainer = get_tabpfn_explainer( + model=clf, + data=X_train, + labels=y_train, + index="k-SII", + max_order=2, +) + +sv = explainer.explain(X_test.iloc[0:1].values, budget=128) +print(sv) # top interactions ranked by magnitude +sv.plot_waterfall() # waterfall plot showing additive contributions +``` + + +```python +from sklearn.datasets import load_diabetes +from sklearn.model_selection import train_test_split +from tabpfn_client import TabPFNRegressor +from tabpfn_extensions.interpretability.shapiq import get_tabpfn_explainer + +X, y = load_diabetes(return_X_y=True, as_frame=True) +X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42) + +reg = TabPFNRegressor() +reg.fit(X_train, y_train) + +# SV with max_order=1 gives plain Shapley values (no interactions) +explainer = get_tabpfn_explainer( + model=reg, + data=X_train, + labels=y_train, + index="SV", + max_order=1, +) + +sv = explainer.explain(X_test.iloc[0:1].values, budget=128) +sv.plot_waterfall() +``` + + + +### Compute SHAP values for a batch of predictions + +The classic SHAP library uses permutation-based imputation. It is less computationally efficient as compared to `shapiq`. + + + +```python +from sklearn.datasets import load_iris +from sklearn.model_selection import train_test_split +from tabpfn_client import TabPFNClassifier +from tabpfn_extensions.interpretability.shap import get_shap_values, plot_shap -from tabpfn_extensions import TabPFNClassifier, interpretability +X, y = load_iris(return_X_y=True, as_frame=True) +X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42) -# Load example dataset -data = load_breast_cancer() -X, y = data.data, data.target -feature_names = data.feature_names -n_samples = 50 +clf = TabPFNClassifier() +clf.fit(X_train, y_train) + +shap_values = get_shap_values(clf, X_test) + +# Aggregate bar chart + per-sample beeswarm +plot_shap(shap_values) +``` + + +```python +from sklearn.datasets import load_diabetes +from sklearn.model_selection import train_test_split +from tabpfn_client import TabPFNRegressor +from tabpfn_extensions.interpretability.shap import get_shap_values, plot_shap + +X, y = load_diabetes(return_X_y=True, as_frame=True) +X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42) + +reg = TabPFNRegressor() +reg.fit(X_train, y_train) + +shap_values = get_shap_values(reg, X_test) +plot_shap(shap_values) +``` + + + + +SHAP's permutation explainer scales with the number of features. On datasets +with many features, expect longer runtimes. For faster results, +consider using `shapiq` or passing a smaller subset to `get_shap_values`. + + +### Visualise global feature effects with Partial Dependence Plots + +PDP and ICE curves show how a feature affects predictions across the whole +dataset, not just one sample. + + + +```python +from sklearn.datasets import load_breast_cancer +from sklearn.model_selection import train_test_split +from tabpfn_client import TabPFNClassifier +from tabpfn_extensions.interpretability.pdp import partial_dependence_plots -# Split data -X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5) +X, y = load_breast_cancer(return_X_y=True, as_frame=True) +X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42) -# Initialize and train model clf = TabPFNClassifier() clf.fit(X_train, y_train) -# Calculate SHAP values -shap_values = interpretability.shap.get_shap_values( - estimator=clf, - test_x=X_test[:n_samples], - attribute_names=feature_names, - algorithm="permutation", +# PDP for two features; set kind="individual" for ICE, or "both" for overlay +partial_dependence_plots( + clf, X_test.values, + features=[0, 1], + kind="average", + target_class=1, ) +``` + + +```python +from sklearn.datasets import load_diabetes +from sklearn.model_selection import train_test_split +from tabpfn_client import TabPFNRegressor +from tabpfn_extensions.interpretability.pdp import partial_dependence_plots -# Create visualization -fig = interpretability.shap.plot_shap(shap_values) +X, y = load_diabetes(return_X_y=True, as_frame=True) +X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42) + +reg = TabPFNRegressor() +reg.fit(X_train, y_train) + +# 1D partial dependence for features 0 and 2, with ICE overlay +partial_dependence_plots(reg, X_test.values, features=[0, 2], kind="both") + +# 2D interaction plot +partial_dependence_plots(reg, X_test.values, features=[(0, 2)]) +``` + + + +### Compare TabPFN vs other model explanations side-by-side + +If you need to compare TabPFN explanations against SHAP explanations of another +model using the exact same imputation strategy, use +`get_tabpfn_imputation_explainer`. This wraps shapiq's generic +`TabularExplainer` with marginal imputation — the same approach the SHAP library +uses. + +```python +from tabpfn_extensions.interpretability.shapiq import ( + get_tabpfn_explainer, + get_tabpfn_imputation_explainer, +) + +# Native TabPFN explanation (remove-and-recontextualize) +native_explainer = get_tabpfn_explainer(model=clf, data=X_train, labels=y_train) +sv_native = native_explainer.explain(X_test.iloc[0:1].values, budget=128) + +# Imputation-based explanation (same strategy as SHAP) +impute_explainer = get_tabpfn_imputation_explainer( + model=clf, + data=X_train, + index="SV", + max_order=1, + imputer="marginal", +) +sv_impute = impute_explainer.explain(X_test.iloc[0:1].values, budget=128) ``` -### Core Functions +### Feature selection + +Sequential feature selection identifies the minimal subset of features that +contributes most to model performance: + +```python +from tabpfn_extensions.interpretability.feature_selection import feature_selection + +selector = feature_selection(clf, X_train.values, y_train.values, n_features_to_select=5) +X_selected = selector.transform(X_test.values) +print("Selected feature indices:", selector.get_support(indices=True)) +``` +### Controlling the budget parameter + +The `budget` parameter in `explainer.explain()` sets how many coalition samples +shapiq evaluates to approximate Shapley values. Each coalition is a subset of +features — evaluating more of them produces more accurate estimates but costs +more model calls. + +In theory, exact Shapley values require evaluating all 2^n feature subsets +(e.g. 1024 for 10 features, ~1 billion for 30). In practice, shapiq's +approximation algorithms converge well before that: + +| Dataset size | Suggested budget | Notes | +|-------------|-----------------|-------| +| Few features (< 10) | `64`–`128` | Converges quickly; low budgets are fine | +| Medium (10–20 features) | `128`–`512` | Good accuracy/speed tradeoff | +| Many features (20+) | `512`–`2048` | Higher budgets help, but returns diminish | + +Start low (e.g. `budget=128`) and increase only if the resulting explanations +look noisy or unstable across repeated runs. + +--- + +## Library Reference + +### `interpretability.shapiq.get_tabpfn_explainer` + +Creates a shapiq `TabPFNExplainer` that uses the remove-and-recontextualize +paradigm for TabPFN models. + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `model` | `TabPFNClassifier \| TabPFNRegressor` | *required* | Fitted TabPFN model | +| `data` | `DataFrame \| ndarray` | *required* | Background / training data | +| `labels` | `DataFrame \| ndarray` | *required* | Labels for the background data | +| `index` | `str` | `"k-SII"` | Shapley index type. Options: `"SV"` (Shapley values), `"k-SII"` (k-Shapley interaction index), `"SII"`, `"FSII"`, `"FBII"`, `"STII"`. With `max_order=1`, `"k-SII"` reduces to standard Shapley values. | +| `max_order` | `int` | `2` | Maximum interaction order. Set to `1` for single-feature attributions only (no interactions). | +| `class_index` | `int \| None` | `None` | Class to explain for classification models. Defaults to class `1` when `None`. Ignored for regression. | +| `**kwargs` | | | Additional keyword arguments forwarded to `shapiq.TabPFNExplainer` | + +**Returns:** `shapiq.TabPFNExplainer` -| Function | Description | -| -------- | ----------- | -| **`get_shap_values`** | Calculates SHAP values for the provided model and data subset. | -| **`plot_shap`** | Generates an interactive visualization showing feature contributions for each prediction. | +Call `.explain(x, budget=N)` where `x` is a 2D numpy array of shape `(1, n_features)` and `budget` is the number of coalition samples to evaluate (see [Controlling the budget parameter](#controlling-the-budget-parameter)). Returns a `shapiq.InteractionValues` object with `.plot_waterfall()`, `.plot_force()`, and other visualization methods. + +--- + +### `interpretability.shapiq.get_tabpfn_imputation_explainer` + +Creates a shapiq `TabularExplainer` that uses imputation-based feature removal +(same strategy as SHAP). + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `model` | `TabPFNClassifier \| TabPFNRegressor` | *required* | Fitted TabPFN model | +| `data` | `DataFrame \| ndarray` | *required* | Background data for imputation sampling | +| `index` | `str` | `"k-SII"` | Shapley index type (same options as above) | +| `max_order` | `int` | `2` | Maximum interaction order | +| `imputer` | `str` | `"marginal"` | Imputation method. See shapiq docs for available options. | +| `class_index` | `int \| None` | `None` | Class to explain (classification only) | +| `**kwargs` | | | Additional keyword arguments forwarded to `shapiq.TabularExplainer` | + +**Returns:** `shapiq.TabularExplainer` + +Same `.explain(x, budget=N)` interface as above. + +--- + +### `interpretability.shap.get_shap_values` + +Computes SHAP values using a permutation-based explainer with automatic backend +selection for TabPFN models. + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `estimator` | sklearn-compatible model | *required* | Fitted model (TabPFN or any sklearn estimator) | +| `test_x` | `DataFrame \| ndarray \| Tensor` | *required* | Samples to explain | +| `attribute_names` | `list[str] \| None` | `None` | Feature names when `test_x` is a numpy array | +| `**kwargs` | | | Forwarded to the underlying `shap.Explainer` | + +**Returns:** `shap.Explanation` — access `.values` for a numpy array of shape +`(n_samples, n_features)` (regression) or `(n_samples, n_features, n_classes)` +(classification). + +--- + +### `interpretability.shap.plot_shap` + +Visualises SHAP values as an aggregate bar chart, a per-sample beeswarm plot, +and (if more than one sample) an interaction scatter for the most important +feature. + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `shap_values` | `shap.Explanation` | *required* | Output from `get_shap_values` | + +**Returns:** `None` (displays matplotlib figures). + +--- + +### `interpretability.pdp.partial_dependence_plots` + +Convenience wrapper around sklearn's `PartialDependenceDisplay.from_estimator`. + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `estimator` | sklearn-compatible model | *required* | Fitted estimator | +| `X` | `ndarray` | *required* | Input features | +| `features` | `list[int \| tuple[int, int]]` | *required* | Feature indices for 1D plots, or `(i, j)` tuples for 2D interaction plots | +| `grid_resolution` | `int` | `20` | Number of grid points per feature axis | +| `kind` | `str` | `"average"` | `"average"` for PDP, `"individual"` for ICE curves, `"both"` for overlay | +| `target_class` | `int \| None` | `None` | For classifiers: which class probability to plot | +| `ax` | `matplotlib.axes.Axes \| None` | `None` | Optional axes to plot into | +| `**kwargs` | | | Forwarded to `PartialDependenceDisplay.from_estimator` | + +**Returns:** `sklearn.inspection.PartialDependenceDisplay` + +--- + +### `interpretability.feature_selection.feature_selection` + +Forward sequential feature selection using cross-validation. + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `estimator` | sklearn-compatible model | *required* | Fitted estimator | +| `X` | `ndarray` | *required* | Input features | +| `y` | `ndarray` | *required* | Target values | +| `n_features_to_select` | `int` | `3` | Number of features to select | +| `feature_names` | `list[str] \| None` | `None` | Feature names (optional) | +| `**kwargs` | | | Forwarded to `sklearn.feature_selection.SequentialFeatureSelector` | + +**Returns:** `sklearn.feature_selection.SequentialFeatureSelector` — call +`.transform(X)` to reduce features, or `.get_support(indices=True)` to get +selected indices. + +--- - - Check out our Google Colab for a demo. - + + + GPU setup, batch inference, and performance tuning. + + + Binary and multi-class classification guide. + + + Point estimates, quantiles, and full distributions. + + + Adapt TabPFN to your domain-specific data. + + From 62c9301ed88e3de37becc6fc3cac08f8a9157311 Mon Sep 17 00:00:00 2001 From: Diana Kriuchkova Date: Thu, 2 Apr 2026 15:59:40 +0200 Subject: [PATCH 2/4] change the order of the sections, minor edits --- capabilities/interpretability.mdx | 139 +++++++++++++++--------------- 1 file changed, 69 insertions(+), 70 deletions(-) diff --git a/capabilities/interpretability.mdx b/capabilities/interpretability.mdx index 0b28559..27dfc6f 100644 --- a/capabilities/interpretability.mdx +++ b/capabilities/interpretability.mdx @@ -8,7 +8,7 @@ The Interpretability Extension adds SHAP (SHapley Additive exPlanations) support - Compare feature importance across samples. - Debug unexpected model behavior. -The extension also provides an easy interface for TabPFN Partial Dependence Plots. +The extension also provides an easy interface for TabPFN Partial Dependence Plots and feature selection. --- @@ -58,12 +58,6 @@ sv.plot_waterfall() Before diving into each method, here is a summary to help you pick the right tool for the question you are trying to answer. -| Method | What it tells you | When to reach for it | -|--------|-------------------|----------------------| -| **shapiq** (recommended) | This is a modern version of classic SHAP library. Tells you which features drove a specific prediction. Uses a remove-and-recontextualize strategy that is native to how TabPFN handles missing data. | You want per-sample explanations and care about feature interactions, or you want the fastest Shapley-based method for TabPFN. | -| **SHAP** | Per-prediction feature attributions via imputation-based permutation. | You need explanations that are directly comparable to SHAP values from other models (XGBoost, Random Forest, etc.), or you are already using the SHAP library in your workflow. | -| **Partial Dependence / ICE** | The global, marginal effect of one or two features across the entire dataset. | You want to understand how a feature affects the model *on average* rather than for a single sample, or you want to visually compare TabPFN against another sklearn estimator. | - **shapiq vs SHAP** — shapiq's `TabPFNExplainer` removes features and re-contextualises the model, which matches how TabPFN natively handles missing @@ -72,20 +66,26 @@ faster and produces explanations that are more faithful to the TabPFN models. We recommend it as the default. ---- - -## Use Cases - +| Method | What it tells you | When to reach for it | +|--------|-------------------|----------------------| +| **shapiq** (recommended) | This is a modern version of classic SHAP library. Tells you which features drove a specific prediction. Uses a remove-and-recontextualize strategy that is native to how TabPFN handles missing data. | You want per-sample explanations and care about feature interactions, or you want the fastest Shapley-based method for TabPFN. | +| **SHAP** | Per-prediction feature attributions via imputation-based permutation. | You need explanations that are directly comparable to SHAP values from other models (XGBoost, Random Forest, etc.), or you are already using the SHAP library in your workflow. | +| **Partial Dependence / ICE** | The global, marginal effect of one or two features across the entire dataset. | You want to understand how a feature affects the model *on average* rather than for a single sample, or you want to visually compare TabPFN against another sklearn estimator. | +| **Feature Selection** | Which minimal subset of features preserves model performance. | You want to simplify your model or identify redundant features before deployment. | -### Choosing the right method +If you are still unsure which method to use, follow the table below to see the best tools for most common questions. | Question | Method | |----------|--------| -| "Why did the model predict *this* for *this sample*?" | shapiq `get_tabpfn_explainer` | -| "Which feature pairs interact most?" | shapiq with `index="k-SII"`, `max_order=2` | -| "How does feature X affect predictions globally?" | Partial dependence plots | -| "I need SHAP values compatible with other models' explanations" | `get_shap_values` or `get_tabpfn_imputation_explainer` | -| "Which features can I drop without losing accuracy?" | `feature_selection` | +| "Why did the model predict *this* for *this sample*?" | shapiq — `get_tabpfn_explainer` | +| "Which feature pairs interact most?" | shapiq — `get_tabpfn_explainer` with `index="k-SII"`, `max_order=2` | +| "How does feature X affect predictions globally?" | Partial Dependence — `partial_dependence_plots` | +| "I need SHAP values compatible with other models' explanations" | SHAP — `get_shap_values` or shapiq — `get_tabpfn_imputation_explainer` | +| "Which features can I drop without losing accuracy?" | Feature Selection — `feature_selection` | + +--- + +## Use Cases ### Explain a prediction with shapiq @@ -148,55 +148,6 @@ sv.plot_waterfall() -### Compute SHAP values for a batch of predictions - -The classic SHAP library uses permutation-based imputation. It is less computationally efficient as compared to `shapiq`. - - - -```python -from sklearn.datasets import load_iris -from sklearn.model_selection import train_test_split -from tabpfn_client import TabPFNClassifier -from tabpfn_extensions.interpretability.shap import get_shap_values, plot_shap - -X, y = load_iris(return_X_y=True, as_frame=True) -X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42) - -clf = TabPFNClassifier() -clf.fit(X_train, y_train) - -shap_values = get_shap_values(clf, X_test) - -# Aggregate bar chart + per-sample beeswarm -plot_shap(shap_values) -``` - - -```python -from sklearn.datasets import load_diabetes -from sklearn.model_selection import train_test_split -from tabpfn_client import TabPFNRegressor -from tabpfn_extensions.interpretability.shap import get_shap_values, plot_shap - -X, y = load_diabetes(return_X_y=True, as_frame=True) -X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42) - -reg = TabPFNRegressor() -reg.fit(X_train, y_train) - -shap_values = get_shap_values(reg, X_test) -plot_shap(shap_values) -``` - - - - -SHAP's permutation explainer scales with the number of features. On datasets -with many features, expect longer runtimes. For faster results, -consider using `shapiq` or passing a smaller subset to `get_shap_values`. - - ### Visualise global feature effects with Partial Dependence Plots PDP and ICE curves show how a feature affects predictions across the whole @@ -250,10 +201,8 @@ partial_dependence_plots(reg, X_test.values, features=[(0, 2)]) ### Compare TabPFN vs other model explanations side-by-side If you need to compare TabPFN explanations against SHAP explanations of another -model using the exact same imputation strategy, use -`get_tabpfn_imputation_explainer`. This wraps shapiq's generic -`TabularExplainer` with marginal imputation — the same approach the SHAP library -uses. +model using the exact same imputation strategy, use `get_tabpfn_imputation_explainer`. This wraps shapiq's generic +`TabularExplainer` with marginal imputation — the same approach the SHAP library uses. ```python from tabpfn_extensions.interpretability.shapiq import ( @@ -288,6 +237,56 @@ selector = feature_selection(clf, X_train.values, y_train.values, n_features_to_ X_selected = selector.transform(X_test.values) print("Selected feature indices:", selector.get_support(indices=True)) ``` + +### Compute SHAP values for a batch of predictions + +The classic SHAP library uses permutation-based imputation. It is less computationally efficient as compared to `shapiq`. + + +SHAP's permutation explainer scales with the number of features. On datasets +with many features, expect longer runtimes. For faster results, +consider using `shapiq` or passing a smaller subset to `get_shap_values`. + + + + +```python +from sklearn.datasets import load_iris +from sklearn.model_selection import train_test_split +from tabpfn_client import TabPFNClassifier +from tabpfn_extensions.interpretability.shap import get_shap_values, plot_shap + +X, y = load_iris(return_X_y=True, as_frame=True) +X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42) + +clf = TabPFNClassifier() +clf.fit(X_train, y_train) + +shap_values = get_shap_values(clf, X_test) + +# Aggregate bar chart + per-sample beeswarm +plot_shap(shap_values) +``` + + +```python +from sklearn.datasets import load_diabetes +from sklearn.model_selection import train_test_split +from tabpfn_client import TabPFNRegressor +from tabpfn_extensions.interpretability.shap import get_shap_values, plot_shap + +X, y = load_diabetes(return_X_y=True, as_frame=True) +X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42) + +reg = TabPFNRegressor() +reg.fit(X_train, y_train) + +shap_values = get_shap_values(reg, X_test) +plot_shap(shap_values) +``` + + + ### Controlling the budget parameter The `budget` parameter in `explainer.explain()` sets how many coalition samples From 1c451da363c48f2f7f2dd08f9d7358088b728c8c Mon Sep 17 00:00:00 2001 From: Diana Kriuchkova Date: Thu, 2 Apr 2026 16:05:33 +0200 Subject: [PATCH 3/4] add images from past tutorial --- capabilities/interpretability.mdx | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/capabilities/interpretability.mdx b/capabilities/interpretability.mdx index 27dfc6f..a7441af 100644 --- a/capabilities/interpretability.mdx +++ b/capabilities/interpretability.mdx @@ -3,13 +3,23 @@ title: "Interpretability" description: "Explain TabPFN predictions with Shapley values, feature interactions, and partial dependence plots." --- -The Interpretability Extension adds SHAP (SHapley Additive exPlanations) support to quantify the contribution of each input feature to an individual prediction. This can be used to: +The Interpretability Extension adds SHAP (SHapley Additive exPlanations) support to quantify the contribution of each input feature to an individual prediction. SHAP values explain a single prediction by attributing the prediction's deviation from the baseline (mean prediction) to individual features. They provide a consistent, game-theoretic measure of feature influence. Mathematically, each SHAP value represents the marginal contribution of a feature across all possible feature combinations. This can be used to: - See which features drive model predictions. - Compare feature importance across samples. - Debug unexpected model behavior. The extension also provides an easy interface for TabPFN Partial Dependence Plots and feature selection. +Data generation example Data generation example + --- ## Installation From 439e4b49c33e3d4142b2a116ff50c090ac3191d5 Mon Sep 17 00:00:00 2001 From: Diana Kriuchkova Date: Thu, 2 Apr 2026 16:14:42 +0200 Subject: [PATCH 4/4] minor changes, american english fixes --- capabilities/interpretability.mdx | 75 ++++++++++++++----------------- 1 file changed, 33 insertions(+), 42 deletions(-) diff --git a/capabilities/interpretability.mdx b/capabilities/interpretability.mdx index a7441af..794ade5 100644 --- a/capabilities/interpretability.mdx +++ b/capabilities/interpretability.mdx @@ -3,22 +3,21 @@ title: "Interpretability" description: "Explain TabPFN predictions with Shapley values, feature interactions, and partial dependence plots." --- -The Interpretability Extension adds SHAP (SHapley Additive exPlanations) support to quantify the contribution of each input feature to an individual prediction. SHAP values explain a single prediction by attributing the prediction's deviation from the baseline (mean prediction) to individual features. They provide a consistent, game-theoretic measure of feature influence. Mathematically, each SHAP value represents the marginal contribution of a feature across all possible feature combinations. This can be used to: +The Interpretability Extension adds SHAP (SHapley Additive exPlanations) support to quantify the contribution of each input feature to an individual prediction. + +SHAP values explain a single prediction by attributing the prediction's deviation from the baseline (mean prediction) to individual features. They provide a consistent, game-theoretic measure of feature influence. Mathematically, each SHAP value represents the marginal contribution of a feature across all possible feature combinations. + +This can be used to: - See which features drive model predictions. - Compare feature importance across samples. - Debug unexpected model behavior. The extension also provides an easy interface for TabPFN Partial Dependence Plots and feature selection. -Data generation example Data generation example +
+ SHAP waterfall plot + SHAP feature importance +
--- @@ -70,7 +69,7 @@ tool for the question you are trying to answer. **shapiq vs SHAP** — shapiq's `TabPFNExplainer` removes features and -re-contextualises the model, which matches how TabPFN natively handles missing +re-contextualizes the model, which matches how TabPFN natively handles missing data. SHAP replaces absent features with random background samples. shapiq is faster and produces explanations that are more faithful to the TabPFN models. We recommend it as the default. @@ -158,7 +157,7 @@ sv.plot_waterfall() -### Visualise global feature effects with Partial Dependence Plots +### Visualize global feature effects with Partial Dependence Plots PDP and ICE curves show how a feature affects predictions across the whole dataset, not just one sample. @@ -215,14 +214,7 @@ model using the exact same imputation strategy, use `get_tabpfn_imputation_expla `TabularExplainer` with marginal imputation — the same approach the SHAP library uses. ```python -from tabpfn_extensions.interpretability.shapiq import ( - get_tabpfn_explainer, - get_tabpfn_imputation_explainer, -) - -# Native TabPFN explanation (remove-and-recontextualize) -native_explainer = get_tabpfn_explainer(model=clf, data=X_train, labels=y_train) -sv_native = native_explainer.explain(X_test.iloc[0:1].values, budget=128) +from tabpfn_extensions.interpretability.shapiq import get_tabpfn_imputation_explainer # Imputation-based explanation (same strategy as SHAP) impute_explainer = get_tabpfn_imputation_explainer( @@ -247,8 +239,27 @@ selector = feature_selection(clf, X_train.values, y_train.values, n_features_to_ X_selected = selector.transform(X_test.values) print("Selected feature indices:", selector.get_support(indices=True)) ``` +### Controlling the budget parameter + +The `budget` parameter in `explainer.explain()` sets how many coalition samples +shapiq evaluates to approximate Shapley values. Each coalition is a subset of +features — evaluating more of them produces more accurate estimates but costs +more model calls. + +In theory, exact Shapley values require evaluating all 2^n feature subsets +(e.g. 1024 for 10 features, ~1 billion for 30). In practice, shapiq's +approximation algorithms converge well before that: -### Compute SHAP values for a batch of predictions +| Dataset size | Suggested budget | Notes | +|-------------|-----------------|-------| +| Few features (< 10) | `64`–`128` | Converges quickly; low budgets are fine | +| Medium (10–20 features) | `128`–`512` | Good accuracy/speed tradeoff | +| Many features (20+) | `512`–`2048` | Higher budgets help, but returns diminish | + +Start low (e.g. `budget=128`) and increase only if the resulting explanations +look noisy or unstable across repeated runs. + +### Compute SHAP values using `shap` library The classic SHAP library uses permutation-based imputation. It is less computationally efficient as compared to `shapiq`. @@ -297,26 +308,6 @@ plot_shap(shap_values) -### Controlling the budget parameter - -The `budget` parameter in `explainer.explain()` sets how many coalition samples -shapiq evaluates to approximate Shapley values. Each coalition is a subset of -features — evaluating more of them produces more accurate estimates but costs -more model calls. - -In theory, exact Shapley values require evaluating all 2^n feature subsets -(e.g. 1024 for 10 features, ~1 billion for 30). In practice, shapiq's -approximation algorithms converge well before that: - -| Dataset size | Suggested budget | Notes | -|-------------|-----------------|-------| -| Few features (< 10) | `64`–`128` | Converges quickly; low budgets are fine | -| Medium (10–20 features) | `128`–`512` | Good accuracy/speed tradeoff | -| Many features (20+) | `512`–`2048` | Higher budgets help, but returns diminish | - -Start low (e.g. `budget=128`) and increase only if the resulting explanations -look noisy or unstable across repeated runs. - --- ## Library Reference @@ -383,7 +374,7 @@ selection for TabPFN models. ### `interpretability.shap.plot_shap` -Visualises SHAP values as an aggregate bar chart, a per-sample beeswarm plot, +Visualizes SHAP values as an aggregate bar chart, a per-sample beeswarm plot, and (if more than one sample) an interaction scatter for the most important feature.