
Accelerated ts.optimize by batching Frechet Cell Filter #439

Open
falletta wants to merge 19 commits into TorchSim:main from falletta:speedup_relax

Conversation

@falletta (Contributor) commented Feb 3, 2026

Update Summary


1. torch_sim/math.py

Removed unbatched/legacy code:

  • expm_frechet_block_enlarge (helper function for block enlargement method)
  • _diff_pade3, _diff_pade5, _diff_pade7, _diff_pade9 (Padé approximation helpers)
  • expm_frechet_algo_64 (original algorithm implementation)
  • matrix_exp (custom matrix exponential function)
  • vec, expm_frechet_kronform (Kronecker form helpers)
  • expm_cond (condition number estimation)
  • class expm (autograd Function class)
  • _is_valid_matrix, _determine_eigenvalue_case (unbatched helpers)

Refactored expm_frechet:

  • Now optimized specifically for batched 3x3 matrices (common case for cell operations)
  • Handles both (B, 3, 3) batched input and (3, 3) single matrix input (auto-adds batch dim)
  • Removed the method parameter (was "SPS" or "blockEnlarge")
  • Inlined the algorithm directly instead of calling helper functions
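
For reference, a minimal usage sketch of the refactored function (the random matrices below are illustrative; the (exponential, Frechet derivative) return pair matches how fm.expm_frechet is called elsewhere in this PR):

import torch
import torch_sim.math as fm

# Batched input: B matrices A and B perturbation directions E, each 3x3.
A = torch.randn(4, 3, 3, dtype=torch.float64) * 0.1
E = torch.randn(4, 3, 3, dtype=torch.float64) * 0.1

# Returns exp(A) and the Frechet derivative L_A(E), both shaped (B, 3, 3).
expm_A, frechet_AE = fm.expm_frechet(A, E)

# A single (3, 3) matrix is also accepted; a batch dimension is added internally.
expm_single, frechet_single = fm.expm_frechet(A[0], E[0])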

Refactored matrix_log_33:

  • Added _ensure_batched, _determine_matrix_log_cases, _process_matrix_log_case helpers
  • Made the matrix log computation work in batched mode

2. torch_sim/optimizers/cell_filters.py

Vectorized compute_cell_forces:

  • Before: used nested loops over systems and directions (9 iterations per system)
  • After: uses batched matrix operations across all systems and all 9 directions simultaneously
  • Key optimization: expm_frechet(A_batch, E_batch) is now called once with all n_systems * 9 matrices batched together
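
In code, the vectorized path boils down to the pattern below (this mirrors the hunk quoted later in the review; deform_grad_log, directions, and ucf_cell_grad are assumed to be (n_systems, 3, 3), (9, 3, 3), and (n_systems, 3, 3) tensors respectively):

# Expand the per-system deformation-gradient log across all 9 directions and flatten,
# so a single batched Frechet call covers every (system, direction) pair.
A_batch = deform_grad_log.unsqueeze(1).expand(n_systems, 9, 3, 3).reshape(-1, 3, 3)
E_batch = directions.unsqueeze(0).expand(n_systems, 9, 3, 3).reshape(-1, 3, 3)

_, expm_derivs_batch = fm.expm_frechet(A_batch, E_batch)  # (n_systems * 9, 3, 3)
expm_derivs = expm_derivs_batch.reshape(n_systems, 9, 3, 3)

# Contract each derivative with the cell gradient to get one force component per
# direction, then reshape back to (n_systems, 3, 3) cell forces.
forces_flat = (expm_derivs * ucf_cell_grad.unsqueeze(1)).sum(dim=(2, 3))
cell_forces = forces_flat.reshape(n_systems, 3, 3)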

3. tests/test_math.py

Refactored tests:

  • TestExpmFrechet: test_expm_frechet, test_small_norm_expm_frechet, test_fuzz
  • TestExpmFrechetTorch: test_expm_frechet, test_fuzz

All of these were updated to use 3x3 matrices and simplified by removing the method-parameter testing. Fuzz tests were streamlined with fewer iterations.

Removed tests:

  • test_problematic_matrix, test_medium_matrix (both numpy and torch versions)
  • TestExpmFrechetTorchGrad class

Tests for comparing computation methods and large matrix performance no longer apply to the 3x3-specialized implementation.

Added tests:

  • TestExpmFrechet.test_large_norm_matrices - Tests scaling behavior for larger norm matrices
  • TestLogM33.test_batched_positive_definite - Tests batched matrix logarithm with round-trip verification
  • TestFrechetCellFilterIntegration - Integration tests for the cell filter pipeline
  • test_wrap_positions_* - Tests for the new wrap_positions property

Results

The figure below shows the speedup achieved for a 10-step atomic relaxation. The test is performed on an 8-atom cubic supercell of MgO using the mace-mpa model. Prior results are shown in blue, while new results are shown in red. The speedup is calculated as speedup (%) = (baseline_time / current_time − 1) × 100. We observe a speedup of up to 564% for large batches.
[Figure: speedup (%) of the new implementation over the baseline for 10-step relaxation; baseline in blue, batched implementation in red]
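
As a worked example of the formula above (the timings here are hypothetical, chosen only to reproduce the quoted 564% figure):

# Illustrative numbers only: a baseline of 6.64 s against a current time of 1.00 s.
baseline_time, current_time = 6.64, 1.00
speedup_pct = (baseline_time / current_time - 1) * 100
print(f"{speedup_pct:.0f}%")  # -> 564%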

@orionarcher (Collaborator)

Could we get some tests verifying identical numerical behavior of the old and new versions? Can be deleted before merging when we get rid of the unbatched version.

@falletta (Contributor, Author) commented Feb 4, 2026

@orionarcher I added test_math_frechet.py to compare the batched code against the SciPy implementations. Please have a look—happy to revise it as needed, and if it looks good, we can include it directly in test_math.py.
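
A minimal sketch of the kind of cross-check that file performs (the sizes, seeds, and tolerances here are assumptions, not copied from the actual file):

import numpy as np
import scipy.linalg
import torch
import torch_sim.math as fm

rng = np.random.default_rng(0)
A_np = rng.standard_normal((8, 3, 3)) * 0.1
E_np = rng.standard_normal((8, 3, 3)) * 0.1

# Batched torch result.
A = torch.as_tensor(A_np, dtype=torch.float64)
E = torch.as_tensor(E_np, dtype=torch.float64)
expm_A, frechet_AE = fm.expm_frechet(A, E)

# Compare each batch entry against SciPy's reference implementation.
for i in range(A_np.shape[0]):
    ref_expm, ref_frechet = scipy.linalg.expm_frechet(A_np[i], E_np[i])
    np.testing.assert_allclose(expm_A[i].numpy(), ref_expm, atol=1e-12)
    np.testing.assert_allclose(frechet_AE[i].numpy(), ref_frechet, atol=1e-12)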

num_tol = 1e-16 if dtype == torch.float64 else 1e-8
batched = T.dim() == 3

if batched:
Collaborator:

Why support both batched and unbatched versions?

Contributor (Author):

I now removed all unbatched code

class TestExpmFrechet:
    """Tests for expm_frechet against scipy.linalg.expm_frechet."""

    def test_small_matrix(self):
Collaborator:

Are we testing the batched or unbatched versions here?

Contributor (Author):

I incorporated only the batched tests into test_math.py

@thomasloux (Collaborator)

Haven't looked carefully at the implementation, but I would potentially support having 2 separate functions for the batched (B, 3, 3) and unbatched (3, 3) algorithms. This would also prevent graph breaks in the future and be easier to read, and in practice a state.cell is always (B, 3, 3), potentially with B=1, so we would always use the batched version anyway.

@falletta (Contributor, Author) commented Feb 5, 2026

@orionarcher I removed all unbatched and unused code while preserving the new performance speedups. Please see the PR description for a detailed list of changes.

@thomasloux It's indeed a good point, but for now it's probably better to keep things clean and stick to the batched implementation only; keeping only the batched path lets us remove quite a few lines of dead code.

@orionarcher (Collaborator)

Again not checking the integration math, but noting that TestExpmFrechetTorch.test_expm_frechet and test_fuzz both discard the Frechet derivative result (observed_expm, _ = ...). Since the Frechet derivative is the core output used by the cell filter, these should verify it too. Can you add the Frechet assertions back?

@falletta (Contributor, Author)

Right, great point! I restored the Frechet derivative assertions by computing the expected derivative via the block-matrix identity and checking it against the second output of fm.expm_frechet.
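
For context, the identity in question is exp([[A, E], [0, A]]) = [[exp(A), L_A(E)], [0, exp(A)]], so the expected derivative can be read off the upper-right block of a 6x6 matrix exponential. A sketch of that check (random matrices and tolerances are illustrative):

import torch
import torch_sim.math as fm

A = torch.randn(16, 3, 3, dtype=torch.float64) * 0.1
E = torch.randn(16, 3, 3, dtype=torch.float64) * 0.1

# Build the block matrix [[A, E], [0, A]]; its exponential holds L_A(E) in the
# upper-right 3x3 block.
top = torch.cat([A, E], dim=-1)
bottom = torch.cat([torch.zeros_like(A), A], dim=-1)
exp_block = torch.linalg.matrix_exp(torch.cat([top, bottom], dim=-2))
expected_frechet = exp_block[..., :3, 3:]

_, observed_frechet = fm.expm_frechet(A, E)
torch.testing.assert_close(observed_frechet, expected_frechet, atol=1e-10, rtol=1e-6)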

@orionarcher (Collaborator)

@thomasloux or @abhijeetgangan, would you take a quick pass on this? LGTM but I don't know the math.

@orionarcher (Collaborator) left a comment:

approving but let's get a more mathy person to review too

Comment on lines 312 to 319
A_batch = (
    deform_grad_log.unsqueeze(1).expand(n_systems, 9, 3, 3).reshape(-1, 3, 3)
)
E_batch = directions.unsqueeze(0).expand(n_systems, 9, 3, 3).reshape(-1, 3, 3)
_, expm_derivs_batch = fm.expm_frechet(A_batch, E_batch)
expm_derivs = expm_derivs_batch.reshape(n_systems, 9, 3, 3)
forces_flat = (expm_derivs * ucf_cell_grad.unsqueeze(1)).sum(dim=(2, 3))
cell_forces = forces_flat.reshape(n_systems, 3, 3)
Member:

nit: this is a bit mathy for the md logic? could we hide this inside a function that takes the directions and deform_grad_log and returns forces?

Contributor (Author):

Sure, I extracted the batched Frechet-derivative math into a private helper

device, dtype = A.device, A.dtype
ident = torch.eye(3, device=device, dtype=dtype).unsqueeze(0).expand(batch_size, 3, 3)

A_norm_1 = torch.norm(A, p=1, dim=(-2, -1))
Member:

nit: we've lost the provenance of the algorithm name being used, given that only the 64 variant is (I believe) now used in the batched code. I am not familiar with the numerical decision-making about whether a blockEnlarge pathway is worth maintaining.

Contributor (Author):

Sure, following Thomas's comment below as well, I reintroduced the blockEnlarge method and added the frechet_method attribute to activate it (previously, that part of the code was never accessed). More details below.

@CompRhys (Member) left a comment:

LGTM, but I just want to double check with @abhijeetgangan that we are okay with removing the other methods that seemingly are not used. Several paths are removed in expm_frechet. The fact that the ASE optimizer tests still pass suggests that they were not necessary.

@thomasloux (Collaborator)

The only part of the implementation that is not strictly equivalent to SciPy is the scaling s, which your batched implementation defines as a single scalar for all matrices rather than one scaling per matrix, as is in theory needed. But in MD you are not supposed to have a cell with absurdly high matrix norms, so I agree with your implementation.
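
A minimal sketch of that difference (the threshold constant and decision logic here follow the standard scaling-and-squaring recipe and are illustrative, not taken from this PR's code):

import torch

theta_13 = 5.371920351148152  # order-13 Pade threshold used in scaling-and-squaring
A = torch.randn(150, 3, 3, dtype=torch.float64) * 0.1
A_norm_1 = torch.linalg.matrix_norm(A, ord=1)  # one 1-norm per matrix, shape (B,)

# Per-matrix scaling: each matrix gets its own s_i (what SciPy effectively does,
# one matrix at a time).
s_per_matrix = torch.clamp(torch.ceil(torch.log2(A_norm_1 / theta_13)), min=0)

# Shared scaling in the batched code: a single s driven by the largest norm in the
# batch. Safe, but it slightly over-scales the well-conditioned matrices.
s_shared = int(s_per_matrix.max().item())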

Still, I would just like to know why you got rid of expm_frechet_block_enlarge. In my test it is actually on par with your implementation, and even slightly faster on my Mac for small batches.

Code to reproduce:

import time

import torch

import torch_sim.math as fm
try:
    from math_old import expm_frechet as unbatched_expm_frechet
except ImportError:
    print("Warning: math_old.expm_frechet not found, unbatched loop will fail.")

    def unbatched_expm_frechet(A, E):
        raise NotImplementedError("math_old.expm_frechet not found")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def frechet_matrix_exp(A, E):
    """Compute exp(A) and its Frechet derivative L_A(E) via block matrix_exp."""
    n = A.shape[-1]
    top = torch.cat([A, E], dim=-1)
    bottom = torch.cat([torch.zeros_like(A), A], dim=-1)
    block = torch.cat([top, bottom], dim=-2)
    exp_block = torch.linalg.matrix_exp(block)
    exp_A = exp_block[..., :n, :n]
    L_A_E = exp_block[..., :n, n:]
    return exp_A, L_A_E


def unbatched_loop(A_batch, E_batch):
    """Loop over batch calling the old unbatched Pade expm_frechet."""
    Rs, Ls = [], []
    for i in range(A_batch.shape[0]):
        R, L = unbatched_expm_frechet(A_batch[i], E_batch[i])
        Rs.append(R)
        Ls.append(L)
    return torch.stack(Rs), torch.stack(Ls)


def benchmark(func, A_batch, E_batch, n_warmup=3, n_runs=30, label=""):
    """Run warmup + timed runs, return average time in seconds."""
    for _ in range(n_warmup):
        func(A_batch, E_batch)
    if device.type == "cuda":
        torch.cuda.synchronize()

    t0 = time.perf_counter()
    for _ in range(n_runs):
        func(A_batch, E_batch)
    if device.type == "cuda":
        torch.cuda.synchronize()
    elapsed = (time.perf_counter() - t0) / n_runs
    return elapsed


def test_benchmark_large_batch(batch_size=150):
    """Time 3 Frechet derivative implementations."""
    torch.manual_seed(42)

    A_batch = torch.randn(batch_size, 3, 3, dtype=torch.float64, device=device) * 0.1
    E_batch = torch.randn(batch_size, 3, 3, dtype=torch.float64, device=device) * 0.1

    n_warmup = 5
    n_runs = 30

    t_batched_pade = benchmark(
        fm.expm_frechet, A_batch, E_batch, n_warmup, n_runs,
    )
    t_block = benchmark(
        frechet_matrix_exp, A_batch, E_batch, n_warmup, n_runs,
    )
    try:
        t_unbatched_loop = benchmark(
            unbatched_loop, A_batch, E_batch, n_warmup, n_runs,
        )
    except NotImplementedError:
        t_unbatched_loop = float('nan')

    print(f"\n{'='*65}")
    print(f"Frechet derivative benchmark  (batch_size={batch_size}, "
          f"device={device}, {n_runs} runs)")
    print(f"  1) fm.expm_frechet  (batched Pade):  {t_batched_pade*1000:8.3f} ms")
    print(f"  2) frechet_matrix_exp (block exp):    {t_block*1000:8.3f} ms")
    print(f"  3) unbatched loop   (old Pade):       {t_unbatched_loop*1000:8.3f} ms")
    print(f"  ---")
    print(f"  block_exp  / batched_pade  = {t_block/t_batched_pade:.2f}x")
    print(f"  old_loop   / batched_pade  = {t_unbatched_loop/t_batched_pade:.2f}x")
    print(f"  old_loop   / block_exp     = {t_unbatched_loop/t_block:.2f}x")
    print(f"{'='*65}")

    # Verify all 3 agree
    _, L_bp = fm.expm_frechet(A_batch, E_batch)
    _, L_blk = frechet_matrix_exp(A_batch, E_batch)
    try:
        _, L_old = unbatched_loop(A_batch, E_batch)
    except NotImplementedError:
        L_old = L_bp
    torch.testing.assert_close(L_bp, L_blk, atol=1e-10, rtol=1e-6)
    torch.testing.assert_close(L_bp, L_old, atol=1e-10, rtol=1e-6)
    print("Correctness check passed.")


if __name__ == "__main__":
    test_benchmark_large_batch(1)
    test_benchmark_large_batch(150)
    test_benchmark_large_batch(1500)

uv run python test_frechet.py

=================================================================
Frechet derivative benchmark  (batch_size=1, device=cpu, 30 runs)
  1) fm.expm_frechet  (batched Pade):     0.114 ms
  2) frechet_matrix_exp (block exp):       0.063 ms
  3) unbatched loop   (old Pade):          0.084 ms
  ---
  block_exp  / batched_pade  = 0.55x
  old_loop   / batched_pade  = 0.73x
  old_loop   / block_exp     = 1.32x
=================================================================
Correctness check passed.

=================================================================
Frechet derivative benchmark  (batch_size=150, device=cpu, 30 runs)
  1) fm.expm_frechet  (batched Pade):     0.323 ms
  2) frechet_matrix_exp (block exp):       0.218 ms
  3) unbatched loop   (old Pade):         11.200 ms
  ---
  block_exp  / batched_pade  = 0.68x
  old_loop   / batched_pade  = 34.64x
  old_loop   / block_exp     = 51.32x
=================================================================
Correctness check passed.

=================================================================
Frechet derivative benchmark  (batch_size=1500, device=cpu, 30 runs)
  1) fm.expm_frechet  (batched Pade):     1.557 ms
  2) frechet_matrix_exp (block exp):       1.343 ms
  3) unbatched loop   (old Pade):        112.160 ms
  ---
  block_exp  / batched_pade  = 0.86x
  old_loop   / batched_pade  = 72.03x
  old_loop   / block_exp     = 83.49x
=================================================================
Correctness check passed.

@thomasloux (Collaborator) commented Feb 13, 2026

On an H200:


=================================================================
Frechet derivative benchmark  (batch_size=1, device=cuda, 30 runs)
  1) fm.expm_frechet  (batched Pade):     0.916 ms
  2) frechet_matrix_exp (block exp):       0.463 ms
  3) unbatched loop   (old Pade):            nan ms
  ---
  block_exp  / batched_pade  = 0.51x
  old_loop   / batched_pade  = nanx
  old_loop   / block_exp     = nanx
=================================================================
Correctness check passed.

=================================================================
Frechet derivative benchmark  (batch_size=150, device=cuda, 30 runs)
  1) fm.expm_frechet  (batched Pade):     0.890 ms
  2) frechet_matrix_exp (block exp):       0.420 ms
  3) unbatched loop   (old Pade):            nan ms
  ---
  block_exp  / batched_pade  = 0.47x
  old_loop   / batched_pade  = nanx
  old_loop   / block_exp     = nanx
=================================================================
Correctness check passed.

=================================================================
Frechet derivative benchmark  (batch_size=1500, device=cuda, 30 runs)
  1) fm.expm_frechet  (batched Pade):     0.902 ms
  2) frechet_matrix_exp (block exp):       0.432 ms
  3) unbatched loop   (old Pade):            nan ms
  ---
  block_exp  / batched_pade  = 0.48x
  old_loop   / batched_pade  = nanx
  old_loop   / block_exp     = nanx
=================================================================
Correctness check passed.

@falletta (Contributor, Author)

Thanks for looking into the math and for running the scaling tests, @thomasloux. The reason I had removed expm_frechet_block_enlarge was that it was never accessed. I have now reinserted the method and added the frechet_method variable, which users can specify in init_kwargs:

ts.optimize(
    ...
    init_kwargs={
        "cell_filter": ts.optimizers.cell_filters.CellFilter.frechet,
        "frechet_method": "blockEnlarge",
        ...
    },
)

In addition, I reinserted the unit tests that compare blockEnlarge with SPS. I verified that the tests pass and that scaling performance is essentially equivalent when switching between the two.


Successfully merging this pull request may close these issues:

Implement batched versions of math operations used in Frechet and all treatment of complex eigs in logM