2,636 changes: 2,636 additions & 0 deletions examples/ChestXray-Classification-ResNet-with-Saliency.ipynb

Large diffs are not rendered by default.

1,044 changes: 1,044 additions & 0 deletions examples/interpretability/lrp_stagenet_mimic4.ipynb

Large diffs are not rendered by default.

199 changes: 199 additions & 0 deletions examples/interpretability/lrp_stagenet_mimic4.py
@@ -0,0 +1,199 @@
"""LRP with StageNet on MIMIC-IV for mortality prediction.

Demonstrates Layer-wise Relevance Propagation (LRP) interpretability for a
StageNet mortality model, using the epsilon and alpha-beta propagation
rules on MIMIC-IV data.

Usage:
python lrp_stagenet_mimic4.py
"""

from pathlib import Path

import torch

from pyhealth.datasets import (
MIMIC4Dataset,
get_dataloader,
load_processors,
save_processors,
split_by_patient,
)
from pyhealth.interpret.methods import LayerwiseRelevancePropagation
from pyhealth.models import StageNet
from pyhealth.tasks import MortalityPredictionStageNetMIMIC4
from pyhealth.trainer import Trainer
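

# For reference, a hedged sketch of how the LRP epsilon rule redistributes
# relevance through a single linear layer. This is illustrative only -- the
# actual propagation is implemented inside
# pyhealth.interpret.methods.LayerwiseRelevancePropagation.
def _lrp_epsilon_linear_sketch(x, weight, bias, relevance_out, eps=0.01):
    """For z = x @ weight.T + bias, redistribute output relevance as
    R_i = x_i * sum_j w_ji * R_j / (z_j + eps * sign(z_j)). The eps term
    keeps denominators away from zero at the cost of exact conservation.
    """
    z = x @ weight.T + bias          # pre-activations, shape [batch, out]
    z = z + eps * torch.sign(z)      # stabilize near-zero denominators
    s = relevance_out / z            # per-output relevance "rates"
    return x * (s @ weight)          # input relevance, shape [batch, in]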


def decode_indices_to_tokens(indices_tensor, processor, feature_key):
"""Decode token indices back to original codes using processor vocabulary."""
if not hasattr(processor, "code_vocab"):
return None
reverse_vocab = {idx: token for token, idx in processor.code_vocab.items()}

def decode(idx):
return reverse_vocab.get(idx, f"<unknown_{idx}>")

items = indices_tensor.tolist()
if indices_tensor.dim() == 1:
return [decode(i) for i in items]
elif indices_tensor.dim() == 2:
return [[decode(i) for i in row] for row in items]
elif indices_tensor.dim() == 3:
return [[[decode(i) for i in inner] for inner in row] for row in items]
return items


def print_lrp_results(attributions, sample_batch, sample_dataset, top_k=10):
"""Print top-k LRP attribution results per feature."""
processors = sample_dataset.input_processors

for feature_key, attr in attributions.items():
if attr.numel() == 0:
continue

input_data = sample_batch[feature_key]
if isinstance(input_data, tuple):
input_tensor = input_data[1]
else:
input_tensor = input_data

total = attr[0].sum().item()
flat = attr[0].flatten()
k = min(top_k, flat.numel())
top_idx = torch.topk(flat.abs(), k=k).indices

print(f"\n {feature_key} (shape={attr.shape}, total_relevance={total:+.6f}):")

is_continuous = torch.is_floating_point(input_tensor)
processor = processors.get(feature_key)

        # Decode discrete token indices once, outside the per-rank loop.
        tokens = None
        if not is_continuous and processor:
            tokens = decode_indices_to_tokens(input_tensor[0], processor, feature_key)

        for rank, fidx in enumerate(top_idx.tolist(), 1):
            val = flat[fidx].item()
            if is_continuous and attr[0].dim() == 3:
                dim2 = attr[0].shape[2]
                t, f = fidx // dim2, fidx % dim2
                if input_tensor.dim() == 3 and t < input_tensor.shape[1] and f < input_tensor.shape[2]:
                    actual = input_tensor[0, t, f].item()
                    print(f" {rank:2d}. T{t} F{f} val={actual:7.2f} -> {val:+.6f}")
                else:
                    print(f" {rank:2d}. idx={fidx} -> {val:+.6f}")
            elif tokens and attr[0].dim() == 3:
                dim2 = attr[0].shape[2]
                t, f = fidx // dim2, fidx % dim2
                if t < len(tokens) and f < len(tokens[t]):
                    print(f" {rank:2d}. Visit {t} '{tokens[t][f]}' -> {val:+.6f}")
                else:
                    print(f" {rank:2d}. idx={fidx} -> {val:+.6f}")
            else:
                print(f" {rank:2d}. idx={fidx} -> {val:+.6f}")


def main():
# Load MIMIC-IV
print("Loading MIMIC-IV dataset...")
base_dataset = MIMIC4Dataset(
ehr_root="/srv/local/data/physionet.org/files/mimiciv/2.2/",
ehr_tables=[
"patients", "admissions", "diagnoses_icd",
"procedures_icd", "labevents",
],
dev=True,
)
base_dataset.stats()

# Processors
processor_dir = Path("../../output/processors/stagenet_mortality_mimic4_lrp")
cache_dir = Path("../../mimic4_stagenet_lrp_cache")

if processor_dir.exists() and any(processor_dir.iterdir()):
print(f"Loading processors from {processor_dir}")
input_processors = load_processors(str(processor_dir))
sample_dataset = base_dataset.set_task(
MortalityPredictionStageNetMIMIC4(padding=20),
processors=input_processors,
cache_dir=str(cache_dir),
)
else:
print("Creating new processors...")
processor_dir.mkdir(parents=True, exist_ok=True)
sample_dataset = base_dataset.set_task(
MortalityPredictionStageNetMIMIC4(padding=20),
cache_dir=str(cache_dir),
)
save_processors(sample_dataset.input_processors, str(processor_dir))

print(f"Samples: {len(sample_dataset)}")

# Split
train_ds, val_ds, test_ds = split_by_patient(sample_dataset, [0.8, 0.1, 0.1])
print(f"Train: {len(train_ds)}, Val: {len(val_ds)}, Test: {len(test_ds)}")

train_loader = get_dataloader(train_ds, batch_size=64, shuffle=True)
val_loader = get_dataloader(val_ds, batch_size=64, shuffle=False)
test_loader = get_dataloader(test_ds, batch_size=1, shuffle=False)

# Model
model = StageNet(
dataset=sample_dataset, embedding_dim=128, chunk_size=128,
levels=3, dropout=0.3,
)
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

# Train
trainer = Trainer(
model=model, device="cpu",
metrics=["pr_auc", "roc_auc", "accuracy", "f1"],
)
trainer.train(
train_dataloader=train_loader, val_dataloader=val_loader,
epochs=5, monitor="roc_auc", optimizer_params={"lr": 1e-4},
)

# Evaluate
results = trainer.evaluate(test_loader)
print("\nTest Results:")
for metric, value in results.items():
print(f" {metric}: {value:.4f}")

# LRP interpretation
sample_batch = next(iter(test_loader))

with torch.no_grad():
output = model(**sample_batch)
probs = output["y_prob"]
pred = torch.argmax(probs, dim=-1)
true_label = sample_batch[model.label_key]
print(f"\nPrediction: true={int(true_label[0].item())}, "
f"pred={int(pred[0].item())}, "
f"P(survived)={probs[0, 0].item():.4f}, P(died)={probs[0, 1].item():.4f}")

# Epsilon rule
print("\nLRP Epsilon-Rule (eps=0.01):")
lrp_eps = LayerwiseRelevancePropagation(
model, rule="epsilon", epsilon=0.01, use_embeddings=True
)
attr_eps = lrp_eps.attribute(**sample_batch)
print_lrp_results(attr_eps, sample_batch, sample_dataset)

# AlphaBeta rule
print("\nLRP AlphaBeta-Rule (alpha=1.0, beta=0.0):")
lrp_ab = LayerwiseRelevancePropagation(
model, rule="alphabeta", alpha=1.0, beta=0.0, use_embeddings=True
)
attr_ab = lrp_ab.attribute(**sample_batch)
print_lrp_results(attr_ab, sample_batch, sample_dataset)

# Conservation comparison
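    # (The epsilon rule approximately conserves total relevance; alpha=1/beta=0
    # keeps only positively contributing paths, so its totals can differ.)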
print("\nRelevance comparison:")
for key in attr_eps:
eps_t = attr_eps[key][0].sum().item()
ab_t = attr_ab[key][0].sum().item()
print(f" {key}: epsilon={eps_t:+.6f}, alphabeta={ab_t:+.6f}")

print(f"\nProcessors saved at: {processor_dir}")


if __name__ == "__main__":
main()
195 changes: 195 additions & 0 deletions examples/interpretability/lrp_stagenet_synthetic.py
@@ -0,0 +1,195 @@
"""LRP with StageNet on synthetic data.

Demonstrates Layer-wise Relevance Propagation (LRP) interpretability
on a StageNet model using synthetic patient data. No external datasets required.

Usage:
python lrp_stagenet_synthetic.py
"""

import random
from typing import Tuple

import numpy as np
import torch

from pyhealth.datasets import get_dataloader
from pyhealth.interpret.methods import LayerwiseRelevancePropagation
from pyhealth.models import StageNet
from pyhealth.processors import StageNetProcessor, StageNetTensorProcessor  # noqa: F401 (backs the "stagenet"/"stagenet_tensor" schemas below)
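

# For reference, a hedged sketch of the LRP alpha-beta rule on a single
# linear layer (illustrative only; the actual propagation is handled by
# pyhealth.interpret.methods.LayerwiseRelevancePropagation). Each input's
# contribution is split into excitatory (+) and inhibitory (-) parts weighted
# by alpha and beta, with alpha - beta = 1; alpha=1, beta=0 keeps only
# positively contributing paths.
def _lrp_alphabeta_linear_sketch(x, weight, bias, relevance_out,
                                 alpha=1.0, beta=0.0):
    xp, xn = x.clamp(min=0), x.clamp(max=0)
    wp, wn = weight.clamp(min=0), weight.clamp(max=0)
    zp = xp @ wp.T + xn @ wn.T + bias.clamp(min=0) + 1e-9  # positive part of z
    zn = xp @ wn.T + xn @ wp.T + bias.clamp(max=0) - 1e-9  # negative part of z
    rp = xp * ((relevance_out / zp) @ wp) + xn * ((relevance_out / zp) @ wn)
    rn = xp * ((relevance_out / zn) @ wn) + xn * ((relevance_out / zn) @ wp)
    return alpha * rp - beta * rn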


def generate_synthetic_data(
num_samples: int = 500,
num_visits_range: Tuple[int, int] = (3, 10),
num_codes_range: Tuple[int, int] = (5, 20),
num_lab_tests: int = 5,
vocab_size: int = 100,
seed: int = 42,
) -> list:
"""Generate synthetic patient samples for StageNet."""
random.seed(seed)
np.random.seed(seed)
samples = []

for i in range(num_samples):
num_visits = random.randint(*num_visits_range)
diagnoses_list, diagnosis_times = [], []
for v in range(num_visits):
num_codes = random.randint(*num_codes_range)
diagnoses_list.append(
[f"D{random.randint(0, vocab_size - 1)}" for _ in range(num_codes)]
)
diagnosis_times.append(0.0 if v == 0 else random.uniform(24, 720))

lab_values_list, lab_times = [], []
meas_idx = 0
for v in range(num_visits):
for _ in range(random.randint(3, 10)):
vec = []
for lab_idx in range(num_lab_tests):
if (i == 0 and meas_idx == 0) or random.random() < 0.8:
vec.append(100.0 + random.gauss(0, 20))
else:
vec.append(None)
lab_values_list.append(vec)
lab_times.append(random.uniform(0, 24))
meas_idx += 1

risky = sum(
1 for codes in diagnoses_list for c in codes if int(c[1:]) < 20
)
risk = num_visits * 0.1 + risky * 0.05 + random.gauss(0, 0.1)

samples.append({
"patient_id": f"P{i:04d}",
"diagnoses": (diagnosis_times, diagnoses_list),
"labs": (lab_times, lab_values_list),
"label": 1 if risk > 0.5 else 0,
})
return samples


def print_top_features(attributions, sample_batch, top_k=10):
"""Print top-k features by absolute LRP relevance."""
for key, attr_tensor in attributions.items():
if attr_tensor is None or attr_tensor.numel() == 0:
continue
attr = attr_tensor[0].detach().cpu().flatten()
k = min(top_k, attr.numel())
_, top_idx = torch.topk(attr.abs(), k=k)

print(f"\n {key} (shape={attr_tensor.shape}):")
for rank, idx in enumerate(top_idx.tolist(), 1):
print(f" {rank:2d}. index={idx}, relevance={attr[idx].item():+.6f}")


def main():
print("Generating synthetic patient data...")
samples = generate_synthetic_data(num_samples=500, seed=42)
print(f" {len(samples)} samples generated")

# Create dataset
from pyhealth.datasets.sample_dataset import InMemorySampleDataset
from pyhealth.processors.base_processor import FeatureProcessor

class LabelProcessor(FeatureProcessor):
def fit(self, samples, key):
pass
def process(self, value):
return torch.tensor([value], dtype=torch.float)
def size(self):
return 1

dataset = InMemorySampleDataset(
samples=samples,
input_schema={"diagnoses": "stagenet", "labs": "stagenet_tensor"},
output_schema={"label": "binary"},
output_processors={"label": LabelProcessor()},
)

# Split
n_train = int(0.7 * len(dataset))
n_val = int(0.15 * len(dataset))
train_ds = dataset.subset(list(range(n_train)))
test_ds = dataset.subset(list(range(n_train + n_val, len(dataset))))
train_loader = get_dataloader(train_ds, batch_size=32, shuffle=True)
test_loader = get_dataloader(test_ds, batch_size=1, shuffle=False)

# Model
model = StageNet(
dataset=dataset, embedding_dim=128, chunk_size=128, levels=3, dropout=0.3,
)
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

    # Train
    device = torch.device("cpu")
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    def to_device(batch):
        """Move tensors (including tensors nested inside tuples) to `device`."""
        moved = {}
        for k, v in batch.items():
            if isinstance(v, torch.Tensor):
                moved[k] = v.to(device)
            elif isinstance(v, tuple):
                moved[k] = tuple(
                    t.to(device) if isinstance(t, torch.Tensor) else t for t in v
                )
            else:
                moved[k] = v
        return moved

for epoch in range(3):
model.train()
total_loss, correct, total = 0.0, 0, 0
for batch in train_loader:
            batch = to_device(batch)
optimizer.zero_grad()
out = model(**batch)
out["loss"].backward()
optimizer.step()
total_loss += out["loss"].item()
            # Flatten to [B] so the elementwise comparison cannot broadcast
            # (y_prob/label may arrive as [B, 1]).
            preds = out["y_prob"].reshape(-1)
            labels = batch["label"].reshape(-1)
            correct += ((preds > 0.5).float() == labels).sum().item()
total += labels.size(0)
print(f" Epoch {epoch+1}/3: loss={total_loss/len(train_loader):.4f}, "
f"acc={100*correct/total:.1f}%")

# LRP
model.eval()
sample_batch = next(iter(test_loader))
    sample_batch = to_device(sample_batch)

with torch.no_grad():
output = model(**sample_batch)
pred_prob = torch.sigmoid(output["logit"]).item()
print(f"\nPrediction: class={int(pred_prob > 0.5)}, prob={pred_prob:.4f}, "
f"true={int(sample_batch['label'].item())}")

# Epsilon rule
print("\nLRP Epsilon-Rule (eps=0.01):")
lrp_eps = LayerwiseRelevancePropagation(
model, rule="epsilon", epsilon=0.01, use_embeddings=True
)
attr_eps = lrp_eps.attribute(**sample_batch)
print_top_features(attr_eps, sample_batch)

# AlphaBeta rule
print("\nLRP AlphaBeta-Rule (alpha=1, beta=0):")
lrp_ab = LayerwiseRelevancePropagation(
model, rule="alphabeta", alpha=1.0, beta=0.0, use_embeddings=True
)
attr_ab = lrp_ab.attribute(**sample_batch)
print_top_features(attr_ab, sample_batch)

# Conservation check
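    # Under the epsilon rule the summed input relevance should roughly match
    # the explained output f(x); the eps stabilizer absorbs a small amount.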
with torch.no_grad():
f_x = model(**sample_batch)["logit"].squeeze().item()
eps_sum = sum(attr_eps[k][0].sum().item() for k in attr_eps)
ab_sum = sum(attr_ab[k][0].sum().item() for k in attr_ab)
print(f"\nConservation: f(x)={f_x:.6f}, "
f"eps_sum={eps_sum:.6f}, ab_sum={ab_sum:.6f}")


if __name__ == "__main__":
main()