Update self-supervised losses implementation and tests #1531

surajyadav-research wants to merge 7 commits into google-deepmind:main
Conversation
Hi @rajasekharporeddy

rdyro left a comment:
Thanks for the PR! I left some comments, mostly about improving the tests
On `def testing_barlow_twins_loss(...)`:

instead of reimplementing the loss logic, can you compare e.g., that the loss is zero (and not nan) for identical inputs and non-zero otherwise?
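The reviewer's suggestion could look like the sketch below, using a hypothetical Barlow-Twins-style stand-in (not the optax implementation). Note that for Barlow Twins only the invariance (on-diagonal) term vanishes for identical inputs; the redundancy term need not, so the check drops it via `lambda_=0`:

```python
# Hypothetical Barlow-Twins-style loss for illustration only.
import jax
import jax.numpy as jnp
import numpy as np

def barlow_twins_style_loss(z1, z2, lambda_=5e-3, eps=1e-5):
    n = z1.shape[0]
    # Standardize each feature over the batch.
    z1n = (z1 - z1.mean(axis=0)) / (z1.std(axis=0) + eps)
    z2n = (z2 - z2.mean(axis=0)) / (z2.std(axis=0) + eps)
    c = (z1n.T @ z2n) / n  # cross-correlation matrix
    on_diag = jnp.sum((jnp.diagonal(c) - 1.0) ** 2)
    off_diag = jnp.sum(c ** 2) - jnp.sum(jnp.diagonal(c) ** 2)
    return on_diag + lambda_ * off_diag

key = jax.random.key(0)
k1, k2 = jax.random.split(key)
a = jax.random.normal(k1, (4, 3), dtype=jnp.float32)
b = jax.random.normal(k2, (4, 3), dtype=jnp.float32)

# Identical inputs: the invariance term is ~0.
assert barlow_twins_style_loss(a, a, lambda_=0.0) < 1e-6
# Never nan and never negative (the loss is a sum of squares).
assert np.isfinite(barlow_twins_style_loss(a, b))
assert barlow_twins_style_loss(a, b) >= 0.0
```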
On `def test_two_view_matches_handmade(self):`:

rewrite as a random inputs test parameterized by a random seed please (instead of hard-coding the numbers)
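A seed-parameterized random-inputs check might look like the sketch below; in the PR it would use absl's `@parameterized.parameters(...)`, shown here as a plain loop, with a hypothetical negative-cosine stand-in for the loss under test:

```python
# Hypothetical stand-in for the loss under test.
import jax
import jax.numpy as jnp
import numpy as np

def negative_cosine(p, z, eps=1e-6):
    pn = p / (jnp.linalg.norm(p, axis=-1, keepdims=True) + eps)
    zn = z / (jnp.linalg.norm(z, axis=-1, keepdims=True) + eps)
    return jnp.mean(-jnp.sum(pn * zn, axis=-1))

# Inputs are drawn from a seeded PRNG instead of being hard-coded.
for seed in (0, 1, 42):
    key = jax.random.key(seed)
    k1, k2 = jax.random.split(key)
    p = jax.random.normal(k1, (2, 3), dtype=jnp.float32)
    z = jax.random.normal(k2, (2, 3), dtype=jnp.float32)
    result = negative_cosine(p, z)
    # Holds for any inputs: a mean cosine lies in [-1, 1].
    assert np.isfinite(result) and -1.0 <= float(result) <= 1.0
```

The point of the pattern is that the assertions are properties true for every seed, rather than a comparison against hand-computed numbers.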
On `def test_single_view_matches_handmade(self):` (in `class DinoLossTest`):

rewrite as a random inputs test parameterized by a random seed please (instead of hard-coding the numbers)
On `def test_single_direction_matches_handmade(self):`:

rewrite as a random inputs test parameterized by a random seed please (instead of hard-coding the numbers)
On `jax.jit(_self_supervised.simsiam_loss)(p, z, symmetric=False)`:

nit: too much whitespace, this can be on a single line
On `def test_symmetric_batched_matches_handmade(self):` (in `class SimSiamLossTest`):

rewrite as a random inputs test parameterized by a random seed please (instead of hard-coding the numbers)
On `def test_single_direction_matches_handmade(self):`:

rewrite as a random inputs test parameterized by a random seed please (instead of hard-coding the numbers)
On `def test_symmetric_batched_matches_handmade(self):` (in `class ByolLossTest`):

rewrite as a random inputs test parameterized by a random seed please (instead of hard-coding the numbers)
Thanks for the review, @rdyro. I'll make the requested updates ASAP.
Hi @rdyro

rdyro left a comment:
Thank you!
I added a couple more comments. Testing these losses is, in general, not trivial; I'm not sure the re-implementation approach is best, but if you think there's no other way, we can go with that.
I'm not entirely sure why copybara is failing, can you try squashing all commits into a single one?
On `def testing_byol_loss(q1_val, z2_val, q2_val, z1_val, eps=1e-6):`:

Let's ideally avoid re-implementation tests; is there another way to test here?
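One alternative to re-implementing the loss in the test file is to check analytic anchors and invariances that the paper's formulation must satisfy. The sketch below uses a hypothetical symmetric BYOL-style stand-in (not the optax code):

```python
# Hypothetical symmetric BYOL-style loss for illustration only.
import jax
import jax.numpy as jnp
import numpy as np

def byol_style_loss(q1, z2, q2, z1, eps=1e-6):
    def direction(q, z):
        z = jax.lax.stop_gradient(z)  # targets carry no gradient
        qn = q / (jnp.linalg.norm(q, axis=-1, keepdims=True) + eps)
        zn = z / (jnp.linalg.norm(z, axis=-1, keepdims=True) + eps)
        return jnp.mean(2.0 - 2.0 * jnp.sum(qn * zn, axis=-1))
    return direction(q1, z2) + direction(q2, z1)

key = jax.random.key(0)
k1, k2, k3, k4 = jax.random.split(key, 4)
q1, z2, q2, z1 = [jax.random.normal(k, (2, 3), dtype=jnp.float32)
                  for k in (k1, k2, k3, k4)]

base = byol_style_loss(q1, z2, q2, z1)
# Invariance: cosine-based losses ignore positive rescaling of either view.
np.testing.assert_allclose(byol_style_loss(3.0 * q1, z2, q2, 0.5 * z1),
                           base, atol=1e-4)
# Analytic anchors: aligned views give 0, anti-aligned views give 8.
np.testing.assert_allclose(byol_style_loss(q1, q1, q1, q1), 0.0, atol=1e-4)
np.testing.assert_allclose(byol_style_loss(q1, -q1, q1, -q1), 8.0, atol=1e-4)
# No gradient flows into the stopped targets.
g = jax.grad(lambda t: byol_style_loss(q1, t, q2, z1))(z2)
assert np.allclose(g, 0.0)
```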
On `def testing_single_direction_byol(q_val, z_val, eps=1e-6):`:

Let's ideally avoid re-implementation tests; is there another way to test here?
On `def testing_simsiam_loss(p1_val, z2_val, p2_val, z1_val, eps=1e-6):`:

Let's ideally avoid re-implementation tests; is there another way to test here?
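For SimSiam, the same idea applies: test an analytic anchor and the stop-gradient behavior rather than re-deriving the value. A hypothetical stand-in (not the optax code):

```python
# Hypothetical symmetric SimSiam-style loss for illustration only.
import jax
import jax.numpy as jnp
import numpy as np

def simsiam_style_loss(p1, z2, p2, z1, eps=1e-6):
    def d(p, z):
        z = jax.lax.stop_gradient(z)  # SimSiam is invalid without this
        pn = p / (jnp.linalg.norm(p, axis=-1, keepdims=True) + eps)
        zn = z / (jnp.linalg.norm(z, axis=-1, keepdims=True) + eps)
        return -jnp.mean(jnp.sum(pn * zn, axis=-1))
    return 0.5 * d(p1, z2) + 0.5 * d(p2, z1)

key = jax.random.key(0)
k1, k2 = jax.random.split(key)
p = jax.random.normal(k1, (2, 3), dtype=jnp.float32)
z = jax.random.normal(k2, (2, 3), dtype=jnp.float32)

# Perfectly aligned views reach the analytic minimum of -1.
np.testing.assert_allclose(simsiam_style_loss(p, p, p, p), -1.0, atol=1e-4)
# No gradient flows through the stopped targets.
g = jax.grad(lambda z_: simsiam_style_loss(p, z_, p, z_))(z)
assert np.allclose(g, 0.0)
```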
On `def testing_single_direction_simsiam(p_val, z_val, eps=1e-6):`:

Let's ideally avoid re-implementation tests; is there another way to test here?
On `def testing_dino_loss(...)`:

Let's ideally avoid re-implementation tests; is there another way to test here?
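For a DINO-style loss, two properties can be tested without re-implementation: the cross-entropy against a sharpened, centered teacher is strictly positive, and the teacher branch carries no gradient. Sketched with a hypothetical stand-in (not the optax code):

```python
# Hypothetical DINO-style loss for illustration only.
import jax
import jax.numpy as jnp
import numpy as np

def dino_style_loss(student_logits, teacher_logits, center,
                    student_temperature=0.1, teacher_temperature=0.04):
    teacher_logits = jax.lax.stop_gradient(teacher_logits)
    # Center and sharpen the teacher, then cross-entropy against the student.
    t = jax.nn.softmax((teacher_logits - center) / teacher_temperature, axis=-1)
    log_s = jax.nn.log_softmax(student_logits / student_temperature, axis=-1)
    return -jnp.mean(jnp.sum(t * log_s, axis=-1))

key = jax.random.key(0)
k1, k2, k3 = jax.random.split(key, 3)
s = jax.random.normal(k1, (2, 3), dtype=jnp.float32)
t = jax.random.normal(k2, (2, 3), dtype=jnp.float32)
c = jax.random.normal(k3, (3,), dtype=jnp.float32)

loss = dino_style_loss(s, t, c)
assert np.isfinite(loss) and loss > 0.0  # cross-entropy is strictly positive
g = jax.grad(lambda tl: dino_style_loss(s, tl, c))(t)
assert np.allclose(g, 0.0)               # stop_gradient blocks the teacher
```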
What's the purpose of this file?
On `key = jax.random.PRNGKey(seed)` (under `@parameterized.parameters(0, 1, 42)` in `def test_single_direction_matches_handmade(self, seed):`):

Prefer the modern `jax.random.key`
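For reference, the difference between the legacy and modern key constructors:

```python
import jax

legacy = jax.random.PRNGKey(0)  # legacy: raw uint32 array
key = jax.random.key(0)         # modern: scalar array with an opaque key dtype
k1, k2 = jax.random.split(key)  # splitting and sampling work the same way
x = jax.random.normal(k1, (2, 3))
assert x.shape == (2, 3)
```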
(force-pushed b5a3c17 to fafa81d)
PR #1531 should be able to fix it once it's merged, can you rebase then? Were you able to give it some more thought if there was a better way to test these losses instead of their re-implementation in the test file?
Hi @rdyro

- BYOL
- SimSiam
- DINO
- Barlow Twins
Great, these sound like very good ideas for tests! To unblock your flow, you could try rebasing your commits to the commit from December 26th: 6b364ac289da913454bc849756f827f675de63c2
(force-pushed fafa81d to 1eb4a25)
The markdown link checker was failing with 503 errors on links to the repo's own files (e.g., getting_started.ipynb, CONTRIBUTING.md). These are transient failures due to GitHub rate-limiting automated link checkers. This config:

- Ignores internal repo blob links that cause false positives
- Adds retry logic for rate-limited (429) responses
- Sets appropriate timeout and retry delays
(force-pushed 1eb4a25 to f5f5417)
Hi @rdyro, I've pushed the updated test code. I tried to rebase it, but I'm getting stuck and couldn't figure out the right steps. Sorry for the inconvenience.
No worries, let's pick this up once the CI is unblocked.
The CI fix is in, so you should be able to rebase to main now @surajyadav-research
Hi @rdyro
(force-pushed 71c617a to 3f4a69e)
Hi @rdyro,
Hi @rdyro
Summary of Improvements #1528
This PR corrects several issues in the existing self-supervised losses (BYOL, SimSiam, DINO, Barlow Twins) and brings the implementations in line with the original papers. The previous versions were minimal but incomplete, and in several cases unsafe (silent broadcasting, missing stop-gradients, incorrect view handling, etc.).
The new implementations add correct two-view behavior, strict validation, and safer numerical handling.
BYOL

- Two-view formulation: previously only a single `(q, z)` direction; users had to manually build the symmetric two-view BYOL loss outside the function. Now supports `symmetric` and four projections, matching the paper.
- Teacher gradients: the missing `stop_gradient` silently broke BYOL. Now applies `lax.stop_gradient` on target projections inside the loss.
- Shape validation
- Cosine similarity & eps handling: previously relied on `finfo.eps`; now uses `_regression.cosine_similarity` and a dtype-safe, configurable `eps`.

SimSiam
- Symmetric loss support: previously only `D(p1, z2)`; symmetric SimSiam required manual assembly.
- Stop-gradient enforcement: applies `stop_gradient(z)`; forgetting it invalidates SimSiam.
- Shape checks
- Shared cosine similarity: uses `_regression.cosine_similarity` with dtype-aligned `eps`.

DINO
- Single-view and two-view support: previously a single `(student, teacher)` pair; cross-view teacher/student matching had to be implemented manually. Now supports `two_view` and a second logit pair.
- Teacher stop-gradient: applies `lax.stop_gradient`.
- Temperature and shape validation
- Centering & broadcasting: previously the `center` subtraction relied on uncontrolled broadcasting.

Barlow Twins
- Input shape & rank checks: enforces `[batch, feature_dim]`.
- Numerically stable normalization: previously `jnp.std` with eps added after; a less explicit numeric pathway.
- Dtype-safe hyperparameters: previously `eps` and `lambda` were raw Python floats (problematic in mixed precision).
- Clear loss decomposition
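The shape-validation and dtype-safe-eps patterns described above can be sketched as follows; `check_2d_pair` and `dtype_eps` are hypothetical helper names for illustration, not the actual optax code:

```python
import jax.numpy as jnp

def check_2d_pair(a, b):
    # Hypothetical helper: reject anything that is not a matching
    # [batch, feature_dim] pair instead of silently broadcasting.
    if a.ndim != 2 or b.ndim != 2:
        raise ValueError(
            f"expected rank-2 [batch, feature_dim] inputs, got ranks {a.ndim} and {b.ndim}")
    if a.shape != b.shape:
        raise ValueError(f"shape mismatch: {a.shape} vs {b.shape}")

def dtype_eps(x, eps=None):
    # Hypothetical helper: derive eps from the input dtype so that e.g.
    # float16 inputs get a representable epsilon, and cast user-provided
    # Python floats to the input dtype for mixed-precision safety.
    return jnp.asarray(jnp.finfo(x.dtype).eps if eps is None else eps, x.dtype)
```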
Tests
The test suite was updated to rigorously validate both the math and the new code paths.
- BYOL
- SimSiam
- DINO
- Barlow Twins