num_neighbors pruning for CoilSetMinDistance#2140
Conversation
Add optional n_neighbors parameter to CoilSetMinDistance that limits pairwise distance computation to K nearest neighbors per coil, selected by centroid distance. Neighbor indices use jax.lax.stop_gradient so AD only traces through the fine-grained distances to nearby coils. Results at full resolution (106 coils, 4833 DOFs, A100 80GB): - 16.5x speedup on CoilSetMinDistance Jacobian alone - 5.8x full auglag optimizer speedup (468s -> 80s) - 3.5x full lsq-exact optimizer speedup (464s -> 134s) - Enables jac_chunk_size=32 (baseline OOMs at chunk=6) - Machine-precision correctness (K=5 matches full at alpha=100) Usage: CoilSetMinDistance(..., n_neighbors=20) Default n_neighbors=None preserves existing behavior. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
ddudt
left a comment
There was a problem hiding this comment.
Remember to also update CHANGELOG.md
Memory benchmark result| Test Name | %Δ | Master (MB) | PR (MB) | Δ (MB) | Time PR (s) | Time Master (s) |
| -------------------------------------- | ------------ | ------------------ | ------------------ | ------------ | ------------------ | ------------------ |
test_objective_jac_w7x | 6.13 % | 4.040e+03 | 4.288e+03 | 247.67 | 43.28 | 41.39 |
test_proximal_jac_w7x_with_eq_update | -0.81 % | 6.579e+03 | 6.526e+03 | -53.34 | 172.42 | 171.81 |
test_proximal_freeb_jac | 0.20 % | 1.337e+04 | 1.340e+04 | 26.24 | 92.78 | 91.27 |
test_proximal_freeb_jac_blocked | 0.56 % | 7.710e+03 | 7.753e+03 | 43.03 | 82.23 | 82.34 |
test_proximal_freeb_jac_batched | 0.51 % | 7.671e+03 | 7.710e+03 | 39.00 | 81.44 | 82.45 |
test_proximal_jac_ripple | 1.01 % | 3.573e+03 | 3.609e+03 | 36.03 | 67.22 | 66.60 |
test_proximal_jac_ripple_bounce1d | 0.26 % | 3.821e+03 | 3.831e+03 | 9.98 | 81.37 | 81.89 |
test_eq_solve | -0.84 % | 2.232e+03 | 2.214e+03 | -18.64 | 101.07 | 101.48 |For the memory plots, go to the summary of |
Codecov Report✅ All modified and coverable lines are covered by tests. Additional details and impacted files@@ Coverage Diff @@
## master #2140 +/- ##
==========================================
- Coverage 94.45% 94.45% -0.01%
==========================================
Files 101 101
Lines 28593 28604 +11
==========================================
+ Hits 27008 27018 +10
- Misses 1585 1586 +1
🚀 New features to boost your workflow:
|
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
|
Note for reviewers: Some of the original code was written by Claude, but I re-wrote almost everything as a human. |
num_neighbors pruning for CoilSetMinDistance
YigitElma
left a comment
There was a problem hiding this comment.
I am not sure about the behavior when the neighboring coils change, but I guess it is similar to how jnp.min or softmin behaves under similar conditions.
It looks good to me.
| obj3.build() | ||
| g1 = obj1.jac_unscaled(obj1.x(coils)) | ||
| g3 = obj3.jac_unscaled(obj3.x(coils)) | ||
| np.testing.assert_allclose(g1, g3) |
There was a problem hiding this comment.
Should this be true if one of the neighboring coils is about to change?
There was a problem hiding this comment.
Yes I believe so. Both the original method and this new pruning should have the same behavior for these derivatives, which doesn't fully capture what would happen if the closest coil changes.


Add optional
num_neighborsparameter to theCoilSetMinDistanceobjective. This limits pairwise distance computation to only thenum_neighborsnearest neighbors per coil, selected by centroid distance. Neighbor indices use jax.lax.stop_gradient so AD only traces through the fine-grained distances to nearby coils.Depending on coilset used, we found:
For more generic modular coil optimizations, pruning may decrease performance due to overhead.
Usage: CoilSetMinDistance(..., n_neighbors=20)
Default n_neighbors=None preserves existing behavior.