
Update self-supervised losses implementation and tests #1531

Open
surajyadav-research wants to merge 7 commits into google-deepmind:main from surajyadav-research:feat/self-supervised-losses-v2

Conversation


@surajyadav-research commented Dec 11, 2025

Summary of Improvements #1528

This PR corrects several issues in the existing self-supervised losses (BYOL, SimSiam, DINO, Barlow Twins) and brings the implementations in line with the original papers. The previous versions were minimal but incomplete, and in several cases unsafe (silent broadcasting, missing stop-gradients, incorrect view handling, etc.).
The new implementations add correct two-view behavior, strict validation, and safer numerical handling.


BYOL

  • Two-view formulation

    • Before: Only supported a single (q, z) direction; users had to manually build the symmetric two-view BYOL loss outside the function.
    • Now: Adds single-direction and symmetric two-view BYOL via a symmetric flag and four projection arguments, matching the paper.
  • Teacher gradients

    • Before: Target projections were differentiable; forgetting to apply stop_gradient silently broke BYOL.
    • Now: Always applies lax.stop_gradient on target projections inside the loss.
  • Shape validation

    • Before: No shape checks → silent broadcasting or mismatched tensors.
    • Now: Validates shapes for all projection pairs, failing fast on incorrect usage.
  • Cosine similarity & eps handling

    • Before: Reimplemented cosine similarity manually and used hard-coded finfo.eps.
    • Now: Uses shared _regression.cosine_similarity and a dtype-safe, configurable eps.
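
For reference, the symmetric two-view objective described above can be sketched as follows. This is a minimal NumPy illustration, not optax's actual API: `byol_loss`, its argument names, and the inline `cosine_similarity` helper are hypothetical stand-ins, and the real implementation applies `jax.lax.stop_gradient` and the shared `_regression.cosine_similarity`.

```python
import numpy as np

def cosine_similarity(a, b, eps=1e-8):
    # Row-wise cosine similarity; eps guards against division by zero norms.
    na = np.linalg.norm(a, axis=-1)
    nb = np.linalg.norm(b, axis=-1)
    return np.sum(a * b, axis=-1) / np.maximum(na * nb, eps)

def byol_loss(q1, z2, q2=None, z1=None, symmetric=False):
    # Single direction: 2 - 2 * cos(q, z); the target z is treated as a
    # constant (jax.lax.stop_gradient in the JAX implementation).
    loss = 2.0 - 2.0 * cosine_similarity(q1, z2)
    if symmetric:
        # Symmetric two-view form: add the swapped direction, per the paper.
        loss = loss + 2.0 - 2.0 * cosine_similarity(q2, z1)
    return float(np.mean(loss))
```

Matching predictions and targets give cosine similarity 1 and hence zero loss; orthogonal rows give the maximum single-direction loss of 2.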

SimSiam

  • Symmetric loss support

    • Before: Only computed D(p1, z2); symmetric SimSiam required manual assembly.
    • Now: Supports single-direction and symmetric two-view SimSiam in one API.
  • Stop-gradient enforcement

    • Before: Relied on the caller to apply stop_gradient(z); forgetting it invalidates SimSiam.
    • Now: The loss internally stops gradients through the target projections.
  • Shape checks

    • Before: No shape checks between predictor and target tensors.
    • Now: Enforces matching shapes for all predictor/target pairs.
  • Shared cosine similarity

    • Before: Manual inline cosine implementation.
    • Now: Uses _regression.cosine_similarity with dtype-aligned eps.
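
The SimSiam objective described above can be sketched in the same way. Again a hedged NumPy illustration with hypothetical names, not optax's API; the JAX version stops gradients through `z` internally.

```python
import numpy as np

def cosine_similarity(a, b, eps=1e-8):
    na = np.linalg.norm(a, axis=-1)
    nb = np.linalg.norm(b, axis=-1)
    return np.sum(a * b, axis=-1) / np.maximum(na * nb, eps)

def simsiam_loss(p1, z2, p2=None, z1=None, symmetric=False):
    # D(p, z) = -cos(p, z), with z treated as a constant
    # (jax.lax.stop_gradient in the JAX implementation).
    d12 = -np.mean(cosine_similarity(p1, z2))
    if not symmetric:
        return float(d12)
    # Symmetric form: 0.5 * D(p1, z2) + 0.5 * D(p2, z1).
    d21 = -np.mean(cosine_similarity(p2, z1))
    return float(0.5 * d12 + 0.5 * d21)
```

When predictor outputs equal the targets, the loss is exactly -1 (cosine similarity 1).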

DINO

  • Single-view and two-view support

    • Before: Only supported a single (student, teacher) pair; cross-view teacher/student matching had to be implemented manually.
    • Now: Adds single-view and symmetric two-view DINO via a two_view flag and a second logit pair.
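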
  • Teacher stop-gradient

    • Before: Teacher probabilities were differentiable → gradients flowed into the EMA teacher.
    • Now: Wraps teacher softmax outputs in lax.stop_gradient.
  • Temperature and shape validation

    • Before: Did not validate shapes or require temperatures to be positive.
    • Now: Checks logits shapes and asserts positive temperatures before use.
  • Centering & broadcasting

    • Before: center subtraction relied on uncontrolled broadcasting.
    • Now: Casts and broadcasts center explicitly, ensuring consistent teacher normalization across views.
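
The DINO objective with centering, temperature sharpening, and a constant teacher can be sketched as below. A NumPy illustration under assumed argument names (`center`, `student_temp`, `teacher_temp` are illustrative); the JAX version wraps the teacher probabilities in `jax.lax.stop_gradient`.

```python
import numpy as np

def log_softmax(x, axis=-1):
    # Numerically stable log-softmax: shift by the max before exponentiating.
    x = x - np.max(x, axis=axis, keepdims=True)
    return x - np.log(np.sum(np.exp(x), axis=axis, keepdims=True))

def dino_loss(student_logits, teacher_logits, center,
              student_temp=0.1, teacher_temp=0.04):
    assert student_temp > 0 and teacher_temp > 0, 'temperatures must be positive'
    assert student_logits.shape == teacher_logits.shape, 'logit shapes must match'
    # Teacher: subtract the center, sharpen with the teacher temperature, and
    # treat the result as a constant (jax.lax.stop_gradient in JAX).
    teacher_probs = np.exp(log_softmax((teacher_logits - center) / teacher_temp))
    student_logp = log_softmax(student_logits / student_temp)
    # Cross-entropy H(teacher, student), averaged over the batch.
    return float(-np.mean(np.sum(teacher_probs * student_logp, axis=-1)))
```

Because softmax is shift-invariant, adding a constant offset to the student logits leaves the loss unchanged, which is one of the cheap sanity checks discussed later in this thread.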

Barlow Twins

  • Input shape & rank checks

    • Before: Accepted arbitrary shapes; incorrect ranks led to meaningless correlation matrices.
    • Now: Enforces matching shapes and rank-2 inputs [batch, feature_dim].
  • Numerically stable normalization

    • Before: Used jnp.std with eps added after; less explicit numeric pathway.
    • Now: Computes variance explicitly, adds eps before sqrt, and normalizes safely.
  • Dtype-safe hyperparameters

    • Before: eps and lambda were raw Python floats (problematic in mixed precision).
    • Now: Casts both to the projection dtype for consistent computation.
  • Clear loss decomposition

    • Before: Diagonal and off-diagonal terms were computed but not clearly separated.
    • Now: Explicit on-diag + off-diag formulation directly matching the Barlow Twins paper.
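
The on-diagonal/off-diagonal decomposition described above can be sketched as follows. A NumPy illustration with assumed names (`lambda_`, `eps` defaults are illustrative, not optax's); eps is added inside the square root, as in the new implementation.

```python
import numpy as np

def barlow_twins_loss(z1, z2, lambda_=5e-3, eps=1e-8):
    assert z1.shape == z2.shape and z1.ndim == 2, 'expect [batch, feature_dim]'
    n = z1.shape[0]

    def normalize(z):
        # Standardize each feature column; eps inside the sqrt for stability.
        var = np.var(z, axis=0)
        return (z - np.mean(z, axis=0)) / np.sqrt(var + eps)

    # Cross-correlation matrix between the two views' standardized features.
    c = normalize(z1).T @ normalize(z2) / n
    on_diag = np.sum((np.diag(c) - 1.0) ** 2)            # pull diagonal toward 1
    off_diag = np.sum(c ** 2) - np.sum(np.diag(c) ** 2)  # push off-diagonal to 0
    return float(on_diag + lambda_ * off_diag)
```

Permuting the rows of both views with the same permutation leaves the cross-correlation matrix, and hence the loss, unchanged.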

Tests

The test suite was updated to rigorously validate both the math and the new code paths.

BYOL

  • Did not test symmetric path → Now tests exact symmetric formula with a handmade cosine reference.
  • Only rough cosine check on random data → Now uses deterministic handmade comparison against the analytical expression.
  • Weak JIT check (shape only) → Now verifies the JIT’ed loss matches the mathematical ground truth.

SimSiam

  • Never validated symmetric mode → Added an exact symmetric test that matches the cosine-based definition.
  • Only tested trivial edge cases (identical/orthogonal) → Now tests the full objective with a precise reference implementation.
  • No strict math-level comparison → Now enforces exact equivalence between implementation and the SimSiam loss definition.

DINO

  • No verification of the full KL / cross-entropy formula → Now compares against a handmade softmax/log-softmax reference.
  • Two-view coupling never tested → Added explicit L12/L21 tests that compute both directions by hand.
  • No argument validation tests → Added tests that temperatures must be positive and logit shapes must match.

Barlow Twins

  • Did not test the full normalization + cross-correlation pipeline → Now checks against a full handmade correlation computation.
  • Previously only asserted non-negativity → Now validates the complete objective numerically, including on- and off-diagonal terms.
  • No invalid-shape testing → Added tests for rank mismatches and shape mismatches to confirm error paths.

@surajyadav-research
Author

Hi @rajasekharporeddy
This PR is ready for review. I’ve updated the docs, implementation, and added/updated tests. All CI checks are green and CLA is done. Would you mind taking a look when you get a chance?

Collaborator

@rdyro left a comment


Thanks for the PR! I left some comments, mostly about improving the tests.

Comment thread optax/losses/_self_supervised_test.py Outdated
dtype=jnp.float32,
)

def testing_barlow_twins_loss(
Collaborator

Instead of reimplementing the loss logic, can you compare, e.g., that the loss is zero (and not NaN) for identical inputs and non-zero otherwise?

Comment thread optax/losses/_self_supervised_test.py Outdated
)
np.testing.assert_allclose(result, handmade_result, atol=1e-6)

def test_two_view_matches_handmade(self):
Collaborator

rewrite as a random inputs test parameterized by a random seed please (instead of hard-coding the numbers)

Comment thread optax/losses/_self_supervised_test.py Outdated

class DinoLossTest(absltest.TestCase):

def test_single_view_matches_handmade(self):
Collaborator

rewrite as a random inputs test parameterized by a random seed please (instead of hard-coding the numbers)

Comment thread optax/losses/_self_supervised_test.py Outdated
)
np.testing.assert_allclose(result, handmade_result, atol=1e-4)

def test_single_direction_matches_handmade(self):
Collaborator

rewrite as a random inputs test parameterized by a random seed please (instead of hard-coding the numbers)

Comment thread optax/losses/_self_supervised_test.py Outdated
result = jax.jit(_self_supervised.simsiam_loss)(
p,
z,
symmetric=False,
Collaborator

nit: too much whitespace, this can be on a single line

Comment thread optax/losses/_self_supervised_test.py Outdated

class SimSiamLossTest(absltest.TestCase):

def test_symmetric_batched_matches_handmade(self):
Collaborator

rewrite as a random inputs test parameterized by a random seed please (instead of hard-coding the numbers)

Comment thread optax/losses/_self_supervised_test.py Outdated
)
np.testing.assert_allclose(result, handmade_result, atol=1e-4)

def test_single_direction_matches_handmade(self):
Collaborator

rewrite as a random inputs test parameterized by a random seed please (instead of hard-coding the numbers)

Comment thread optax/losses/_self_supervised_test.py Outdated

class ByolLossTest(absltest.TestCase):

def test_symmetric_batched_matches_handmade(self):
Collaborator

rewrite as a random inputs test parameterized by a random seed please (instead of hard-coding the numbers)

@surajyadav-research
Author

Thanks for the review, @rdyro. I’ll make the requested updates ASAP.

@surajyadav-research
Author

Hi @rdyro
I’ve updated the test code based on all your suggestions. However, I’m still seeing an import/copybara error and I’m not sure what’s causing it.

Collaborator

@rdyro left a comment


Thank you!

I added a couple more comments. Testing these losses is in general not trivial; I'm not sure the re-implementation approach is the best, but if you think there's no other way, we can go with that.

I'm not entirely sure why copybara is failing, can you try squashing all commits into a single one?

Comment thread optax/losses/_self_supervised_test.py Outdated
z1 = jax.random.normal(k3, (2, 3), dtype=jnp.float32)
z2 = jax.random.normal(k4, (2, 3), dtype=jnp.float32)

def testing_byol_loss(q1_val, z2_val, q2_val, z1_val, eps=1e-6):
Collaborator

Let's ideally avoid a re-implementation tests, is there another way to test here?

Comment thread optax/losses/_self_supervised_test.py Outdated
q = jax.random.normal(k1, (2, 3), dtype=jnp.float32)
z = jax.random.normal(k2, (2, 3), dtype=jnp.float32)

def testing_single_direction_byol(q_val, z_val, eps=1e-6):
Collaborator

Let's ideally avoid a re-implementation tests, is there another way to test here?

Comment thread optax/losses/_self_supervised_test.py Outdated
z1 = jax.random.normal(k3, (2, 3), dtype=jnp.float32)
z2 = jax.random.normal(k4, (2, 3), dtype=jnp.float32)

def testing_simsiam_loss(p1_val, z2_val, p2_val, z1_val, eps=1e-6):
Collaborator

Let's ideally avoid a re-implementation tests, is there another way to test here?

Comment thread optax/losses/_self_supervised_test.py Outdated
p = jax.random.normal(k1, (2, 3), dtype=jnp.float32)
z = jax.random.normal(k2, (2, 3), dtype=jnp.float32)

def testing_single_direction_simsiam(p_val, z_val, eps=1e-6):
Collaborator

Let's ideally avoid a re-implementation tests, is there another way to test here?

Comment thread optax/losses/_self_supervised_test.py Outdated
teacher_temperature = 0.04
teacher_center = jax.random.normal(k3, (3,), dtype=jnp.float32)

def testing_dino_loss(
Collaborator

Let's ideally avoid a re-implementation tests, is there another way to test here?

Comment thread mlc_config.json
Collaborator

What's the purpose of this file?

Comment thread optax/losses/_self_supervised_test.py Outdated

@parameterized.parameters(0, 1, 42)
def test_single_direction_matches_handmade(self, seed):
key = jax.random.PRNGKey(seed)
Collaborator

Prefer the modern jax.random.key

@surajyadav-research force-pushed the feat/self-supervised-losses-v2 branch from b5a3c17 to fafa81d on December 29, 2025 at 15:34
@surajyadav-research
Author

Hi @rdyro,
Ruff is flagging changes I didn’t make, coming from commit #1497 (pytype suppression comments on the lines above).
How should I handle this?

@rdyro
Collaborator

rdyro commented Dec 29, 2025

PR #1531 should be able to fix it once it's merged, can you rebase then?

Were you able to give it some more thought if there was a better way to test these losses instead of their re-implementation in the test file?

@surajyadav-research
Author

Hi @rdyro
Yes, I’ve removed the redundant code in the tests and also implemented a few additional functions for better test coverage. I’m still confused by the error I’m encountering; I’ll push the new changes right away. These are the changes I made:

BYOL

  • Added a single test that validates symmetric (two-view) behavior by comparing against the average of two single-direction calls.
  • Added zero-loss sanity check when q == z.
  • Added scale invariance check: loss(a*q, b*z) == loss(q, z) for any positive scalars a, b.
  • Added batch permutation invariance check: loss(q[perm], z[perm]) == loss(q, z) for the same permutation perm.

SimSiam

  • Added a single test that validates symmetric (two-view) behavior by comparing against the average of two single-direction calls.
  • Added identity sanity check: when p == z, loss is approximately -1 (cosine similarity = 1).
  • Added scale invariance check: loss(a*p, b*z) == loss(p, z) for any positive scalars a, b.
  • Added batch permutation invariance check: loss(p[perm], z[perm]) == loss(p, z) for the same permutation perm.

DINO

  • Added two-view consistency test: two_view=True equals the average of two two_view=False calls with swapped pairs.
  • Added a sanity check that a student matching the teacher yields a loss lower than or equal to that of a deterministic mismatch.
  • Added logit translation invariance check: adding a constant offset to logits does not change the loss.

Barlow Twins

  • Added batch permutation invariance check: shuffling examples does not change the loss.
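
Property-style checks like those listed above avoid re-implementing the loss in the test file. As a hedged NumPy sketch (using an illustrative stand-in cosine loss, not optax's actual functions), the scale-invariance, permutation-invariance, and identity checks look like this:

```python
import numpy as np

def neg_cosine_loss(p, z, eps=1e-8):
    # Illustrative cosine-based loss (stands in for BYOL/SimSiam here).
    num = np.sum(p * z, axis=-1)
    den = np.maximum(np.linalg.norm(p, axis=-1) * np.linalg.norm(z, axis=-1), eps)
    return float(-np.mean(num / den))

rng = np.random.default_rng(42)
p = rng.normal(size=(8, 16))
z = rng.normal(size=(8, 16))

# Scale invariance: positive rescaling of either input leaves the loss unchanged.
assert np.isclose(neg_cosine_loss(2.5 * p, 0.3 * z), neg_cosine_loss(p, z))

# Batch permutation invariance: shuffling both views with the same permutation
# leaves the mean per-example loss unchanged.
perm = rng.permutation(8)
assert np.isclose(neg_cosine_loss(p[perm], z[perm]), neg_cosine_loss(p, z))

# Identity sanity check: p == z gives cosine similarity 1, so the loss is -1.
assert np.isclose(neg_cosine_loss(p, p), -1.0)
```

The same pattern (known value at a fixed point plus invariances of the mathematical definition) exercises the implementation without duplicating its logic.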

@rdyro
Collaborator

rdyro commented Dec 29, 2025

Great, these sound like very good ideas for tests!

To unblock your flow, you could try rebasing your commits onto the commit from December 26th, 6b364ac289da913454bc849756f827f675de63c2

@surajyadav-research force-pushed the feat/self-supervised-losses-v2 branch from fafa81d to 1eb4a25 on December 29, 2025 at 19:03
surajyadav-research and others added 6 commits December 30, 2025 00:59
The markdown link checker was failing with 503 errors on links to
the repo's own files (e.g., getting_started.ipynb, CONTRIBUTING.md).
These are transient failures due to GitHub rate-limiting automated
link checkers.

This config:
- Ignores internal repo blob links that cause false positives
- Adds retry logic for rate-limited (429) responses
- Sets appropriate timeout and retry delays
@surajyadav-research force-pushed the feat/self-supervised-losses-v2 branch from 1eb4a25 to f5f5417 on December 29, 2025 at 19:35
@surajyadav-research
Author

Hi @rdyro , I’ve pushed the updated test code. I tried to rebase it, but I’m getting stuck and couldn’t figure out the right steps. Sorry for the inconvenience.

@rdyro
Collaborator

rdyro commented Dec 29, 2025

No worries, let's pick this up once the CI is unblocked.

@rdyro
Collaborator

rdyro commented Dec 29, 2025

The CI fix is in, so you should be able to rebase to main now @surajyadav-research

@surajyadav-research
Author

Hi @rdyro
I was checking the line that was causing the error, and it’s still present on main. When I rebase, that same line shows up in my local code as well.

@surajyadav-research force-pushed the feat/self-supervised-losses-v2 branch from 71c617a to 3f4a69e on December 30, 2025 at 12:23
@surajyadav-research
Author

Hi @rdyro ,
The code is now passing tests, and I’ve added the new tests I explained earlier.
The mlc_config.json file was causing timeouts during the tests, which is why it has been modified. I will squash all the commits into one shortly.
Whenever you have some time, could you please review the changes?
Thank you very much for your patience with me.

@surajyadav-research
Author

Hi @rdyro
In this implementation, I’ve updated the tests I mentioned earlier.
When you get a chance, could you please review the changes?
