diff --git a/benchmarks/README.md b/benchmarks/README.md new file mode 100644 index 0000000..128d5cd --- /dev/null +++ b/benchmarks/README.md @@ -0,0 +1,165 @@ +# ConforMix Benchmarking Suite + +A comprehensive benchmarking framework for evaluating ConforMix on standard protein conformational datasets. + +## Overview + +This module provides automated tools to: +- Run ConforMix on benchmark datasets (domain motion, fold-switching, cryptic pockets) +- Compute evaluation metrics (RMSD, TM-score, conformational coverage) +- Generate professional Markdown and HTML reports + +## Installation + +The benchmarking suite is included in the ConforMix repository. Ensure you have the main package installed: + +```bash +pip install ./conformix_boltz +``` + +Then install benchmark dependencies: + +```bash +pip install pandas pyyaml numpy mdtraj pytest +``` + +## Quick Start + +### Run a Benchmark + +```bash +# Run on domain motion dataset +python -m benchmarks.run_benchmark --config benchmarks/configs/domainmotion.yaml + +# Run on specific proteins only +python -m benchmarks.run_benchmark \ + --config benchmarks/configs/domainmotion.yaml \ + --proteins P0205 P69441 + +# Dry run (see what would be executed) +python -m benchmarks.run_benchmark \ + --config benchmarks/configs/domainmotion.yaml \ + --dry-run +``` + +### Compute Metrics + +```bash +python -m benchmarks.evaluate_metrics \ + --results benchmark_results/domainmotion/all_results.json +``` + +### Generate Reports + +```bash +python -m benchmarks.generate_report \ + --metrics benchmark_results/domainmotion/metrics.json +``` + +## Available Datasets + +| Dataset | Config File | Proteins | Description | +|---------|-------------|----------|-------------| +| Domain Motion | `configs/domainmotion.yaml` | 38 | Large-scale domain movements | +| Fold-Switching | `configs/foldswitching.yaml` | 15 | Proteins that change secondary structure | +| Cryptic Pockets | `configs/crypticpockets.yaml` | 34 | Hidden binding sites | +| Membrane Transporters | `configs/membranetransporters.yaml` | - | Conformational changes in transport | + +## Configuration Options + +Create a YAML configuration file: + +```yaml +dataset_name: "my_dataset" +csv_path: "datasets/my_dataset.csv" +output_dir: "benchmark_results/my_dataset" + +# Sampling parameters +num_twist_targets: 5 # Number of RMSD targets +samples_per_target: 2 # Samples per target +twist_strength: 15.0 # Twist potential strength +structured_regions_only: true + +# Execution settings +timeout_seconds: 3600 # Timeout per protein +skip_existing: true # Skip already processed +``` + +## Metrics Computed + +| Metric | Description | +|--------|-------------| +| **Min RMSD to Alt** | Minimum RMSD from any sample to alternate structure | +| **Mean RMSD to Alt** | Average RMSD across all samples | +| **Conformational Coverage** | How close best sample is to known alternate | +| **RMSD Diversity** | Average pairwise RMSD between samples | + +## Output Structure + +``` +benchmark_results/ +└── domainmotion/ + ├── config.json # Configuration used + ├── all_results.json # Raw benchmark results + ├── metrics.json # Computed metrics + ├── report.md # Markdown report + ├── report.html # HTML report + ├── .cache/ # Downloaded structures + └── P0205/ # Per-protein outputs + ├── result.json + └── samples.cif +``` + +## Running Tests + +```bash +# Run all tests +pytest benchmarks/tests/ -v + +# Run specific test class +pytest benchmarks/tests/test_benchmark.py::TestRMSDComputation -v +``` + +## API Usage + +```python +from benchmarks 
import run_benchmark, compute_metrics, generate_report +from benchmarks.run_benchmark import BenchmarkConfig + +# Load configuration +config = BenchmarkConfig.from_yaml("benchmarks/configs/domainmotion.yaml") + +# Run benchmark +results = run_benchmark(config) + +# Compute metrics +from benchmarks.evaluate_metrics import compute_all_metrics +metrics = compute_all_metrics(config.output_dir / "all_results.json") + +# Generate reports +generate_report(config.output_dir / "metrics.json") +``` + +## Adding Custom Datasets + +1. Create a CSV file with columns: + - `system_id`: Unique identifier + - `pdb1`: First PDB ID with chain (e.g., `1AKE_A`) + - `pdb2`: Second PDB ID with chain (alternate state) + - `RMSD`: Ground truth RMSD between states + - `TM-score`: (optional) Structural similarity + +2. Create a YAML config pointing to your CSV + +3. Run the benchmark + +## Contributing + +When adding new metrics or features: +1. Add implementation to appropriate module +2. Add tests to `tests/test_benchmark.py` +3. Update this README + +## License + +MIT License - same as ConforMix main repository. diff --git a/benchmarks/__init__.py b/benchmarks/__init__.py new file mode 100644 index 0000000..b0be95f --- /dev/null +++ b/benchmarks/__init__.py @@ -0,0 +1,19 @@ +""" +Benchmarks module for ConforMix. + +This module provides an automated benchmarking framework for evaluating +ConforMix on standard protein conformational datasets. +""" + +from .run_benchmark import run_benchmark, BenchmarkConfig +from .evaluate_metrics import compute_metrics, MetricsResult +from .generate_report import generate_report + +__version__ = "0.1.0" +__all__ = [ + "run_benchmark", + "BenchmarkConfig", + "compute_metrics", + "MetricsResult", + "generate_report", +] diff --git a/benchmarks/configs/crypticpockets.yaml b/benchmarks/configs/crypticpockets.yaml new file mode 100644 index 0000000..24a12c8 --- /dev/null +++ b/benchmarks/configs/crypticpockets.yaml @@ -0,0 +1,16 @@ +# Cryptic Pockets Benchmark Configuration +# Tests ConforMix on 34 proteins with hidden binding sites + +dataset_name: "crypticpockets" +csv_path: "datasets/crypticpockets.csv" +output_dir: "benchmark_results/crypticpockets" + +# Sampling parameters +num_twist_targets: 6 +samples_per_target: 2 +twist_strength: 15.0 +structured_regions_only: true + +# Execution settings +timeout_seconds: 3600 +skip_existing: true diff --git a/benchmarks/configs/domainmotion.yaml b/benchmarks/configs/domainmotion.yaml new file mode 100644 index 0000000..005f059 --- /dev/null +++ b/benchmarks/configs/domainmotion.yaml @@ -0,0 +1,16 @@ +# Domain Motion Benchmark Configuration +# Tests ConforMix on 38 proteins with large-scale domain movements + +dataset_name: "domainmotion" +csv_path: "datasets/domainmotion.csv" +output_dir: "benchmark_results/domainmotion" + +# Sampling parameters +num_twist_targets: 5 +samples_per_target: 2 +twist_strength: 15.0 +structured_regions_only: true + +# Execution settings +timeout_seconds: 3600 +skip_existing: true diff --git a/benchmarks/configs/foldswitching.yaml b/benchmarks/configs/foldswitching.yaml new file mode 100644 index 0000000..6b4d95d --- /dev/null +++ b/benchmarks/configs/foldswitching.yaml @@ -0,0 +1,16 @@ +# Fold-Switching Benchmark Configuration +# Tests ConforMix on 15 proteins that switch between different folds + +dataset_name: "foldswitching" +csv_path: "datasets/foldswitching.csv" +output_dir: "benchmark_results/foldswitching" + +# Sampling parameters - more samples for challenging transitions +num_twist_targets: 8 
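+# Together with samples_per_target: 3 below, this should yield about
+# 8 x 3 = 24 candidate conformations per protein.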
+samples_per_target: 3 +twist_strength: 20.0 +structured_regions_only: true + +# Execution settings +timeout_seconds: 5400 # 90 minutes - fold switching is harder +skip_existing: true diff --git a/benchmarks/configs/membranetransporters.yaml b/benchmarks/configs/membranetransporters.yaml new file mode 100644 index 0000000..3f8371d --- /dev/null +++ b/benchmarks/configs/membranetransporters.yaml @@ -0,0 +1,16 @@ +# Membrane Transporters Benchmark Configuration +# Tests ConforMix on membrane transporter proteins + +dataset_name: "membranetransporters" +csv_path: "datasets/membranetransporters.csv" +output_dir: "benchmark_results/membranetransporters" + +# Sampling parameters +num_twist_targets: 5 +samples_per_target: 2 +twist_strength: 15.0 +structured_regions_only: true + +# Execution settings +timeout_seconds: 4800 # 80 minutes - larger proteins +skip_existing: true diff --git a/benchmarks/evaluate_metrics.py b/benchmarks/evaluate_metrics.py new file mode 100644 index 0000000..a8fd921 --- /dev/null +++ b/benchmarks/evaluate_metrics.py @@ -0,0 +1,491 @@ +""" +evaluate_metrics.py + +Compute evaluation metrics for ConforMix benchmark results. +Calculates RMSD, TM-score, and conformational coverage. + +Usage: + python -m benchmarks.evaluate_metrics --results benchmark_results/domainmotion/all_results.json +""" + +import argparse +import json +import logging +import subprocess +import tempfile +from dataclasses import dataclass +from pathlib import Path +from typing import Optional + +import numpy as np + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(levelname)s - %(message)s" +) +logger = logging.getLogger(__name__) + + +@dataclass +class MetricsResult: + """Metrics for a single protein benchmark.""" + + system_id: str + pdb1: str + pdb2: str + ground_truth_rmsd: float + + min_rmsd_to_alt: Optional[float] = None + max_rmsd_to_alt: Optional[float] = None + mean_rmsd_to_alt: Optional[float] = None + + best_tm_score: Optional[float] = None + mean_tm_score: Optional[float] = None + + conformational_coverage: Optional[float] = None + rmsd_diversity: Optional[float] = None + + num_samples: int = 0 + error_message: Optional[str] = None + + def to_dict(self) -> dict: + """Convert to dictionary for serialization.""" + return { + "system_id": self.system_id, + "pdb1": self.pdb1, + "pdb2": self.pdb2, + "ground_truth_rmsd": self.ground_truth_rmsd, + "min_rmsd_to_alt": self.min_rmsd_to_alt, + "max_rmsd_to_alt": self.max_rmsd_to_alt, + "mean_rmsd_to_alt": self.mean_rmsd_to_alt, + "best_tm_score": self.best_tm_score, + "mean_tm_score": self.mean_tm_score, + "conformational_coverage": self.conformational_coverage, + "rmsd_diversity": self.rmsd_diversity, + "num_samples": self.num_samples, + "error_message": self.error_message, + } + + +def fetch_pdb_structure(pdb_id: str, chain_id: str, cache_dir: Path) -> Optional[Path]: + """ + Fetch PDB structure from RCSB. 
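+
+    Note: the full PDB entry is downloaded; ``chain_id`` is currently unused here.
+
+    Example (cache path illustrative; hits the network on the first call)::
+
+        pdb_path = fetch_pdb_structure("1AKE", "A", Path(".cache"))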
+ + Args: + pdb_id: PDB ID (e.g., '1AKE') + chain_id: Chain ID (e.g., 'A') + cache_dir: Directory to cache downloaded files + + Returns: + Path to PDB file, or None if fetch failed + """ + import urllib.request + import urllib.error + + cache_dir.mkdir(parents=True, exist_ok=True) + pdb_path = cache_dir / f"{pdb_id}.pdb" + + if pdb_path.exists(): + return pdb_path + + url = f"https://files.rcsb.org/download/{pdb_id}.pdb" + + try: + logger.info(f"Fetching structure {pdb_id} from RCSB...") + with urllib.request.urlopen(url, timeout=30) as response: + content = response.read() + + with open(pdb_path, "wb") as f: + f.write(content) + + return pdb_path + + except urllib.error.URLError as e: + logger.error(f"Failed to fetch {pdb_id}: {e}") + return None + + +def compute_rmsd_mdtraj( + sample_path: Path, + reference_path: Path, + atom_selection: str = "backbone", +) -> Optional[float]: + """ + Compute RMSD between sample and reference using MDTraj. + + Args: + sample_path: Path to sample structure + reference_path: Path to reference structure + atom_selection: Atom selection for RMSD + + Returns: + RMSD in Angstroms, or None if computation failed + """ + try: + import mdtraj as md + + sample = md.load(str(sample_path)) + reference = md.load(str(reference_path)) + + if atom_selection == "backbone": + sample_atoms = sample.topology.select("backbone") + ref_atoms = reference.topology.select("backbone") + elif atom_selection == "ca": + sample_atoms = sample.topology.select("name CA") + ref_atoms = reference.topology.select("name CA") + else: + sample_atoms = sample.topology.select("all") + ref_atoms = reference.topology.select("all") + + min_atoms = min(len(sample_atoms), len(ref_atoms)) + if min_atoms == 0: + return None + + sample_coords = sample.xyz[0, sample_atoms[:min_atoms]] + ref_coords = reference.xyz[0, ref_atoms[:min_atoms]] + + sample_centered = sample_coords - sample_coords.mean(axis=0) + ref_centered = ref_coords - ref_coords.mean(axis=0) + + correlation_matrix = np.dot(sample_centered.T, ref_centered) + U, S, Vt = np.linalg.svd(correlation_matrix) + rotation = np.dot(Vt.T, U.T) + + if np.linalg.det(rotation) < 0: + Vt[-1, :] *= -1 + rotation = np.dot(Vt.T, U.T) + + sample_aligned = np.dot(sample_centered, rotation) + rmsd = np.sqrt(np.mean(np.sum((sample_aligned - ref_centered) ** 2, axis=1))) + + return float(rmsd * 10) + + except Exception as e: + logger.warning(f"RMSD computation failed: {e}") + return None + + +def compute_rmsd_simple( + coords1: np.ndarray, + coords2: np.ndarray, +) -> float: + """ + Compute RMSD between two coordinate arrays after optimal alignment. + + Args: + coords1: First set of coordinates (N, 3) + coords2: Second set of coordinates (N, 3) + + Returns: + RMSD in same units as input + """ + centered1 = coords1 - coords1.mean(axis=0) + centered2 = coords2 - coords2.mean(axis=0) + + correlation = np.dot(centered1.T, centered2) + U, S, Vt = np.linalg.svd(correlation) + rotation = np.dot(Vt.T, U.T) + + if np.linalg.det(rotation) < 0: + Vt[-1, :] *= -1 + rotation = np.dot(Vt.T, U.T) + + aligned1 = np.dot(centered1, rotation) + rmsd = np.sqrt(np.mean(np.sum((aligned1 - centered2) ** 2, axis=1))) + + return float(rmsd) + + +def extract_ca_coords_from_pdb(pdb_path: Path) -> Optional[np.ndarray]: + """ + Extract CA coordinates from a PDB file. 
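+
+    Example (path illustrative)::
+
+        coords = extract_ca_coords_from_pdb(Path(".cache/1AKE.pdb"))
+        if coords is not None:
+            assert coords.shape[1] == 3  # one (x, y, z) row per CA atom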
+ + Args: + pdb_path: Path to PDB file + + Returns: + Array of CA coordinates (N, 3), or None if failed + """ + coords = [] + + try: + with open(pdb_path) as f: + for line in f: + if line.startswith("ATOM") and line[12:16].strip() == "CA": + x = float(line[30:38]) + y = float(line[38:46]) + z = float(line[46:54]) + coords.append([x, y, z]) + + if coords: + return np.array(coords) + return None + + except Exception as e: + logger.warning(f"Failed to extract CA coords from {pdb_path}: {e}") + return None + + +def extract_ca_coords_from_cif(cif_path: Path) -> Optional[np.ndarray]: + """ + Extract CA coordinates from a CIF file. + + Args: + cif_path: Path to CIF file + + Returns: + Array of CA coordinates (N, 3), or None if failed + """ + coords = [] + + try: + with open(cif_path) as f: + in_atom_site = False + columns = {} + + for line in f: + if line.startswith("_atom_site."): + column_name = line.split(".")[1].strip() + columns[column_name] = len(columns) + in_atom_site = True + elif in_atom_site and line.startswith("ATOM"): + parts = line.split() + if len(parts) > max(columns.values()): + atom_name_idx = columns.get("label_atom_id", columns.get("auth_atom_id", 3)) + if parts[atom_name_idx] == "CA": + x_idx = columns.get("Cartn_x", 10) + y_idx = columns.get("Cartn_y", 11) + z_idx = columns.get("Cartn_z", 12) + x = float(parts[x_idx]) + y = float(parts[y_idx]) + z = float(parts[z_idx]) + coords.append([x, y, z]) + elif in_atom_site and line.strip() == "#": + in_atom_site = False + + if coords: + return np.array(coords) + return None + + except Exception as e: + logger.warning(f"Failed to extract CA coords from {cif_path}: {e}") + return None + + +def find_sample_files(output_dir: Path) -> list[Path]: + """Find all sample structure files in an output directory.""" + samples = [] + + for pattern in ["*.cif", "*.pdb"]: + samples.extend(output_dir.glob(f"**/{pattern}")) + + samples = [s for s in samples if "reference" not in s.name.lower()] + + return sorted(samples) + + +def compute_metrics( + benchmark_result: dict, + cache_dir: Path, +) -> MetricsResult: + """ + Compute metrics for a single protein benchmark result. 
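+
+    The input is a single entry of ``all_results.json`` (values illustrative)::
+
+        {
+            "system_id": "P69441",
+            "pdb1": "1AKE_A",
+            "pdb2": "4AKE_A",
+            "ground_truth_rmsd": 7.1,
+            "success": true,
+            "output_dir": "benchmark_results/domainmotion/P69441"
+        }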
+
+    Args:
+        benchmark_result: Dictionary with benchmark result data
+        cache_dir: Directory for caching downloaded structures
+
+    Returns:
+        MetricsResult with computed metrics
+    """
+    system_id = benchmark_result["system_id"]
+    pdb1 = benchmark_result["pdb1"]
+    pdb2 = benchmark_result["pdb2"]
+    ground_truth_rmsd = benchmark_result["ground_truth_rmsd"]
+    output_dir = Path(benchmark_result["output_dir"])
+
+    if not benchmark_result["success"]:
+        return MetricsResult(
+            system_id=system_id,
+            pdb1=pdb1,
+            pdb2=pdb2,
+            ground_truth_rmsd=ground_truth_rmsd,
+            error_message=benchmark_result.get("error_message", "Benchmark failed"),
+        )
+
+    pdb2_id, chain2_id = pdb2.split("_") if "_" in pdb2 else (pdb2, "A")
+    alt_pdb_path = fetch_pdb_structure(pdb2_id, chain2_id, cache_dir)
+
+    if alt_pdb_path is None:
+        return MetricsResult(
+            system_id=system_id,
+            pdb1=pdb1,
+            pdb2=pdb2,
+            ground_truth_rmsd=ground_truth_rmsd,
+            error_message=f"Failed to fetch alternate structure {pdb2}",
+        )
+
+    alt_coords = extract_ca_coords_from_pdb(alt_pdb_path)
+    if alt_coords is None:
+        return MetricsResult(
+            system_id=system_id,
+            pdb1=pdb1,
+            pdb2=pdb2,
+            ground_truth_rmsd=ground_truth_rmsd,
+            error_message=f"Failed to extract coords from {pdb2}",
+        )
+
+    sample_files = find_sample_files(output_dir)
+    if not sample_files:
+        return MetricsResult(
+            system_id=system_id,
+            pdb1=pdb1,
+            pdb2=pdb2,
+            ground_truth_rmsd=ground_truth_rmsd,
+            num_samples=0,
+            error_message="No sample files found",
+        )
+
+    rmsds_to_alt = []
+    sample_coords_list = []
+
+    for sample_path in sample_files:
+        if sample_path.suffix == ".cif":
+            sample_coords = extract_ca_coords_from_cif(sample_path)
+        else:
+            sample_coords = extract_ca_coords_from_pdb(sample_path)
+
+        if sample_coords is None:
+            continue
+
+        sample_coords_list.append(sample_coords)
+
+        min_len = min(len(sample_coords), len(alt_coords))
+        if min_len > 0:
+            rmsd = compute_rmsd_simple(
+                sample_coords[:min_len],
+                alt_coords[:min_len]
+            )
+            rmsds_to_alt.append(rmsd)
+
+    pairwise_rmsds = []
+    for i in range(len(sample_coords_list)):
+        for j in range(i + 1, len(sample_coords_list)):
+            coords_i = sample_coords_list[i]
+            coords_j = sample_coords_list[j]
+            min_len = min(len(coords_i), len(coords_j))
+            if min_len > 0:
+                rmsd = compute_rmsd_simple(coords_i[:min_len], coords_j[:min_len])
+                pairwise_rmsds.append(rmsd)
+
+    # float() below keeps numpy scalars JSON-serializable in to_dict().
+    rmsd_diversity = float(np.mean(pairwise_rmsds)) if pairwise_rmsds else None
+
+    if rmsds_to_alt and ground_truth_rmsd > 0:
+        # Fraction of the ground-truth conformational distance recovered by
+        # the best sample: 1.0 when the best sample matches the alternate
+        # state, 0.0 when no sample gets closer than the starting state.
+        ratio = min(rmsds_to_alt) / ground_truth_rmsd
+        conformational_coverage = max(0.0, 1.0 - ratio)
+    else:
+        conformational_coverage = None
+
+    return MetricsResult(
+        system_id=system_id,
+        pdb1=pdb1,
+        pdb2=pdb2,
+        ground_truth_rmsd=ground_truth_rmsd,
+        min_rmsd_to_alt=min(rmsds_to_alt) if rmsds_to_alt else None,
+        max_rmsd_to_alt=max(rmsds_to_alt) if rmsds_to_alt else None,
+        mean_rmsd_to_alt=float(np.mean(rmsds_to_alt)) if rmsds_to_alt else None,
+        conformational_coverage=conformational_coverage,
+        rmsd_diversity=rmsd_diversity,
+        num_samples=len(sample_files),
+    )
+
+
+def compute_all_metrics(
+    results_path: Path,
+    output_path: Optional[Path] = None,
+) -> list[MetricsResult]:
+    """
+    Compute metrics for all benchmark results.
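+
+    Example (paths illustrative)::
+
+        metrics = compute_all_metrics(
+            Path("benchmark_results/domainmotion/all_results.json"),
+            output_path=Path("benchmark_results/domainmotion/metrics.json"),
+        )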
+ + Args: + results_path: Path to all_results.json + output_path: Optional path to save metrics JSON + + Returns: + List of MetricsResult objects + """ + logger.info(f"Loading results from {results_path}") + + with open(results_path) as f: + results = json.load(f) + + cache_dir = results_path.parent / ".cache" + cache_dir.mkdir(exist_ok=True) + + metrics = [] + + for i, result in enumerate(results): + logger.info(f"[{i+1}/{len(results)}] Computing metrics for {result['system_id']}") + m = compute_metrics(result, cache_dir) + metrics.append(m) + + if m.error_message: + logger.warning(f" Error: {m.error_message}") + else: + logger.info(f" RMSD to alt: {m.min_rmsd_to_alt:.2f}Å (min), " + f"{m.mean_rmsd_to_alt:.2f}Å (mean)") + + if output_path: + with open(output_path, "w") as f: + json.dump([m.to_dict() for m in metrics], f, indent=2) + logger.info(f"Saved metrics to {output_path}") + + return metrics + + +def main(): + """Main entry point for CLI.""" + parser = argparse.ArgumentParser( + description="Compute metrics for ConforMix benchmark results" + ) + parser.add_argument( + "--results", + type=Path, + required=True, + help="Path to all_results.json from benchmark run" + ) + parser.add_argument( + "--output", + type=Path, + help="Path to save metrics JSON (default: metrics.json in same directory)" + ) + + args = parser.parse_args() + + output_path = args.output or args.results.parent / "metrics.json" + + metrics = compute_all_metrics(args.results, output_path) + + successful = [m for m in metrics if m.error_message is None] + + print("\n" + "=" * 60) + print("METRICS SUMMARY") + print("=" * 60) + print(f"Total proteins: {len(metrics)}") + print(f"Successful: {len(successful)}") + + if successful: + rmsds = [m.min_rmsd_to_alt for m in successful if m.min_rmsd_to_alt is not None] + if rmsds: + print(f"Min RMSD to alt: {np.mean(rmsds):.2f}Å ± {np.std(rmsds):.2f}Å") + + coverages = [m.conformational_coverage for m in successful if m.conformational_coverage is not None] + if coverages: + print(f"Coverage: {np.mean(coverages)*100:.1f}% ± {np.std(coverages)*100:.1f}%") + + print(f"Output: {output_path}") + print("=" * 60) + + +if __name__ == "__main__": + main() diff --git a/benchmarks/generate_report.py b/benchmarks/generate_report.py new file mode 100644 index 0000000..1a0bd3f --- /dev/null +++ b/benchmarks/generate_report.py @@ -0,0 +1,419 @@ +""" +generate_report.py + +Generate HTML and Markdown reports from benchmark metrics. + +Usage: + python -m benchmarks.generate_report --metrics benchmark_results/domainmotion/metrics.json +""" + +import argparse +import json +import logging +from datetime import datetime +from pathlib import Path +from typing import Optional + +import numpy as np + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(levelname)s - %(message)s" +) +logger = logging.getLogger(__name__) + + +def generate_markdown_report( + metrics_path: Path, + output_path: Optional[Path] = None, + title: Optional[str] = None, +) -> str: + """ + Generate a Markdown report from metrics. 
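+
+    Example (paths illustrative)::
+
+        report = generate_markdown_report(
+            Path("benchmark_results/domainmotion/metrics.json"),
+            output_path=Path("benchmark_results/domainmotion/report.md"),
+        )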
+ + Args: + metrics_path: Path to metrics.json + output_path: Optional path to save report + title: Optional report title + + Returns: + Markdown report string + """ + with open(metrics_path) as f: + metrics = json.load(f) + + if title is None: + title = f"ConforMix Benchmark Report" + + timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + + successful = [m for m in metrics if m.get("error_message") is None] + failed = [m for m in metrics if m.get("error_message") is not None] + + rmsds_to_alt = [m["min_rmsd_to_alt"] for m in successful if m.get("min_rmsd_to_alt") is not None] + gt_rmsds = [m["ground_truth_rmsd"] for m in successful if m.get("min_rmsd_to_alt") is not None] + coverages = [m["conformational_coverage"] for m in successful if m.get("conformational_coverage") is not None] + diversities = [m["rmsd_diversity"] for m in successful if m.get("rmsd_diversity") is not None] + + lines = [ + f"# {title}", + "", + f"**Generated:** {timestamp}", + "", + "---", + "", + "## Summary", + "", + f"| Metric | Value |", + f"|--------|-------|", + f"| Total Proteins | {len(metrics)} |", + f"| Successful | {len(successful)} ({100*len(successful)/len(metrics):.1f}%) |", + f"| Failed | {len(failed)} |", + ] + + if rmsds_to_alt: + lines.append(f"| Avg Min RMSD to Alt | {np.mean(rmsds_to_alt):.2f}Å ± {np.std(rmsds_to_alt):.2f}Å |") + if gt_rmsds: + lines.append(f"| Avg Ground Truth RMSD | {np.mean(gt_rmsds):.2f}Å ± {np.std(gt_rmsds):.2f}Å |") + if coverages: + lines.append(f"| Avg Conformational Coverage | {np.mean(coverages)*100:.1f}% ± {np.std(coverages)*100:.1f}% |") + if diversities: + lines.append(f"| Avg Sample Diversity | {np.mean(diversities):.2f}Å ± {np.std(diversities):.2f}Å |") + + lines.extend([ + "", + "---", + "", + "## Per-Protein Results", + "", + "| System ID | PDB1 | PDB2 | GT RMSD | Min RMSD to Alt | Coverage | Samples |", + "|-----------|------|------|---------|-----------------|----------|---------|", + ]) + + for m in sorted(metrics, key=lambda x: x.get("min_rmsd_to_alt") or 999): + system_id = m["system_id"] + pdb1 = m["pdb1"] + pdb2 = m["pdb2"] + gt_rmsd = m.get("ground_truth_rmsd", 0) + min_rmsd = m.get("min_rmsd_to_alt") + coverage = m.get("conformational_coverage") + num_samples = m.get("num_samples", 0) + error = m.get("error_message") + + if error: + lines.append(f"| {system_id} | {pdb1} | {pdb2} | {gt_rmsd:.1f}Å | [X] Error | - | 0 |") + else: + min_rmsd_str = f"{min_rmsd:.2f}Å" if min_rmsd is not None else "-" + coverage_str = f"{coverage*100:.0f}%" if coverage is not None else "-" + lines.append(f"| {system_id} | {pdb1} | {pdb2} | {gt_rmsd:.1f}Å | {min_rmsd_str} | {coverage_str} | {num_samples} |") + + if failed: + lines.extend([ + "", + "---", + "", + "## Failed Proteins", + "", + "| System ID | Error |", + "|-----------|-------|", + ]) + + for m in failed: + error = m.get("error_message", "Unknown error")[:80] + lines.append(f"| {m['system_id']} | {error} |") + + lines.extend([ + "", + "---", + "", + "## Interpretation Guide", + "", + "- **Min RMSD to Alt**: Minimum RMSD from any generated sample to the alternate (target) conformation", + "- **Coverage**: How close the best sample is to the known alternate state (higher is better)", + "- **Samples**: Number of conformational samples generated", + "", + "> Lower Min RMSD to Alt indicates better conformational sampling.", + "> High coverage (>50%) suggests ConforMix successfully found the alternate state.", + "", + ]) + + report = "\n".join(lines) + + if output_path: + with open(output_path, "w") as f: + 
f.write(report)
+        logger.info(f"Saved report to {output_path}")
+
+    return report
+
+
+def generate_html_report(
+    metrics_path: Path,
+    output_path: Optional[Path] = None,
+    title: Optional[str] = None,
+) -> str:
+    """
+    Generate an HTML report from metrics.
+
+    Args:
+        metrics_path: Path to metrics.json
+        output_path: Optional path to save report
+        title: Optional report title
+
+    Returns:
+        HTML report string
+    """
+    with open(metrics_path) as f:
+        metrics = json.load(f)
+
+    if title is None:
+        title = "ConforMix Benchmark Report"
+
+    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+
+    successful = [m for m in metrics if m.get("error_message") is None]
+    failed = [m for m in metrics if m.get("error_message") is not None]
+
+    rmsds_to_alt = [m["min_rmsd_to_alt"] for m in successful if m.get("min_rmsd_to_alt") is not None]
+    coverages = [m["conformational_coverage"] for m in successful if m.get("conformational_coverage") is not None]
+
+    # Guard the summary cards against empty lists (np.mean([]) is nan).
+    avg_rmsd_str = f"{np.mean(rmsds_to_alt):.2f}Å" if rmsds_to_alt else "-"
+    avg_coverage_str = f"{np.mean(coverages)*100:.0f}%" if coverages else "-"
+
+    html = f"""<!DOCTYPE html>
+<html>
+<head>
+<meta charset="utf-8">
+<title>{title}</title>
+<style>
+body {{ font-family: sans-serif; margin: 2em; }}
+.stats {{ display: flex; gap: 1em; margin: 1em 0; }}
+.stat-card {{ border: 1px solid #ccc; border-radius: 6px; padding: 1em; text-align: center; }}
+.stat-value {{ font-size: 1.6em; font-weight: bold; }}
+.stat-label {{ color: #666; }}
+table {{ border-collapse: collapse; width: 100%; }}
+th, td {{ border: 1px solid #ccc; padding: 0.4em 0.6em; text-align: left; }}
+tr.error {{ color: #a00; }}
+</style>
+</head>
+<body>
+<h1>{title}</h1>
+<p><em>Generated: {timestamp}</em></p>
+
+<div class="stats">
+    <div class="stat-card">
+        <div class="stat-value">{len(metrics)}</div>
+        <div class="stat-label">Total Proteins</div>
+    </div>
+    <div class="stat-card">
+        <div class="stat-value">{len(successful)}</div>
+        <div class="stat-label">Successful</div>
+    </div>
+    <div class="stat-card">
+        <div class="stat-value">{len(failed)}</div>
+        <div class="stat-label">Failed</div>
+    </div>
+    <div class="stat-card">
+        <div class="stat-value">{avg_rmsd_str}</div>
+        <div class="stat-label">Avg Min RMSD to Alt</div>
+    </div>
+    <div class="stat-card">
+        <div class="stat-value">{avg_coverage_str}</div>
+        <div class="stat-label">Avg Coverage</div>
+    </div>
+</div>
+
+<h2>Per-Protein Results</h2>
+<table>
+    <thead>
+        <tr>
+            <th>System ID</th><th>PDB1</th><th>PDB2</th><th>GT RMSD</th>
+            <th>Min RMSD to Alt</th><th>Coverage</th><th>Samples</th>
+        </tr>
+    </thead>
+    <tbody>
+"""
+
+    for m in sorted(metrics, key=lambda x: x.get("min_rmsd_to_alt") or 999):
+        system_id = m["system_id"]
+        pdb1 = m["pdb1"]
+        pdb2 = m["pdb2"]
+        gt_rmsd = m.get("ground_truth_rmsd", 0)
+        min_rmsd = m.get("min_rmsd_to_alt")
+        coverage = m.get("conformational_coverage")
+        num_samples = m.get("num_samples", 0)
+        error = m.get("error_message")
+
+        if error:
+            html += f"""        <tr class="error">
+            <td>{system_id}</td><td>{pdb1}</td><td>{pdb2}</td><td>{gt_rmsd:.1f}Å</td>
+            <td>[X] Error</td><td>-</td><td>0</td>
+        </tr>
+"""
+        else:
+            min_rmsd_str = f"{min_rmsd:.2f}Å" if min_rmsd is not None else "-"
+            coverage_str = f"{coverage*100:.0f}%" if coverage is not None else "-"
+            html += f"""        <tr>
+            <td>{system_id}</td><td>{pdb1}</td><td>{pdb2}</td><td>{gt_rmsd:.1f}Å</td>
+            <td>{min_rmsd_str}</td><td>{coverage_str}</td><td>{num_samples}</td>
+        </tr>
+"""
+
+    html += """    </tbody>
+</table>
+</body>
+</html>
+ + +""" + + if output_path: + with open(output_path, "w") as f: + f.write(html) + logger.info(f"Saved HTML report to {output_path}") + + return html + + +def generate_report( + metrics_path: Path, + output_dir: Optional[Path] = None, + title: Optional[str] = None, +) -> tuple[Path, Path]: + """ + Generate both Markdown and HTML reports. + + Args: + metrics_path: Path to metrics.json + output_dir: Directory to save reports + title: Optional report title + + Returns: + Tuple of (markdown_path, html_path) + """ + if output_dir is None: + output_dir = metrics_path.parent + + output_dir.mkdir(parents=True, exist_ok=True) + + md_path = output_dir / "report.md" + html_path = output_dir / "report.html" + + generate_markdown_report(metrics_path, md_path, title) + generate_html_report(metrics_path, html_path, title) + + return md_path, html_path + + +def main(): + """Main entry point for CLI.""" + parser = argparse.ArgumentParser( + description="Generate reports from ConforMix benchmark metrics" + ) + parser.add_argument( + "--metrics", + type=Path, + required=True, + help="Path to metrics.json" + ) + parser.add_argument( + "--output-dir", + type=Path, + help="Directory to save reports (default: same as metrics)" + ) + parser.add_argument( + "--title", + type=str, + help="Report title" + ) + parser.add_argument( + "--format", + choices=["markdown", "html", "both"], + default="both", + help="Output format" + ) + + args = parser.parse_args() + + output_dir = args.output_dir or args.metrics.parent + + if args.format == "markdown": + md_path = output_dir / "report.md" + generate_markdown_report(args.metrics, md_path, args.title) + print(f"Generated: {md_path}") + elif args.format == "html": + html_path = output_dir / "report.html" + generate_html_report(args.metrics, html_path, args.title) + print(f"Generated: {html_path}") + else: + md_path, html_path = generate_report(args.metrics, output_dir, args.title) + print(f"Generated: {md_path}") + print(f"Generated: {html_path}") + + +if __name__ == "__main__": + main() diff --git a/benchmarks/run_benchmark.py b/benchmarks/run_benchmark.py new file mode 100644 index 0000000..413fcdc --- /dev/null +++ b/benchmarks/run_benchmark.py @@ -0,0 +1,418 @@ +""" +run_benchmark.py + +Main entry point for the ConforMix benchmarking suite. +Runs ConforMix on benchmark datasets and saves results for evaluation. 
+ +Usage: + python -m benchmarks.run_benchmark --config benchmarks/configs/domainmotion.yaml + python -m benchmarks.run_benchmark --config benchmarks/configs/domainmotion.yaml --proteins P0205 P69441 +""" + +import argparse +import json +import logging +import shutil +import subprocess +import sys +import time +from dataclasses import dataclass, field +from pathlib import Path +from typing import Optional + +import pandas as pd +import yaml + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(levelname)s - %(message)s" +) +logger = logging.getLogger(__name__) + + +@dataclass +class BenchmarkConfig: + """Configuration for a benchmark run.""" + + dataset_name: str + csv_path: Path + output_dir: Path + num_twist_targets: int = 5 + samples_per_target: int = 2 + structured_regions_only: bool = True + subset_residues: Optional[str] = None + twist_strength: float = 15.0 + timeout_seconds: int = 3600 + skip_existing: bool = True + + @classmethod + def from_yaml(cls, yaml_path: Path) -> "BenchmarkConfig": + """Load configuration from a YAML file.""" + with open(yaml_path) as f: + data = yaml.safe_load(f) + + return cls( + dataset_name=data["dataset_name"], + csv_path=Path(data["csv_path"]), + output_dir=Path(data["output_dir"]), + num_twist_targets=data.get("num_twist_targets", 5), + samples_per_target=data.get("samples_per_target", 2), + structured_regions_only=data.get("structured_regions_only", True), + subset_residues=data.get("subset_residues"), + twist_strength=data.get("twist_strength", 15.0), + timeout_seconds=data.get("timeout_seconds", 3600), + skip_existing=data.get("skip_existing", True), + ) + + def to_dict(self) -> dict: + """Convert to dictionary for serialization.""" + return { + "dataset_name": self.dataset_name, + "csv_path": str(self.csv_path), + "output_dir": str(self.output_dir), + "num_twist_targets": self.num_twist_targets, + "samples_per_target": self.samples_per_target, + "structured_regions_only": self.structured_regions_only, + "subset_residues": self.subset_residues, + "twist_strength": self.twist_strength, + "timeout_seconds": self.timeout_seconds, + "skip_existing": self.skip_existing, + } + + +@dataclass +class BenchmarkResult: + """Result from benchmarking a single protein.""" + + system_id: str + pdb1: str + pdb2: str + ground_truth_rmsd: float + success: bool + runtime_seconds: float + output_dir: Path + error_message: Optional[str] = None + num_samples_generated: int = 0 + + def to_dict(self) -> dict: + """Convert to dictionary for serialization.""" + return { + "system_id": self.system_id, + "pdb1": self.pdb1, + "pdb2": self.pdb2, + "ground_truth_rmsd": self.ground_truth_rmsd, + "success": self.success, + "runtime_seconds": self.runtime_seconds, + "output_dir": str(self.output_dir), + "error_message": self.error_message, + "num_samples_generated": self.num_samples_generated, + } + + +def fetch_sequence_from_pdb(pdb_id: str, chain_id: str, cache_dir: Path) -> Optional[Path]: + """ + Fetch sequence from RCSB PDB and save as FASTA file. 
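+
+    Example (cache path illustrative; hits the network on the first call)::
+
+        fasta_path = fetch_sequence_from_pdb("1AKE", "A", Path(".cache"))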
+ + Args: + pdb_id: PDB ID (e.g., '1AKE') + chain_id: Chain ID (e.g., 'A') + cache_dir: Directory to cache downloaded files + + Returns: + Path to FASTA file, or None if fetch failed + """ + import urllib.request + import urllib.error + + cache_dir.mkdir(parents=True, exist_ok=True) + fasta_path = cache_dir / f"{pdb_id}_{chain_id}.fasta" + + if fasta_path.exists(): + logger.debug(f"Using cached FASTA: {fasta_path}") + return fasta_path + + url = f"https://www.rcsb.org/fasta/entry/{pdb_id}/display" + + try: + logger.info(f"Fetching sequence for {pdb_id}_{chain_id} from RCSB...") + with urllib.request.urlopen(url, timeout=30) as response: + content = response.read().decode("utf-8") + + lines = content.strip().split("\n") + target_header = None + target_sequence = [] + current_header = None + current_sequence = [] + + for line in lines: + if line.startswith(">"): + if current_header and chain_id in current_header: + target_header = current_header + target_sequence = current_sequence + current_header = line + current_sequence = [] + else: + current_sequence.append(line) + + if current_header and chain_id in current_header: + target_header = current_header + target_sequence = current_sequence + + if target_header and target_sequence: + with open(fasta_path, "w") as f: + f.write(f">{pdb_id}_{chain_id}\n") + f.write("".join(target_sequence) + "\n") + logger.info(f"Saved FASTA to: {fasta_path}") + return fasta_path + else: + logger.warning(f"Chain {chain_id} not found in {pdb_id}") + return None + + except urllib.error.URLError as e: + logger.error(f"Failed to fetch {pdb_id}: {e}") + return None + + +def run_conformix_boltz( + fasta_path: Path, + output_dir: Path, + config: BenchmarkConfig, +) -> tuple[bool, float, str]: + """ + Run ConforMix-Boltz on a single protein. 
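+
+    Example (paths illustrative)::
+
+        success, runtime, error = run_conformix_boltz(
+            Path(".cache/1AKE_A.fasta"), Path("benchmark_results/P69441"), config
+        )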
+ + Args: + fasta_path: Path to input FASTA file + output_dir: Directory to save outputs + config: Benchmark configuration + + Returns: + Tuple of (success, runtime_seconds, error_message) + """ + output_dir.mkdir(parents=True, exist_ok=True) + + cmd = [ + sys.executable, "-m", "boltz.run_conformixrmsd_boltz", + "--fasta_path", str(fasta_path), + "--out_dir", str(output_dir), + "--num_twist_targets", str(config.num_twist_targets), + "--samples_per_target", str(config.samples_per_target), + "--twist_strength", str(config.twist_strength), + ] + + if config.structured_regions_only: + cmd.append("--structured_regions_only") + + if config.subset_residues: + cmd.extend(["--subset_residues", config.subset_residues]) + + logger.info(f"Running: {' '.join(cmd)}") + + start_time = time.time() + + try: + result = subprocess.run( + cmd, + capture_output=True, + text=True, + timeout=config.timeout_seconds, + cwd=Path(__file__).parent.parent, + ) + runtime = time.time() - start_time + + if result.returncode == 0: + return True, runtime, "" + else: + error_msg = result.stderr[:500] if result.stderr else "Unknown error" + return False, runtime, error_msg + + except subprocess.TimeoutExpired: + runtime = time.time() - start_time + return False, runtime, f"Timeout after {config.timeout_seconds}s" + except Exception as e: + runtime = time.time() - start_time + return False, runtime, str(e) + + +def count_samples(output_dir: Path) -> int: + """Count the number of sample files generated.""" + patterns = ["*.cif", "*.pdb", "*.xtc"] + count = 0 + for pattern in patterns: + count += len(list(output_dir.glob(f"**/{pattern}"))) + return count + + +def run_benchmark( + config: BenchmarkConfig, + proteins: Optional[list[str]] = None, + dry_run: bool = False, +) -> list[BenchmarkResult]: + """ + Run the benchmark on a dataset. 
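+
+    Example (dry_run=True only reports what would be executed)::
+
+        config = BenchmarkConfig.from_yaml(Path("benchmarks/configs/domainmotion.yaml"))
+        results = run_benchmark(config, proteins=["P69441"], dry_run=True)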
+ + Args: + config: Benchmark configuration + proteins: Optional list of specific protein IDs to run + dry_run: If True, only print what would be done + + Returns: + List of BenchmarkResult objects + """ + logger.info(f"Loading dataset from {config.csv_path}") + df = pd.read_csv(config.csv_path) + + if proteins: + df = df[df["system_id"].isin(proteins) | + df["pdb1"].str.split("_").str[0].isin(proteins)] + logger.info(f"Filtered to {len(df)} proteins") + + logger.info(f"Running benchmark on {len(df)} proteins from {config.dataset_name}") + + config.output_dir.mkdir(parents=True, exist_ok=True) + + with open(config.output_dir / "config.json", "w") as f: + json.dump(config.to_dict(), f, indent=2) + + cache_dir = config.output_dir / ".cache" + results = [] + + for idx, row in df.iterrows(): + system_id = row["system_id"] + pdb1 = row["pdb1"] + pdb2 = row["pdb2"] + ground_truth_rmsd = row.get("RMSD", 0.0) + + pdb_id, chain_id = pdb1.split("_") if "_" in pdb1 else (pdb1, "A") + + protein_output_dir = config.output_dir / system_id + + if config.skip_existing and (protein_output_dir / "result.json").exists(): + logger.info(f"Skipping {system_id} (already exists)") + with open(protein_output_dir / "result.json") as f: + cached = json.load(f) + results.append(BenchmarkResult( + system_id=cached["system_id"], + pdb1=cached["pdb1"], + pdb2=cached["pdb2"], + ground_truth_rmsd=cached["ground_truth_rmsd"], + success=cached["success"], + runtime_seconds=cached["runtime_seconds"], + output_dir=Path(cached["output_dir"]), + error_message=cached.get("error_message"), + num_samples_generated=cached.get("num_samples_generated", 0), + )) + continue + + logger.info(f"[{idx+1}/{len(df)}] Processing {system_id} ({pdb1} -> {pdb2})") + + if dry_run: + logger.info(f" [DRY RUN] Would process {system_id}") + continue + + fasta_path = fetch_sequence_from_pdb(pdb_id, chain_id, cache_dir) + + if fasta_path is None: + result = BenchmarkResult( + system_id=system_id, + pdb1=pdb1, + pdb2=pdb2, + ground_truth_rmsd=ground_truth_rmsd, + success=False, + runtime_seconds=0.0, + output_dir=protein_output_dir, + error_message="Failed to fetch sequence from PDB", + ) + else: + success, runtime, error = run_conformix_boltz( + fasta_path, protein_output_dir, config + ) + + num_samples = count_samples(protein_output_dir) if success else 0 + + result = BenchmarkResult( + system_id=system_id, + pdb1=pdb1, + pdb2=pdb2, + ground_truth_rmsd=ground_truth_rmsd, + success=success, + runtime_seconds=runtime, + output_dir=protein_output_dir, + error_message=error if not success else None, + num_samples_generated=num_samples, + ) + + with open(protein_output_dir / "result.json", "w") as f: + json.dump(result.to_dict(), f, indent=2) + + results.append(result) + + success_count = sum(1 for r in results if r.success) + logger.info(f" Status: {'SUCCESS' if result.success else 'FAILED'} " + f"(Runtime: {result.runtime_seconds:.1f}s, " + f"Total: {success_count}/{len(results)} successful)") + + with open(config.output_dir / "all_results.json", "w") as f: + json.dump([r.to_dict() for r in results], f, indent=2) + + logger.info(f"Benchmark complete. 
Results saved to {config.output_dir}") + return results + + +def main(): + """Main entry point for CLI.""" + parser = argparse.ArgumentParser( + description="Run ConforMix benchmarks on standard datasets" + ) + parser.add_argument( + "--config", + type=Path, + required=True, + help="Path to benchmark configuration YAML file" + ) + parser.add_argument( + "--proteins", + nargs="+", + help="Optional: specific protein IDs to run" + ) + parser.add_argument( + "--dry-run", + action="store_true", + help="Print what would be done without running" + ) + parser.add_argument( + "--output-dir", + type=Path, + help="Override output directory from config" + ) + + args = parser.parse_args() + + config = BenchmarkConfig.from_yaml(args.config) + + if args.output_dir: + config.output_dir = args.output_dir + + results = run_benchmark( + config=config, + proteins=args.proteins, + dry_run=args.dry_run, + ) + + if not args.dry_run: + success_count = sum(1 for r in results if r.success) + total_runtime = sum(r.runtime_seconds for r in results) + + print("\n" + "=" * 60) + print("BENCHMARK SUMMARY") + print("=" * 60) + print(f"Dataset: {config.dataset_name}") + print(f"Proteins: {len(results)}") + print(f"Successful: {success_count}/{len(results)} ({100*success_count/len(results):.1f}%)") + print(f"Runtime: {total_runtime:.1f}s ({total_runtime/60:.1f}min)") + print(f"Output: {config.output_dir}") + print("=" * 60) + + +if __name__ == "__main__": + main() diff --git a/benchmarks/tests/__init__.py b/benchmarks/tests/__init__.py new file mode 100644 index 0000000..09aa370 --- /dev/null +++ b/benchmarks/tests/__init__.py @@ -0,0 +1 @@ +"""Tests package for benchmarks module.""" diff --git a/benchmarks/tests/test_benchmark.py b/benchmarks/tests/test_benchmark.py new file mode 100644 index 0000000..9afdf87 --- /dev/null +++ b/benchmarks/tests/test_benchmark.py @@ -0,0 +1,307 @@ +""" +test_benchmark.py + +Unit tests for the ConforMix benchmarking suite. 
+ +Run with: + pytest benchmarks/tests/test_benchmark.py -v +""" + +import json +import tempfile +from pathlib import Path +from unittest.mock import patch, MagicMock + +import numpy as np +import pytest + +from benchmarks.run_benchmark import ( + BenchmarkConfig, + BenchmarkResult, + run_benchmark, + fetch_sequence_from_pdb, + count_samples, +) +from benchmarks.evaluate_metrics import ( + MetricsResult, + compute_rmsd_simple, + extract_ca_coords_from_pdb, + compute_metrics, +) +from benchmarks.generate_report import ( + generate_markdown_report, + generate_html_report, +) + + +class TestBenchmarkConfig: + """Tests for BenchmarkConfig class.""" + + def test_from_yaml(self, tmp_path): + """Test loading config from YAML.""" + yaml_content = """ +dataset_name: test_dataset +csv_path: datasets/test.csv +output_dir: outputs/test +num_twist_targets: 3 +samples_per_target: 2 +structured_regions_only: true +""" + yaml_path = tmp_path / "test_config.yaml" + yaml_path.write_text(yaml_content) + + config = BenchmarkConfig.from_yaml(yaml_path) + + assert config.dataset_name == "test_dataset" + assert config.num_twist_targets == 3 + assert config.samples_per_target == 2 + assert config.structured_regions_only is True + + def test_to_dict(self): + """Test config serialization.""" + config = BenchmarkConfig( + dataset_name="test", + csv_path=Path("test.csv"), + output_dir=Path("output"), + ) + + d = config.to_dict() + + assert d["dataset_name"] == "test" + assert d["csv_path"] == "test.csv" + assert "num_twist_targets" in d + + +class TestBenchmarkResult: + """Tests for BenchmarkResult class.""" + + def test_to_dict(self): + """Test result serialization.""" + result = BenchmarkResult( + system_id="test_system", + pdb1="1ABC_A", + pdb2="2XYZ_A", + ground_truth_rmsd=5.0, + success=True, + runtime_seconds=120.5, + output_dir=Path("/output/test"), + num_samples_generated=10, + ) + + d = result.to_dict() + + assert d["system_id"] == "test_system" + assert d["success"] is True + assert d["runtime_seconds"] == 120.5 + assert d["num_samples_generated"] == 10 + + +class TestRMSDComputation: + """Tests for RMSD computation functions.""" + + def test_compute_rmsd_simple_identical(self): + """Test RMSD of identical coordinates is zero.""" + coords = np.array([ + [0, 0, 0], + [1, 0, 0], + [1, 1, 0], + [0, 1, 0], + ], dtype=float) + + rmsd = compute_rmsd_simple(coords, coords) + + assert rmsd < 1e-6 + + def test_compute_rmsd_simple_translation(self): + """Test RMSD is zero for translated coordinates (after alignment).""" + coords1 = np.array([ + [0, 0, 0], + [1, 0, 0], + [1, 1, 0], + [0, 1, 0], + ], dtype=float) + + coords2 = coords1 + np.array([10, 20, 30]) + + rmsd = compute_rmsd_simple(coords1, coords2) + + assert rmsd < 1e-6 + + def test_compute_rmsd_simple_known_value(self): + """Test RMSD computation with known result.""" + # Square in XY plane + coords1 = np.array([ + [0, 0, 0], + [1, 0, 0], + [1, 1, 0], + [0, 1, 0], + ], dtype=float) + + # Distorted square - can't be aligned perfectly + coords2 = np.array([ + [0, 0, 0], + [1.5, 0, 0], # Stretched + [1.5, 1, 0], + [0, 1, 0], + ], dtype=float) + + rmsd = compute_rmsd_simple(coords1, coords2) + + # RMSD should be non-zero for distorted coordinates + assert rmsd > 0.1 + + +class TestPDBParsing: + """Tests for PDB file parsing.""" + + def test_extract_ca_coords(self, tmp_path): + """Test CA coordinate extraction from PDB.""" + pdb_content = """ATOM 1 N ALA A 1 0.000 0.000 0.000 1.00 0.00 N +ATOM 2 CA ALA A 1 1.458 0.000 0.000 1.00 0.00 C +ATOM 3 C ALA A 1 2.009 
1.420 0.000 1.00 0.00 C +ATOM 4 O ALA A 1 1.251 2.390 0.000 1.00 0.00 O +ATOM 5 N GLY A 2 3.326 1.548 0.000 1.00 0.00 N +ATOM 6 CA GLY A 2 3.992 2.826 0.000 1.00 0.00 C +END +""" + pdb_path = tmp_path / "test.pdb" + pdb_path.write_text(pdb_content) + + coords = extract_ca_coords_from_pdb(pdb_path) + + assert coords is not None + assert coords.shape == (2, 3) + np.testing.assert_array_almost_equal(coords[0], [1.458, 0.0, 0.0]) + np.testing.assert_array_almost_equal(coords[1], [3.992, 2.826, 0.0]) + + +class TestMetricsResult: + """Tests for MetricsResult class.""" + + def test_to_dict(self): + """Test metrics serialization.""" + metrics = MetricsResult( + system_id="test", + pdb1="1ABC_A", + pdb2="2XYZ_A", + ground_truth_rmsd=5.0, + min_rmsd_to_alt=3.5, + conformational_coverage=0.7, + ) + + d = metrics.to_dict() + + assert d["system_id"] == "test" + assert d["min_rmsd_to_alt"] == 3.5 + assert d["conformational_coverage"] == 0.7 + + +class TestReportGeneration: + """Tests for report generation.""" + + def test_generate_markdown_report(self, tmp_path): + """Test Markdown report generation.""" + metrics = [ + { + "system_id": "test1", + "pdb1": "1ABC_A", + "pdb2": "2XYZ_A", + "ground_truth_rmsd": 5.0, + "min_rmsd_to_alt": 3.5, + "mean_rmsd_to_alt": 4.0, + "conformational_coverage": 0.7, + "num_samples": 10, + "error_message": None, + }, + { + "system_id": "test2", + "pdb1": "3DEF_A", + "pdb2": "4GHI_A", + "ground_truth_rmsd": 8.0, + "min_rmsd_to_alt": None, + "conformational_coverage": None, + "num_samples": 0, + "error_message": "Test error", + }, + ] + + metrics_path = tmp_path / "metrics.json" + with open(metrics_path, "w") as f: + json.dump(metrics, f) + + output_path = tmp_path / "report.md" + report = generate_markdown_report(metrics_path, output_path) + + assert output_path.exists() + assert "ConforMix Benchmark Report" in report + assert "test1" in report + assert "test2" in report + assert "Error" in report + + def test_generate_html_report(self, tmp_path): + """Test HTML report generation.""" + metrics = [ + { + "system_id": "test1", + "pdb1": "1ABC_A", + "pdb2": "2XYZ_A", + "ground_truth_rmsd": 5.0, + "min_rmsd_to_alt": 3.5, + "conformational_coverage": 0.7, + "num_samples": 10, + "error_message": None, + }, + ] + + metrics_path = tmp_path / "metrics.json" + with open(metrics_path, "w") as f: + json.dump(metrics, f) + + output_path = tmp_path / "report.html" + report = generate_html_report(metrics_path, output_path) + + assert output_path.exists() + assert "" in report + assert "test1" in report + + +class TestCountSamples: + """Tests for sample counting.""" + + def test_count_samples(self, tmp_path): + """Test counting sample files.""" + (tmp_path / "sample1.cif").write_text("test") + (tmp_path / "sample2.cif").write_text("test") + (tmp_path / "sample3.pdb").write_text("test") + (tmp_path / "other.txt").write_text("test") + + count = count_samples(tmp_path) + + assert count == 3 + + +class TestIntegration: + """Integration tests.""" + + def test_dry_run_benchmark(self, tmp_path): + """Test dry run of benchmark (no actual execution).""" + csv_content = """system_id,pdb1,pdb2,RMSD,TM-score +test1,1ABC_A,2XYZ_A,5.0,0.8 +test2,3DEF_A,4GHI_A,8.0,0.7 +""" + csv_path = tmp_path / "test_dataset.csv" + csv_path.write_text(csv_content) + + config = BenchmarkConfig( + dataset_name="test", + csv_path=csv_path, + output_dir=tmp_path / "output", + ) + + results = run_benchmark(config, dry_run=True) + + assert len(results) == 0 + + +if __name__ == "__main__": + pytest.main([__file__, "-v"])