Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 16 additions & 16 deletions src/diffusers/models/attention_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -862,23 +862,23 @@ def _native_attention_backward_op(
key.requires_grad_(True)
value.requires_grad_(True)

query_t, key_t, value_t = (x.permute(0, 2, 1, 3) for x in (query, key, value))
out = torch.nn.functional.scaled_dot_product_attention(
query=query_t,
key=key_t,
value=value_t,
attn_mask=ctx.attn_mask,
dropout_p=ctx.dropout_p,
is_causal=ctx.is_causal,
scale=ctx.scale,
enable_gqa=ctx.enable_gqa,
)
out = out.permute(0, 2, 1, 3)
with torch.enable_grad():
query_t, key_t, value_t = (x.permute(0, 2, 1, 3) for x in (query, key, value))
out = torch.nn.functional.scaled_dot_product_attention(
query=query_t,
key=key_t,
value=value_t,
attn_mask=ctx.attn_mask,
dropout_p=ctx.dropout_p,
is_causal=ctx.is_causal,
scale=ctx.scale,
enable_gqa=ctx.enable_gqa,
)
out = out.permute(0, 2, 1, 3)

grad_out_t = grad_out.permute(0, 2, 1, 3)
grad_query_t, grad_key_t, grad_value_t = torch.autograd.grad(
outputs=out, inputs=[query_t, key_t, value_t], grad_outputs=grad_out_t, retain_graph=False
)
grad_query_t, grad_key_t, grad_value_t = torch.autograd.grad(
outputs=out, inputs=[query_t, key_t, value_t], grad_outputs=grad_out, retain_graph=False
)

grad_query = grad_query_t.permute(0, 2, 1, 3)
grad_key = grad_key_t.permute(0, 2, 1, 3)
Expand Down
103 changes: 103 additions & 0 deletions tests/models/testing_utils/parallelism.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,64 @@ def _context_parallel_worker(rank, world_size, master_port, model_class, init_di
dist.destroy_process_group()


def _context_parallel_backward_worker(
rank, world_size, master_port, model_class, init_dict, cp_dict, inputs_dict, return_dict
):
"""Worker function for context parallel backward pass testing."""
try:
# Set up distributed environment
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = str(master_port)
os.environ["RANK"] = str(rank)
os.environ["WORLD_SIZE"] = str(world_size)

# Get device configuration
device_config = DEVICE_CONFIG.get(torch_device, DEVICE_CONFIG["cuda"])
backend = device_config["backend"]
device_module = device_config["module"]

# Initialize process group
dist.init_process_group(backend=backend, rank=rank, world_size=world_size)

# Set device for this process
device_module.set_device(rank)
device = torch.device(f"{torch_device}:{rank}")

# Create model in training mode
model = model_class(**init_dict)
model.to(device)
model.train()

# Move inputs to device
inputs_on_device = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs_dict.items()}

# Enable context parallelism
cp_config = ContextParallelConfig(**cp_dict)
model.enable_parallelism(config=cp_config)

# Run forward and backward pass
output = model(**inputs_on_device, return_dict=False)[0]
loss = output.sum()
loss.backward()

# Check that backward actually produced at least one valid gradient
grads = [p.grad for p in model.parameters() if p.requires_grad and p.grad is not None]
has_valid_grads = len(grads) > 0 and all(torch.isfinite(g).all() for g in grads)

# Only rank 0 reports results
if rank == 0:
return_dict["status"] = "success"
return_dict["has_valid_grads"] = bool(has_valid_grads)

except Exception as e:
if rank == 0:
return_dict["status"] = "error"
return_dict["error"] = str(e)
finally:
if dist.is_initialized():
dist.destroy_process_group()


def _custom_mesh_worker(
rank,
world_size,
Expand Down Expand Up @@ -204,6 +262,51 @@ def test_context_parallel_inference(self, cp_type, batch_size: int = 1):
def test_context_parallel_batch_inputs(self, cp_type):
    """Re-run the context parallel inference test with a batch size of 2."""
    self.test_context_parallel_inference(cp_type=cp_type, batch_size=2)

@pytest.mark.parametrize("cp_type", ["ulysses_degree", "ring_degree"], ids=["ulysses", "ring"])
def test_context_parallel_backward(self, cp_type, batch_size: int = 1):
    """Check that a backward pass under context parallelism yields valid gradients.

    Spawns two worker processes running ``_context_parallel_backward_worker`` and
    asserts that the run reported success and produced finite gradients.

    Args:
        cp_type: Name of the ``ContextParallelConfig`` field that receives the
            world size (``"ulysses_degree"`` or ``"ring_degree"``).
        batch_size: Batch size passed to ``get_dummy_inputs``.
    """
    if not torch.distributed.is_available():
        pytest.skip("torch.distributed is not available.")

    if not hasattr(self.model_class, "_cp_plan") or self.model_class._cp_plan is None:
        pytest.skip("Model does not have a _cp_plan defined for context parallel inference.")

    if cp_type == "ring_degree":
        active_backend, _ = _AttentionBackendRegistry.get_active_backend()
        if active_backend == AttentionBackendName.NATIVE:
            pytest.skip("Ring attention is not supported with the native attention backend.")

    world_size = 2
    init_dict = self.get_init_dict()
    inputs_dict = self.get_dummy_inputs(batch_size=batch_size)

    # Move all tensors to CPU for multiprocessing
    inputs_dict = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in inputs_dict.items()}
    cp_dict = {cp_type: world_size}

    # Find a free port for distributed communication
    master_port = _find_free_port()

    # Use a multiprocessing manager for cross-process communication. The `with`
    # block shuts the manager's server process down even if spawning fails,
    # instead of leaking one server process per parametrized run.
    with mp.Manager() as manager:
        return_dict = manager.dict()

        # Spawn worker processes
        mp.spawn(
            _context_parallel_backward_worker,
            args=(world_size, master_port, self.model_class, init_dict, cp_dict, inputs_dict, return_dict),
            nprocs=world_size,
            join=True,
        )

        # Copy into a plain dict: the proxy is unusable once the manager exits.
        results = dict(return_dict)

    assert results.get("status") == "success", (
        f"Context parallel backward pass failed: {results.get('error', 'Unknown error')}"
    )
    assert results.get("has_valid_grads"), "Context parallel backward pass did not produce valid gradients."

@pytest.mark.parametrize("cp_type", ["ulysses_degree", "ring_degree"], ids=["ulysses", "ring"])
def test_context_parallel_backward_batch_inputs(self, cp_type):
    """Re-run the context parallel backward test with a batch size of 2."""
    self.test_context_parallel_backward(cp_type=cp_type, batch_size=2)

@pytest.mark.parametrize(
"cp_type,mesh_shape,mesh_dim_names",
[
Expand Down