Skip to content

num_neighbors pruning for CoilSetMinDistance#2140

Merged
ddudt merged 11 commits intomasterfrom
tme/coilset-min-distance-pruning
Apr 13, 2026
Merged

num_neighbors pruning for CoilSetMinDistance#2140
ddudt merged 11 commits intomasterfrom
tme/coilset-min-distance-pruning

Conversation

@TElder-thea
Copy link
Copy Markdown
Collaborator

@TElder-thea TElder-thea commented Mar 27, 2026

Add optional num_neighbors parameter to the CoilSetMinDistance objective. This limits pairwise distance computation to only the num_neighbors nearest neighbors per coil, selected by centroid distance. Neighbor indices use jax.lax.stop_gradient so AD only traces through the fine-grained distances to nearby coils.

Depending on coilset used, we found:

  • 16.5x speedup on CoilSetMinDistance Jacobian alone
  • 5.8x full auglag optimizer speedup (468s -> 80s)
  • 3.5x full lsq-exact optimizer speedup (464s -> 134s)
  • Enables jac_chunk_size=32 (baseline OOMs at chunk=6)
  • Machine-precision correctness (K=5 matches full at alpha=100)
    For more generic modular coil optimizations, pruning may decrease performance due to overhead.

Usage: CoilSetMinDistance(..., n_neighbors=20)
Default n_neighbors=None preserves existing behavior.

Add optional n_neighbors parameter to CoilSetMinDistance that limits
pairwise distance computation to K nearest neighbors per coil, selected
by centroid distance. Neighbor indices use jax.lax.stop_gradient so AD
only traces through the fine-grained distances to nearby coils.

Results at full resolution (106 coils, 4833 DOFs, A100 80GB):
- 16.5x speedup on CoilSetMinDistance Jacobian alone
- 5.8x full auglag optimizer speedup (468s -> 80s)
- 3.5x full lsq-exact optimizer speedup (464s -> 134s)
- Enables jac_chunk_size=32 (baseline OOMs at chunk=6)
- Machine-precision correctness (K=5 matches full at alpha=100)

Usage: CoilSetMinDistance(..., n_neighbors=20)
Default n_neighbors=None preserves existing behavior.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@ddudt ddudt marked this pull request as draft March 27, 2026 17:15
@ddudt ddudt added performance New feature or request to make the code faster coil stuff relating to coils and coil optimization easy Short and simple to code or review labels Mar 27, 2026
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
ddudt
ddudt previously requested changes Mar 27, 2026
Copy link
Copy Markdown
Collaborator

@ddudt ddudt left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remember to also update CHANGELOG.md

Comment thread desc/objectives/_coils.py Outdated
Comment thread desc/objectives/_coils.py Outdated
Comment thread desc/objectives/_coils.py Outdated
@github-actions
Copy link
Copy Markdown
Contributor

github-actions bot commented Mar 27, 2026

Memory benchmark result

|               Test Name                |      %Δ      |    Master (MB)     |      PR (MB)       |    Δ (MB)    |    Time PR (s)     |  Time Master (s)   |
| -------------------------------------- | ------------ | ------------------ | ------------------ | ------------ | ------------------ | ------------------ |
  test_objective_jac_w7x                 |    6.13 %    |     4.040e+03      |     4.288e+03      |    247.67    |       43.28        |       41.39        |
  test_proximal_jac_w7x_with_eq_update   |   -0.81 %    |     6.579e+03      |     6.526e+03      |    -53.34    |       172.42       |       171.81       |
  test_proximal_freeb_jac                |    0.20 %    |     1.337e+04      |     1.340e+04      |    26.24     |       92.78        |       91.27        |
  test_proximal_freeb_jac_blocked        |    0.56 %    |     7.710e+03      |     7.753e+03      |    43.03     |       82.23        |       82.34        |
  test_proximal_freeb_jac_batched        |    0.51 %    |     7.671e+03      |     7.710e+03      |    39.00     |       81.44        |       82.45        |
  test_proximal_jac_ripple               |    1.01 %    |     3.573e+03      |     3.609e+03      |    36.03     |       67.22        |       66.60        |
  test_proximal_jac_ripple_bounce1d      |    0.26 %    |     3.821e+03      |     3.831e+03      |     9.98     |       81.37        |       81.89        |
  test_eq_solve                          |   -0.84 %    |     2.232e+03      |     2.214e+03      |    -18.64    |       101.07       |       101.48       |

