Add regression test for scale_by_rms zero-gradient stability #1553

Open
TanmayThakur2209 wants to merge 6 commits into google-deepmind:main from TanmayThakur2209:add-rms-zero-grad-regression-test

Conversation

@TanmayThakur2209
Contributor

This PR adds a regression test to ensure scale_by_rms does not produce NaN or infinite updates when gradients are zero. This guards against potential future numerical stability regressions without changing optimizer behavior.

@rdyro
Collaborator

rdyro commented Jan 7, 2026

This test only passes if eps is non-zero. It's an interesting question what we should return when eps == 0.

One possibility is to always do a safe division / safe sqrt, but that adds an extra max operation. Perhaps we should instead have an extra argument that enables a "safe" version of this transform when setting eps == 0. What do you think?

@TanmayThakur2209
Contributor Author

TanmayThakur2209 commented Jan 8, 2026

Thank you for highlighting this. The test relies on the default eps > 0 and would fail with eps == 0. I propose updating it to use a non-zero epsilon explicitly, tx = transform.scale_by_rms(eps=1e-8), which guarantees no division by zero and matches the current optax semantics. The test then asserts numerical stability in the recommended configuration (eps > 0) and does not define the behavior for eps == 0.

Separately, I really like the idea of an optional safe version for users who want eps == 0 but still want numerical stability. It would clamp the RMS denominator to a small positive value so that division by zero never occurs, at the cost of changing the math and adding extra operations. This could be exposed either as an explicit flag, scale_by_rms(eps=0.0, safe=True), or as a separate compositional transform, for example:

optax.chain(
    optax.scale_by_rms(eps=0.0),
    optax.zero_nans(),
)

Both avoid NaN updates when gradients are zero, but they differ in API philosophy. Either way, user intent is explicit: the default behaviour is unchanged, and users opt in to the safe path.

For this PR, I’m happy to scope the change to the test update only and leave the safe variant as a potential future design discussion. I’d be interested in exploring such an option further if you think it would be useful.


2 participants