From 488f7596eaa936b24894bcff5b3f41ce51f59c4c Mon Sep 17 00:00:00 2001 From: Mike Cheung Date: Wed, 25 Mar 2026 10:28:26 +0800 Subject: [PATCH 1/2] add UT for backward --- tests/models/testing_utils/parallelism.py | 104 ++++++++++++++++++++++ 1 file changed, 104 insertions(+) diff --git a/tests/models/testing_utils/parallelism.py b/tests/models/testing_utils/parallelism.py index bea832904041..766316742379 100644 --- a/tests/models/testing_utils/parallelism.py +++ b/tests/models/testing_utils/parallelism.py @@ -98,6 +98,65 @@ def _context_parallel_worker(rank, world_size, master_port, model_class, init_di dist.destroy_process_group() +def _context_parallel_backward_worker( + rank, world_size, master_port, model_class, init_dict, cp_dict, inputs_dict, return_dict +): + """Worker function for context parallel backward pass testing.""" + try: + # Set up distributed environment + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = str(master_port) + os.environ["RANK"] = str(rank) + os.environ["WORLD_SIZE"] = str(world_size) + + # Get device configuration + device_config = DEVICE_CONFIG.get(torch_device, DEVICE_CONFIG["cuda"]) + backend = device_config["backend"] + device_module = device_config["module"] + + # Initialize process group + dist.init_process_group(backend=backend, rank=rank, world_size=world_size) + + # Set device for this process + device_module.set_device(rank) + device = torch.device(f"{torch_device}:{rank}") + + # Create model in training mode + model = model_class(**init_dict) + model.to(device) + model.train() + + # Move inputs to device + inputs_on_device = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs_dict.items()} + + # Enable context parallelism + cp_config = ContextParallelConfig(**cp_dict) + model.enable_parallelism(config=cp_config) + + # Run forward and backward pass + output = model(**inputs_on_device, return_dict=False)[0] + loss = output.sum() + loss.backward() + + # Check that all trainable parameters have finite gradients + has_valid_grads = all( + p.grad is not None and torch.isfinite(p.grad).all() for p in model.parameters() if p.requires_grad + ) + + # Only rank 0 reports results + if rank == 0: + return_dict["status"] = "success" + return_dict["has_valid_grads"] = bool(has_valid_grads) + + except Exception as e: + if rank == 0: + return_dict["status"] = "error" + return_dict["error"] = str(e) + finally: + if dist.is_initialized(): + dist.destroy_process_group() + + def _custom_mesh_worker( rank, world_size, @@ -204,6 +263,51 @@ def test_context_parallel_inference(self, cp_type, batch_size: int = 1): def test_context_parallel_batch_inputs(self, cp_type): self.test_context_parallel_inference(cp_type, batch_size=2) + @pytest.mark.parametrize("cp_type", ["ulysses_degree", "ring_degree"], ids=["ulysses", "ring"]) + def test_context_parallel_backward(self, cp_type, batch_size: int = 1): + if not torch.distributed.is_available(): + pytest.skip("torch.distributed is not available.") + + if not hasattr(self.model_class, "_cp_plan") or self.model_class._cp_plan is None: + pytest.skip("Model does not have a _cp_plan defined for context parallel inference.") + + if cp_type == "ring_degree": + active_backend, _ = _AttentionBackendRegistry.get_active_backend() + if active_backend == AttentionBackendName.NATIVE: + pytest.skip("Ring attention is not supported with the native attention backend.") + + world_size = 2 + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs(batch_size=batch_size) + + # Move all tensors to CPU for multiprocessing + inputs_dict = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in inputs_dict.items()} + cp_dict = {cp_type: world_size} + + # Find a free port for distributed communication + master_port = _find_free_port() + + # Use multiprocessing manager for cross-process communication + manager = mp.Manager() + return_dict = manager.dict() + + # Spawn worker processes + mp.spawn( + _context_parallel_backward_worker, + args=(world_size, master_port, self.model_class, init_dict, cp_dict, inputs_dict, return_dict), + nprocs=world_size, + join=True, + ) + + assert return_dict.get("status") == "success", ( + f"Context parallel backward pass failed: {return_dict.get('error', 'Unknown error')}" + ) + assert return_dict.get("has_valid_grads"), "Context parallel backward pass did not produce valid gradients." + + @pytest.mark.parametrize("cp_type", ["ulysses_degree", "ring_degree"], ids=["ulysses", "ring"]) + def test_context_parallel_backward_batch_inputs(self, cp_type): + self.test_context_parallel_backward(cp_type, batch_size=2) + @pytest.mark.parametrize( "cp_type,mesh_shape,mesh_dim_names", [ From 9ed7c25b1366b4c526b5ef152c475851ed825358 Mon Sep 17 00:00:00 2001 From: Mike Cheung Date: Wed, 25 Mar 2026 11:12:32 +0800 Subject: [PATCH 2/2] fix SDPA attention backward --- src/diffusers/models/attention_dispatch.py | 32 +++++++++++----------- tests/models/testing_utils/parallelism.py | 7 ++--- 2 files changed, 19 insertions(+), 20 deletions(-) diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index 42dc63273740..375abb24d131 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -862,23 +862,23 @@ def _native_attention_backward_op( key.requires_grad_(True) value.requires_grad_(True) - query_t, key_t, value_t = (x.permute(0, 2, 1, 3) for x in (query, key, value)) - out = torch.nn.functional.scaled_dot_product_attention( - query=query_t, - key=key_t, - value=value_t, - attn_mask=ctx.attn_mask, - dropout_p=ctx.dropout_p, - is_causal=ctx.is_causal, - scale=ctx.scale, - enable_gqa=ctx.enable_gqa, - ) - out = out.permute(0, 2, 1, 3) + with torch.enable_grad(): + query_t, key_t, value_t = (x.permute(0, 2, 1, 3) for x in (query, key, value)) + out = torch.nn.functional.scaled_dot_product_attention( + query=query_t, + key=key_t, + value=value_t, + attn_mask=ctx.attn_mask, + dropout_p=ctx.dropout_p, + is_causal=ctx.is_causal, + scale=ctx.scale, + enable_gqa=ctx.enable_gqa, + ) + out = out.permute(0, 2, 1, 3) - grad_out_t = grad_out.permute(0, 2, 1, 3) - grad_query_t, grad_key_t, grad_value_t = torch.autograd.grad( - outputs=out, inputs=[query_t, key_t, value_t], grad_outputs=grad_out_t, retain_graph=False - ) + grad_query_t, grad_key_t, grad_value_t = torch.autograd.grad( + outputs=out, inputs=[query_t, key_t, value_t], grad_outputs=grad_out, retain_graph=False + ) grad_query = grad_query_t.permute(0, 2, 1, 3) grad_key = grad_key_t.permute(0, 2, 1, 3) diff --git a/tests/models/testing_utils/parallelism.py b/tests/models/testing_utils/parallelism.py index 766316742379..9bf4bcb62019 100644 --- a/tests/models/testing_utils/parallelism.py +++ b/tests/models/testing_utils/parallelism.py @@ -138,10 +138,9 @@ def _context_parallel_backward_worker( loss = output.sum() loss.backward() - # Check that all trainable parameters have finite gradients - has_valid_grads = all( - p.grad is not None and torch.isfinite(p.grad).all() for p in model.parameters() if p.requires_grad - ) + # Check that backward actually produced at least one valid gradient + grads = [p.grad for p in model.parameters() if p.requires_grad and p.grad is not None] + has_valid_grads = len(grads) > 0 and all(torch.isfinite(g).all() for g in grads) # Only rank 0 reports results if rank == 0: