diff --git a/capabilities/interpretability.mdx b/capabilities/interpretability.mdx
index 1a5cdc0..794ade5 100644
--- a/capabilities/interpretability.mdx
+++ b/capabilities/interpretability.mdx
@@ -1,75 +1,440 @@
---
title: "Interpretability"
-description: "Explain TabPFN predictions with SHAP values and feature selection."
+description: "Explain TabPFN predictions with Shapley values, feature interactions, and partial dependence plots."
---
-The [Interpretability Extension](https://github.com/PriorLabs/tabpfn-extensions/tree/main/src/tabpfn_extensions/interpretability) adds SHAP (SHapley Additive exPlanations) support to quantify the contribution of each input feature to an individual prediction. This can be used to:
+The [Interpretability Extension](https://github.com/PriorLabs/tabpfn-extensions/tree/main/src/tabpfn_extensions/interpretability) adds SHAP (SHapley Additive exPlanations) support to quantify the contribution of each input feature to an individual prediction.
+
+SHAP values explain a single prediction by attributing the prediction's deviation from the baseline (mean prediction) to individual features. They provide a consistent, game-theoretic measure of feature influence. Mathematically, each SHAP value represents the marginal contribution of a feature across all possible feature combinations.
+
+This can be used to:
- See which features drive model predictions.
- Compare feature importance across samples.
- Debug unexpected model behavior.
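The marginal-contribution definition above can be sketched exactly for a tiny case. The value function `v` below is a hypothetical stand-in for model output on a subset of "present" features, not the TabPFN API; with two features, averaging each feature's marginal contribution over both orderings yields the exact Shapley values:

```python
from itertools import permutations

# Hypothetical value function: model output given a set of present features.
# v(empty set) is the baseline; v({0, 1}) is the full prediction.
v = {
    frozenset(): 10.0,
    frozenset({0}): 14.0,
    frozenset({1}): 11.0,
    frozenset({0, 1}): 18.0,
}

def shapley_values(v, features):
    """Average each feature's marginal contribution over all orderings."""
    phi = {f: 0.0 for f in features}
    orderings = list(permutations(features))
    for order in orderings:
        present = set()
        for f in order:
            # Marginal contribution of f given the features already present.
            phi[f] += v[frozenset(present | {f})] - v[frozenset(present)]
            present.add(f)
    return {f: total / len(orderings) for f, total in phi.items()}

phi = shapley_values(v, [0, 1])
print(phi)  # {0: 5.5, 1: 2.5}
```

The contributions are additive: they sum to the prediction's deviation from the baseline (5.5 + 2.5 = 18.0 - 10.0), which is the property the waterfall plots below visualize.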
-**Shapley Values**
-
-SHAP values explain a single prediction by attributing the prediction's deviation from the baseline (mean prediction) to individual features. They provide a consistent, game-theoretic measure of feature influence. Mathematically, each SHAP value represents the marginal contribution of a feature across all possible feature combinations.
+The extension also provides a simple interface for partial dependence plots and feature selection with TabPFN models.
-
+
+

+

+
-## Getting Started
+---
-Install the `interpretability` extension:
+## Installation
```bash
-pip install "tabpfn-extensions[interpretability]"
+pip install tabpfn-client "tabpfn-extensions[interpretability]"
```
-Then, use SHAP with any trained TabPFN model. This example shows how to use the `TabPFNClassifier`, however, a `TabPFNRegressor` can be used analogously.
+This installs `shapiq`, `shap`, and the other dependencies needed for all the
+methods on this page.
+
+---
+
+## Quickstart
+
+Train a model, explain a single prediction, and plot the result:
+
+
+Interpretability computations are resource-intensive, so the examples on this page use our API client.
+For fully local execution instead of the cloud API, replace the `tabpfn_client` import with
+`tabpfn` and ensure you have a GPU available. See [best practices](/best-practices) for GPU setup. All code
+examples below work identically with either backend.
+
+```python
+from sklearn.datasets import load_iris
+from sklearn.model_selection import train_test_split
+from tabpfn_client import TabPFNClassifier
+from tabpfn_extensions.interpretability.shapiq import get_tabpfn_explainer
+
+X, y = load_iris(return_X_y=True, as_frame=True)
+X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
+
+clf = TabPFNClassifier()
+clf.fit(X_train, y_train)
+
+explainer = get_tabpfn_explainer(model=clf, data=X_train, labels=y_train)
+sv = explainer.explain(X_test.iloc[0:1].values, budget=128)
+sv.plot_waterfall()
+```
+
+---
+
+## Choosing a Method
+
+Before diving into each method, here is a summary to help you pick the right
+tool for the question you are trying to answer.
+
+
+**shapiq vs SHAP** — shapiq's `TabPFNExplainer` removes features and
+re-contextualizes the model, which matches how TabPFN natively handles missing
+data. SHAP replaces absent features with random background samples. shapiq is
+faster and produces explanations that are more faithful to the TabPFN models. We
+recommend it as the default.
+
+
+| Method | What it tells you | When to reach for it |
+|--------|-------------------|----------------------|
+| **shapiq** (recommended) | A modern successor to the classic SHAP library. Tells you which features drove a specific prediction, using a remove-and-recontextualize strategy that matches how TabPFN natively handles missing data. | You want per-sample explanations and care about feature interactions, or you want the fastest Shapley-based method for TabPFN. |
+| **SHAP** | Per-prediction feature attributions via imputation-based permutation. | You need explanations that are directly comparable to SHAP values from other models (XGBoost, Random Forest, etc.), or you are already using the SHAP library in your workflow. |
+| **Partial Dependence / ICE** | The global, marginal effect of one or two features across the entire dataset. | You want to understand how a feature affects the model *on average* rather than for a single sample, or you want to visually compare TabPFN against another sklearn estimator. |
+| **Feature Selection** | Which minimal subset of features preserves model performance. | You want to simplify your model or identify redundant features before deployment. |
+
+If you are still unsure, use the table below to match common questions to the right tool.
+
+| Question | Method |
+|----------|--------|
+| "Why did the model predict *this* for *this sample*?" | shapiq — `get_tabpfn_explainer` |
+| "Which feature pairs interact most?" | shapiq — `get_tabpfn_explainer` with `index="k-SII"`, `max_order=2` |
+| "How does feature X affect predictions globally?" | Partial Dependence — `partial_dependence_plots` |
+| "I need SHAP values compatible with other models' explanations" | SHAP — `get_shap_values` or shapiq — `get_tabpfn_imputation_explainer` |
+| "Which features can I drop without losing accuracy?" | Feature Selection — `feature_selection` |
+
+---
+
+## Use Cases
+
+### Explain a prediction with shapiq
+
+Use Shapley interaction indices to understand not just which features matter,
+but which feature *pairs* drive a prediction together.
+
+
+
```python
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
+from tabpfn_client import TabPFNClassifier
+from tabpfn_extensions.interpretability.shapiq import get_tabpfn_explainer
+
+X, y = load_breast_cancer(return_X_y=True, as_frame=True)
+X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
+
+clf = TabPFNClassifier()
+clf.fit(X_train, y_train)
+
+# k-SII captures pairwise feature interactions
+explainer = get_tabpfn_explainer(
+ model=clf,
+ data=X_train,
+ labels=y_train,
+ index="k-SII",
+ max_order=2,
+)
+
+sv = explainer.explain(X_test.iloc[0:1].values, budget=128)
+print(sv) # top interactions ranked by magnitude
+sv.plot_waterfall() # waterfall plot showing additive contributions
+```
+
+
+```python
+from sklearn.datasets import load_diabetes
+from sklearn.model_selection import train_test_split
+from tabpfn_client import TabPFNRegressor
+from tabpfn_extensions.interpretability.shapiq import get_tabpfn_explainer
-from tabpfn_extensions import TabPFNClassifier, interpretability
+X, y = load_diabetes(return_X_y=True, as_frame=True)
+X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
-# Load example dataset
-data = load_breast_cancer()
-X, y = data.data, data.target
-feature_names = data.feature_names
-n_samples = 50
+reg = TabPFNRegressor()
+reg.fit(X_train, y_train)
-# Split data
-X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5)
+# SV with max_order=1 gives plain Shapley values (no interactions)
+explainer = get_tabpfn_explainer(
+ model=reg,
+ data=X_train,
+ labels=y_train,
+ index="SV",
+ max_order=1,
+)
+
+sv = explainer.explain(X_test.iloc[0:1].values, budget=128)
+sv.plot_waterfall()
+```
+
+
+
+### Visualize global feature effects with Partial Dependence Plots
+
+PDP and ICE curves show how a feature affects predictions across the whole
+dataset, not just one sample.
+
+
+
+```python
+from sklearn.datasets import load_breast_cancer
+from sklearn.model_selection import train_test_split
+from tabpfn_client import TabPFNClassifier
+from tabpfn_extensions.interpretability.pdp import partial_dependence_plots
+
+X, y = load_breast_cancer(return_X_y=True, as_frame=True)
+X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
-# Initialize and train model
clf = TabPFNClassifier()
clf.fit(X_train, y_train)
-# Calculate SHAP values
-shap_values = interpretability.shap.get_shap_values(
- estimator=clf,
- test_x=X_test[:n_samples],
- attribute_names=feature_names,
- algorithm="permutation",
+# PDP for two features; set kind="individual" for ICE, or "both" for overlay
+partial_dependence_plots(
+ clf, X_test.values,
+ features=[0, 1],
+ kind="average",
+ target_class=1,
+)
+```
+
+
+```python
+from sklearn.datasets import load_diabetes
+from sklearn.model_selection import train_test_split
+from tabpfn_client import TabPFNRegressor
+from tabpfn_extensions.interpretability.pdp import partial_dependence_plots
+
+X, y = load_diabetes(return_X_y=True, as_frame=True)
+X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
+
+reg = TabPFNRegressor()
+reg.fit(X_train, y_train)
+
+# 1D partial dependence for features 0 and 2, with ICE overlay
+partial_dependence_plots(reg, X_test.values, features=[0, 2], kind="both")
+
+# 2D interaction plot
+partial_dependence_plots(reg, X_test.values, features=[(0, 2)])
+```
+
+
+
+### Compare TabPFN vs other model explanations side-by-side
+
+If you need to compare TabPFN explanations against SHAP explanations of another
+model using the exact same imputation strategy, use `get_tabpfn_imputation_explainer`. This wraps shapiq's generic
+`TabularExplainer` with marginal imputation — the same approach the SHAP library uses.
+
+```python
+from tabpfn_extensions.interpretability.shapiq import get_tabpfn_imputation_explainer
+
+# Imputation-based explanation (same strategy as SHAP)
+impute_explainer = get_tabpfn_imputation_explainer(
+ model=clf,
+ data=X_train,
+ index="SV",
+ max_order=1,
+ imputer="marginal",
)
+sv_impute = impute_explainer.explain(X_test.iloc[0:1].values, budget=128)
+```
-# Create visualization
-fig = interpretability.shap.plot_shap(shap_values)
+### Feature selection
+
+Sequential feature selection identifies a small subset of features that
+preserves most of the model's performance:
+
+```python
+from tabpfn_extensions.interpretability.feature_selection import feature_selection
+
+selector = feature_selection(clf, X_train.values, y_train.values, n_features_to_select=5)
+X_selected = selector.transform(X_test.values)
+print("Selected feature indices:", selector.get_support(indices=True))
```
+### Controlling the budget parameter
+
+The `budget` parameter in `explainer.explain()` sets how many coalition samples
+shapiq evaluates to approximate Shapley values. Each coalition is a subset of
+features — evaluating more of them produces more accurate estimates but costs
+more model calls.
+
+In theory, exact Shapley values require evaluating all 2^n feature subsets
+(e.g. 1024 for 10 features, ~1 billion for 30). In practice, shapiq's
+approximation algorithms converge well before that:
+
+| Dataset size | Suggested budget | Notes |
+|-------------|-----------------|-------|
+| Few features (< 10) | `64`–`128` | Converges quickly; low budgets are fine |
+| Medium (10–20 features) | `128`–`512` | Good accuracy/speed tradeoff |
+| Many features (20+) | `512`–`2048` | Higher budgets help, but returns diminish |
+
+Start low (e.g. `budget=128`) and increase only if the resulting explanations
+look noisy or unstable across repeated runs.
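
The numbers above are easy to sanity-check: the exact coalition count grows as 2^n while the budget stays fixed, which is why approximation is unavoidable beyond a handful of features. A small illustration (the feature counts and budgets are just examples from the table above):

```python
# Exact Shapley values require evaluating all 2**n coalitions; the budget
# caps how many of them shapiq actually samples.
for n_features, budget in [(8, 128), (15, 256), (30, 1024)]:
    exact = 2 ** n_features
    covered = min(budget / exact, 1.0)
    print(f"{n_features:>2} features: {exact:>13,} coalitions; "
          f"budget {budget} evaluates {covered:.2%} of them")
```

With 8 features a budget of 128 already covers half of all coalitions, while at 30 features even a budget of 2048 touches a vanishing fraction, so shapiq relies on its approximation algorithms rather than exhaustive enumeration.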
+
+### Compute SHAP values using the `shap` library
+
+The classic SHAP library uses permutation-based imputation. It is less computationally efficient than `shapiq`.
+
+
+SHAP's permutation explainer scales with the number of features. On datasets
+with many features, expect longer runtimes. For faster results,
+consider using `shapiq` or passing a smaller subset of samples to `get_shap_values`.
+
+
+
+
+```python
+from sklearn.datasets import load_iris
+from sklearn.model_selection import train_test_split
+from tabpfn_client import TabPFNClassifier
+from tabpfn_extensions.interpretability.shap import get_shap_values, plot_shap
+
+X, y = load_iris(return_X_y=True, as_frame=True)
+X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
+
+clf = TabPFNClassifier()
+clf.fit(X_train, y_train)
+
+shap_values = get_shap_values(clf, X_test)
+
+# Aggregate bar chart + per-sample beeswarm
+plot_shap(shap_values)
+```
+
+
+```python
+from sklearn.datasets import load_diabetes
+from sklearn.model_selection import train_test_split
+from tabpfn_client import TabPFNRegressor
+from tabpfn_extensions.interpretability.shap import get_shap_values, plot_shap
+
+X, y = load_diabetes(return_X_y=True, as_frame=True)
+X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
+
+reg = TabPFNRegressor()
+reg.fit(X_train, y_train)
+
+shap_values = get_shap_values(reg, X_test)
+plot_shap(shap_values)
+```
+
+
+
+---
+
+## Library Reference
+
+### `interpretability.shapiq.get_tabpfn_explainer`
+
+Creates a shapiq `TabPFNExplainer` that uses the remove-and-recontextualize
+paradigm for TabPFN models.
+
+| Parameter | Type | Default | Description |
+|-----------|------|---------|-------------|
+| `model` | `TabPFNClassifier \| TabPFNRegressor` | *required* | Fitted TabPFN model |
+| `data` | `DataFrame \| ndarray` | *required* | Background / training data |
+| `labels` | `DataFrame \| ndarray` | *required* | Labels for the background data |
+| `index` | `str` | `"k-SII"` | Shapley index type. Options: `"SV"` (Shapley values), `"k-SII"` (k-Shapley interaction index), `"SII"`, `"FSII"`, `"FBII"`, `"STII"`. With `max_order=1`, `"k-SII"` reduces to standard Shapley values. |
+| `max_order` | `int` | `2` | Maximum interaction order. Set to `1` for single-feature attributions only (no interactions). |
+| `class_index` | `int \| None` | `None` | Class to explain for classification models. Defaults to class `1` when `None`. Ignored for regression. |
+| `**kwargs` | | | Additional keyword arguments forwarded to `shapiq.TabPFNExplainer` |
-### Core Functions
+**Returns:** `shapiq.TabPFNExplainer`
-| Function | Description |
-| -------- | ----------- |
-| **`get_shap_values`** | Calculates SHAP values for the provided model and data subset. |
-| **`plot_shap`** | Generates an interactive visualization showing feature contributions for each prediction. |
+Call `.explain(x, budget=N)` where `x` is a 2D numpy array of shape `(1, n_features)` and `budget` is the number of coalition samples to evaluate (see [Controlling the budget parameter](#controlling-the-budget-parameter)). Returns a `shapiq.InteractionValues` object with `.plot_waterfall()`, `.plot_force()`, and other visualization methods.
+
+---
+
+### `interpretability.shapiq.get_tabpfn_imputation_explainer`
+
+Creates a shapiq `TabularExplainer` that uses imputation-based feature removal
+(same strategy as SHAP).
+
+| Parameter | Type | Default | Description |
+|-----------|------|---------|-------------|
+| `model` | `TabPFNClassifier \| TabPFNRegressor` | *required* | Fitted TabPFN model |
+| `data` | `DataFrame \| ndarray` | *required* | Background data for imputation sampling |
+| `index` | `str` | `"k-SII"` | Shapley index type (same options as above) |
+| `max_order` | `int` | `2` | Maximum interaction order |
+| `imputer` | `str` | `"marginal"` | Imputation method. See shapiq docs for available options. |
+| `class_index` | `int \| None` | `None` | Class to explain (classification only) |
+| `**kwargs` | | | Additional keyword arguments forwarded to `shapiq.TabularExplainer` |
+
+**Returns:** `shapiq.TabularExplainer`
+
+Same `.explain(x, budget=N)` interface as above.
+
+---
+
+### `interpretability.shap.get_shap_values`
+
+Computes SHAP values using a permutation-based explainer with automatic backend
+selection for TabPFN models.
+
+| Parameter | Type | Default | Description |
+|-----------|------|---------|-------------|
+| `estimator` | sklearn-compatible model | *required* | Fitted model (TabPFN or any sklearn estimator) |
+| `test_x` | `DataFrame \| ndarray \| Tensor` | *required* | Samples to explain |
+| `attribute_names` | `list[str] \| None` | `None` | Feature names when `test_x` is a numpy array |
+| `**kwargs` | | | Forwarded to the underlying `shap.Explainer` |
+
+**Returns:** `shap.Explanation` — access `.values` for a numpy array of shape
+`(n_samples, n_features)` (regression) or `(n_samples, n_features, n_classes)`
+(classification).
+
+---
+
+### `interpretability.shap.plot_shap`
+
+Visualizes SHAP values as an aggregate bar chart, a per-sample beeswarm plot,
+and (if more than one sample) an interaction scatter for the most important
+feature.
+
+| Parameter | Type | Default | Description |
+|-----------|------|---------|-------------|
+| `shap_values` | `shap.Explanation` | *required* | Output from `get_shap_values` |
+
+**Returns:** `None` (displays matplotlib figures).
+
+---
+
+### `interpretability.pdp.partial_dependence_plots`
+
+Convenience wrapper around sklearn's `PartialDependenceDisplay.from_estimator`.
+
+| Parameter | Type | Default | Description |
+|-----------|------|---------|-------------|
+| `estimator` | sklearn-compatible model | *required* | Fitted estimator |
+| `X` | `ndarray` | *required* | Input features |
+| `features` | `list[int \| tuple[int, int]]` | *required* | Feature indices for 1D plots, or `(i, j)` tuples for 2D interaction plots |
+| `grid_resolution` | `int` | `20` | Number of grid points per feature axis |
+| `kind` | `str` | `"average"` | `"average"` for PDP, `"individual"` for ICE curves, `"both"` for overlay |
+| `target_class` | `int \| None` | `None` | For classifiers: which class probability to plot |
+| `ax` | `matplotlib.axes.Axes \| None` | `None` | Optional axes to plot into |
+| `**kwargs` | | | Forwarded to `PartialDependenceDisplay.from_estimator` |
+
+**Returns:** `sklearn.inspection.PartialDependenceDisplay`
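
Under the hood this is sklearn's standard partial dependence computation: predictions averaged over a grid of values for the chosen feature. A minimal sketch of that machinery using sklearn's `partial_dependence` function, with a plain `RandomForestRegressor` standing in for TabPFN to keep the example self-contained:

```python
from sklearn.datasets import load_diabetes
from sklearn.ensemble import RandomForestRegressor
from sklearn.inspection import partial_dependence

# A plain sklearn regressor stands in for TabPFN here.
X, y = load_diabetes(return_X_y=True)
est = RandomForestRegressor(n_estimators=10, random_state=0).fit(X, y)

# Average prediction over a grid of values for feature 0.
res = partial_dependence(est, X, features=[0], grid_resolution=20, kind="average")
print(res["average"].shape)  # (n_outputs, n_grid_points), here (1, <=20)
```

Because the wrapper forwards `grid_resolution`, `kind`, and the other keyword arguments to this machinery, the semantics documented by sklearn apply unchanged.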
+
+---
+
+### `interpretability.feature_selection.feature_selection`
+
+Forward sequential feature selection using cross-validation.
+
+| Parameter | Type | Default | Description |
+|-----------|------|---------|-------------|
+| `estimator` | sklearn-compatible model | *required* | Fitted estimator |
+| `X` | `ndarray` | *required* | Input features |
+| `y` | `ndarray` | *required* | Target values |
+| `n_features_to_select` | `int` | `3` | Number of features to select |
+| `feature_names` | `list[str] \| None` | `None` | Feature names (optional) |
+| `**kwargs` | | | Forwarded to `sklearn.feature_selection.SequentialFeatureSelector` |
+
+**Returns:** `sklearn.feature_selection.SequentialFeatureSelector` — call
+`.transform(X)` to reduce features, or `.get_support(indices=True)` to get
+selected indices.
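
Since the wrapper forwards to sklearn's `SequentialFeatureSelector`, the returned object behaves like any sklearn selector. A minimal sketch of that interface, using a plain `LogisticRegression` in place of TabPFN so the example runs without a TabPFN backend:

```python
from sklearn.datasets import load_breast_cancer
from sklearn.feature_selection import SequentialFeatureSelector
from sklearn.linear_model import LogisticRegression

X, y = load_breast_cancer(return_X_y=True)

# Forward selection: greedily add the feature that most improves CV score.
sfs = SequentialFeatureSelector(
    LogisticRegression(max_iter=2000),
    n_features_to_select=3,
    direction="forward",
    cv=3,
)
sfs.fit(X, y)

print(sfs.get_support(indices=True))  # indices of the 3 selected columns
X_reduced = sfs.transform(X)
print(X_reduced.shape)  # (569, 3)
```

The `feature_selection` wrapper returns the same kind of object, so `.transform` and `.get_support` work as shown.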
+
+---
-
- Check out our Google Colab for a demo.
-
+
+
+ GPU setup, batch inference, and performance tuning.
+
+
+ Binary and multi-class classification guide.
+
+
+ Point estimates, quantiles, and full distributions.
+
+
+ Adapt TabPFN to your domain-specific data.
+
+