diff --git a/conformix_boltz/pytest.ini b/conformix_boltz/pytest.ini new file mode 100644 index 0000000..7eda4ea --- /dev/null +++ b/conformix_boltz/pytest.ini @@ -0,0 +1,30 @@ +[pytest] +# Pytest configuration for ConforMix-Boltz + +# Test discovery +python_files = test_*.py +python_classes = Test* +python_functions = test_* + +# Markers +markers = + integration: Integration tests that require model download and GPU + slow: Tests that take a long time to run + unit: Unit tests that are fast + +# Test paths +testpaths = tests + +# Output options +addopts = + -v + --strict-markers + --tb=short + --disable-warnings + +# Coverage (if pytest-cov is installed) +# addopts = +# --cov=boltz +# --cov-report=html +# --cov-report=term-missing + diff --git a/conformix_boltz/tests/__init__.py b/conformix_boltz/tests/__init__.py new file mode 100644 index 0000000..18ee6b5 --- /dev/null +++ b/conformix_boltz/tests/__init__.py @@ -0,0 +1,2 @@ +"""Tests for ConforMix-Boltz.""" + diff --git a/conformix_boltz/tests/conftest.py b/conformix_boltz/tests/conftest.py new file mode 100644 index 0000000..10b8552 --- /dev/null +++ b/conformix_boltz/tests/conftest.py @@ -0,0 +1,86 @@ +"""Pytest configuration and fixtures for ConforMix-Boltz tests.""" +import pytest +import numpy as np +import tempfile +from pathlib import Path +import mdtraj as md + + +@pytest.fixture +def temp_dir(): + """Create a temporary directory for tests.""" + with tempfile.TemporaryDirectory() as tmpdir: + yield Path(tmpdir) + + +@pytest.fixture +def example_fasta_file(temp_dir): + """Create an example FASTA file.""" + fasta_path = temp_dir / "test_protein.fasta" + fasta_content = """>test_protein +MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQAPILSRVGDGTQDNLSGAEKAVQVKVKALPDAQFEVVHSLAKWKRQTLGQHDFSAGEGLYTHMKALRPDEDRLSPLHSVYVDQWDWERVMGDGERQFSTLKSTVEAIWAGIKATEAAVSEEFGLAPFLPDQIHFVHSQELLSRYPDLDAKGRERAIAKDLGAVFLVGIGGKLSDGHRHDVRAPDYDDWSTPSELGHAGLNGDILVWNPVLEDAFELSSMGIRVDADTLKHQLALTGDEDRLELEWHQALLRGEMPQTIGGGIGQSRLTMLLLQLPHIGQVQAGVWPAAVRESVPSLL +""" + fasta_path.write_text(fasta_content) + return fasta_path + + +@pytest.fixture +def short_fasta_file(temp_dir): + """Create a short FASTA file for quick tests.""" + fasta_path = temp_dir / "short_protein.fasta" + fasta_content = """>short_protein +MKTAYIAKQR +""" + fasta_path.write_text(fasta_content) + return fasta_path + + +@pytest.fixture +def example_trajectory(): + """Create an example MDTraj trajectory for testing.""" + # Create a simple 3-residue trajectory + n_frames = 10 + n_atoms = 15 # 5 atoms per residue (N, CA, C, O, CB) + + # Create topology + top = md.Topology() + chain = top.add_chain() + + for i in range(3): + residue = top.add_residue(f"RES{i}", chain) + for atom_name in ["N", "CA", "C", "O", "CB"]: + top.add_atom(atom_name, md.element.carbon, residue) + + # Create coordinates + xyz = np.random.randn(n_frames, n_atoms, 3) * 0.1 + np.array([0, 0, 0]) + + traj = md.Trajectory(xyz=xyz, topology=top) + return traj + + +@pytest.fixture +def example_cif_file(temp_dir): + """Create a minimal example CIF file.""" + cif_path = temp_dir / "test_structure.cif" + # Minimal CIF content + cif_content = """data_test +# +loop_ +_atom_site.group_PDB +_atom_site.id +_atom_site.type_symbol +_atom_site.label_atom_id +_atom_site.label_comp_id +_atom_site.label_asym_id +_atom_site.label_seq_id +_atom_site.Cartn_x +_atom_site.Cartn_y +_atom_site.Cartn_z +ATOM 1 N N MET A 1 10.0 20.0 30.0 +ATOM 2 C CA MET A 1 11.0 21.0 31.0 +ATOM 3 C C MET A 1 12.0 22.0 32.0 +ATOM 4 O O MET A 1 13.0 23.0 33.0 +""" + cif_path.write_text(cif_content) + return cif_path + diff --git a/conformix_boltz/tests/test_integration.py b/conformix_boltz/tests/test_integration.py new file mode 100644 index 0000000..adb5819 --- /dev/null +++ b/conformix_boltz/tests/test_integration.py @@ -0,0 +1,89 @@ +"""End-to-end integration tests for ConforMix-Boltz.""" +import pytest +from pathlib import Path +import tempfile +import shutil + + +@pytest.mark.integration +@pytest.mark.skip(reason="Requires model download and GPU - run manually") +class TestEndToEnd: + """End-to-end integration tests.""" + + def test_full_pipeline_short_sequence(self, temp_dir): + """Test full pipeline with a short sequence.""" + # This test requires: + # 1. Model download + # 2. GPU access + # 3. Significant time + + # Create a short test sequence + fasta_path = temp_dir / "test.fasta" + fasta_content = """>test_protein +MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQAPILSRVGDGTQDNLSGAEKAVQVKVKALPDAQFEVVHSLAKWKRQTLGQHDFSAGEGLYTHMKALRPDEDRLSPLHSVYVDQWDWERVMGDGERQFSTLKSTVEAIWAGIKATEAAVSEEFGLAPFLPDQIHFVHSQELLSRYPDLDAKGRERAIAKDLGAVFLVGIGGKLSDGHRHDVRAPDYDDWSTPSELGHAGLNGDILVWNPVLEDAFELSSMGIRVDADTLKHQLALTGDEDRLELEWHQALLRGEMPQTIGGGIGQSRLTMLLLQLPHIGQVQAGVWPAAVRESVPSLL +""" + fasta_path.write_text(fasta_content) + + output_dir = temp_dir / "output" + + # Run the pipeline + # This would be the actual call: + # python -m boltz.run_conformixrmsd_boltz \ + # --fasta_path {fasta_path} \ + # --out_dir {output_dir} \ + # --num_twist_targets 3 \ + # --samples_per_target 1 + + # Check outputs + # assert (output_dir / "default_reference").exists() + # assert (output_dir / "final_filtered" / "topology.pdb").exists() + # assert (output_dir / "final_filtered" / "samples.xtc").exists() + + # For now, just verify the test structure + assert fasta_path.exists() + + def test_pipeline_with_custom_reference(self, temp_dir): + """Test pipeline with a custom reference structure.""" + # This would test using a provided reference CIF + # instead of generating one + pass + + def test_pipeline_with_subset_residues(self, temp_dir): + """Test pipeline with subset residue selection.""" + # This would test the --subset_residues functionality + pass + + +@pytest.mark.integration +class TestDataProcessing: + """Tests for data processing components.""" + + def test_cif_to_xtc_conversion(self, temp_dir): + """Test CIF to XTC conversion.""" + # This would test the cif_to_xtc utility + # with actual CIF files + pass + + def test_trajectory_filtering(self, temp_dir): + """Test trajectory filtering pipeline.""" + # This would test the full filtering pipeline + pass + + +@pytest.mark.slow +@pytest.mark.skip(reason="Very slow - run only for full validation") +class TestFullValidation: + """Full validation tests that take a long time.""" + + def test_large_protein(self): + """Test with a large protein (>500 residues).""" + pass + + def test_multiple_targets(self): + """Test with many RMSD targets.""" + pass + + def test_many_samples(self): + """Test with many samples per target.""" + pass + diff --git a/conformix_boltz/tests/test_run_conformixrmsd.py b/conformix_boltz/tests/test_run_conformixrmsd.py new file mode 100644 index 0000000..377c1d1 --- /dev/null +++ b/conformix_boltz/tests/test_run_conformixrmsd.py @@ -0,0 +1,78 @@ +"""Tests for the main ConforMix RMSD script.""" +import pytest +from pathlib import Path +from unittest.mock import Mock, patch, MagicMock + +from boltz.run_conformixrmsd_boltz import ( + get_sequence_length_from_fasta, + main, +) + + +class TestGetSequenceLengthFromFasta: + """Tests for FASTA sequence length extraction.""" + + def test_get_sequence_length_simple(self, example_fasta_file): + """Test getting sequence length from simple FASTA.""" + length = get_sequence_length_from_fasta(example_fasta_file) + assert length > 0 + assert isinstance(length, int) + + def test_get_sequence_length_multiline(self, temp_dir): + """Test getting sequence length from multiline FASTA.""" + fasta_path = temp_dir / "multiline.fasta" + fasta_content = """>test +MKTAYIAKQR +QISFVKSHFS +""" + fasta_path.write_text(fasta_content) + + length = get_sequence_length_from_fasta(fasta_path) + assert length == 20 # Combined length + + def test_get_sequence_length_with_header(self, temp_dir): + """Test that header lines are ignored.""" + fasta_path = temp_dir / "with_header.fasta" + fasta_content = """>test_protein description +MKTAYIAKQR +""" + fasta_path.write_text(fasta_content) + + length = get_sequence_length_from_fasta(fasta_path) + assert length == 10 + + +class TestMainFunction: + """Tests for the main function.""" + + @pytest.mark.skip(reason="Requires model download and GPU") + def test_main_basic(self, example_fasta_file, temp_dir): + """Test basic main function execution.""" + # This test requires actual model and may be slow + # Skip in CI, run manually for integration testing + output_dir = temp_dir / "output" + + with patch('boltz.run_conformixrmsd_boltz.run_untwisted') as mock_untwisted, \ + patch('boltz.run_conformixrmsd_boltz.run_twisted') as mock_twisted: + + # Mock the model loading + mock_model = Mock() + mock_untwisted.load_model.return_value = mock_model + + # Mock prediction + mock_untwisted.predict.callback.return_value = None + mock_twisted.predict.callback.return_value = None + + # This would require actual implementation + # For now, just test argument parsing + pass + + def test_main_invalid_fasta(self, temp_dir): + """Test main function with invalid FASTA path.""" + invalid_fasta = temp_dir / "nonexistent.fasta" + output_dir = temp_dir / "output" + + # Should exit with error + # This would be tested with actual execution + assert not invalid_fasta.exists() + diff --git a/conformix_boltz/tests/test_utils.py b/conformix_boltz/tests/test_utils.py new file mode 100644 index 0000000..db77573 --- /dev/null +++ b/conformix_boltz/tests/test_utils.py @@ -0,0 +1,169 @@ +"""Tests for utility functions.""" +import pytest +import numpy as np +from pathlib import Path +import mdtraj as md + +from boltz.utils.cif_to_xtc import ( + find_cif_files, + filter_unphysical_traj, + combine_structures, + load_cif_structures, +) + + +class TestFindCifFiles: + """Tests for finding CIF files.""" + + def test_find_cif_files_single(self, temp_dir): + """Test finding a single CIF file.""" + cif_file = temp_dir / "test.cif" + cif_file.write_text("data_test\n") + + found = find_cif_files(temp_dir) + assert len(found) == 1 + assert found[0] == cif_file + + def test_find_cif_files_multiple(self, temp_dir): + """Test finding multiple CIF files.""" + (temp_dir / "file1.cif").write_text("data_test\n") + (temp_dir / "file2.cif").write_text("data_test\n") + subdir = temp_dir / "subdir" + subdir.mkdir() + (subdir / "file3.cif").write_text("data_test\n") + + found = find_cif_files(temp_dir) + assert len(found) == 3 + + def test_find_cif_files_none(self, temp_dir): + """Test finding no CIF files.""" + found = find_cif_files(temp_dir) + assert len(found) == 0 + + def test_find_cif_files_case_insensitive(self, temp_dir): + """Test case-insensitive CIF file finding.""" + (temp_dir / "test.CIF").write_text("data_test\n") + (temp_dir / "test2.Cif").write_text("data_test\n") + + found = find_cif_files(temp_dir) + assert len(found) == 2 + + +class TestFilterUnphysicalTraj: + """Tests for filtering unphysical trajectories.""" + + def test_filter_keeps_good_trajectory(self, example_trajectory): + """Test that good trajectories pass filtering.""" + filtered = filter_unphysical_traj(example_trajectory) + assert filtered.n_frames > 0 + assert filtered.n_atoms == example_trajectory.n_atoms + + def test_filter_removes_clashes(self, temp_dir): + """Test that structures with clashes are filtered.""" + # Create trajectory with overlapping atoms + top = md.Topology() + chain = top.add_chain() + residue = top.add_residue("RES", chain) + top.add_atom("CA", md.element.carbon, residue) + top.add_atom("CA2", md.element.carbon, residue) + + # Create coordinates with atoms too close + xyz = np.array([[[0.0, 0.0, 0.0], [0.1, 0.1, 0.1]]]) # Very close atoms + traj = md.Trajectory(xyz=xyz, topology=top) + + filtered = filter_unphysical_traj(traj, clash_distance=0.5) + # Should filter out frames with clashes + assert filtered.n_frames <= traj.n_frames + + def test_filter_strict_mode(self, example_trajectory): + """Test strict filtering mode.""" + # This should work if trajectory is good + try: + filtered = filter_unphysical_traj(example_trajectory, strict=True) + assert filtered.n_frames > 0 + except ValueError: + # If all frames filtered, that's also valid + pass + + +class TestCombineStructures: + """Tests for combining structures.""" + + def test_combine_single_structure(self, example_trajectory): + """Test combining a single structure.""" + structures = [example_trajectory] + combined = combine_structures(structures) + + assert combined is not None + assert combined.n_frames == 1 + assert combined.n_atoms == example_trajectory.n_atoms + + def test_combine_multiple_structures(self, example_trajectory): + """Test combining multiple structures.""" + structures = [example_trajectory, example_trajectory] + combined = combine_structures(structures) + + assert combined is not None + assert combined.n_frames == 2 + assert combined.n_atoms == example_trajectory.n_atoms + + def test_combine_empty_list(self): + """Test combining empty list.""" + combined = combine_structures([]) + assert combined is None + + def test_combine_mismatched_atoms(self, example_trajectory): + """Test combining structures with mismatched atom counts.""" + # Create trajectory with different atom count + top = md.Topology() + chain = top.add_chain() + residue = top.add_residue("RES", chain) + top.add_atom("CA", md.element.carbon, residue) + + xyz = np.array([[[0.0, 0.0, 0.0]]]) + traj2 = md.Trajectory(xyz=xyz, topology=top) + + structures = [example_trajectory, traj2] + combined = combine_structures(structures) + + # Should only include matching structures + assert combined is not None + assert combined.n_frames == 1 # Only first structure matches itself + + +class TestLoadCifStructures: + """Tests for loading CIF structures.""" + + def test_load_cif_file(self, example_cif_file): + """Test loading a CIF file.""" + structures = load_cif_structures([example_cif_file]) + assert len(structures) == 1 + assert isinstance(structures[0], md.Trajectory) + + def test_load_multiple_cif_files(self, temp_dir): + """Test loading multiple CIF files.""" + cif1 = temp_dir / "file1.cif" + cif2 = temp_dir / "file2.cif" + + # Create minimal CIF files + cif_content = """data_test +loop_ +_atom_site.group_PDB +_atom_site.id +_atom_site.type_symbol +_atom_site.label_atom_id +_atom_site.label_comp_id +_atom_site.label_asym_id +_atom_site.label_seq_id +_atom_site.Cartn_x +_atom_site.Cartn_y +_atom_site.Cartn_z +ATOM 1 N N MET A 1 10.0 20.0 30.0 +""" + cif1.write_text(cif_content) + cif2.write_text(cif_content) + + structures = load_cif_structures([cif1, cif2]) + assert len(structures) == 2 + assert all(isinstance(s, md.Trajectory) for s in structures) +