Conversation
Code Review
This pull request significantly expands the interpretability documentation, adding comprehensive guides for the shapiq extension, SHAP values, and Partial Dependence Plots. The review feedback identifies several opportunities for improvement, including maintaining consistent American English spelling, refining wordy phrasing for better readability, clarifying ambiguous table headers, and ensuring that all code examples are self-contained and runnable for the user.
**capabilities/interpretability.mdx** (outdated)

> <Tip>
> **shapiq vs SHAP** — shapiq's `TabPFNExplainer` removes features and
> re-contextualises the model, which matches how TabPFN natively handles missing
**capabilities/interpretability.mdx** (outdated)

> ### Compute SHAP values for a batch of predictions
>
> The classic SHAP library uses permutation-based imputation. It is less computationally efficient as compared to `shapiq`.
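To make the imputation comparison concrete, here is a minimal, self-contained sketch of what background-sample imputation means for Shapley values. It uses a hand-rolled toy model and exact enumeration rather than TabPFN or the shap library, and all names (`model`, `value`, `shapley`, `background`) are illustrative only: features outside a coalition are filled in from a background sample, and a feature's Shapley value is its weighted average marginal contribution across coalitions.

```python
from itertools import combinations
from math import factorial

# Toy "model": a deterministic function of three features (illustrative only).
def model(x):
    return 2.0 * x[0] + 1.0 * x[1] + 0.5 * x[2]

background = [0.0, 0.0, 0.0]   # background sample used to impute "absent" features
x = [1.0, 1.0, 1.0]            # instance to explain
n = 3

def value(coalition):
    # Features outside the coalition are imputed from the background sample.
    filled = [x[i] if i in coalition else background[i] for i in range(n)]
    return model(filled)

def shapley(i):
    # Exact Shapley value: weighted marginal contribution of feature i
    # over all coalitions S of the remaining features.
    others = [j for j in range(n) if j != i]
    total = 0.0
    for size in range(n):
        for S in combinations(others, size):
            weight = factorial(size) * factorial(n - size - 1) / factorial(n)
            total += weight * (value(set(S) | {i}) - value(set(S)))
    return total

phi = [shapley(i) for i in range(n)]
print(phi)  # ~ [2.0, 1.0, 0.5] for this additive toy model
```

Because the toy model is additive, each feature's Shapley value equals its own term, and the values plus the baseline prediction sum to the full prediction (the efficiency property). Exact enumeration like this costs 2^n model evaluations, which is why shapiq approximates under a budget.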
**capabilities/interpretability.mdx** (outdated)

> consider using `shapiq` or passing a smaller subset to `get_shap_values`.
> </Note>
>
> ### Visualise global feature effects with Partial Dependence Plots
>
> ```python
> from tabpfn_extensions.interpretability.shapiq import (
>     get_tabpfn_explainer,
>     get_tabpfn_imputation_explainer,
> )
>
> # Native TabPFN explanation (remove-and-recontextualize)
> native_explainer = get_tabpfn_explainer(model=clf, data=X_train, labels=y_train)
> sv_native = native_explainer.explain(X_test.iloc[0:1].values, budget=128)
>
> # Imputation-based explanation (same strategy as SHAP)
> impute_explainer = get_tabpfn_imputation_explainer(
>     model=clf,
>     data=X_train,
>     index="SV",
>     max_order=1,
>     imputer="marginal",
> )
> sv_impute = impute_explainer.explain(X_test.iloc[0:1].values, budget=128)
> ```
This code example is not self-contained because it uses variables (clf, X_train, X_test) that are not defined within the snippet. This prevents users from being able to copy and paste the code to run it. Please add the necessary setup code to make the example runnable on its own.
```python
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from tabpfn_client import TabPFNClassifier
from tabpfn_extensions.interpretability.shapiq import (
    get_tabpfn_explainer,
    get_tabpfn_imputation_explainer,
)

X, y = load_breast_cancer(return_X_y=True, as_frame=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

clf = TabPFNClassifier()
clf.fit(X_train, y_train)

# Native TabPFN explanation (remove-and-recontextualize)
native_explainer = get_tabpfn_explainer(model=clf, data=X_train, labels=y_train)
sv_native = native_explainer.explain(X_test.iloc[0:1].values, budget=128)

# Imputation-based explanation (same strategy as SHAP)
impute_explainer = get_tabpfn_imputation_explainer(
    model=clf,
    data=X_train,
    index="SV",
    max_order=1,
    imputer="marginal",
)
sv_impute = impute_explainer.explain(X_test.iloc[0:1].values, budget=128)
```
> ```python
> from tabpfn_extensions.interpretability.feature_selection import feature_selection
>
> selector = feature_selection(clf, X_train.values, y_train.values, n_features_to_select=5)
> X_selected = selector.transform(X_test.values)
> print("Selected feature indices:", selector.get_support(indices=True))
> ```
This code example for feature selection is not self-contained, as it depends on variables (clf, X_train, y_train, X_test) defined in other sections. To improve the user experience, please make this snippet runnable by including the necessary data loading and model training steps.
```python
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from tabpfn_client import TabPFNClassifier
from tabpfn_extensions.interpretability.feature_selection import feature_selection

X, y = load_breast_cancer(return_X_y=True, as_frame=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

clf = TabPFNClassifier()
clf.fit(X_train, y_train)

selector = feature_selection(clf, X_train.values, y_train.values, n_features_to_select=5)
X_selected = selector.transform(X_test.values)
print("Selected feature indices:", selector.get_support(indices=True))
```
**capabilities/interpretability.mdx** (outdated)

> (e.g. 1024 for 10 features, ~1 billion for 30). In practice, shapiq's
> approximation algorithms converge well before that:
>
> | Dataset size | Suggested budget | Notes |
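For context on the quoted numbers: an exact Shapley computation has to evaluate the model on every coalition (subset) of the n features, i.e. 2^n evaluations, which is where the 1024-for-10-features and ~1-billion-for-30 figures come from. A quick check in plain Python (the budget column itself remains the doc author's suggestion):

```python
# Number of feature coalitions an exact Shapley computation must evaluate:
# every subset of the n features, i.e. 2**n model evaluations.
for n_features in (10, 20, 30):
    n_coalitions = 2 ** n_features
    print(f"{n_features} features -> {n_coalitions:,} coalitions")

# 10 features -> 1,024 coalitions
# 20 features -> 1,048,576 coalitions
# 30 features -> 1,073,741,824 coalitions (~1 billion, matching the quoted text)
```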
**capabilities/interpretability.mdx** (outdated)

> ### `interpretability.shap.plot_shap`
>
> Visualises SHAP values as an aggregate bar chart, a per-sample beeswarm plot,
Codex usage limits have been reached for code reviews. Please check with the admins of this repo to increase the limits by adding credits. |
**psinger-prior** left a comment:

Thanks! Overall mostly LGTM, please consider the comments (also from gemini).
> ---
>
> The [Interpretability Extension](https://github.com/PriorLabs/tabpfn-extensions/tree/main/src/tabpfn_extensions/interpretability) adds SHAP (SHapley Additive exPlanations) support to quantify the contribution of each input feature to an individual prediction. This can be used to:
> The Interpretability Extension adds SHAP (SHapley Additive exPlanations) support to quantify the contribution of each input feature to an individual prediction.
Would suggest adding a reference to what SHAP is.
> | `attribute_names` | `list[str] \| None` | `None` | Feature names when `test_x` is a numpy array |
> | `**kwargs` | | | Forwarded to the underlying `shap.Explainer` |
>
> **Returns:** `shap.Explanation` — access `.values` for a numpy array of shape
This function actually seems to return a numpy array: https://github.com/PriorLabs/tabpfn-extensions/blob/main/src/tabpfn_extensions/interpretability/shap.py#L191
Not sure on which side this should be changed.
> ---
>
> ## Library Reference
I would personally not have all the library references here; it's easier to just reference the GitHub repo directly (or we could generate these automatically from there). That also makes the update path easier and avoids issues with return values like the one below.
> @@ -1,75 +1,440 @@
> ---
> title: "Interpretability"
We could add a top section on why TabPFN works well for interpretability in general (smooth predictions, not overfit, etc.) and note that it works out of the box with most interpretability functionality in sklearn, since we follow the sklearn API. Then we can show a few popular examples.
> <Tip>
> **shapiq vs SHAP** — shapiq's `TabPFNExplainer` removes features and
> re-contextualizes the model, which matches how TabPFN natively handles missing
> data. SHAP replaces absent features with random background samples. shapiq is
Even if the package name is lower case, would suggest not starting a sentence with a lower-case "shapiq".
> | Method | What it tells you | When to reach for it |
> |--------|-------------------|----------------------|
> | **shapiq** (recommended) | This is a modern version of classic SHAP library. Tells you which features drove a specific prediction. Uses a remove-and-recontextualize strategy that is native to how TabPFN handles missing data. | You want per-sample explanations and care about feature interactions, or you want the fastest Shapley-based method for TabPFN. |
I would not call it modern, rather an extension of the original SHAP library.
> | **Partial Dependence / ICE** | The global, marginal effect of one or two features across the entire dataset. | You want to understand how a feature affects the model *on average* rather than for a single sample, or you want to visually compare TabPFN against another sklearn estimator. |
> | **Feature Selection** | Which minimal subset of features preserves model performance. | You want to simplify your model or identify redundant features before deployment. |
>
> If you are still unsure which method to use, follow the table below to see the best tools for most common questions.
Maybe also add a decision matrix based on dataset size?
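As background for the Partial Dependence / ICE row above: the partial dependence of a feature is just the average model prediction as that feature is swept over a grid while every other feature keeps its observed value. Here is a minimal hand-rolled sketch on a toy model (illustrative only, with made-up names; in the docs' setting one would call `sklearn.inspection.partial_dependence` on the fitted TabPFN estimator instead):

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 3))

# Toy model: nonlinear in feature 0, linear in the others (illustrative only).
def predict(X):
    return np.sin(X[:, 0]) + 0.5 * X[:, 1] - 0.25 * X[:, 2]

def partial_dependence(predict, X, feature, grid):
    """Average prediction with `feature` clamped to each grid value (mean of ICE curves)."""
    pd_values = []
    for v in grid:
        X_mod = X.copy()
        X_mod[:, feature] = v          # clamp the feature for every sample
        pd_values.append(predict(X_mod).mean())
    return np.array(pd_values)

grid = np.linspace(-2, 2, 5)
pd_curve = partial_dependence(predict, X, feature=0, grid=grid)
print(np.round(pd_curve, 3))  # tracks sin(grid) up to a constant offset
```

Keeping the per-sample curves instead of averaging them gives ICE plots, which reveal when the "average" effect hides heterogeneous behaviour across samples.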
Draft PR for now - but generally improved the interpretability doc with examples, explanations, and a library reference.