For the memory plots, go to the summary of Memory Benchmarks workflow and download the artifact.

@codecov
Copy link
Copy Markdown

codecov bot commented Mar 27, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 94.45%. Comparing base (2471d55) to head (39aaca5).
⚠️ Report is 1 commits behind head on master.

Additional details and impacted files
@@            Coverage Diff             @@
##           master    #2140      +/-   ##
==========================================
- Coverage   94.45%   94.45%   -0.01%     
==========================================
  Files         101      101              
  Lines       28593    28604      +11     
==========================================
+ Hits        27008    27018      +10     
- Misses       1585     1586       +1     
Files with missing lines Coverage Δ
desc/objectives/_coils.py 99.34% <100.00%> (+<0.01%) ⬆️

... and 2 files with indirect coverage changes

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@ddudt ddudt marked this pull request as ready for review April 8, 2026 21:38
@ddudt ddudt requested review from a team, YigitElma, ddudt, dpanici, f0uriest, rahulgaur104 and unalmis and removed request for a team April 8, 2026 21:39
@ddudt
Copy link
Copy Markdown
Collaborator

ddudt commented Apr 8, 2026

Note for reviewers: Some of the original code was written by Claude, but I re-wrote almost everything as a human.

@ddudt ddudt changed the title feat: add n_neighbors pruning to CoilSetMinDistance num_neighbors pruning for CoilSetMinDistance Apr 8, 2026
@ddudt ddudt dismissed their stale review April 9, 2026 15:36

Changes made

Copy link
Copy Markdown
Collaborator

@YigitElma YigitElma left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure about the behavior when the neighboring coils change, but I guess it is similar to how jnp.min or softmin behaves under similar conditions.

It looks good to me.

obj3.build()
g1 = obj1.jac_unscaled(obj1.x(coils))
g3 = obj3.jac_unscaled(obj3.x(coils))
np.testing.assert_allclose(g1, g3)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be true if one of the neighboring coils is about to change?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes I believe so. Both the original method and this new pruning should have the same behavior for these derivatives, which doesn't fully capture what would happen if the closest coil changes.

@ddudt ddudt merged commit 03637dd into master Apr 13, 2026
27 checks passed
@ddudt ddudt deleted the tme/coilset-min-distance-pruning branch April 13, 2026 18:47
@YigitElma
Copy link
Copy Markdown
Collaborator

If I get confused in the future again, here is a small example,
image
image

Details
from desc.objectives.utils import softmin
def fun(x):
    f = x**2
    g = x
    return softmin(jnp.array([f, g]), alpha=8)

x = jnp.linspace(-0.5, 1.5, 300)
y = jax.vmap(fun)(x)
dy = jax.vmap(jax.grad(fun))(x)
plt.title("softmin with alpha=8")
plt.plot(x, x, "--", label="g", lw=2)
plt.plot(x, x**2, "--", label="f", lw=2)
plt.plot(x, y, label="softmin(f, g)", lw=3)
plt.plot(x, dy, label="dy/dx")
plt.legend()

def fun(x):
    f = x**2
    g = x
    return jnp.min(jnp.array([f, g]))

x = jnp.linspace(-0.5, 1.5, 300)
y = jax.vmap(fun)(x)
dy = jax.vmap(jax.grad(fun))(x)
plt.title("min(f, g)")
plt.plot(x, x, "--", label="g", lw=2)
plt.plot(x, x**2, "--", label="f", lw=2)
plt.plot(x, y, label="min(f, g)", lw=3)
plt.plot(x, dy, label="dy/dx")
plt.legend()

<\details

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

coil stuff relating to coils and coil optimization easy Short and simple to code or review performance New feature or request to make the code faster

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants